Extending Operation with QJIT

These are the contents of a file that I’m trying to import:

import pennylane as qml
from catalyst import qjit
from pennylane.numpy import pi, tensor
from pennylane.operation import Operation


class StronglyEntanglingLayer(Operation):
    """Custom two-wire operation with six rotation angles (original version).

    The decomposition applies RZ, RY, RZ on each wire and then a CNOT in each
    direction. NOTE: this version raises an error on import; the fix given
    later in the thread is to remove ``@qjit`` and add ``@staticmethod`` to
    ``compute_decomposition``.
    """

    num_params = 6  # weight0 ... weight5
    num_wires = 2  # the decomposition acts on wires[0] and wires[1]
    grad_method = None
    # NOTE(review): only one recipe entry is given while num_params is 6, and
    # grad_method is None — confirm whether this recipe is used at all.
    grad_recipe = ([[0.5, 1, pi / 2], [-0.5, 1, -pi / 2]],)

    # Cause of the import-time error: the arguments carry type annotations,
    # so qjit takes them as the argument types and compiles this function
    # immediately, at class-definition time. It then tries to abstract the
    # annotation classes themselves (``tensor``, ``list``) as argument values,
    # which fails with "Unsupported argument type: <class 'type'>".
    # Catalyst compiles the decomposition anyway when the calling QNode is
    # qjit-compiled, so this decorator is unnecessary.
    @qjit
    def compute_decomposition(
        weight0: tensor,
        weight1: tensor,
        weight2: tensor,
        weight3: tensor,
        weight4: tensor,
        weight5: tensor,
        wires: list,
    ):
        """Return the gates implementing this layer, in application order.

        NOTE: there is no ``self`` parameter, so this method should be
        declared with ``@staticmethod``. Also, returning a list of operations
        from a function that is itself JIT-compiled is not supported.
        """
        op_list = [
            # Z-Y-Z single-qubit rotations on each wire
            qml.RZ(weight0, wires=wires[0]),
            qml.RZ(weight1, wires=wires[1]),
            qml.RY(weight2, wires=wires[0]),
            qml.RY(weight3, wires=wires[1]),
            qml.RZ(weight4, wires=wires[0]),
            qml.RZ(weight5, wires=wires[1]),
            # entangle the two wires with a CNOT in each direction
            qml.CNOT(wires=[wires[0], wires[1]]),
            qml.CNOT(wires=[wires[1], wires[0]]),
        ]
        return op_list

This is the error message I get when trying to import the above:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
File /project_title/.venv/lib/python3.11/site-packages/jax/_src/api_util.py:450, in shaped_abstractify(x)
    449 try:
--> 450   return _shaped_abstractify_handlers[type(x)](x)
    451 except KeyError:

KeyError: <class 'type'>

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
File /project_title/.venv/lib/python3.11/site-packages/jax/_src/dtypes.py:101, in _canonicalize_dtype(x64_enabled, allow_opaque_dtype, dtype)
    100 try:
--> 101   dtype_ = np.dtype(dtype)
    102 except TypeError as e:

TypeError: Cannot interpret '<attribute 'dtype' of 'numpy.ndarray' objects>' as a data type

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
File /project_title/.venv/lib/python3.11/site-packages/catalyst/compilation_pipelines.py:214, in CompiledFunction.get_runtime_signature(*args)
    213 for arg in args:
--> 214     r_sig.append(jax.api_util.shaped_abstractify(arg))
    215 return r_sig

File /project_title/.venv/lib/python3.11/site-packages/jax/_src/api_util.py:452, in shaped_abstractify(x)
    451 except KeyError:
--> 452   return _shaped_abstractify_slow(x)

File /project_title/.venv/lib/python3.11/site-packages/jax/_src/api_util.py:441, in _shaped_abstractify_slow(x)
    440 if hasattr(x, 'dtype'):
--> 441   dtype = dtypes.canonicalize_dtype(x.dtype, allow_opaque_dtype=True)
    442 else:

File /project_title/.venv/lib/python3.11/site-packages/jax/_src/dtypes.py:117, in canonicalize_dtype(dtype, allow_opaque_dtype)
    116 def canonicalize_dtype(dtype: Any, allow_opaque_dtype: bool = False) -> Union[DType, OpaqueDType]:
