Hello everyone,

I am trying to solve differential equations using a VQA (variational quantum algorithm) approach. For this, my loss function has to include the derivative of the quantum circuit with respect to its input. I now have a working version using JAX and Optax; however, it is painfully slow.

```
import pennylane as qml
import jax
from jax import numpy as jnp
import optax
import catalyst
# Number of qubits in the circuit.
n_wires = 4
# One row of three rotation angles (RX, RY, RX) per wire: circuit() reads
# weights[wire, k] for k in {0, 1, 2}. The previous shape (4, n_wires) only
# lined up with that indexing because n_wires happens to be 4; JAX clamps
# out-of-range indices silently, so other values of n_wires would have reused
# the wrong angles without raising.
weights = jnp.ones((n_wires, 3))
# NOTE(review): bias is not referenced anywhere else in this script.
bias = jnp.array([1.0])
# Adam optimizer and its state, initialised for the weights pytree.
opt = optax.adam(learning_rate=0.1)
opt_state = opt.init(weights)
@qml.qnode(qml.device("lightning.qubit", wires=n_wires), diff_method="best")
def circuit(x, weights):
    """Evaluate the variational circuit for one input value.

    Args:
        x: Scalar input, encoded through ``arccos`` (so it should lie in
            [-1, 1]).
        weights: Array read as ``weights[wire, k]`` for ``k`` in 0..2.

    Returns:
        Expectation value of the sum of Pauli-Z operators over all wires.
    """
    # Embedding: every qubit receives the same RY(2 * arccos(x)) rotation.
    encoding_angle = 2 * jnp.arccos(x)
    for wire in range(n_wires):
        qml.RY(encoding_angle, wires=wire)

    # Variational layer: RX-RY-RX on each wire, then a CNOT to the next wire
    # (wrapping around at the end).
    # NOTE(review): indentation was lost in the paste; the CNOT is assumed to
    # sit inside this loop because it uses the loop index.
    for wire in range(n_wires):
        qml.RX(weights[wire, 0], wires=wire)
        qml.RY(weights[wire, 1], wires=wire)
        qml.RX(weights[wire, 2], wires=wire)
        qml.CNOT(wires=[wire, (wire + 1) % n_wires])

    # Observable: total magnetization in the z-direction.
    total_z = qml.sum(*(qml.PauliZ(wire) for wire in range(n_wires)))
    return qml.expval(total_z)
def loss_fnc(weights):
    """Loss for the problem du/dx = 1 with u(0) = 0, where u is ``circuit``.

    Args:
        weights: Circuit parameters forwarded to ``circuit``.

    Returns:
        Mean squared deviation of du/dx from 1 on 21 points in [0, 0.99],
        plus the squared circuit output at x = 0.
    """
    grid = jnp.linspace(0, 0.99, 21)

    # Derivative of the circuit output with respect to its input x.
    du_dx = jax.grad(circuit, argnums=0)
    slopes = jnp.array([du_dx(point, weights) for point in grid])
    residual_term = jnp.mean((slopes - jnp.ones_like(slopes)) ** 2)

    # Initial-condition penalty: u(0) should be 0.
    u_at_zero = circuit(jnp.array(0.0), weights)
    boundary_term = jnp.mean(u_at_zero ** 2)

    return residual_term + boundary_term
def optimize(weights, opt_state, n=20):
    """Run ``n`` gradient steps on ``loss_fnc`` with the module-level ``opt``.

    Args:
        weights: Starting parameters.
        opt_state: Optimizer state matching ``weights``.
        n: Number of update steps.

    Returns:
        Tuple ``(weights, opt_state, loss_history)``; ``loss_history`` holds
        the loss evaluated before each update.
    """
    # The transformed function does not depend on the step, so build it once.
    loss_and_grad = jax.value_and_grad(loss_fnc)

    loss_history = []
    for i in range(1, n + 1):
        loss_val, grads = loss_and_grad(weights)
        updates, opt_state = opt.update(grads, opt_state)
        weights = optax.apply_updates(weights, updates)
        print(f"Step: {i} Loss: {loss_val}")
        loss_history.append(loss_val)

    return weights, opt_state, loss_history
# Run the optimization with the default n=20 steps, starting from the
# module-level weights and optimizer state.
weights, opt_state, loss_history = optimize(weights, opt_state)
```

In order to speed up the computation, I tried to use jax.jit and/or Catalyst; however, it seems that it is not possible to take a derivative inside precompiled code. See the following example, where I only added the jax.jit decorator to the loss function:

