Getting circuit derivatives inside jax.jit

Hello everyone,

I am trying to solve differential equations with a variational quantum algorithm (VQA) approach. For this, my loss function has to include the derivative of the quantum circuit's output with respect to its input. I now have a working version using jax and optax; however, it is painfully slow.

import pennylane as qml
import jax
from jax import numpy as jnp
import optax
import catalyst

n_wires = 4
weights = jnp.ones((n_wires, 3))  # RX, RY, RX angles for each wire
bias = jnp.array([1.0])
opt = optax.adam(learning_rate=0.1)
opt_state = opt.init(weights)

@qml.qnode(qml.device("lightning.qubit", wires=n_wires), diff_method="best")
def circuit(x, weights):

    # Embedding Ansatz
    for i in range(n_wires):
        qml.RY(2*jnp.arccos(x),wires = i)

    # Variational Ansatz   
    for i in range(n_wires):
        qml.RX(weights[i, 0], wires=i)
        qml.RY(weights[i, 1], wires=i)
        qml.RX(weights[i, 2], wires=i)
        qml.CNOT(wires=[i, (i + 1) % n_wires])

    # Total magnetization in z-direction as cost function
    return qml.expval(qml.sum(*[qml.PauliZ(i) for i in range(n_wires)]))

def loss_fnc(weights):
    # Loss function of: du/dx = 1, u(0) = 0
    x = jnp.linspace(0,0.99,21)
    _dudx = jax.grad(circuit, argnums=0)
    dudx = jnp.array([_dudx(i, weights) for i in x])

    loss_diff = jnp.mean((dudx - jnp.ones_like(dudx))**2)
    loss_initial = jnp.mean((circuit(jnp.array(0.0),weights))**2)
    
    return loss_diff + loss_initial

def optimize(weights, opt_state, n=20):
    loss_history = []

    for i in range(1,n+1):
        loss_val, grads = jax.value_and_grad(loss_fnc)(weights)
        updates, opt_state = opt.update(grads, opt_state)
        weights = optax.apply_updates(weights, updates)
        print(f"Step: {i}  Loss: {loss_val}")
        loss_history.append(loss_val)

    return weights, opt_state, loss_history

weights, opt_state, loss_history = optimize(weights, opt_state)
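The list comprehension `[_dudx(i, weights) for i in x]` calls the gradient function once per sample from Python, and under jax.jit it is unrolled into one copy of the gradient computation per point. A common JAX alternative is to map the single-point derivative over the batch with `jax.vmap`. The sketch below uses a hypothetical pure-JAX stand-in, `toy_model`, in place of the QNode so that it runs on its own; whether `vmap` applies directly to a given QNode depends on the device and diff_method, which is not verified here.

```python
import jax
import jax.numpy as jnp

def toy_model(x, weights):
    # Hypothetical stand-in for `circuit(x, weights)`: same arccos embedding
    # of x, but a simple classical expression instead of a quantum simulation.
    theta = 2 * jnp.arccos(x)
    return jnp.sum(jnp.cos(theta + weights[:, 0]) * jnp.cos(weights[:, 1]))

# du/dx for a single scalar x ...
dudx_single = jax.grad(toy_model, argnums=0)
# ... vectorized over a batch of x values; weights are shared (in_axes=None)
dudx_batched = jax.vmap(dudx_single, in_axes=(0, None))

def loss_fnc(weights):
    # Loss for du/dx = 1, u(0) = 0, as in the original post
    x = jnp.linspace(0, 0.99, 21)
    dudx = dudx_batched(x, weights)
    loss_diff = jnp.mean((dudx - 1.0) ** 2)
    loss_initial = toy_model(jnp.array(0.0), weights) ** 2
    return loss_diff + loss_initial

weights = jnp.ones((4, 3))
loss_val, grads = jax.jit(jax.value_and_grad(loss_fnc))(weights)
```

`in_axes=(0, None)` maps over `x` while broadcasting the same `weights` to every point, so the batch becomes one vectorized computation instead of 21 separate calls.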

In order to speed up the computation, I tried to use jax.jit and/or catalyst; however, it seems that it is not possible to take a derivative inside precompiled code. See the following example, where I only added the jax.jit decorator to the loss function:

import pennylane as qml
import jax
from jax import numpy as jnp
import optax
import catalyst

n_wires = 4
weights = jnp.ones((n_wires, 3))  # RX, RY, RX angles for each wire
bias = jnp.array([1.0])
opt = optax.adam(learning_rate=0.1)
opt_state = opt.init(weights)

@qml.qnode(qml.device("lightning.qubit", wires=n_wires), diff_method="best")
def circuit(x, weights):

    # Embedding Ansatz
    for i in range(n_wires):
        qml.RY(2*jnp.arccos(x),wires = i)

    # Variational Ansatz   
    for i in range(n_wires):
        qml.RX(weights[i, 0], wires=i)
        qml.RY(weights[i, 1], wires=i)
        qml.RX(weights[i, 2], wires=i)
        qml.CNOT(wires=[i, (i + 1) % n_wires])

    # Total magnetization in z-direction as cost function
    return qml.expval(qml.sum(*[qml.PauliZ(i) for i in range(n_wires)]))

@jax.jit
def loss_fnc(weights):
    # Loss function of: du/dx = 1, u(0) = 0
    x = jnp.linspace(0,0.99,21)
    _dudx = jax.grad(circuit, argnums=0)
    dudx = jnp.array([_dudx(i, weights) for i in x])

    loss_diff = jnp.mean((dudx - jnp.ones_like(dudx))**2)
    loss_initial = jnp.mean((circuit(jnp.array(0.0),weights))**2)
    
    return loss_diff + loss_initial

def optimize(weights, opt_state, n=20):
    loss_history = []

    for i in range(1,n+1):
        loss_val, grads = jax.value_and_grad(loss_fnc)(weights)
        updates, opt_state = opt.update(grads, opt_state)
        weights = optax.apply_updates(weights, updates)
        print(f"Step: {i}  Loss: {loss_val}")
        loss_history.append(loss_val)

    return weights, opt_state, loss_history

weights, opt_state, loss_history = optimize(weights, opt_state)

This leads to the following error:

---------------------------------------------------------------------------
JaxStackTraceBeforeTransformation         Traceback (most recent call last)
File <frozen runpy>:198, in _run_module_as_main()

File <frozen runpy>:88, in _run_code()
...
---> 74   raise ValueError(
     75       "Pure callbacks do not support JVP. "
     76       "Please use `jax.custom_jvp` to use callbacks while taking gradients.")

ValueError: Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients.

Does anyone have experience implementing a custom callback for PennyLane circuits (QNodes), or does anyone know how to do so?
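The error message indicates that, under jax.jit, the circuit execution is routed through a `jax.pure_callback`, which has no derivative rule of its own. The general JAX mechanism for giving such a callback a derivative is `jax.custom_jvp`. The sketch below shows only that mechanism on a toy scalar function; `_host_value` and `_host_derivative` are hypothetical host-side stand-ins (NumPy `sin` and `cos`) for an external simulator, not a PennyLane API. Because the loss here differentiates twice (d/dx inside the loss, then d/dweights), the derivative callback needs its own `custom_jvp` as well; with a plain `pure_callback` for the derivative, the second differentiation would raise the same error again.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical host-side functions standing in for an external simulator.
# They receive and return plain NumPy arrays.
def _host_value(x):
    return np.asarray(np.sin(x), dtype=x.dtype)

def _host_derivative(x):
    return np.asarray(np.cos(x), dtype=x.dtype)

def _callback(fn, x):
    # pure_callback needs the output shape and dtype up front
    return jax.pure_callback(fn, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

@jax.custom_jvp
def value(x):
    return _callback(_host_value, x)

@jax.custom_jvp
def derivative(x):
    return _callback(_host_derivative, x)

@value.defjvp
def _value_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    # First-order rule: d/dx sin(x) = cos(x), evaluated by another callback
    return value(x), derivative(x) * t

@derivative.defjvp
def _derivative_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    # Second-order rule: d/dx cos(x) = -sin(x); needed for grad-of-grad
    return derivative(x), -value(x) * t

first = jax.jit(jax.grad(value))
second = jax.jit(jax.grad(jax.grad(value)))
```

Each derivative order the loss needs has to be covered by a rule like this; connecting such rules to a QNode would be up to the device or interface and is not shown here.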

On a side note, I implemented the same example in PyTorch and got a significant speedup.

This is my qml.about():

Name: PennyLane
Version: 0.35.1
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author: 
Author-email: 
License: Apache License 2.0
Location: /home/stefan/.local/lib/python3.11/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, pennylane-lightning, requests, rustworkx, scipy, semantic-version, toml, typing-extensions
Required-by: PennyLane-Catalyst, PennyLane_Lightning

Platform info:           Linux-6.5.13-7-MANJARO-x86_64-with-glibc2.39
Python version:          3.11.8
Numpy version:           1.26.4
Scipy version:           1.11.4
Installed devices:
- nvidia.custatevec (PennyLane-Catalyst-0.5.0)
- nvidia.cutensornet (PennyLane-Catalyst-0.5.0)
- softwareq.qpp (PennyLane-Catalyst-0.5.0)
- default.clifford (PennyLane-0.35.1)
- default.gaussian (PennyLane-0.35.1)
- default.mixed (PennyLane-0.35.1)
- default.qubit (PennyLane-0.35.1)
- default.qubit.autograd (PennyLane-0.35.1)
- default.qubit.jax (PennyLane-0.35.1)
- default.qubit.legacy (PennyLane-0.35.1)
- default.qubit.tf (PennyLane-0.35.1)
- default.qubit.torch (PennyLane-0.35.1)
- default.qutrit (PennyLane-0.35.1)
- null.qubit (PennyLane-0.35.1)
- lightning.qubit (PennyLane_Lightning-0.35.1)

Hi @Stefan_Berger,

There are two things you could try:

  1. Use lightning.qubit without jax.jit but set diff_method="adjoint" when instantiating your QNode. On my local machine, this runs in about 15 seconds for n=20.

  2. Use default.qubit with jax.jit (you can optionally set diff_method="backprop" but this should be the default anyway). On my local machine, this runs in about 3 seconds after a lengthy initial compilation.

Regarding the first point: in PennyLane v0.35 and earlier, Lightning Qubit with diff_method="best" was not aware of its "adjoint" mode, because of some new structure we needed to add to the backend implementation. So when "best" is selected, everything falls back to parameter-shift, which is likely why you see such a slow runtime. Adjoint will be the default diff_method for Lightning as of the upcoming 0.36 release.

If you have a large number of CPU cores, you can also enable the observable to be split across them, rather than combined, which may offer some speedups, depending on the workload. For this, you can create the device as:

dev = qml.device("lightning.qubit", wires=n_wires, batch_obs=True)

For more info on this mode, you can check here.

Hi @Tom_Bromley,

Thanks for your tips! I benchmarked and compared the different approaches, and I think I found unexpected behavior with lightning.qubit.

With your first tip, the loss increased instead of decreasing, so maybe the derivative was computed incorrectly. After switching to default.qubit, the results were as expected, with or without jax.jit. This was the output with lightning.qubit:

Step: 1  Loss: 22.767391204833984
Step: 2  Loss: 22.07586669921875
Step: 3  Loss: 24.740163803100586
Step: 4  Loss: 28.548276901245117
Step: 5  Loss: 32.848453521728516
Step: 6  Loss: 37.56929016113281
Step: 7  Loss: 42.778018951416016
Step: 8  Loss: 48.53341293334961
Step: 9  Loss: 54.82943344116211
Step: 10  Loss: 61.5770149230957
6.03 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

When comparing default.qubit with jax.jit against torch, I found that torch was always faster, even when ignoring the initial compilation time of jax.jit. Here are the results:

First, jax.jit including the compilation time:

Step: 100  Loss: 0.28363654017448425
...
Step: 1000  Loss: 0.01202180702239275
2min 59s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

And then calling the same optimization function again:

Step: 100  Loss: 0.28363654017448425
...
Step: 1000  Loss: 0.01202180702239275
51.2 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

And finally the torch results:

Step: 100  Loss: 0.2836303651903658
...
Step: 1000  Loss: 0.013612815638004542
37.1 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

I would have expected jax with compilation to be faster than torch, but maybe my code is not written optimally. I tried using vmap to speed up the my_model function, but without success. If anyone has tips on how to write more performant code, feel free to let me know; I am always eager to speed up my code (even though this investigation may have already taken more time than the potential speedup 🙂).
Code:

PyTorch

import pennylane as qml
import torch

n_wires = 4

torch.manual_seed(42)
weights = torch.ones((n_wires,3), requires_grad=True)
bias = torch.zeros(1, requires_grad=True)
params = {"weights": weights, "bias": bias}
opt = torch.optim.Adam([ weights, bias], lr=0.1)
loss_history = []


@qml.qnode(qml.device("default.qubit", wires=n_wires), diff_method="best")
def circuit(x, weights):

    # Embedding Ansatz
    for i in range(n_wires):
        qml.RY(2*torch.arccos(x),wires = i)

    # Variational Ansatz   
    for i in range(n_wires):
        qml.RX(weights[i, 0], wires=i)
        qml.RY(weights[i, 1], wires=i)
        qml.RX(weights[i, 2], wires=i)
        qml.CNOT(wires=[i, (i + 1) % n_wires])

    # Total magnetization in z-direction as cost function
    return qml.expval(qml.sum(*[qml.PauliZ(i) for i in range(n_wires)]))

def my_model(data, weights, bias):
    return circuit(data, weights) + bias

def loss_fnc(params):
    # Loss function of: du/dx = 1, u(0) = 0
    x = torch.linspace(0,0.99,11,requires_grad=True)
    u_pred = my_model(x, params["weights"], params["bias"])
    grad_outputs = torch.ones_like(u_pred)
    dudx = torch.autograd.grad(u_pred, x, grad_outputs=grad_outputs, create_graph=True)[0]
    
    loss_diff = torch.mean((dudx - torch.ones_like(dudx))**2)
    loss_initial = torch.mean(my_model(torch.zeros_like(x),params["weights"], params["bias"])**2)
    
    return loss_diff + loss_initial

def optimize(params, n=1000):
    loss_history = []

    for i in range(1,n+1):
        opt.zero_grad()
        loss_val = loss_fnc(params)
        loss_val.backward()
        opt.step()
        if i%100 == 0: print(f"Step: {i}  Loss: {loss_val}")
        loss_history.append(loss_val.item())  # store a Python float, not the graph-attached tensor

    return params, loss_history

%timeit -r1 -n1 optimize(params)

default.qubit with jax.jit

import pennylane as qml
import jax
from jax import numpy as jnp
import optax

n_wires = 4
weights = jnp.ones((n_wires,3))
bias = jnp.array(0.)
opt = optax.adam(learning_rate=0.1)
params = {"weights": weights, "bias": bias}
opt_state = opt.init(params)

@qml.qnode(qml.device("default.qubit", wires=n_wires), diff_method="best")
def circuit(x, weights):

    # Embedding Ansatz
    for i in range(n_wires):
        qml.RY(2*jnp.arccos(x),wires = i)

    # Variational Ansatz   
    for i in range(n_wires):
        qml.RX(weights[i, 0], wires=i)
        qml.RY(weights[i, 1], wires=i)
        qml.RX(weights[i, 2], wires=i)
        qml.CNOT(wires=[i, (i + 1) % n_wires])

    # Total magnetization in z-direction as cost function
    return qml.expval(qml.sum(*[qml.PauliZ(i) for i in range(n_wires)]))

@jax.jit
def my_model(data, weights, bias):
    return circuit(data, weights) + bias

@jax.jit
def loss_fnc(params):
    # Loss function of: du/dx = 1, u(0) = 0
    x = jnp.linspace(0,0.99,11)
    _dudx = jax.grad(my_model, argnums=0)
    dudx = jnp.array([_dudx(i, params["weights"], params["bias"]) for i in x])
    
    loss_diff = jnp.mean((dudx - jnp.ones_like(dudx))**2)
    loss_initial = jnp.mean(my_model(jnp.zeros_like(x),params["weights"], params["bias"])**2)
    
    return loss_diff + loss_initial

def optimize(params, opt_state, n=1000):
    loss_history = []

    for i in range(1,n+1):
        loss_val, grads = jax.value_and_grad(loss_fnc)(params)
        updates, opt_state = opt.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        if i%100 == 0: jax.debug.print("Step: {i}  Loss: {loss_val}", i=i, loss_val=loss_val)
        loss_history.append(loss_val)

    return params, opt_state, loss_history

%timeit -r1 -n1 optimize(params, opt_state)
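In the jax.jit version above, only `loss_fnc` and `my_model` are compiled; `jax.value_and_grad`, the optimizer update and the loop are still dispatched from Python on every step. A common pattern is to compile the whole training step as one function, which may reduce that per-step overhead. A minimal sketch, assuming a hypothetical pure-JAX stand-in `toy_model` for `my_model`, and plain gradient descent with an assumed learning rate of 1e-3 instead of Adam to keep it dependency-free. Optax's `opt.update` and `optax.apply_updates` are jittable as well, so they could be called inside `train_step` with `opt_state` passed in and returned.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 1e-3  # assumed value for plain gradient descent

def toy_model(x, weights, bias):
    # Hypothetical stand-in for `my_model` (circuit output + bias)
    theta = 2 * jnp.arccos(x)
    return jnp.sum(jnp.cos(theta + weights[:, 0]) * jnp.cos(weights[:, 1])) + bias

# du/dx vectorized over x, with weights and bias shared
dudx = jax.vmap(jax.grad(toy_model, argnums=0), in_axes=(0, None, None))

def loss_fnc(params):
    # Loss for du/dx = 1, u(0) = 0
    x = jnp.linspace(0, 0.99, 11)
    d = dudx(x, params["weights"], params["bias"])
    loss_diff = jnp.mean((d - 1.0) ** 2)
    loss_initial = toy_model(jnp.array(0.0), params["weights"], params["bias"]) ** 2
    return loss_diff + loss_initial

@jax.jit
def train_step(params):
    # One compiled function: loss, gradient and parameter update together
    loss_val, grads = jax.value_and_grad(loss_fnc)(params)
    params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return params, loss_val

def optimize(params, n=100):
    loss_history = []
    for _ in range(n):
        params, loss_val = train_step(params)
        loss_history.append(float(loss_val))
    return params, loss_history

params = {"weights": jnp.ones((4, 3)), "bias": jnp.array(0.0)}
params, loss_history = optimize(params)
```

Only the first call to `train_step` pays the compilation cost; later calls reuse the compiled step as long as the shapes and dtypes of `params` stay the same.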

lightning.qubit without jax.jit

import pennylane as qml
import jax
from jax import numpy as jnp
import optax

n_wires = 4
weights = jnp.ones((n_wires,3))
bias = jnp.array(0.)
opt = optax.adam(learning_rate=0.1)
params = {"weights": weights, "bias": bias}
opt_state = opt.init(params)

@qml.qnode(qml.device("lightning.qubit", wires=n_wires), diff_method="adjoint")
def circuit(x, weights):

    # Embedding Ansatz
    for i in range(n_wires):
        qml.RY(2*jnp.arccos(x),wires = i)

    # Variational Ansatz   
    for i in range(n_wires):
        qml.RX(weights[i, 0], wires=i)
        qml.RY(weights[i, 1], wires=i)
        qml.RX(weights[i, 2], wires=i)
        qml.CNOT(wires=[i, (i + 1) % n_wires])

    # Total magnetization in z-direction as cost function
    return qml.expval(qml.sum(*[qml.PauliZ(i) for i in range(n_wires)]))

def my_model(data, weights, bias):
    return circuit(data, weights) + bias

def loss_fnc(params):
    # Loss function of: du/dx = 1, u(0) = 0
    x = jnp.linspace(0,0.99,11)
    _dudx = jax.grad(my_model, argnums=0)
    dudx = jnp.array([_dudx(i, params["weights"], params["bias"]) for i in x])
    
    loss_diff = jnp.mean((dudx - jnp.ones_like(dudx))**2)
    loss_initial = jnp.mean(my_model(jnp.zeros_like(x),params["weights"], params["bias"])**2)
    
    return loss_diff + loss_initial

def optimize(params, opt_state, n=10):
    loss_history = []

    for i in range(1,n+1):
        loss_val, grads = jax.value_and_grad(loss_fnc)(params)
        updates, opt_state = opt.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        if i%1 == 0: jax.debug.print("Step: {i}  Loss: {loss_val}", i=i, loss_val=loss_val)
        loss_history.append(loss_val)

    return params, opt_state, loss_history

%timeit -r1 -n1 optimize(params, opt_state)

Hi @Stefan_Berger, thank you for your detailed post! I’m checking with Tom and will get back to you soon.

Hi @Stefan_Berger, this week has been very busy for the team but we haven’t forgotten your comment! We’ll get back to you next week.

Some insights from my side.

The first is that your simulations involve only 4 qubits. lightning.qubit is our high-performance simulator and is optimized for circuits with 15+ qubits. Because it targets larger circuits, it can have more overhead on small problems.
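For a sense of scale behind the 15+ qubit remark: a dense state vector holds 2^n amplitudes, so at 4 qubits there are only 16 numbers to update per gate and fixed per-call overhead is likely to dominate. A back-of-the-envelope sketch, assuming complex128 storage (16 bytes per amplitude):

```python
def statevector_size(n_qubits, bytes_per_amplitude=16):
    # Number of amplitudes and memory footprint of a dense state vector;
    # 16 bytes per amplitude assumes complex128 storage.
    n_amplitudes = 2 ** n_qubits
    return n_amplitudes, n_amplitudes * bytes_per_amplitude

for n in (4, 15):
    amps, nbytes = statevector_size(n)
    print(f"{n} qubits: {amps} amplitudes, {nbytes} bytes")
# 4 qubits: 16 amplitudes, 256 bytes
# 15 qubits: 32768 amplitudes, 524288 bytes
```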

The second insight is that it looks like you are taking higher-order derivatives. While second-order derivatives work out of the box with default.qubit and backprop, they need to be explicitly requested with all other differentiation methods by setting max_diff=2 as a QNode argument. This way we don't have to calculate higher-order derivatives when users don't need them.
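To see why this is a second-order derivative: the loss contains du/dx, and the optimizer then needs d(loss)/dw, which involves the mixed derivative d²u/(dx dw). A minimal pure-JAX illustration with a hypothetical one-parameter model u(x, w) = sin(wx), checked against the analytic result d/dw [w cos(wx)] = cos(wx) - wx sin(wx):

```python
import jax
import jax.numpy as jnp

def u(x, w):
    # Hypothetical one-parameter model, used only to show the derivative structure
    return jnp.sin(w * x)

def dudx(x, w):
    # First derivative with respect to the input x (what the loss uses)
    return jax.grad(u, argnums=0)(x, w)

# Gradient of du/dx with respect to the trainable parameter w:
# this is the mixed second derivative the optimizer needs.
d2u_dxdw = jax.grad(dudx, argnums=1)

x, w = 0.5, 1.0
numeric = d2u_dxdw(x, w)
analytic = jnp.cos(w * x) - w * x * jnp.sin(w * x)
```

With backprop, JAX handles this nesting itself; for other diff_methods, the reply above explains that this is the case requiring max_diff=2.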

As @Tom_Bromley mentioned, lightning.qubit defaulted to parameter-shift (in v0.35.1 and below), a hardware-compatible differentiation method that requires 2-4 executions per trainable parameter and supports higher-order derivatives. Adjoint differentiation (the default in v0.36.0 and above) requires far less simulation time, but does not support higher-order derivatives.
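To put the parameter-shift cost into rough numbers for this thread's circuit, here is a back-of-the-envelope count under stated assumptions: each Pauli rotation uses the two-term shift rule, so a mixed second derivative d²u/(dx dw) costs 4 executions per (x-dependent gate, weight gate) pair, and there is no caching or deduplication of tapes (PennyLane's actual count may differ). Only the du/dx term of the loss is counted.

```python
def param_shift_cost(n_x_gates, n_weight_gates, n_points):
    # Assumption: two-term shift rule per Pauli rotation and no caching,
    # so d^2u/(dx dw) needs 4 executions per (x-gate, weight-gate) pair.
    per_point = 4 * n_x_gates * n_weight_gates
    return per_point, per_point * n_points

# Circuit in this thread: 4 RY embedding gates depend on x,
# 4 wires x 3 rotations = 12 gates depend on the weights, 21 x-points.
per_point, per_step = param_shift_cost(n_x_gates=4, n_weight_gates=12, n_points=21)
print(per_point, per_step)  # 192 4032
```

By contrast, backprop on default.qubit differentiates through the simulation itself, so its cost is roughly a small multiple of a forward simulation rather than a per-parameter execution count.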

Basically, performance can get messy, involves a lot of compromises, and sometimes doesn’t really have a good explanation.

So for your problem, I would recommend using default.qubit with whichever interface turns out to be faster for this particular scenario.