Using QubitUnitary with qjit raises differentiability issues

Dear PennyLane team,

I am currently trying to speed up my code by moving from torch as my training environment to jax+catalyst.

My quantum circuit contains a QubitUnitary (which is fixed and not trained).
With lightning.qubit and torch, training this circuit works, but as soon as I switch to catalyst+jax it raises an error, namely:

raise DifferentiableCompileError(
catalyst.utils.exceptions.DifferentiableCompileError: QubitUnitary is non-differentiable on 'lightning.qubit' device

Can I change something in the settings of qjit to make this work?

Below is a minimal example that reproduces the error: 1. using pytorch (working) and 2. using jax and catalyst (not working). Both use the same backend "lightning.qubit".

import scipy.stats as scst
import numpy as np
import pennylane as qml
import torch

device = "lightning.qubit"
tau = 2*10**(-3) #length timestep [s]
shots = 100
n_wires_state = 3
lr=0.05
tsteps=1000

dev_state = qml.device(device, wires=n_wires_state, shots=shots)
unitary = scst.unitary_group.rvs(8)
weights = torch.tensor([0.,0.,0.], requires_grad=True)

@qml.qnode(device=dev_state)
def cost_circuit(weights, unitary):
    qml.RY(weights[0], wires=0)
    qml.RY(weights[1], wires=1)
    qml.RY(weights[2], wires=2)
    qml.QubitUnitary(unitary, wires=[0,1,2])
    return qml.expval(qml.PauliZ(0))

def costfunc(weights): 
    cost = 1 - cost_circuit(weights, unitary)
    return cost  
    
def closure():
    opt.zero_grad()
    loss = costfunc(weights)
    loss.backward(retain_graph=True)
    return loss

opt = torch.optim.Adam([weights], lr=lr)
for k in range(tsteps):
    if k % 50 == 0:
        print(f"Step Adam velocity {k}, cost: {costfunc(weights)}")
    opt.step(closure)

import scipy.stats as scst
import numpy as np
import pennylane as qml
import matplotlib.pyplot as plt
from catalyst import qjit, measure, for_loop, value_and_grad
from jax import numpy as jnp
import jax
import optax

device = "lightning.qubit"
shots = 100
n_wires_state = 3
lr=0.05
tsteps=1000

dev_state = qml.device(device, wires=n_wires_state, shots=shots)
unitary = jnp.array(scst.unitary_group.rvs(8))
weights = jnp.array([0.,0.,0.])

@qjit()
@qml.qnode(device=dev_state)
def cost_circuit(weights):
    qml.RY(weights[0], wires=0)
    qml.RY(weights[1], wires=1)
    qml.RY(weights[2], wires=2)
    qml.QubitUnitary(unitary, wires=[0,1,2])
    return qml.expval(qml.PauliZ(0))
    
optimizer = optax.adam(lr)

@qjit
def costfunc(weights): 
    cost = 1 - cost_circuit(weights, unitary)
    return cost  
        
@qjit
def update_step_jit(i, args):
    weights, opt_state = args
    loss_val, grads = value_and_grad(costfunc)(weights)
    updates, opt_state = optimizer.update(grads, opt_state)
    weights = optax.apply_updates(weights, updates)
    def print_fn():
        jax.debug.print("Step Adam pressure {i}, cost: {loss_val}", i=i, loss_val=loss_val)

    # print the loss every 50 steps
    jax.lax.cond((jnp.mod(i, 50) == 0), print_fn, lambda: None)
    return (weights, opt_state)
        
@qjit
def optimization_jit(params, tsteps):

    opt_state = optimizer.init(params)
    args = (params, opt_state)
    (params, opt_state) = for_loop(0, tsteps, 1)(update_step_jit)(args)

    return params

weights = optimization_jit(weights, tsteps)

I tried to make `unitary` static, but that's apparently not possible for arrays.
I'm thankful for any advice :)
Best regards,
Pia

Hi @Pia ,

There were several issues happening here.

  1. You had a minor error where you added the unitary in your costfunc but not in the signature for your cost_circuit.
  2. You used a JAX print statement where you needed to use a Catalyst print instead.
  3. Lightning’s adjoint-jacobian differentiation method doesn’t have support for QubitUnitary. In that case you need to change the differentiation method to parameter-shift or finite-diff.
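
To make point 3 a bit more concrete, here is a toy sketch in plain Python (no PennyLane or Catalyst involved) of why the parameter-shift rule copes with a fixed unitary: the rule only shifts the trainable RY angle and evaluates the circuit twice, so the fixed matrix is simply applied and never has to be differentiated. The single-qubit matrix `U` and all helper names are made up for illustration; this is not how Catalyst implements it internally.

```python
import cmath
import math

# Hypothetical fixed single-qubit unitary (not trained), standing in for QubitUnitary.
# [[a, -conj(b)], [b, conj(a)]] with |a|^2 + |b|^2 = 1 is unitary.
A = math.cos(0.4) * cmath.exp(0.3j)
B = math.sin(0.4) * cmath.exp(-1.1j)
U = [[A, -B.conjugate()], [B, A.conjugate()]]

def expval_z(theta):
    """<Z> after RY(theta) on |0>, followed by the fixed unitary U."""
    # RY(theta)|0> = [cos(theta/2), sin(theta/2)]
    state = [math.cos(theta / 2), math.sin(theta / 2)]
    out0 = U[0][0] * state[0] + U[0][1] * state[1]
    out1 = U[1][0] * state[0] + U[1][1] * state[1]
    return abs(out0) ** 2 - abs(out1) ** 2

def parameter_shift_grad(theta):
    """Two-term shift rule for a Pauli-generated rotation such as RY."""
    shift = math.pi / 2
    return 0.5 * (expval_z(theta + shift) - expval_z(theta - shift))

def finite_diff_grad(theta, h=1e-6):
    """Central finite difference, used here only as a cross-check."""
    return (expval_z(theta + h) - expval_z(theta - h)) / (2 * h)

theta = 0.7
print(parameter_shift_grad(theta), finite_diff_grad(theta))
```

The two printed numbers should agree up to the finite-difference error, even though nothing in the sketch differentiates `U` itself.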

The code below shows your code with these three changes, plus some print statements that I added just to check that everything was working as expected.

import scipy.stats as scst
import numpy as np
import pennylane as qml
import matplotlib.pyplot as plt
from catalyst import qjit, measure, for_loop, value_and_grad
from jax import numpy as jnp
import jax
import optax
import catalyst # Added

device = "lightning.qubit"
shots = 100
n_wires_state = 3
lr=0.05
tsteps=1000

dev_state = qml.device(device, wires=n_wires_state, shots=shots)
unitary = jnp.array(scst.unitary_group.rvs(8))
weights = jnp.array([0.,0.,0.])

@qjit()
@qml.qnode(device=dev_state, diff_method="parameter-shift") # Added diff_method="parameter-shift"
def cost_circuit(weights):
    qml.RY(weights[0], wires=0)
    qml.RY(weights[1], wires=1)
    qml.RY(weights[2], wires=2)
    qml.QubitUnitary(unitary, wires=[0,1,2])
    return qml.expval(qml.PauliZ(0))

optimizer = optax.adam(lr)

print('Init cost: ',cost_circuit(weights)) # Added

@qjit
def costfunc(weights):
    cost = 1 - cost_circuit(weights) # , unitary  # Removed unitary
    return cost

print('Init costfn: ',costfunc(weights)) # Added
print('Init grad ',qml.qjit(value_and_grad(costfunc))(weights)) # Added

@qjit
def update_step_jit(i, args):
    weights, opt_state = args 
    loss_val, grads = value_and_grad(costfunc)(weights)
    updates, opt_state = optimizer.update(grads, opt_state)
    weights = optax.apply_updates(weights, updates)
    
    def print_fn():
        # Removed this print statement because it's causing issues. Print with Catalyst instead
        #jax.debug.print("Step Adam pressure {i}, cost: {loss_val}", i=i, loss_val=loss_val)

        # Use this print statement instead
        catalyst.debug.print("Step Adam pressure {i}, cost: {loss_val}", i=i, loss_val=loss_val)
    
    # if print_training=True, print the loss every 50 steps
    jax.lax.cond((jnp.mod(i, 50) == 0), print_fn, lambda: None)
        
    return (weights, opt_state)

@qjit
def optimization_jit(params, tsteps):

    opt_state = optimizer.init(params)
    args = (params, opt_state)
    (params, opt_state) = for_loop(0, tsteps, 1)(update_step_jit)(args)

    return params

weights = optimization_jit(weights, tsteps)

print("Final weights: ",weights) # Added
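
Regarding making `unitary` static: in plain `jax.jit`, static arguments have to be hashable, and arrays are not, which is most likely the same limitation you ran into. I'm assuming `qjit` behaves similarly in this respect. The small sketch below uses `jax.jit` with a made-up matrix and made-up function names to show both the failure and the closure alternative, which is what the code above does by reading `unitary` from the enclosing scope.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical fixed (untrained) matrix, standing in for the unitary.
fixed_matrix = jnp.array([[0.0, 1.0], [1.0, 0.0]])

@jax.jit
def apply_fixed(vec):
    # `fixed_matrix` is captured from the enclosing scope, so the compiled
    # function treats it as a constant; only `vec` is a traced argument.
    return fixed_matrix @ vec

def apply_any(matrix, vec):
    return matrix @ vec

# Marking the matrix argument as static fails on the first call,
# because static arguments must be hashable and arrays are not.
apply_static = jax.jit(apply_any, static_argnums=0)
try:
    apply_static(np.eye(2), jnp.array([1.0, 2.0]))
    static_failed = False
except (ValueError, TypeError):
    static_failed = True

print("static array argument rejected:", static_failed)
print(apply_fixed(jnp.array([1.0, 2.0])))
```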

I hope this helps!

Hey Catalina,
thank you! That solved the issue :)
I have a follow-up question. I also used a mid-circuit measurement in my circuit, which is apparently not possible in a differentiable circuit? I get the error:

catalyst.utils.exceptions.DifferentiableCompileError: MidCircuitMeasure is not allowed in gradients

Is there a way to make it work? Or is it just not supported with catalyst yet? Would it work if I solely use jax instead?

Here is the minimal example with an additional mid-circuit measurement (it might seem useless here, but it is required in my real circuit):

import scipy.stats as scst

import numpy as np
import pennylane as qml
import matplotlib.pyplot as plt
from catalyst import qjit
import catalyst  #, measure, for_loop, value_and_grad, debug
from jax import numpy as jnp
import jax
import optax

device = "lightning.qubit"
tau = 2*10**(-3) #length timestep [s]
shots = 100
n_wires_state = 3
lr=0.05
tsteps=1000

dev_state = qml.device(device, wires=n_wires_state, shots=shots)
unitary = jnp.array(scst.unitary_group.rvs(8))
weights = jnp.array([0.,0.,0.])

@qjit()
@qml.qnode(device=dev_state, diff_method="parameter-shift")
def cost_circuit(weights):
    qml.RY(weights[0], wires=0)
    qml.RY(weights[1], wires=1)
    qml.RY(weights[2], wires=2)
    qml.QubitUnitary(unitary, wires=[0,1,2])
    catalyst.measure(wires=1)
    return qml.expval(qml.PauliZ(0))
    
optimizer = optax.adam(lr)

@qjit
def costfunc(weights): 
    cost = 1 - cost_circuit(weights)
    return cost  
        
@qjit
def update_step_jit(i, args):
    weights, opt_state = args
    loss_val, grads = catalyst.value_and_grad(costfunc)(weights)
    updates, opt_state = optimizer.update(grads, opt_state)
    weights = optax.apply_updates(weights, updates)
    def print_fn():
        #jax.debug.print("Step Adam pressure {i}, cost: {loss_val}", i=i, loss_val=loss_val)
        catalyst.debug.print("Step Adam pressure {i}, cost: {loss_val}", i=i, loss_val=loss_val)
    # print the loss every 50 steps
    jax.lax.cond((jnp.mod(i, 50) == 0), print_fn, lambda: None)
    return (weights, opt_state)
        
@qjit
def optimization_jit(params, tsteps):

    opt_state = optimizer.init(params)
    args = (params, opt_state)
    (params, opt_state) = catalyst.for_loop(0, tsteps, 1)(update_step_jit)(args)

    return params

weights = optimization_jit(weights, tsteps)

print("Final weights: ",weights) # Added

Hi @Pia ,

Unfortunately, differentiating circuits that contain mid-circuit measurements isn’t yet supported in Catalyst.

However, my colleague Isaac shared this example of how it works with JAX :)

import pennylane as qml
import jax
from jax import numpy as jnp

dev = qml.device("lightning.qubit", wires=2)

@jax.jit
@qml.qnode(dev)
def circuit(x):
    qml.RY(x, wires=1)
    qml.CNOT(wires=[0, 1])
    qml.measure(0)
    return qml.expval(qml.X(1))

x = jnp.array(0.1)

jax.grad(circuit)(x)

Note that you should use jax.grad / optax for differentiation with jax.jit’d QNodes.

I hope this helps!

Dear Catalina,
good to know. I’m switching to jax now.
Unfortunately I immediately ran into the next problem, which I struggle to fix myself: combining adjoint and control doesn’t seem to work under jax. Each of them works individually without problems (and with pytorch the combination works as well).

Here is the minimal example:

import scipy.stats as scst
import pennylane as qml
import matplotlib.pyplot as plt
from catalyst import qjit
from jax import numpy as jnp
import jax
import optax

device = "lightning.qubit"
shots = 100
n_wires_state = 3
lr=0.05
tsteps=1000

dev_state = qml.device(device, wires=n_wires_state, shots=shots)
unitary = jnp.array(scst.unitary_group.rvs(8))
weights = jnp.array([0.,0.,0.])

def func_circ(weights,nq):
    qml.RY(weights, wires=nq)

@jax.jit
@qml.qnode(device=dev_state, interface="jax")#diff_method="finite-diff")
def cost_circuit(weights):
    func_circ(weights[1],1)
    qml.ctrl(qml.adjoint(func_circ),control=(0))(weights[2],2)
    qml.QubitUnitary(unitary, wires=[0,1,2])
    qml.measure(wires=1)
    return [qml.sample(qml.PauliZ(0)), qml.sample(qml.PauliZ(1))]
    
optimizer = optax.adam(lr)

@jax.jit
def costfunc(weights): 
    res = jnp.sum(jnp.array(cost_circuit(weights)), axis=1)
    vals, counts = jnp.unique(res, return_counts=True, size=1)
    res = counts[0]/len(res)
    cost = 1 - res
    return cost  

@jax.jit
def update_step_jit(i, args):
    weights, opt_state = args
    loss_val, grads = jax.value_and_grad(costfunc)(weights)
    updates, opt_state = optimizer.update(grads, opt_state)
    weights = optax.apply_updates(weights, updates)
    def print_fn():
        jax.debug.print("Step Adam pressure {i}, cost: {loss_val}", i=i, loss_val=loss_val)
    jax.lax.cond((jnp.mod(i, 50) == 0), print_fn, lambda: None)
    return (weights, opt_state)
        
@jax.jit
def optimization_jit(params, tsteps):

    opt_state = optimizer.init(params)
    args = (params, opt_state)
    (params, opt_state) = jax.lax.fori_loop(0, tsteps, update_step_jit, args)

    return params

weights = optimization_jit(weights, tsteps)

and this is the error message I get:

Traceback (most recent call last):
  File "/…/minimalExample.py", line 50, in <module>
    print(costfunc(weights)) # works
    ^^^^^^^^^^^^^^^^^
  File "/…/minimalExample.py", line 43, in costfunc
    res = jnp.sum(jnp.array(cost_circuit(weights)), axis=1)
    ^^^^^^^^^^^^^^^^^^^^^
  File "/…/lib/python3.12/site-packages/pennylane/workflow/qnode.py", line 905, in __call__
    return self._impl_call(*args, **kwargs)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/…/lib/python3.12/site-packages/pennylane/workflow/qnode.py", line 881, in _impl_call
    res = qml.execute(
    ^^^^^^^^^^^^
  File "/…/lib/python3.12/site-packages/pennylane/workflow/execution.py", line 227, in execute
    tapes, post_processing = transform_program(tapes)
    ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/…/lib/python3.12/site-packages/pennylane/transforms/core/transform_program.py", line 580, in __call__
    new_tapes, fn = transform(tape, *targs, **tkwargs)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/…/lib/python3.12/site-packages/pennylane/gradients/parameter_shift.py", line 765, in _expand_transform_param_shift
    [new_tape], postprocessing = qml.devices.preprocess.decompose(
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/…/lib/python3.12/site-packages/pennylane/transforms/core/transform_dispatcher.py", line 153, in __call__
    transformed_tapes, processing_fn = self._transform(obj, *targs, **tkwargs)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/…/lib/python3.12/site-packages/pennylane/devices/preprocess.py", line 408, in decompose
    if all(stopping_condition(op) for op in tape.operations[len(prep_op) :]):
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/…/lib/python3.12/site-packages/pennylane/devices/preprocess.py", line 408, in <genexpr>
    if all(stopping_condition(op) for op in tape.operations[len(prep_op) :]):
    ^^^^^^^^^^^^^^^^^^^^^^
  File "/…/lib/python3.12/site-packages/pennylane/gradients/parameter_shift.py", line 746, in _param_shift_stopping_condition
    if not op.has_decomposition:
    ^^^^^^^^^^^^^^^^^^^^
  File "/…/lib/python3.12/site-packages/pennylane/ops/op_math/controlled.py", line 721, in has_decomposition
    if _is_single_qubit_special_unitary(self.base):
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]
The error occurred while tracing the function cost_circuit at /superscratch/psiegl/PhD/QCProject/minimalExample.py:26 for jit. This concrete value was not available in Python because it depends on the value of the argument weights.
See Errors — JAX documentation

Is there a way to allow this combination?
Many thanks and best regards,
Pia

Hi @Pia ,

I can replicate your error. I’m looking into workarounds for this. I’ll let you know what I can find.

Hi @Pia ,

I’ve made a bug report for this issue.

So the issue here is that using qml.ctrl(qml.adjoint(...), ...)(...) doesn’t work under jax.jit. For something like the code I added in the bug report you could technically use qjit. But as we saw already, you wouldn’t be able to compute gradients if you also have mid-circuit measurements.
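
For context, the `TracerBoolConversionError` in your traceback is JAX's generic error when Python code needs a concrete True/False from a value that is only known at trace time; in the traceback this happens inside `has_decomposition` in `controlled.py`. The sketch below reproduces the same class of error in plain JAX. It is not the PennyLane code path, and the function names are made up for illustration.

```python
import jax
import jax.numpy as jnp

@jax.jit
def branch_in_python(x):
    # A Python `if` needs a concrete bool, but under jit `x` is a tracer.
    if x > 0:
        return x
    return -x

@jax.jit
def branch_in_jax(x):
    # jnp.where keeps the decision inside the traced computation.
    return jnp.where(x > 0, x, -x)

try:
    branch_in_python(jnp.array(0.1))
    raised = False
except jax.errors.ConcretizationTypeError:
    # TracerBoolConversionError is a subclass of ConcretizationTypeError.
    raised = True

print("Python branch failed under jit:", raised)
print(branch_in_jax(jnp.array(-0.5)))
```

In your own code you can rewrite such a branch with `jnp.where` or `jax.lax.cond`, but here the branch lives inside the library, which is why removing `jax.jit` or waiting for a fix are the practical options.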

So at this point I think you need to choose whether you really need that mid-circuit measurement. From the last code that you shared it doesn’t look like the mid-circuit measurement is doing much. So if you can avoid using it then you can go back to using qjit and things should work nicely again.

I hope this helps.

Let me know if you run into other issues or if you can’t avoid using the mid-circuit measurement.

Dear Catalina,
thanks for looking into that and making a bug report. In my real code mid-circuit measurements are actually needed, so I have to wait for this bug fix.
Best regards,
Pia

Hi @Pia ,

The fix may not be straightforward, so at this point it’s hard to predict how long it will take or when our team will have time for it.

Since the issue is caused by jitting, another option is to remove jitting. This will allow you to use mid-circuit measurements and gradients. Below is an example:

import pennylane as qml
from jax import numpy as jnp
import jax

# create the device
device = "lightning.qubit"
shots = 100
n_wires = 3
dev = qml.device(device, wires=n_wires, shots=shots)

# create the function
def func_circ(weights,qubit):
    qml.RY(weights, wires=qubit)

# find the adjoint
ad = qml.adjoint(func_circ)

# jitting causes issues
@qml.qnode(device=dev)
def cost_circuit(weights):
    # use a controlled operation on the adjoint
    qml.X(0)
    qml.ctrl(ad, control=(0))(weights[0], 2)
    m_0 = qml.measure(2)
    qml.cond(m_0, qml.RY)(0.1, wires=2)
    return qml.expval(qml.PauliZ(2))

# run the circuit for some weights
weights = jnp.array([0.5,0.3,0.1])
print(cost_circuit(weights))

jax.grad(cost_circuit)(weights)

Let me know if this works for you!