```
import pennylane as qml
import jax
from jax import numpy as jnp
import optax
import catalyst
# Number of qubits in the circuit.
n_wires = 4
# One row of three rotation angles (RX, RY, RX) per wire: circuit() reads
# weights[wire, k] for k in {0, 1, 2}. The previous shape (4, n_wires) only
# lined up with that indexing because n_wires happens to be 4; JAX clamps
# out-of-range indices silently, so other values of n_wires would have reused
# the wrong angles without raising.
weights = jnp.ones((n_wires, 3))
# NOTE(review): bias is not referenced anywhere else in this script.
bias = jnp.array([1.0])
# Adam optimizer and its state, initialised for the weights pytree.
opt = optax.adam(learning_rate=0.1)
opt_state = opt.init(weights)
@qml.qnode(qml.device("lightning.qubit", wires=n_wires), diff_method="best")
def circuit(x, weights):
    """Evaluate the variational circuit for one input value.

    Args:
        x: Scalar input, encoded through ``arccos`` (so it should lie in
            [-1, 1]).
        weights: Array read as ``weights[wire, k]`` for ``k`` in 0..2.

    Returns:
        Expectation value of the sum of Pauli-Z operators over all wires.
    """
    # Embedding: every qubit receives the same RY(2 * arccos(x)) rotation.
    encoding_angle = 2 * jnp.arccos(x)
    for wire in range(n_wires):
        qml.RY(encoding_angle, wires=wire)

    # Variational layer: RX-RY-RX on each wire, then a CNOT to the next wire
    # (wrapping around at the end).
    # NOTE(review): indentation was lost in the paste; the CNOT is assumed to
    # sit inside this loop because it uses the loop index.
    for wire in range(n_wires):
        qml.RX(weights[wire, 0], wires=wire)
        qml.RY(weights[wire, 1], wires=wire)
        qml.RX(weights[wire, 2], wires=wire)
        qml.CNOT(wires=[wire, (wire + 1) % n_wires])

    # Observable: total magnetization in the z-direction.
    total_z = qml.sum(*(qml.PauliZ(wire) for wire in range(n_wires)))
    return qml.expval(total_z)
@jax.jit
def loss_fnc(weights):
    """Jitted loss for du/dx = 1 with u(0) = 0, where u is ``circuit``.

    Args:
        weights: Circuit parameters forwarded to ``circuit``.

    Returns:
        Mean squared deviation of du/dx from 1 on 21 points in [0, 0.99],
        plus the squared circuit output at x = 0.
    """
    # NOTE: the surrounding post reports that differentiating this jitted
    # function fails with "Pure callbacks do not support JVP".
    grid = jnp.linspace(0, 0.99, 21)

    # Derivative of the circuit output with respect to its input x.
    du_dx = jax.grad(circuit, argnums=0)
    slopes = jnp.array([du_dx(point, weights) for point in grid])
    residual_term = jnp.mean((slopes - jnp.ones_like(slopes)) ** 2)

    # Initial-condition penalty: u(0) should be 0.
    u_at_zero = circuit(jnp.array(0.0), weights)
    boundary_term = jnp.mean(u_at_zero ** 2)

    return residual_term + boundary_term
def optimize(weights, opt_state, n=20):
    """Run ``n`` gradient steps on ``loss_fnc`` with the module-level ``opt``.

    Args:
        weights: Starting parameters.
        opt_state: Optimizer state matching ``weights``.
        n: Number of update steps.

    Returns:
        Tuple ``(weights, opt_state, loss_history)``; ``loss_history`` holds
        the loss evaluated before each update.
    """
    # The transformed function does not depend on the step, so build it once.
    loss_and_grad = jax.value_and_grad(loss_fnc)

    loss_history = []
    for i in range(1, n + 1):
        loss_val, grads = loss_and_grad(weights)
        updates, opt_state = opt.update(grads, opt_state)
        weights = optax.apply_updates(weights, updates)
        print(f"Step: {i} Loss: {loss_val}")
        loss_history.append(loss_val)

    return weights, opt_state, loss_history
# Run the optimization with the default n=20 steps, starting from the
# module-level weights and optimizer state. With loss_fnc jitted, this is the
# call that produces the error reported below.
weights, opt_state, loss_history = optimize(weights, opt_state)
```

This leads to the following error:

```
---------------------------------------------------------------------------
JaxStackTraceBeforeTransformation Traceback (most recent call last)
File <frozen runpy>:198, in _run_module_as_main()
File <frozen runpy>:88, in _run_code()
...
---> 74 raise ValueError(
75 "Pure callbacks do not support JVP. "
76 "Please use `jax.custom_jvp` to use callbacks while taking gradients.")
ValueError: Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients.
```

Does anyone have experience implementing a custom callback (for example via `jax.custom_jvp`, as the error message suggests) for PennyLane quantum circuits? Or does anyone know how to do so?

On a side note, I implemented the same example in PyTorch and got a significant speedup.

This is my qml.about():

```
Name: PennyLane
Version: 0.35.1
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author:
Author-email:
License: Apache License 2.0
Location: /home/stefan/.local/lib/python3.11/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, pennylane-lightning, requests, rustworkx, scipy, semantic-version, toml, typing-extensions
Required-by: PennyLane-Catalyst, PennyLane_Lightning
Platform info: Linux-6.5.13-7-MANJARO-x86_64-with-glibc2.39
Python version: 3.11.8
Numpy version: 1.26.4
Scipy version: 1.11.4
Installed devices:
- nvidia.custatevec (PennyLane-Catalyst-0.5.0)
- nvidia.cutensornet (PennyLane-Catalyst-0.5.0)
- softwareq.qpp (PennyLane-Catalyst-0.5.0)
- default.clifford (PennyLane-0.35.1)
- default.gaussian (PennyLane-0.35.1)
- default.mixed (PennyLane-0.35.1)
- default.qubit (PennyLane-0.35.1)
- default.qubit.autograd (PennyLane-0.35.1)
- default.qubit.jax (PennyLane-0.35.1)
- default.qubit.legacy (PennyLane-0.35.1)
- default.qubit.tf (PennyLane-0.35.1)
- default.qubit.torch (PennyLane-0.35.1)
- default.qutrit (PennyLane-0.35.1)
- null.qubit (PennyLane-0.35.1)
- lightning.qubit (PennyLane_Lightning-0.35.1)
```