--> 117   return _canonicalize_dtype(config.x64_enabled, allow_opaque_dtype, dtype)

File /project_title/.venv/lib/python3.11/site-packages/jax/_src/dtypes.py:103, in _canonicalize_dtype(x64_enabled, allow_opaque_dtype, dtype)
    102 except TypeError as e:
--> 103   raise TypeError(f'dtype {dtype!r} not understood') from e
    105 if x64_enabled:

TypeError: dtype <attribute 'dtype' of 'numpy.ndarray' objects> not understood

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
Cell In[1], line 8
      5 import matplotlib.pyplot as plt
      7 import quantum.project_title.data.data_generation as data_gen
----> 8 from quantum.project_title.circuit_components.ansatz import StronglyEntanglingLayer
      9 from quantum.project_title.circuit_components.data_loader import PriceLoader
     10 from quantum.project_title.utils.utils import (angle_encode_spot_price, 
     11                                                                normalize_option_price, 
     12                                                                decode, 
     13                                                                calculate_delta)

File /project_title/quantum/project_title/circuit_components/ansatz.py:7
      3 from pennylane.numpy import pi, tensor
      4 from pennylane.operation import Operation
----> 7 class StronglyEntanglingLayer(Operation):
      8     num_params = 6
      9     num_wires = 2

File /project_title/quantum/project_title/circuit_components/ansatz.py:13, in StronglyEntanglingLayer()
     10 grad_method = None
     11 grad_recipe = ([[0.5, 1, pi / 2], [-0.5, 1, -pi / 2]],)
---> 13 @qjit
     14 def compute_decomposition(
     15     weight0: tensor,
     16     weight1: tensor,
     17     weight2: tensor,
     18     weight3: tensor,
     19     weight4: tensor,
     20     weight5: tensor,
     21     wires: list,
     22 ):
     23     op_list = [
     24         qml.RZ(weight0, wires=wires[0]),
     25         qml.RZ(weight1, wires=wires[1]),
   (...)
     31         qml.CNOT(wires=[wires[1], wires[0]]),
     32     ]
     33     return op_list

File /project_title/.venv/lib/python3.11/site-packages/catalyst/compilation_pipelines.py:675, in qjit(fn, target, keep_intermediate, verbose, logfile)
    588 """A just-in-time decorator for PennyLane and JAX programs using Catalyst.
    589 
    590 This decorator enables both just-in-time and ahead-of-time compilation,
   (...)
    671         :class:`~.pennylane_extensions.QJITDevice`.
    672 """
    674 if fn is not None:
--> 675     return QJIT(fn, target, keep_intermediate, CompileOptions(verbose, logfile))
    677 def wrap_fn(fn):
    678     return QJIT(fn, target, keep_intermediate, CompileOptions(verbose, logfile))

File /project_title/.venv/lib/python3.11/site-packages/catalyst/compilation_pipelines.py:484, in QJIT.__init__(self, fn, target, keep_intermediate, compile_options)
    482 self.user_typed = True
    483 if target in ("mlir", "binary"):
--> 484     self.mlir_module = self.get_mlir(*parameter_types)
    485 if target == "binary":
    486     self.compiled_function = self.compile()

Hey @Nikhil_Narayanan!

I think there’s some of the error message missing. Is that all of the output?

Yes my bad, the bottom of the error was cut off:

File /project_title/.venv/lib/python3.11/site-packages/catalyst/compilation_pipelines.py:529, in QJIT.get_mlir(self, *args)
    520 def get_mlir(self, *args):
    521     """Trace self.qfunc
    522 
    523     Args:
   (...)
    527         an MLIR module
    528     """
--> 529     self.c_sig = CompiledFunction.get_runtime_signature(*args)
    531     with Patcher(
    532         (qml.QNode, "__call__", QFunc.__call__),
    533     ):
    534         mlir_module, ctx, jaxpr = tracer.get_mlir(self.qfunc, *self.c_sig)

File /project_title/.venv/lib/python3.11/site-packages/catalyst/compilation_pipelines.py:218, in CompiledFunction.get_runtime_signature(*args)
    216 except Exception as exc:
    217     arg_type = type(arg)
--> 218     raise TypeError(f"Unsupported argument type: {arg_type}") from exc

TypeError: Unsupported argument type: <class 'type'>

