Help with Optimizers and speedups

Hi!
I am new to PennyLane, and have some questions about how to speed up some of the code that I have. My goal is to see how QAOA performs with both increasing graph size and depth. In order to get some data, I need to run the simulations multiple times; however, my code quickly becomes slow as I increase the number of qubits or the depth of the QAOA circuit. So I’m curious about three things:

Which qubit device is recommended to use to get increased speed? (I am currently using default.qubit)
Is there any efficient way to run the optimization? (possibly jit-ing the computations?)
I calculate the expectation value of the Hamiltonian using qml.ExpvalCost. Is there an argument that can be passed to speed up these evaluations?

Currently, running a 6-qubit system on a 3-regular graph takes around 15 minutes.

Here is some of the code that I have written. Note that the “edges” structure in my case is essentially just a list of tuples (i, j, w) describing the edges and their corresponding weights, while graph is a networkx class. Of these functions, I think it is QAOAcomplete in particular that is the slowest part of the code, so any help in speeding up that part would be incredibly helpful. Thanks in advance for any tips :slight_smile:

Here is the python file if it is easier to read off that: QAOA-MaxcutAttempt.py (10.5 KB)

def ProblemUnitary(edge1,edge2,t,weight):
    """Creates the problem unitary which is specific to MaxCut. Applies e^{-i t H_c}

    Args:
        edge1 ([int]): [The control qubit]
        edge2 ([int]): [The target qubit]
        t ([float]): [The t in the expression]
        weight ([float]): [weight of the edge between the nodes in the graph]
    """
    qml.CNOT(wires = [edge1,edge2])
    qml.RZ(2*t*weight,wires = edge2)
    qml.CNOT(wires = [edge1,edge2])

def MixerUnitary(edge1,t):
    """Creates the mixer unitary given by e^{-i t\sum_i X_i}. In this function, it is decomposed into a Rz rotation

    Args:
        edge1 ([int]): [The qubit (wire) to apply the circuit on]
        t ([float]): [The t in the expression]
    """
    qml.Hadamard(wires = edge1)
    qml.RZ(2*t,wires = edge1)
    qml.Hadamard(wires = edge1)

def OneLayer(gamma,beta):
    """Create the one layer of the QAOA algorithm

    Args:
        gamma ([float]): [The k'th value of gamma]
        beta ([float]): [The k'th value of beta]
    """
    for i,j,w in edges:
        ProblemUnitary(i,j,gamma,w)
    for i in graph.nodes():
        MixerUnitary(i,beta)

def QAOAMaxCutAnsatz(parameters,**kwargs):
    #This creates the QAOA ansatz
    if len(np.shape(parameters))== 2:
        p = len(parameters[0])
    else:
        p = len(parameters)//2 
    for i in range(len(graph.nodes())):
        qml.Hadamard(wires = i)
    for i in range(p):
        if len(np.shape(parameters))== 2:
            OneLayer(parameters[0][i],parameters[1][i])
        else:
            OneLayer(parameters[i],parameters[p+i])
    #return [qml.sample(qml.PauliZ(i)) for i in range(len(graph.nodes()))]

def QAOAcomplete(p,init_parameters, Hamiltonian, dev, graph, steps = 200):
    params = init_parameters

    cost_function = qml.ExpvalCost(QAOAMaxCutAnsatz, Hamiltonian,dev, graph = graph) #diff_method = autograd currently
    
    opt = qml.AdamOptimizer()

    for i in range(steps):
        params = opt.step(cost_function,params)
        if  (i+1)%steps == 0:
            print(f'objective after step {i+1}: {cost_function(params)}')
    
    return params,cost_function(params)

Hey @Viro, welcome :slight_smile:

I have a few tips that can help:

  • Prefer to use return qml.expval(H) within a QNode, rather than qml.ExpvalCost(). This is because, if you return a Hamiltonian expectation within a QNode, PennyLane can perform various optimizations.

  • In addition, a big speed advantage can be achieved by JIT-ing the computation, as you suggest.

Below, I have modified your code to use JAX, and have applied the @jax.jit decorator to the QNode:

import pennylane as qml
from networkx import Graph
import jax
from jax import numpy as np


def ProblemUnitary(edge1, edge2, t, weight):
    qml.CNOT(wires=[edge1, edge2])
    qml.RZ(2 * t * weight, wires=edge2)
    qml.CNOT(wires=[edge1, edge2])


def MixerUnitary(edge1, t):
    qml.Hadamard(wires=edge1)
    qml.RZ(2 * t, wires=edge1)
    qml.Hadamard(wires=edge1)


def OneLayer(gamma, beta, graph):
    for i, j in graph.edges():
        ProblemUnitary(i, j, gamma, 1)
    for i in graph.nodes():
        MixerUnitary(i, beta)


dev = qml.device("default.qubit", wires=3)
graph = Graph([(0, 1), (1, 2), (2, 0)])
cost_h, mixer_h = qml.qaoa.maxcut(graph)
init_parameters = np.ones([2, 2], dtype=np.float32)


@jax.jit
@qml.qnode(dev, interface="jax")
def cost_function(parameters):
    for i in range(len(graph.nodes())):
        qml.Hadamard(wires=i)

    for i in range(2):
        OneLayer(parameters[0][i], parameters[1][i], graph)

    return qml.expval(cost_h)


params = init_parameters
steps = 200

opt = qml.AdamOptimizer()

for i in range(steps):
    params, cost = opt.step_and_cost(
        cost_function, params, grad_fn=jax.grad(cost_function)
    )

    if (i + 1) % 5 == 0:
        print(f"objective after step {i+1}: {cost}")

Here, I made some assumptions as to your original graph and initial parameters.

You should find that this runs a lot faster, once the initial compilation is complete!

However, there are a couple of caveats to note with JIT support:

  • Currently, only default.qubit has JIT support, so applying @jax.jit will likely fail with other devices.

  • Using the JIT imposes some restrictions on the classical processing. For example, if and for statements that depend on a tensor/array parameter are not allowed; note that I had to remove the if statements in your cost function (see the sketch below).
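
To make that second caveat concrete, here is a minimal sketch of my own (plain JAX, independent of PennyLane) showing why a Python if on a traced value fails under @jax.jit, and one common workaround using jnp.where:

import jax
from jax import numpy as jnp


@jax.jit
def branch_fails(x):
    # A Python `if` on a traced array cannot be resolved at trace time,
    # so calling this function raises a tracer/concretization error.
    if x > 0:
        return x ** 2
    return -x


@jax.jit
def branch_works(x):
    # jnp.where evaluates both branches and selects the result,
    # which is compatible with JIT tracing.
    return jnp.where(x > 0, x ** 2, -x)


print(branch_works(jnp.array(0.5)))  # 0.25
# branch_fails(jnp.array(0.5))       # raises an error when traced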


Some timing data for 200 steps of the above script on my laptop:

  • Using interface="autograd": 2.31s
  • Using interface="jax": 60.95s
  • Using interface="jax" and @jax.jit: 11.04s (note that this includes compilation time of 3.99s!)

While Autograd is faster for my small 3-qubit example, I imagine that as the number of qubits scales, the @jax.jit approach will likely surpass Autograd in terms of speed :slight_smile:
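
In case it is useful for your own benchmarking, here is a rough sketch (my own addition, assuming the jitted cost_function and init_parameters defined above) that separates the one-off compilation time from the per-call cost; block_until_ready() accounts for JAX's asynchronous dispatch:

import time

# First call triggers tracing and compilation of the jitted QNode.
start = time.perf_counter()
cost_function(init_parameters).block_until_ready()
print(f"first call (includes compilation): {time.perf_counter() - start:.2f}s")

# Subsequent calls reuse the cached, compiled program.
start = time.perf_counter()
cost_function(init_parameters).block_until_ready()
print(f"subsequent call: {time.perf_counter() - start:.4f}s")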

Hi again, thanks a lot for the input and the quick reply!
I tried your suggestion and used jax.jit as you mentioned; however, the code seems unable to compile the circuit even for relatively small depths. I tried to run QAOA with depth p = 5 for a graph with 6 nodes and 3 edges per node, but the code won’t compile. Any ideas as to what is causing the compilation to halt?
Autograd works fine though, runs relatively quickly :smiley:

Attaching the code here for convenience: PennyLaneHelp.py (4.3 KB)

Hi @Viro, what is the exact error you’re getting?

In any case I would suggest checking the following which might help you:

  • Make sure you’re using the latest version of PennyLane and Python
  • Try to break your problem into the smallest version of itself. This can help you debug and it can help us help you.
  • Try a different example, such as this demo, and see if you get the same problem. If you do get the same problem then it’s probably not a problem in your code.

Please let me know how it goes!

Hi again, sorry for the late reply.
It’s been a while since I worked with that code in particular, but I believe the compilation failure was caused by the @vectorize decorator I used; however, I am not sure of this.

I have tried to work with the JAX workflow but find it a bit difficult. For instance, I tried to write code that runs the circuit for several different initializations; however, whenever I try to optimize, I get the error

“Can’t differentiate w.r.t. type <class ‘jaxlib.xla_extension.DeviceArray’>”

QAOARandomSearch.py (6.1 KB)

Is there an obvious error here that makes it so that optimization is not possible?

Thanks in advance

Hi @Viro!
I get a different error message when I try to run your code. What versions of the different libraries and python are you using?

Also, could you please post the full error message? Thanks!

Here is an image of the full error message:

I think I am running Python 3.8.8, PennyLane 0.18, NumPy 1.20.1, and JAX 0.2.24.

Hi @Viro, I have looked into this and haven’t been able to find an answer.

What the error is basically saying is that it can’t differentiate with respect to the output of your cost function.

I would suggest that you try to simplify the problem as much as possible. Find the simplest piece of code that reproduces your error. This will be useful in finding how to fix the problem.

In line with this, try to create a problem with a single parameter, mark the trainable parameter with requires_grad=True, and use requires_grad=False for any non-trainable parameters. Not defining this often causes errors.
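
As a minimal sketch of what I mean (using PennyLane’s wrapped NumPy, which applies to the autograd interface; the parameter names here are just placeholders):

from pennylane import numpy as pnp

# Trainable QAOA angles: mark them explicitly as trainable.
params = pnp.array([[0.5, 0.5], [0.5, 0.5]], requires_grad=True)

# Fixed data, e.g. edge weights, should be marked as non-trainable
# so the optimizer does not try to differentiate with respect to it.
weights = pnp.array([1.0, 1.0, 1.0], requires_grad=False)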

If you haven’t been able to find the cause of the problem by Monday please let me know here and I will ask some PennyLane developers to take a look too.

If you do find the cause of the problem please also write it here so that others can benefit from the answer in the future.

I wish I could help more but hopefully you will be able to find the cause by creating a Minimum Working Example.

Hey @Viro!

Based on your error message, it looks like autograd is being used for differentiation, and not JAX! For example, the error message is coming from autograd/tracer.py:

This could mean one of two things:

  • You are using qml.grad() to compute the derivative. This only works with Autograd; when using JAX, use jax.grad instead.

  • The QNode needs to be created with interface="jax" (see the sketch below).
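
For reference, a minimal sketch combining both points (a toy single-qubit circuit, just to show the pattern):

import pennylane as qml
import jax
from jax import numpy as jnp

dev = qml.device("default.qubit", wires=1)


@qml.qnode(dev, interface="jax")
def circuit(x):
    qml.RX(x, wires=0)
    return qml.expval(qml.PauliZ(0))


x = jnp.array(0.3)
# Use jax.grad (not qml.grad) when the QNode uses the JAX interface.
print(jax.grad(circuit)(x))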

Let me know if that helps!

Hmm, that’s odd.
During optimization, I call

> params,cost = opt.step_and_cost(cost_function, params,grad_fun = jax.grad(cost_function))

and the QNode indeed uses the ‘jax’ interface. :confused:

Ah, that must be the issue; you want grad_fn=, not grad_fun!

You are indeed correct @josh, a small typo made all the difference :smiley:
Thanks!
I have one more question. When attempting to optimize, I want to set the parameters to different values at each iteration. However,

for i in range(len(gammaArray)):
    params = np.ones((2,p))
    print(gammaArray[i])
    params.at[0].set(gammaArray[i])
    params.at[1].set(betaArray[i])
    print(params)
    prevcost = 0
    cost = cost_function(params)
    j = 0
    opt = qml.AdamOptimizer()
    while (j < maxsteps):
        params,cost = opt.step_and_cost(cost_function, params,grad_fn = jax.grad(cost_function))

When printing out “params”, I only get a ones((2,p)) array instead of the altered version (JAX’s numpy does not support the typical

params[0,:] = gammaArray[i]
params[1,:] = betaArray[i]

way of changing arrays). How does one go about changing such JAX arrays as one would a regular numpy array, so that I can use different initial points when optimizing?

EDIT: calling params = np.array([gammaArray[i],betaArray[i]]) instead of the convoluted version I mentioned earlier solved the issue :slight_smile:
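
For anyone reading later, the .at syntax in the earlier snippet was not wrong per se; the catch is that JAX arrays are immutable, so .at[...].set(...) returns a new array that has to be reassigned rather than updating in place. A small sketch:

from jax import numpy as jnp

params = jnp.ones((2, 3))
# .at[...].set(...) returns a new array; reassign it to keep the update.
params = params.at[0].set(jnp.array([0.1, 0.2, 0.3]))
params = params.at[1].set(jnp.array([0.4, 0.5, 0.6]))
print(params)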

Either way, I am curious how many qubits it is expected to be able to compile for. I tried compiling for a graph of 5 nodes, which equates to 5 wires, and that compiled reasonably quickly (a minute or two). Now I have tried to do the same for 10 qubits; however, I get the following message:

From what I have found, the compilation time is heavily dependent on the number of wires. This message also pops up multiple times. Is it a CPU hardware limitation (I am running this on a low-spec 2017 MacBook Pro) or some other problem that can be readily fixed?
Update: it does compile in the end, it just takes around 20 minutes :sweat_smile:

@josh great catch on that typo!

@Viro thank you for posting your solutions to the other problems you mentioned.

I’m not sure if it works with JAX, but I suggest you try using lightning.qubit instead of default.qubit. This can increase your speed by 3X or more; a rough sketch of the swap is below.
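
This sketch assumes the autograd interface (as noted earlier, only default.qubit is known to have @jax.jit support), and the circuit is just a placeholder; I am using adjoint differentiation, which recent lightning.qubit versions support, so drop diff_method if your version complains:

import pennylane as qml
from pennylane import numpy as pnp

# Swap in the C++-backed simulator; the rest of the workflow stays the same.
dev = qml.device("lightning.qubit", wires=2)


@qml.qnode(dev, diff_method="adjoint")
def circuit(x):
    qml.RX(x[0], wires=0)
    qml.CNOT(wires=[0, 1])
    qml.RY(x[1], wires=1)
    return qml.expval(qml.PauliZ(1))


x = pnp.array([0.3, 0.7], requires_grad=True)
print(qml.grad(circuit)(x))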

Let me know if this works!