Dear PennyLane team,
I am currently trying to speed up my code by moving from torch as my training environment to jax + catalyst.
My quantum circuit contains a QubitUnitary (which is fixed and not trained).
Training this circuit with lightning.qubit and torch works fine, but as soon as I switch to catalyst + jax it raises the following error:
```
raise DifferentiableCompileError(
catalyst.utils.exceptions.DifferentiableCompileError: QubitUnitary is non-differentiable on 'lightning.qubit' device
```
Can I change something in the settings of `qjit` to make this work?
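For example, I wondered whether forcing finite-difference gradients would sidestep the check. A minimal sketch of what I mean (assuming `method="fd"` of `catalyst.value_and_grad` is even applicable here; `costfunc` is the function from example 2 below, and `loss_and_grads` is just a name I made up):

```
from catalyst import qjit, value_and_grad

# Sketch: differentiate the cost by finite differences (method="fd")
# instead of an analytic method, hoping that QubitUnitary then no
# longer needs to be differentiable on the device.
# costfunc is defined in example 2 below.
@qjit
def loss_and_grads(weights):
    return value_and_grad(costfunc, method="fd")(weights)
```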
Below you find a minimal example that reproduces the error: 1. using pytorch (working) and 2. using jax + catalyst (not working). Both use the same `lightning.qubit` backend.
1. pytorch (working):

```
import scipy.stats as scst
import numpy as np
import pennylane as qml
import torch

device = "lightning.qubit"
tau = 2 * 10 ** (-3)  # timestep length [s]
shots = 100
n_wires_state = 3
lr = 0.05
tsteps = 1000

dev_state = qml.device(device, wires=n_wires_state, shots=shots)
unitary = scst.unitary_group.rvs(8)  # fixed, non-trainable unitary
weights = torch.tensor([0., 0., 0.], requires_grad=True)

@qml.qnode(device=dev_state)
def cost_circuit(weights, unitary):
    qml.RY(weights[0], wires=0)
    qml.RY(weights[1], wires=1)
    qml.RY(weights[2], wires=2)
    qml.QubitUnitary(unitary, wires=[0, 1, 2])
    return qml.expval(qml.PauliZ(0))

def costfunc(weights):
    cost = 1 - cost_circuit(weights, unitary)
    return cost

def closure():
    opt.zero_grad()
    loss = costfunc(weights)
    loss.backward(retain_graph=True)
    return loss

opt = torch.optim.Adam([weights], lr=lr)
for k in range(tsteps):
    if k % 50 == 0:
        print(f"Step {k}, cost: {costfunc(weights)}")
    opt.step(closure)
```
2. jax + catalyst (not working):

```
import scipy.stats as scst
import numpy as np
import pennylane as qml
import matplotlib.pyplot as plt
from catalyst import qjit, measure, for_loop, value_and_grad
from jax import numpy as jnp
import jax
import optax

device = "lightning.qubit"
shots = 100
n_wires_state = 3
lr = 0.05
tsteps = 1000

dev_state = qml.device(device, wires=n_wires_state, shots=shots)
unitary = jnp.array(scst.unitary_group.rvs(8))  # fixed, non-trainable unitary
weights = jnp.array([0., 0., 0.])

@qjit()
@qml.qnode(device=dev_state)
def cost_circuit(weights):
    qml.RY(weights[0], wires=0)
    qml.RY(weights[1], wires=1)
    qml.RY(weights[2], wires=2)
    qml.QubitUnitary(unitary, wires=[0, 1, 2])
    return qml.expval(qml.PauliZ(0))

optimizer = optax.adam(lr)

@qjit
def costfunc(weights):
    cost = 1 - cost_circuit(weights)
    return cost

@qjit
def update_step_jit(i, args):
    weights, opt_state = args
    loss_val, grads = value_and_grad(costfunc)(weights)
    updates, opt_state = optimizer.update(grads, opt_state)
    weights = optax.apply_updates(weights, updates)

    def print_fn():
        jax.debug.print("Step {i}, cost: {loss_val}", i=i, loss_val=loss_val)

    # print the loss every 50 steps
    jax.lax.cond(jnp.mod(i, 50) == 0, print_fn, lambda: None)
    return (weights, opt_state)

@qjit
def optimization_jit(params, tsteps):
    opt_state = optimizer.init(params)
    args = (params, opt_state)
    (params, opt_state) = for_loop(0, tsteps, 1)(update_step_jit)(args)
    return params

weights = optimization_jit(weights, tsteps)
```
I tried to make `unitary` static, but that's apparently not possible for arrays.
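For reference, this is roughly what that attempt looked like (marking the unitary argument as static via `qjit`'s `static_argnums`; the exact call is reconstructed from memory):

```
# Attempt: treat the unitary as a compile-time constant by declaring
# argument 1 static. static_argnums requires hashable arguments, and
# jax/numpy arrays are not hashable, so this fails for the unitary.
@qjit(static_argnums=1)
@qml.qnode(device=dev_state)
def cost_circuit(weights, unitary):
    qml.RY(weights[0], wires=0)
    qml.RY(weights[1], wires=1)
    qml.RY(weights[2], wires=2)
    qml.QubitUnitary(unitary, wires=[0, 1, 2])
    return qml.expval(qml.PauliZ(0))
```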
I'm thankful for any advice :)
Best regards,
Pia