Thanks! So you’re just trying to import the above operation elsewhere and that’s giving you the error you’re seeing?

1 Like

Yes the error happens at the import

I think the error is probably related to the fact that your class is incomplete. Check this out: Adding new operators — PennyLane 0.30.0 documentation

It looks like you’re missing __init__!

Okay, just spoke to some of our development team for Catalyst. It’s a relatively new project for me, so apologies that I didn’t understand what the issue was immediately! :sweat_smile:

This is what they said:

It appears that you are trying to return a list of operations. This is not currently allowed inside functions that will be JIT-compiled.

It would be good to know what you’re trying to do :thinking:. What’s your end goal?

I am trying to call the function as shown below. I’m trying to train a variational circuit that does regression while taking advantage of the JAX features in Catalyst.

# NOTE: ``dev`` is not defined in this snippet — presumably a two-wire device
# created elsewhere.
@qml.qnode(dev, diff_method = "parameter-shift")
def qnn(phi, *weights):
    """Angle-encode one feature, apply three entangling layers, measure <Z_0>.

    Args:
        phi: the normalized, angle-encoded spot price; used as the angle of
            every encoding gate below.
        *weights: trainable angles, consumed six at a time (weights[0:6],
            weights[6:12], weights[12:18]) by the three layers.

    Returns:
        The expectation value of PauliZ on wire 0.
    """
    # feature encoding
    qml.RY(phi, wires=0)
    qml.RY(phi, wires=1)
    qml.IsingXX(phi, wires=[0, 1])
    qml.RY(phi, wires=0)
    qml.RY(phi, wires=1)
    # three trainable layers, six angles each
    for i in range(3):
        StronglyEntanglingLayer(*weights[(i * 6):((i + 1) * 6)], wires=[0,1])
    return qml.expval(qml.PauliZ(0))

Moreover, my other functions are as below:

@qjit
def network_fn(angle_encoded_spot, *weights):
    """Run ``qnn`` and rescale its <Z> output from [-1, 1] to [0, 1]."""
    quantum_out = qnn(angle_encoded_spot, *weights)
    # NOTE: per the reply below, the Catalyst release discussed here needs
    # ``quantum_out[0]`` on this line; the un-indexed form only works from
    # the next release.
    return (quantum_out + 1) / 2

# Cost function using the output of the QNode: squared error divided by the
# target. NOTE: ``target`` must be non-zero.
@qjit
def param_shift_cost(target, angle_encoded_spot, *weights):
    """Return ``(network_fn(angle_encoded_spot, *weights) - target)**2 / target``."""
    output = network_fn(angle_encoded_spot, *weights)
    return (output - target) ** 2 / target

Hey @Nikhil_Narayanan! I spoke to some catalyst developers and here’s what the solution is.

  1. Remove `@qjit` from `compute_decomposition`.
  2. Add `@staticmethod` to `compute_decomposition`.

compute_decomposition is QJIT’d regardless, so no need to do it here :). One more thing to change is network_fn:

@qjit
def network_fn(angle_encoded_spot, *weights):
    """Run ``qnn`` and rescale its <Z> output from [-1, 1] to [0, 1]."""
    quantum_out = qnn(angle_encoded_spot, *weights)
    # In the Catalyst release discussed here, the QNode result must be indexed
    # with [0] to get the expectation value; per this thread, the next release
    # removes that requirement.
    return (quantum_out[0] + 1) / 2

Side note: in the next release of Catalyst, you’ll be able to do this:

@qjit
def network_fn(angle_encoded_spot, *weights):
    """Run ``qnn`` and rescale its <Z> output from [-1, 1] to [0, 1]."""
    # Per this thread, this form (no [0] indexing of the QNode result) works
    # from the Catalyst release after the one discussed here.
    quantum_out = qnn(angle_encoded_spot, *weights)
    return (quantum_out + 1) / 2

Here’s a complete example:

from catalyst import qjit, grad
import pennylane as qml
import numpy as np
from pennylane.numpy import pi, tensor

from jax import numpy as jnp
from catalyst import qjit

