# Getting circuit derivatives inside jax.jit

Hello everyone,

I am trying to solve differential equations using a VQA approach. For this, my loss function has to include the derivative of the quantum circuit. I now have a working version using JAX and Optax; however, it is painfully slow.

``````python
import pennylane as qml
import jax
from jax import numpy as jnp
import optax
import catalyst

n_wires = 4
weights = jnp.ones((4, n_wires))
bias = jnp.array([1.0])
opt = optax.adam(learning_rate=0.1)  # assumed: the optimizer definition is not shown in the post
opt_state = opt.init(weights)

@qml.qnode(qml.device("lightning.qubit", wires=n_wires), diff_method="best")
def circuit(x, weights):

    # Embedding Ansatz
    for i in range(n_wires):
        qml.RY(2*jnp.arccos(x), wires=i)

    # Variational Ansatz
    for i in range(n_wires):
        qml.RX(weights[i, 0], wires=i)
        qml.RY(weights[i, 1], wires=i)
        qml.RX(weights[i, 2], wires=i)
        qml.CNOT(wires=[i, (i + 1) % n_wires])

    # Total magnetization in z-direction as cost function
    return qml.expval(qml.sum(*[qml.PauliZ(i) for i in range(n_wires)]))

# du/dx at a single point (assumed definition; not shown in the post)
_dudx = jax.grad(circuit, argnums=0)

def loss_fnc(weights):
    # Loss function of: du/dx = 1, u(0) = 0
    x = jnp.linspace(0, 0.99, 21)
    dudx = jnp.array([_dudx(i, weights) for i in x])

    loss_diff = jnp.mean((dudx - jnp.ones_like(dudx))**2)
    loss_initial = jnp.mean((circuit(jnp.array(0.0), weights))**2)

    return loss_diff + loss_initial

def optimize(weights, opt_state, n=20):
    loss_history = []

    for i in range(1, n+1):
        # Update step (assumed; these lines are missing from the post)
        loss_val, grads = jax.value_and_grad(loss_fnc)(weights)
        updates, opt_state = opt.update(grads, opt_state)
        weights = optax.apply_updates(weights, updates)
        print(f"Step: {i}  Loss: {loss_val}")
        loss_history.append(loss_val)

    return weights, opt_state, loss_history

weights, opt_state, loss_history = optimize(weights, opt_state)
``````
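For reference, the loss in the code above can be written out as follows. This assumes that the helper `_dudx(x, weights)`, whose definition the post does not show, returns the derivative of the circuit output with respect to $x$ at a single point:

$$
u_\theta(x) = \Big\langle \sum_{j=0}^{n-1} Z_j \Big\rangle_{x,\theta},
\qquad
\mathcal{L}(\theta) = \frac{1}{N}\sum_{i=1}^{N}\left(\frac{\partial u_\theta}{\partial x}(x_i) - 1\right)^2 + u_\theta(0)^2,
$$

with $n$ = `n_wires` = 4, $\theta$ = `weights`, and $N = 21$ points $x_i$ evenly spaced on $[0, 0.99]$. Both terms vanish for the exact solution $u(x) = x$. Because $\mathcal{L}$ already contains $\partial u_\theta / \partial x$, the gradient $\nabla_\theta \mathcal{L}$ requires the mixed second derivative $\partial^2 u_\theta / \partial \theta\, \partial x$ of the circuit.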

To speed up the computation, I tried to use `jax.jit` and/or Catalyst; however, it seems that it is not possible to take a derivative inside precompiled code. See the following example, where I only added the `jax.jit` decorator to the loss function:

``````python
import pennylane as qml
import jax
from jax import numpy as jnp
import optax
import catalyst

n_wires = 4
weights = jnp.ones((4, n_wires))
bias = jnp.array([1.0])
opt = optax.adam(learning_rate=0.1)  # assumed: the optimizer definition is not shown in the post
opt_state = opt.init(weights)

@qml.qnode(qml.device("lightning.qubit", wires=n_wires), diff_method="best")
def circuit(x, weights):

    # Embedding Ansatz
    for i in range(n_wires):
        qml.RY(2*jnp.arccos(x), wires=i)

    # Variational Ansatz
    for i in range(n_wires):
        qml.RX(weights[i, 0], wires=i)
        qml.RY(weights[i, 1], wires=i)
        qml.RX(weights[i, 2], wires=i)
        qml.CNOT(wires=[i, (i + 1) % n_wires])

    # Total magnetization in z-direction as cost function
    return qml.expval(qml.sum(*[qml.PauliZ(i) for i in range(n_wires)]))

# du/dx at a single point (assumed definition; not shown in the post)
_dudx = jax.grad(circuit, argnums=0)

@jax.jit
def loss_fnc(weights):
    # Loss function of: du/dx = 1, u(0) = 0
    x = jnp.linspace(0, 0.99, 21)
    dudx = jnp.array([_dudx(i, weights) for i in x])

    loss_diff = jnp.mean((dudx - jnp.ones_like(dudx))**2)
    loss_initial = jnp.mean((circuit(jnp.array(0.0), weights))**2)

    return loss_diff + loss_initial

def optimize(weights, opt_state, n=20):
    loss_history = []

    for i in range(1, n+1):
        # Update step (assumed; these lines are missing from the post)
        loss_val, grads = jax.value_and_grad(loss_fnc)(weights)
        updates, opt_state = opt.update(grads, opt_state)
        weights = optax.apply_updates(weights, updates)
        print(f"Step: {i}  Loss: {loss_val}")
        loss_history.append(loss_val)

    return weights, opt_state, loss_history

weights, opt_state, loss_history = optimize(weights, opt_state)
``````

This leads to the following error:

``````
---------------------------------------------------------------------------
JaxStackTraceBeforeTransformation         Traceback (most recent call last)
File <frozen runpy>:198, in _run_module_as_main()

File <frozen runpy>:88, in _run_code()
...
---> 74   raise ValueError(
     75       "Pure callbacks do not support JVP. "

ValueError: Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients.
``````

Does anyone have experience implementing a custom callback (the error message suggests `jax.custom_jvp`) for PennyLane QNodes, or know how to do so?
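The error message points at `jax.custom_jvp`. Below is a minimal pure-JAX sketch of that pattern, with NumPy functions (`_host_u` and `_host_dudx`, both hypothetical) standing in for an external simulator: the value and its derivative are both computed on the host through `jax.pure_callback`, and the custom JVP rule tells JAX how to combine them. This is a generic illustration, not how PennyLane itself connects Lightning to JAX.

``````python
import numpy as np
import jax
import jax.numpy as jnp


def _host_u(x):
    # Host-side NumPy code: a stand-in for a simulator that JAX cannot trace
    x = np.asarray(x)
    return np.asarray(np.sin(x), dtype=x.dtype)


def _host_dudx(x):
    # Host-side derivative of _host_u (e.g. what an external gradient method returns)
    x = np.asarray(x)
    return np.asarray(np.cos(x), dtype=x.dtype)


@jax.custom_jvp
def u(x):
    out_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(_host_u, out_shape, x)


@u.defjvp
def u_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    out_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    y = u(x)
    # The derivative values also come from the host. The tangent enters linearly,
    # which lets JAX transpose this rule for reverse mode (jax.grad).
    dy = jax.pure_callback(_host_dudx, out_shape, x)
    return y, dy * x_dot


# First-order gradient, compiled with jax.jit
grad_u = jax.jit(jax.grad(u))
``````

Note that `_host_dudx` is itself a plain callback, so differentiating `grad_u` once more would again raise "Pure callbacks do not support JVP". A loss that already contains $\partial u/\partial x$, like the one in this post, would need the derivative function to carry its own custom rule as well.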

On a side note, I implemented the same example in PyTorch and got a significant speedup.

``````
Name: PennyLane
Version: 0.35.1
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author:
Author-email:
Location: /home/stefan/.local/lib/python3.11/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, pennylane-lightning, requests, rustworkx, scipy, semantic-version, toml, typing-extensions
Required-by: PennyLane-Catalyst, PennyLane_Lightning

Platform info:           Linux-6.5.13-7-MANJARO-x86_64-with-glibc2.39
Python version:          3.11.8
Numpy version:           1.26.4
Scipy version:           1.11.4
Installed devices:
- nvidia.custatevec (PennyLane-Catalyst-0.5.0)
- nvidia.cutensornet (PennyLane-Catalyst-0.5.0)
- softwareq.qpp (PennyLane-Catalyst-0.5.0)
- default.clifford (PennyLane-0.35.1)
- default.gaussian (PennyLane-0.35.1)
- default.mixed (PennyLane-0.35.1)
- default.qubit (PennyLane-0.35.1)
- default.qubit.jax (PennyLane-0.35.1)
- default.qubit.legacy (PennyLane-0.35.1)
- default.qubit.tf (PennyLane-0.35.1)
- default.qubit.torch (PennyLane-0.35.1)
- default.qutrit (PennyLane-0.35.1)
- null.qubit (PennyLane-0.35.1)
- lightning.qubit (PennyLane_Lightning-0.35.1)
``````

There are two things you could try:

1. Use `lightning.qubit` without `jax.jit` but set `diff_method="adjoint"` when instantiating your QNode. On my local machine, this runs in about 15 seconds for `n=20`.

2. Use `default.qubit` with `jax.jit` (you can optionally set `diff_method="backprop"` but this should be the default anyway). On my local machine, this runs in about 3 seconds after a lengthy initial compilation.

Regarding the first point: in PennyLane 0.35 and earlier, Lightning Qubit with `diff_method="best"` was unaware of the adjoint mode, because some new structure still had to be added to the backend implementation. So when `"best"` is selected, everything falls back to parameter-shift, which is likely why your runtime is so slow. Adjoint will be the default `diff_method` for Lightning as of the upcoming 0.36 release.

If you have a large number of CPU cores, you can also have the observables split across them rather than combined, which may offer some speedup depending on the workload. For this, create the device as:

``````python
dev = qml.device("lightning.qubit", wires=n_wires, batch_obs=True)
``````


Hi @Tom_Bromley,

Thanks for your tips! I benchmarked and compared the different approaches, and I think I found unexpected behavior when using `lightning.qubit`.

When using your first tip, I noticed that the loss increases; maybe the derivative is computed incorrectly. When switching to `default.qubit`, the result was as expected, with or without `jax.jit`. This was the output with `lightning.qubit`:

``````
Step: 1  Loss: 22.767391204833984
Step: 2  Loss: 22.07586669921875
Step: 3  Loss: 24.740163803100586
Step: 4  Loss: 28.548276901245117
Step: 5  Loss: 32.848453521728516
Step: 6  Loss: 37.56929016113281
Step: 7  Loss: 42.778018951416016
Step: 8  Loss: 48.53341293334961
Step: 9  Loss: 54.82943344116211
Step: 10  Loss: 61.5770149230957
6.03 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
``````

When comparing `default.qubit` with `jax.jit` against `torch`, I found that torch was always faster, even when ignoring the initial compilation time of `jax.jit`. Here are the results:

First, `jax.jit` including the compilation time:

``````
Step: 100  Loss: 0.28363654017448425
...
Step: 1000  Loss: 0.01202180702239275
2min 59s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
``````

And then calling the same optimization function again:

``````
Step: 100  Loss: 0.28363654017448425
...
Step: 1000  Loss: 0.01202180702239275
51.2 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
``````

And finally, the `torch` results:

``````
Step: 100  Loss: 0.2836303651903658
...
Step: 1000  Loss: 0.013612815638004542
37.1 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
``````

I would have expected JAX with compilation to be faster than torch, but maybe my code is not written optimally. I tried using `vmap` to speed up the `my_model` function, but without success. If anyone has tips on how to write more performant code, feel free to let me know; I am always eager to speed up my code (even though this investigation may already have taken more time than the potential speedup will save).
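On the `vmap` question: inside `jax.jit`, a Python list comprehension over the points is unrolled into one gradient computation per point, which lengthens tracing and compilation. Below is a minimal pure-JAX sketch of two alternatives, using a hypothetical stand-in function `u` instead of a QNode: `jax.vmap` over the per-point gradient, and the gradient of a sum, which equals the pointwise derivatives whenever each output $u(x_i)$ depends only on its own $x_i$. Runtime is not benchmarked here, and whether either variant speeds up the real circuit depends on the device.

``````python
import jax
import jax.numpy as jnp


def u(x, w):
    # Hypothetical stand-in for the circuit: a scalar function of a single point x
    theta = 2 * jnp.arccos(x)
    return jnp.sum(jnp.cos(theta + w))


w = jnp.array([0.1, 0.2, 0.3])
xs = jnp.linspace(0.0, 0.99, 11)
du = jax.grad(u, argnums=0)  # du/dx at one point


@jax.jit
def dudx_loop(xs, w):
    # Python loop: unrolled into 11 separate gradient computations when traced
    return jnp.array([du(x, w) for x in xs])


@jax.jit
def dudx_vmap(xs, w):
    # vmap: one vectorized gradient computation over all points
    return jax.vmap(du, in_axes=(0, None))(xs, w)


@jax.jit
def dudx_sum(xs, w):
    # Gradient of the sum: valid because u(x_i) depends only on x_i
    return jax.grad(lambda v: jnp.sum(jax.vmap(u, in_axes=(0, None))(v, w)))(xs)
``````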

### Code:

#### `PyTorch`

``````python
import pennylane as qml
import torch

n_wires = 4

torch.manual_seed(42)
# Parameter and data definitions are not shown in the post; the values below are assumed
weights = torch.ones((n_wires, 3), requires_grad=True)
bias = torch.tensor(0.0, requires_grad=True)
x = torch.linspace(0, 0.99, 11, requires_grad=True)
params = {"weights": weights, "bias": bias}
opt = torch.optim.Adam([weights, bias], lr=0.1)
loss_history = []

@qml.qnode(qml.device("default.qubit", wires=n_wires), diff_method="best")
def circuit(x, weights):

    # Embedding Ansatz
    for i in range(n_wires):
        qml.RY(2*torch.arccos(x), wires=i)

    # Variational Ansatz
    for i in range(n_wires):
        qml.RX(weights[i, 0], wires=i)
        qml.RY(weights[i, 1], wires=i)
        qml.RX(weights[i, 2], wires=i)
        qml.CNOT(wires=[i, (i + 1) % n_wires])

    # Total magnetization in z-direction as cost function
    return qml.expval(qml.sum(*[qml.PauliZ(i) for i in range(n_wires)]))

def my_model(data, weights, bias):
    return circuit(data, weights) + bias

def loss_fnc(params):
    # Loss function of: du/dx = 1, u(0) = 0
    u_pred = my_model(x, params["weights"], params["bias"])
    # du/dx at every point (assumed; this line is missing from the post).
    # create_graph=True keeps dudx differentiable with respect to the weights.
    dudx = torch.autograd.grad(u_pred, x, grad_outputs=torch.ones_like(u_pred), create_graph=True)[0]

    loss_diff = torch.mean((dudx - torch.ones_like(dudx))**2)
    loss_initial = torch.mean(my_model(torch.zeros_like(x), params["weights"], params["bias"])**2)

    return loss_diff + loss_initial

def optimize(params, n=1000):
    loss_history = []

    for i in range(1, n+1):
        opt.zero_grad()  # clear gradients accumulated in the previous step
        loss_val = loss_fnc(params)
        loss_val.backward()
        opt.step()
        if i % 100 == 0: print(f"Step: {i}  Loss: {loss_val.item()}")
        loss_history.append(loss_val.item())  # store a float, not the autograd graph

    return params, loss_history

%timeit -r1 -n1 optimize(params)
``````

#### `default.qubit` with `jax.jit`

``````python
import pennylane as qml
import jax
from jax import numpy as jnp
import optax

n_wires = 4
weights = jnp.ones((n_wires, 3))
bias = jnp.array(0.)
params = {"weights": weights, "bias": bias}
opt = optax.adam(learning_rate=0.1)  # assumed (not shown in the post), matching lr=0.1 in the PyTorch version
opt_state = opt.init(params)

@qml.qnode(qml.device("default.qubit", wires=n_wires), diff_method="best")
def circuit(x, weights):

    # Embedding Ansatz
    for i in range(n_wires):
        qml.RY(2*jnp.arccos(x), wires=i)

    # Variational Ansatz
    for i in range(n_wires):
        qml.RX(weights[i, 0], wires=i)
        qml.RY(weights[i, 1], wires=i)
        qml.RX(weights[i, 2], wires=i)
        qml.CNOT(wires=[i, (i + 1) % n_wires])

    # Total magnetization in z-direction as cost function
    return qml.expval(qml.sum(*[qml.PauliZ(i) for i in range(n_wires)]))

@jax.jit
def my_model(data, weights, bias):
    return circuit(data, weights) + bias

# du/dx at a single point (assumed definition; not shown in the post)
_dudx = jax.grad(my_model, argnums=0)

@jax.jit
def loss_fnc(params):
    # Loss function of: du/dx = 1, u(0) = 0
    x = jnp.linspace(0, 0.99, 11)
    dudx = jnp.array([_dudx(i, params["weights"], params["bias"]) for i in x])

    loss_diff = jnp.mean((dudx - jnp.ones_like(dudx))**2)
    loss_initial = jnp.mean(my_model(jnp.zeros_like(x), params["weights"], params["bias"])**2)

    return loss_diff + loss_initial

def optimize(params, opt_state, n=1000):
    loss_history = []

    for i in range(1, n+1):
        # Update step (assumed; these lines are missing from the post)
        loss_val, grads = jax.value_and_grad(loss_fnc)(params)
        updates, opt_state = opt.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        if i % 100 == 0: jax.debug.print("Step: {i}  Loss: {loss_val}", i=i, loss_val=loss_val)
        loss_history.append(loss_val)

    return params, opt_state, loss_history

%timeit -r1 -n1 optimize(params, opt_state)
``````

#### `lightning.qubit` without `jax.jit`

``````python
import pennylane as qml
import jax
from jax import numpy as jnp
import optax

n_wires = 4
weights = jnp.ones((n_wires, 3))
bias = jnp.array(0.)
params = {"weights": weights, "bias": bias}
opt = optax.adam(learning_rate=0.1)  # assumed: the optimizer definition is not shown in the post
opt_state = opt.init(params)

# QNode decorator assumed from the first tip (lightning.qubit with adjoint); missing from the post
@qml.qnode(qml.device("lightning.qubit", wires=n_wires), diff_method="adjoint")
def circuit(x, weights):

    # Embedding Ansatz
    for i in range(n_wires):
        qml.RY(2*jnp.arccos(x), wires=i)

    # Variational Ansatz
    for i in range(n_wires):
        qml.RX(weights[i, 0], wires=i)
        qml.RY(weights[i, 1], wires=i)
        qml.RX(weights[i, 2], wires=i)
        qml.CNOT(wires=[i, (i + 1) % n_wires])

    # Total magnetization in z-direction as cost function
    return qml.expval(qml.sum(*[qml.PauliZ(i) for i in range(n_wires)]))

def my_model(data, weights, bias):
    return circuit(data, weights) + bias

# du/dx at a single point (assumed definition; not shown in the post)
_dudx = jax.grad(my_model, argnums=0)

def loss_fnc(params):
    # Loss function of: du/dx = 1, u(0) = 0
    x = jnp.linspace(0, 0.99, 11)
    dudx = jnp.array([_dudx(i, params["weights"], params["bias"]) for i in x])

    loss_diff = jnp.mean((dudx - jnp.ones_like(dudx))**2)
    loss_initial = jnp.mean(my_model(jnp.zeros_like(x), params["weights"], params["bias"])**2)

    return loss_diff + loss_initial

def optimize(params, opt_state, n=10):
    loss_history = []

    for i in range(1, n+1):
        # Update step (assumed; these lines are missing from the post)
        loss_val, grads = jax.value_and_grad(loss_fnc)(params)
        updates, opt_state = opt.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        if i % 1 == 0: jax.debug.print("Step: {i}  Loss: {loss_val}", i=i, loss_val=loss_val)
        loss_history.append(loss_val)

    return params, opt_state, loss_history

%timeit -r1 -n1 optimize(params, opt_state)
``````

Hi @Stefan_Berger, thank you for your detailed post! I’m checking with Tom and will get back to you soon.

Hi @Stefan_Berger, this week has been very busy for the team but we haven’t forgotten your comment! We’ll get back to you next week.


Some insights from my side.

One is that your simulations involve only 4 qubits. `lightning.qubit` is our high performance simulator, and is optimized for 15+ qubits. Since it focuses on the larger circuits, it can sometimes have a bit more overhead on small problems.

The second insight is that it looks like you are taking higher-order derivatives. While second-order derivatives work out of the box with `default.qubit` and backprop, they need to be explicitly requested with all other differentiation methods by setting `max_diff=2` as a `QNode` argument. This way we don't have to calculate higher-order derivatives when users don't need them.

As @Tom_Bromley mentioned, `lightning.qubit` defaulted to `parameter-shift` (as of v0.35.1 and below), a hardware-compatible differentiation method that requires 2-4 executions per trainable parameter and is compatible with higher-order derivatives. Adjoint differentiation (the default for v0.36.0 and above) requires far less simulation time, but is not compatible with higher-order derivatives.

Basically, performance can get messy, involves a lot of compromises, and sometimes doesn’t really have a good explanation.

So for your problem, I would recommend using `default.qubit` with whichever interface turns out to be faster for this particular scenario.
