Using QubitUnitary with qjit raises differentiability issues

Dear Catalina,
good to know. I’m switching to jax now.
Unfortunately I directly run into the next problem which I struggle to fix myself: Combining adjoint and control seems to doesn’t work under jax. Each of them individually works without problems (and with pytorch also the combination works well).

Here is the minimal example:

import matplotlib.pyplot as plt
from catalyst import qjit
from jax import numpy as jnp
import jax
import optax

device = "lightning.qubit"
shots = 100
n_wires_state = 3
lr=0.05
tsteps=1000

dev_state = qml.device(device, wires=n_wires_state, shots=shots)
unitary = jnp.array(scst.unitary_group.rvs(8))
weights = jnp.array([0.,0.,0.])

def func_circ(weights,nq):
    qml.RY(weights, wires=nq)

@jax.jit
@qml.qnode(device=dev_state, interface="jax")#diff_method="finite-diff")
def cost_circuit(weights):
    func_circ(weights[1],1)
    qml.ctrl(qml.adjoint(func_circ),control=(0))(weights[2],2)
    qml.QubitUnitary(unitary, wires=[0,1,2])
    qml.measure(wires=1)
    return [qml.sample(qml.PauliZ(0)), qml.sample(qml.PauliZ(1))]
    
optimizer = optax.adam(lr)

@jax.jit
def costfunc(weights): 
    res = jnp.sum(jnp.array(cost_circuit(weights)), axis=1)
    vals, counts = jnp.unique(res, return_counts=True, size=1)
    res = counts[0]/len(res)
    cost = 1 - res
    return cost  

@jax.jit
def update_step_jit(i, args):
    weights, opt_state = args
    loss_val, grads = jax.value_and_grad(costfunc)(weights)
    updates, opt_state = optimizer.update(grads, opt_state)
    weights = optax.apply_updates(weights, updates)
    def print_fn():
        jax.debug.print("Step Adam pressure {i}, cost: {loss_val}", i=i, loss_val=loss_val)
    jax.lax.cond((jnp.mod(i, 50) == 0), print_fn, lambda: None)
    return (weights, opt_state)
        
@jax.jit
def optimization_jit(params, tsteps):

    opt_state = optimizer.init(params)
    args = (params, opt_state)
    (params, opt_state) = jax.lax.fori_loop(0, tsteps, update_step_jit, args)

    return params

weights = optimization_jit(weights, tsteps)
type or paste code here

and this is the error message I get:

Traceback (most recent call last):
File “/…/minimalExample.py”, line 50, in
print(costfunc(weights)) # works
^^^^^^^^^^^^^^^^^
File “/…/minimalExample.py”, line 43, in costfunc
res = jnp.sum(jnp.array(cost_circuit(weights)), axis=1)
^^^^^^^^^^^^^^^^^^^^^
File “/…/lib/python3.12/site-packages/pennylane/workflow/qnode.py”, line 905, in call
return self._impl_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/…/lib/python3.12/site-packages/pennylane/workflow/qnode.py”, line 881, in _impl_call
res = qml.execute(
^^^^^^^^^^^^
File “/…/lib/python3.12/site-packages/pennylane/workflow/execution.py”, line 227, in execute
tapes, post_processing = transform_program(tapes)
^^^^^^^^^^^^^^^^^^^^^^^^
File “/…/lib/python3.12/site-packages/pennylane/transforms/core/transform_program.py”, line 580, in call
new_tapes, fn = transform(tape, *targs, **tkwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/…/lib/python3.12/site-packages/pennylane/gradients/parameter_shift.py”, line 765, in _expand_transform_param_shift
[new_tape], postprocessing = qml.devices.preprocess.decompose(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/…/lib/python3.12/site-packages/pennylane/transforms/core/transform_dispatcher.py”, line 153, in call
transformed_tapes, processing_fn = self._transform(obj, *targs, **tkwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/…/lib/python3.12/site-packages/pennylane/devices/preprocess.py”, line 408, in decompose
if all(stopping_condition(op) for op in tape.operations[len(prep_op) :]):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/…/lib/python3.12/site-packages/pennylane/devices/preprocess.py”, line 408, in
if all(stopping_condition(op) for op in tape.operations[len(prep_op) :]):
^^^^^^^^^^^^^^^^^^^^^^
File “/…/lib/python3.12/site-packages/pennylane/gradients/parameter_shift.py”, line 746, in _param_shift_stopping_condition
if not op.has_decomposition:
^^^^^^^^^^^^^^^^^^^^
File “/…/lib/python3.12/site-packages/pennylane/ops/op_math/controlled.py”, line 721, in has_decomposition
if _is_single_qubit_special_unitary(self.base):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool
The error occurred while tracing the function cost_circuit at /superscratch/psiegl/PhD/QCProject/minimalExample.py:26 for jit. This concrete value was not available in Python because it depends on the value of the argument weights.
See Errors — JAX documentation

Is there a way to allow this combination?
Many thanks and best regards,
Pia