class StronglyEntanglingLayer(qml.operation.Operation):
    """Two-wire operation with six rotation angles (corrected version).

    Decomposes into RZ, RY, RZ rotations on each wire followed by a CNOT in
    each direction. Differences from the user's original: an explicit
    ``__init__``, and ``compute_decomposition`` is a ``@staticmethod`` without
    ``@qjit``.
    """

    num_params = 6  # w0 ... w5
    num_wires = 2
    grad_method = None
    # NOTE(review): only one recipe entry is given while num_params is 6, and
    # grad_method is None — confirm whether this recipe is used at all.
    grad_recipe = ([[0.5, 1, np.pi / 2], [-0.5, 1, -np.pi / 2]],)

    def __init__(self, *weights, wires=None):
        # Forwards the angles and wires unchanged to Operation.
        super().__init__(*weights, wires=wires)

    @staticmethod
    def compute_decomposition(w0, w1, w2, w3, w4, w5, wires):
        """Return the gates implementing the layer, in application order.

        No ``@qjit`` here: when the calling QNode is qjit-compiled, Catalyst
        compiles this decomposition as part of it.
        """
        op_list = [
            # Z-Y-Z single-qubit rotations on each wire
            qml.RZ(w0, wires=wires[0]),
            qml.RZ(w1, wires=wires[1]),
            qml.RY(w2, wires=wires[0]),
            qml.RY(w3, wires=wires[1]),
            qml.RZ(w4, wires=wires[0]),
            qml.RZ(w5, wires=wires[1]),
            # entangle the two wires with a CNOT in each direction
            qml.CNOT(wires=[wires[0], wires[1]]),
            qml.CNOT(wires=[wires[1], wires[0]]),
        ]
        return op_list

# Two-wire lightning.qubit simulator; gradients use the parameter-shift rule.
@qml.qnode(qml.device("lightning.qubit", wires=2), diff_method = "parameter-shift")
def qnn(phi, weights):
    """Angle-encode ``phi``, apply three entangling layers, and measure <Z_0>.

    Args:
        phi: the normalized, angle-encoded spot price; used as the angle of
            every encoding gate below.
        weights: a single array of trainable angles (the example passes 18).
            Layer ``i`` uses ``weights[6*i : 6*(i+1)]``. Unlike the earlier
            ``*weights`` version, this is one array argument.

    Returns:
        The expectation value of PauliZ on wire 0.
    """
    # feature encoding
    qml.RY(phi, wires=0)
    qml.RY(phi, wires=1)
    qml.IsingXX(phi, wires=[0, 1])
    qml.RY(phi, wires=0)
    qml.RY(phi, wires=1)
    # three trainable layers, six angles each
    for i in range(3):
        StronglyEntanglingLayer(*weights[(i * 6):((i + 1) * 6)], wires=[0,1])
    return qml.expval(qml.PauliZ(0))

@qjit
def network_fn(angle_encoded_spot, weights):
    """Run ``qnn`` and rescale its <Z> output from [-1, 1] to [0, 1]."""
    quantum_out = qnn(angle_encoded_spot, weights)
    # NOTE(review): earlier in this thread the then-current Catalyst release
    # needed ``quantum_out[0]`` here; the un-indexed form is described as
    # working from the next release. Match this line to your installed version.
    return (quantum_out + 1) / 2

# Cost function using the output of the QNode: squared error divided by the
# target. NOTE: ``target`` must be non-zero.
@qjit
def param_shift_cost(target, angle_encoded_spot, weights):
    """Return ``(network_fn(angle_encoded_spot, weights) - target)**2 / target``."""
    output = network_fn(angle_encoded_spot, weights)
    return (output - target) ** 2 / target

# Smoke test: one cost evaluation with target 1.0, input angle 3.14 and
# weights 0..17. NOTE: ``jnp.array(list(range(18)))`` builds an integer-dtype
# array; use float values for real trainable parameters.
print(param_shift_cost(1.0, 3.14, jnp.array(list(range(18)))))
# Presumably prints the MLIR that Catalyst produced for the compiled function.
print(param_shift_cost.mlir)

Let me know if this helps!

I want to train my network using the Adam optimizer — I am planning on using optax as shown here: Stochastic optimization — JAXopt 0.7 documentation. I will try this tomorrow, but I am unsure whether this would be supported inside QJIT with Catalyst.

@Nikhil_Narayanan did my previous response help? Let me know if you run into any more issues!

I think my main question is whether jaxopt and optax are currently compatible with Catalyst.

Hi, your previous response helped, but I’m having a lot of trouble with the training step. I imported jaxopt and optax; below is the code I’m using to train:

def data_iterator(rng_key):
    """Yield ``(training_x[i], training_y[i])`` pairs in a random order, one pass.

    NOTE: ``training_x`` and ``training_y`` are globals not defined in this
    snippet.
    """
     
    rng_key, subkey = jax.random.split(rng_key)
    # NOTE(review): ``subkey`` is never used. The usual JAX pattern is to draw
    # randomness from ``subkey`` and keep ``rng_key`` for later splits.
    perm = jax.random.permutation(rng_key, jnp.array(range(len(training_x))))
    for index in perm:
        yield (training_x[index], training_y[index])

@qjit
def optimisation():
    """Attempt to train the weights with Adam via jaxopt's OptaxSolver.

    NOTE(review): this fails when traced (see the traceback below).
    ``run_iterator`` calls ``update(params, state, ..., data=data)``, which
    calls the cost function as ``fun(params, ..., data=data)``. The parameters
    arrive as the first positional argument and each batch as ``data=``. But
    ``param_shift_cost(target, angle_encoded_spot, *weights)`` takes ``target``
    first and has no ``data`` parameter, so the cost function needs a
    ``fun(params, data)`` signature. As explained later in the thread, JAX also
    cannot differentiate a qjit function directly; the optimizer needs a
    value-and-gradient function built with ``catalyst.grad``. Finally, the
    result of ``run_iterator`` is discarded, so this function returns None.
    """
    
    # initial weights: 18 angles (3 layers x 6)
    init_weights = jnp.array([0.6191368085366578, 
                -0.22716372247007463, 
                0.39964323632317433, 
                1.0533182225902902, 
                1.1384541354087307, 
                -0.31927423082968687, 
                -0.38598253427438795, 
                1.567254004140948, 
                1.8395378005048666, 
                0.4180817001090079, 
                1.062053061058973, 
                0.5349723405892166, 
                0.36431094419401455, 
                1.629228622648432, 
                0.2990813054605557, 
                -0.6528184516978838, 
                0.6739291834199453, 
                0.27512818131726857])
    
    # define optimizer
    opt = optax.adam(0.01)
    solver = OptaxSolver(opt=opt, fun=param_shift_cost, maxiter=80)
    rng_key = jax.random.PRNGKey(42)

    # one randomly ordered pass over the training data; jaxopt stops after
    # maxiter items even if the iterator has more
    iterator = data_iterator(rng_key)
    solver.run_iterator(init_weights, iterator)

However, I get the following errors when I run the above:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[27], line 8
      5     for index in perm:
      6         yield (training_x[index], training_y[index])
----> 8 @qjit
      9 def optimisation():
     10     
     11     # initial weights
     12     init_weights = jnp.array([0.6191368085366578, 
     13                 -0.22716372247007463, 
     14                 0.39964323632317433, 
   (...)
     28                 0.6739291834199453, 
     29                 0.27512818131726857])
     31     # define optimizer

File /quantum_risk_engine/.venv/lib/python3.11/site-packages/catalyst/compilation_pipelines.py:675, in qjit(fn, target, keep_intermediate, verbose, logfile)
    588 """A just-in-time decorator for PennyLane and JAX programs using Catalyst.
    589 
    590 This decorator enables both just-in-time and ahead-of-time compilation,
   (...)
    671         :class:`~.pennylane_extensions.QJITDevice`.
    672 """
    674 if fn is not None:
--> 675     return QJIT(fn, target, keep_intermediate, CompileOptions(verbose, logfile))
    677 def wrap_fn(fn):
    678     return QJIT(fn, target, keep_intermediate, CompileOptions(verbose, logfile))

File /quantum_risk_engine/.venv/lib/python3.11/site-packages/catalyst/compilation_pipelines.py:484, in QJIT.__init__(self, fn, target, keep_intermediate, compile_options)
    482 self.user_typed = True
    483 if target in ("mlir", "binary"):
--> 484     self.mlir_module = self.get_mlir(*parameter_types)
    485 if target == "binary":
    486     self.compiled_function = self.compile()

File /quantum_risk_engine/.venv/lib/python3.11/site-packages/catalyst/compilation_pipelines.py:534, in QJIT.get_mlir(self, *args)
    529 self.c_sig = CompiledFunction.get_runtime_signature(*args)
    531 with Patcher(
    532     (qml.QNode, "__call__", QFunc.__call__),
    533 ):
--> 534     mlir_module, ctx, jaxpr = tracer.get_mlir(self.qfunc, *self.c_sig)
    536 mod = mlir_module.operation
    537 self._jaxpr = jaxpr

File /quantum_risk_engine/.venv/lib/python3.11/site-packages/catalyst/jax_tracer.py:64, in get_mlir(func, *args, **kwargs)
     61 jprim.mlir_fn_cache.clear()
     63 with TracingContext():
---> 64     jaxpr = jax.make_jaxpr(func)(*args, **kwargs)
     66 nrep = jaxpr_replicas(jaxpr)
     67 effects = [eff for eff in jaxpr.effects if eff in jax.core.ordered_effects]

    [... skipping hidden 6 frame]

Cell In[27], line 37, in optimisation()
     34 rng_key = jax.random.PRNGKey(42)
     36 iterator = data_iterator(rng_key)
---> 37 solver.run_iterator(init_weights, iterator)

File /quantum_risk_engine/.venv/lib/python3.11/site-packages/jaxopt/_src/base.py:399, in StochasticSolver.run_iterator(self, init_params, iterator, *args, **kwargs)
    396 # TODO(mblondel): try and benchmark lax.fori_loop with host_call for `next`.
    397 for data in itertools.islice(iterator, 0, self.maxiter):
--> 399   params, state = self.update(params, state, *args, **kwargs, data=data)
    401 return OptStep(params=params, state=state)

File /quantum_risk_engine/.venv/lib/python3.11/site-packages/jaxopt/_src/optax_wrapper.py:141, in OptaxSolver.update(self, params, state, *args, **kwargs)
    138 if self.pre_update:
    139   params, state = self.pre_update(params, state, *args, **kwargs)
--> 141 (value, aux), grad = self._value_and_grad_fun(params, *args, **kwargs)
    143 delta, opt_state = self.opt.update(grad, state.internal_state, params)
    144 params = self._apply_updates(params, delta)

    [... skipping hidden 8 frame]

File /quantum_risk_engine/.venv/lib/python3.11/site-packages/jaxopt/_src/base.py:70, in _add_aux_to_fun.<locals>.fun_with_aux(*a, **kw)
     69 def fun_with_aux(*a, **kw):
---> 70   return fun(*a, **kw), None

File /quantum_risk_engine/.venv/lib/python3.11/site-packages/catalyst/compilation_pipelines.py:561, in QJIT.__call__(self, *args, **kwargs)
    559 def __call__(self, *args, **kwargs):
    560     if TracingContext.is_tracing():
--> 561         return self.qfunc(*args, **kwargs)
    563     if any(isinstance(arg, jax.core.Tracer) for arg in args):
    564         raise ValueError(
    565             "Cannot use JAX to trace through a qjit compiled function. If you attempted "
    566             "to use jax.jit or jax.grad, please use their equivalent from Catalyst."
    567         )

TypeError: param_shift_cost() got multiple values for argument 'data'

Are you able to get your code working if you loop manually? I.e., what’s depicted here: Stochastic optimization — JAXopt 0.7 documentation

It would be good to see if your code works manually updating over one data point (no batching).

I’ve been trying the following with no luck (without batching) - I’ll continue trying to debug and let you know if I get it to work

def data_iterator(rng_key):  
    """Yield ``(training_x[i], training_y[i])`` pairs in a random order, one pass.

    NOTE: ``training_x`` and ``training_y`` are globals not defined in this
    snippet.
    """
    rng_key, subkey = jax.random.split(rng_key)
    # NOTE(review): ``subkey`` is never used. The usual JAX pattern is to draw
    # randomness from ``subkey`` and keep ``rng_key`` for later splits.
    perm = jax.random.permutation(rng_key, jnp.array(range(len(training_x))))
    for index in perm:
        yield (training_x[index], training_y[index])

@qjit
def optimisation():
    """Second attempt: a full solve on the first training point only.

    NOTE(review): this still fails (see the traceback below). ``solver.run``
    calls the cost function as
    ``fun(init_weights, angle_encoded_spot=..., target=...)``. So
    ``init_weights`` fills the cost function's first positional slot and
    collides with one of the keyword arguments ("got multiple values for
    argument ..."). The optimized parameters must be the cost function's
    first argument. ``rng_key`` and ``iterator`` are unused here, and the
    result of ``run`` is discarded, so this function returns None.
    """
    
    # initial weights: 18 angles (3 layers x 6)
    init_weights = jnp.array([0.6191368085366578, 
                -0.22716372247007463, 
                0.39964323632317433, 
                1.0533182225902902, 
                1.1384541354087307, 
                -0.31927423082968687, 
                -0.38598253427438795, 
                1.567254004140948, 
                1.8395378005048666, 
                0.4180817001090079, 
                1.062053061058973, 
                0.5349723405892166, 
                0.36431094419401455, 
                1.629228622648432, 
                0.2990813054605557, 
                -0.6528184516978838, 
                0.6739291834199453, 
                0.27512818131726857])
    
    # define optimizer
    opt = optax.adam(0.01)
    solver = OptaxSolver(opt=opt, fun=param_shift_cost, maxiter=80)
    rng_key = jax.random.PRNGKey(42)

    iterator = data_iterator(rng_key)
    solver.run(init_weights, angle_encoded_spot = training_x[0], target = training_y[0])

Below is the error that I got:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[47], line 7
      4     for index in perm:
      5         yield (training_x[index], training_y[index])
----> 7 @qjit
      8 def optimisation():
      9     
     10     # initial weights
     11     init_weights = jnp.array([0.6191368085366578, 
     12                 -0.22716372247007463, 
     13                 0.39964323632317433, 
   (...)
     27                 0.6739291834199453, 
     28                 0.27512818131726857])
     30     # define optimizer

File /quantum_risk_engine/.venv/lib/python3.11/site-packages/catalyst/compilation_pipelines.py:675, in qjit(fn, target, keep_intermediate, verbose, logfile)
    588 """A just-in-time decorator for PennyLane and JAX programs using Catalyst.
    589 
    590 This decorator enables both just-in-time and ahead-of-time compilation,
   (...)
    671         :class:`~.pennylane_extensions.QJITDevice`.
    672 """
    674 if fn is not None:
--> 675     return QJIT(fn, target, keep_intermediate, CompileOptions(verbose, logfile))
    677 def wrap_fn(fn):
    678     return QJIT(fn, target, keep_intermediate, CompileOptions(verbose, logfile))

File /quantum_risk_engine/.venv/lib/python3.11/site-packages/catalyst/compilation_pipelines.py:484, in QJIT.__init__(self, fn, target, keep_intermediate, compile_options)
    482 self.user_typed = True
    483 if target in ("mlir", "binary"):
--> 484     self.mlir_module = self.get_mlir(*parameter_types)
    485 if target == "binary":
    486     self.compiled_function = self.compile()

File /quantum_risk_engine/.venv/lib/python3.11/site-packages/catalyst/compilation_pipelines.py:534, in QJIT.get_mlir(self, *args)
    529 self.c_sig = CompiledFunction.get_runtime_signature(*args)
    531 with Patcher(
    532     (qml.QNode, "__call__", QFunc.__call__),
    533 ):
--> 534     mlir_module, ctx, jaxpr = tracer.get_mlir(self.qfunc, *self.c_sig)
    536 mod = mlir_module.operation
    537 self._jaxpr = jaxpr

File /quantum_risk_engine/.venv/lib/python3.11/site-packages/catalyst/jax_tracer.py:64, in get_mlir(func, *args, **kwargs)
     61 jprim.mlir_fn_cache.clear()
     63 with TracingContext():
---> 64     jaxpr = jax.make_jaxpr(func)(*args, **kwargs)
     66 nrep = jaxpr_replicas(jaxpr)
     67 effects = [eff for eff in jaxpr.effects if eff in jax.core.ordered_effects]

    [... skipping hidden 6 frame]

Cell In[47], line 37, in optimisation()
     35 iterator = data_iterator(rng_key)
     36 print(training_x[0])
---> 37 solver.run(init_weights, angle_encoded_spot = training_x[0], target = training_y[0])

File /quantum_risk_engine/.venv/lib/python3.11/site-packages/jaxopt/_src/base.py:354, in IterativeSolver.run(self, init_params, *args, **kwargs)
    347   decorator = idf.custom_root(
    348       self.optimality_fun,
    349       has_aux=True,
    350       solve=self.implicit_diff_solve,
    351       reference_signature=reference_signature)
    352   run = decorator(run)
--> 354 return run(init_params, *args, **kwargs)

File /quantum_risk_engine/.venv/lib/python3.11/site-packages/jaxopt/_src/base.py:316, in IterativeSolver._run(self, init_params, *args, **kwargs)
    298 # We unroll the very first iteration. This allows `init_val` and `body_fun`
    299 # below to have the same output type, which is a requirement of
    300 # lax.while_loop and lax.scan.
   (...)
    311 # of a `lax.cond` for now in order to avoid staging the initial
    312 # update and the run loop. They might not be staging compatible.
    314 zero_step = self._make_zero_step(init_params, state)
--> 316 opt_step = self.update(init_params, state, *args, **kwargs)
    317 init_val = (opt_step, (args, kwargs))
    319 jit, unroll = self._get_loop_options()

File /quantum_risk_engine/.venv/lib/python3.11/site-packages/jaxopt/_src/optax_wrapper.py:141, in OptaxSolver.update(self, params, state, *args, **kwargs)
    138 if self.pre_update:
    139   params, state = self.pre_update(params, state, *args, **kwargs)
--> 141 (value, aux), grad = self._value_and_grad_fun(params, *args, **kwargs)
    143 delta, opt_state = self.opt.update(grad, state.internal_state, params)
    144 params = self._apply_updates(params, delta)

    [... skipping hidden 8 frame]

File /quantum_risk_engine/.venv/lib/python3.11/site-packages/jaxopt/_src/base.py:70, in _add_aux_to_fun.<locals>.fun_with_aux(*a, **kw)
     69 def fun_with_aux(*a, **kw):
---> 70   return fun(*a, **kw), None

File /quantum_risk_engine/.venv/lib/python3.11/site-packages/catalyst/compilation_pipelines.py:561, in QJIT.__call__(self, *args, **kwargs)
    559 def __call__(self, *args, **kwargs):
    560     if TracingContext.is_tracing():
--> 561         return self.qfunc(*args, **kwargs)
    563     if any(isinstance(arg, jax.core.Tracer) for arg in args):
    564         raise ValueError(
    565             "Cannot use JAX to trace through a qjit compiled function. If you attempted "
    566             "to use jax.jit or jax.grad, please use their equivalent from Catalyst."
    567         )

TypeError: param_shift_cost() got multiple values for argument 'angle_encoded_spot'

Can you provide a small dummy dataset that we can use to try and replicate the error on our end?

Hey @Nikhil_Narayanan! There is a slight subtlety to your question – when you say compatible, do you mean inside the qjit or outside the qjit?

  • Outside the qjit, the qjitted function will look like a regular Python function, so it will work with either jaxopt or optax. However, JAX won’t know how to differentiate a qjit function natively, so you would need to pass the optimizer both the cost function and the gradient function. For example,

    def cost(params):
        ...  # placeholder: your qjit-compatible cost function goes here

    @qjit
    def cost_and_grad(params):
        """Return ``(cost, gradient)`` so the optimizer never differentiates the qjit function itself."""
        # catalyst.grad differentiates ``cost`` with respect to argument 0;
        # presumably [0] picks out the gradient for ``params`` from the
        # returned sequence — confirm against your Catalyst version.
        grad = catalyst.grad(cost, argnum=0)
        return cost(params), grad(params)[0]

    # value_and_grad=True tells jaxopt that the function already returns
    # (value, gradient), so JAX does not need to differentiate it.
    opt = jaxopt.GradientDescent(cost_and_grad, stepsize=0.4, value_and_grad=True)
    
  • Inside the qjit, it is a little more complicated – you need to make sure the optimizer you are working with is jax.jit compatible. This is the case for jaxopt (and you can see an example in the docs), but I haven’t tested optax. In either case, you still need to use the trick above.


Note: Catalyst version 0.2.0, coming out soon, has better JAX integration, and will make it easier to integrate with the optimizers :slight_smile:

1 Like

Thank you @isaacdevlugt and @josh! This clarifies it.

1 Like

Awesome! Glad we were able to help here. And thanks for taking Catalyst for a spin! :smile: