JAX TracerArrayConversionError when using qml.Hermitian with jax.vmap

Hi PennyLane community,

I’m encountering a jax.errors.TracerArrayConversionError when trying to use qml.Hermitian with jax.vmap. Here’s a minimal example of my code:

import pennylane as qml
import jax
import jax.numpy as jnp
import optax

jax.config.update("jax_enable_x64", True)

device = qml.device("default.qubit", wires=1)

dms = [
    jnp.array([[1, 0], [0, 0]]),
    jnp.array([[0, 0], [0, 1]]),
]

key = jax.random.PRNGKey(42)
ys = jnp.array([0, 1]).astype(int)
xs = jax.random.uniform(key, shape=(len(ys), ))

@qml.qnode(device, interface='jax')
def qnode(params, x: float, dm: jnp.ndarray) -> float:
    qml.RX(x, wires=0)
    qml.RX(params[0], wires=0)
    return qml.expval(qml.Hermitian(dm, wires=0))

qnn_vectorized = jax.jit(jax.vmap(qnode, in_axes=(None, 0, 0)))

def loss_function(params, x_set, dm):
    fidelities = qnn_vectorized(params, x_set, dm)
    loss = ((1 - fidelities) ** 2).sum() / len(fidelities)
    return loss

# Define optimizer
params = jnp.array([0.1])
opt = optax.adam(learning_rate=0.1)
opt_state = opt.init(params)
grad = jax.jit(jax.grad(loss_function))

for epoch in range(5):
    print(f'Epoch {epoch}')
    grad_value = grad(params, xs, dms)
    updates, opt_state = opt.update(grad_value, opt_state)
    params = optax.apply_updates(params, updates)

print("Optimized parameters:", params)

Error:

When running this code, I get the following error at line 24 of my script, i.e. return qml.expval(qml.Hermitian(dm, wires=0)):
jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape int64[2]

Question:

How can I use qml.Hermitian with jax.vmap without encountering the TracerArrayConversionError? Is there a recommended way to handle this situation?
Any help or guidance would be greatly appreciated!

Thanks in advance.

Hi @Kernel123!

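The issue is that dms is a Python list, where it should be a single jnp.array. jax.vmap treats a list as a pytree and maps over axis 0 of each element separately, so inside qnode, dm is not one 2x2 matrix but a list holding one row of each matrix (hence the traced array with shape int64[2] in your error).

So stacking the matrices into one array of shape (2, 2, 2), as follows, should fix the issue.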

dms = jnp.array([
    jnp.array([[1, 0], [0, 0]]),
    jnp.array([[0, 0], [0, 1]]),
])
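To see why the list matters, here is a small JAX-only sketch (no PennyLane involved; probe and seen are hypothetical helpers for illustration). It records, at trace time, what jax.vmap passes into the function body for a list of matrices versus one stacked array:

```python
import jax
import jax.numpy as jnp

dms_list = [
    jnp.array([[1, 0], [0, 0]]),
    jnp.array([[0, 0], [0, 1]]),
]
dms_stacked = jnp.array(dms_list)  # one array of shape (2, 2, 2)

seen = []  # per-call shapes recorded while vmap traces `probe`

def probe(dm):
    # Record the structure and shapes that one vmapped call receives.
    seen.append(jax.tree_util.tree_map(lambda a: a.shape, dm))
    # Return a batched value so vmap has an output to stack.
    return sum(jnp.sum(leaf) for leaf in jax.tree_util.tree_leaves(dm))

jax.vmap(probe)(dms_list)     # list: each leaf is mapped over its own axis 0
jax.vmap(probe)(dms_stacked)  # array: mapped over the leading (batch) axis

print(seen)  # [[(2,), (2,)], (2, 2)]
```

With the list, each call receives one row of each matrix, which matches the shape [2] in the error message; with the stacked array, each call receives one full 2x2 matrix, which is what qml.Hermitian expects.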

If you keep having issues, can you please post the version of JAX you’re using? We don’t support JAX 0.5 yet.
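A minimal sketch for reporting the installed versions, assuming a standard install. jax.__version__ and jaxlib.__version__ are standard attributes; the 0.5 cut-off below only mirrors the note above and is not taken from an official compatibility table. On the PennyLane side, qml.about() prints the PennyLane version and environment details.

```python
import jax
import jaxlib

print("jax", jax.__version__)
print("jaxlib", jaxlib.__version__)

# Compare only the (major, minor) parts, e.g. "0.4.35" -> (0, 4).
jax_major_minor = tuple(int(p) for p in jax.__version__.split(".")[:2])
if jax_major_minor >= (0, 5):
    print("JAX 0.5 or newer detected; this combination may not be supported yet.")
```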

I hope this solves your issue! Let me know if it does.


Thank you so much, Catalina. That solved the problem. However, I still have a question related to this topic. When I change the device to default.mixed (I need to include some quantum channels to simulate noise), I encounter a similar error:
ValueError: Converting a JAX array to a NumPy array not supported when using the JAX JIT.
I haven't made any changes to the provided code except for your suggestion (wrapping dms in a jnp.array, line 11 of my script) and changing the device to default.mixed (line 9). Perhaps I should start a new topic in the PennyLane forum.

Thanks again for your support.

Hi @Kernel123,

I can replicate your issue.

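I can see in the error traceback that Hermitian.eigendecomposition calls qml.math.to_numpy on the observable's matrix, so that it can use the matrix entries as a cache key for the eigendecomposition (which is needed for the measurement's diagonalizing gates). Under jax.jit the matrix is a traced array, so that conversion fails. We'd need to avoid that somehow.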
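The failure mode can be reproduced without PennyLane. In this sketch, numpy_eigvals is a hypothetical stand-in for any code path that calls np.asarray on its input (as qml.math.to_numpy does in the traceback); under jax.jit the input is a tracer, so the conversion raises. Computing the eigenvalues with jnp.linalg.eigvalsh keeps everything traceable:

```python
import jax
import jax.numpy as jnp
import numpy as np

def numpy_eigvals(m):
    # Converts to NumPy first, like the to_numpy call in the traceback.
    return np.linalg.eigvalsh(np.asarray(m))

def jax_eigvals(m):
    # Stays inside JAX, so it works on traced arrays.
    return jnp.linalg.eigvalsh(m)

m = jnp.array([[1.0, 0.0], [0.0, 0.0]])

try:
    jax.jit(numpy_eigvals)(m)
    numpy_failed = False
except jax.errors.TracerArrayConversionError:
    numpy_failed = True

eigvals = jax.jit(jax_eigvals)(m)
print(numpy_failed)  # True
print(eigvals)       # [0. 1.]
```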

I can also confirm that passing dm as a (traced) argument is what triggers it: if you remove dm everywhere in your code, default.mixed works with no issues.

The only way I was able to avoid errors was to set a specific dm within the qnode, which I know is not what you want. I’ll add the code below in case it helps.

Error traceback

Epoch 0
---------------------------------------------------------------------------
TracerArrayConversionError                Traceback (most recent call last)
/usr/local/lib/python3.11/dist-packages/pennylane/math/single_dispatch.py in _to_numpy_jax(x)
    797     try:
--> 798         return np.array(getattr(x, "val", x))
    799     except TracerArrayConversionError as e:

20 frames
/usr/local/lib/python3.11/dist-packages/jax/_src/core.py in __array__(self, *args, **kw)
    714   def __array__(self, *args, **kw):
--> 715     raise TracerArrayConversionError(self)
    716 

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape int64[2,2,2]
The error occurred while tracing the function qnode at <ipython-input-8-d92340179049>:21 for jit. This concrete value was not available in Python because it depends on the value of the argument dm.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
<ipython-input-8-d92340179049> in <cell line: 0>()
     40 for epoch in range(5):
     41     print(f'Epoch {epoch}')
---> 42     grad_value = grad(params, xs, dms)
     43     updates, opt_state = opt.update(grad_value, opt_state)
     44     params = optax.apply_updates(params, updates)

    [... skipping hidden 21 frame]

<ipython-input-8-d92340179049> in loss_function(params, x_set, dm)
     28 
     29 def loss_function(params, x_set, dm):
---> 30     fidelities = qnn_vectorized(params, x_set, dm)
     31     loss = ((1 - fidelities) ** 2).sum() / len(fidelities)
     32     return loss

    [... skipping hidden 14 frame]

/usr/local/lib/python3.11/dist-packages/pennylane/workflow/qnode.py in __call__(self, *args, **kwargs)
    903         if qml.capture.enabled():
    904             return capture_qnode(self, *args, **kwargs)
--> 905         return self._impl_call(*args, **kwargs)
    906 
    907 

/usr/local/lib/python3.11/dist-packages/pennylane/workflow/qnode.py in _impl_call(self, *args, **kwargs)
    879         self._transform_program.set_classical_component(self, args, kwargs)
    880 
--> 881         res = qml.execute(
    882             (tape,),
    883             device=self.device,

/usr/local/lib/python3.11/dist-packages/pennylane/workflow/execution.py in execute(tapes, device, diff_method, interface, transform_program, inner_transform, config, grad_on_execution, gradient_kwargs, cache, cachesize, max_diff, device_vjp, mcm_config, gradient_fn)
    230         return post_processing(tapes)
    231 
--> 232     results = run(tapes, device, config, inner_transform)
    233     return post_processing(results)

/usr/local/lib/python3.11/dist-packages/pennylane/workflow/run.py in run(tapes, device, config, inner_transform_program)
    285     )
    286     if no_interface_boundary_required:
--> 287         results = inner_execute(tapes)
    288         return results
    289 

/usr/local/lib/python3.11/dist-packages/pennylane/workflow/run.py in inner_execute(tapes)
    245 
    246         if transformed_tapes:
--> 247             results = device.execute(transformed_tapes, execution_config=execution_config)
    248         else:
    249             results = ()

/usr/local/lib/python3.11/dist-packages/pennylane/devices/modifiers/single_tape_support.py in execute(self, circuits, execution_config)
     30             is_single_circuit = True
     31             circuits = (circuits,)
---> 32         results = batch_execute(self, circuits, execution_config)
     33         return results[0] if is_single_circuit else results
     34 

/usr/local/lib/python3.11/dist-packages/pennylane/devices/legacy_facade.py in execute(self, circuits, execution_config)
    372         first_shot = circuits[0].shots
    373         if all(t.shots == first_shot for t in circuits):
--> 374             return _set_shots(dev, first_shot)(dev.batch_execute)(circuits, **kwargs)
    375         return tuple(
    376             _set_shots(dev, t.shots)(dev.batch_execute)((t,), **kwargs)[0] for t in circuits

/usr/lib/python3.11/contextlib.py in inner(*args, **kwds)
     79         def inner(*args, **kwds):
     80             with self._recreate_cm():
---> 81                 return func(*args, **kwds)
     82         return inner
     83 

/usr/local/lib/python3.11/dist-packages/pennylane/devices/_qubit_device.py in batch_execute(self, circuits, **kwargs)
    481             self.reset()
    482 
--> 483             res = self.execute(circuit, **kwargs)
    484             results.append(res)
    485 

/usr/local/lib/python3.11/dist-packages/pennylane/logging/decorators.py in wrapper_entry(*args, **kwargs)
     59                 **_debug_log_kwargs,
     60             )
---> 61         return func(*args, **kwargs)
     62 
     63     @wraps(func)

/usr/local/lib/python3.11/dist-packages/pennylane/devices/default_mixed.py in execute(self, circuit, **kwargs)
    851                 wires_list.append(m.wires)
    852             self.measured_wires = qml.wires.Wires.all_wires(wires_list)
--> 853         return super().execute(circuit, **kwargs)
    854 
    855     @debug_logger

/usr/local/lib/python3.11/dist-packages/pennylane/devices/_qubit_device.py in execute(self, circuit, **kwargs)
    271         self.apply(
    272             circuit.operations,
--> 273             rotations=self._get_diagonalizing_gates(circuit),
    274             **kwargs,
    275         )

/usr/local/lib/python3.11/dist-packages/pennylane/devices/_qubit_device.py in _get_diagonalizing_gates(self, circuit)
   1777         """
   1778         # pylint:disable=no-self-use
-> 1779         return circuit.diagonalizing_gates
   1780 
   1781     def _is_lightning_device(self):

/usr/local/lib/python3.11/dist-packages/pennylane/tape/qscript.py in diagonalizing_gates(self)
    404                 # in which case we just don't append any
    405                 with contextlib.suppress(qml.operation.DiagGatesUndefinedError):
--> 406                     rotation_gates.extend(observable.diagonalizing_gates())
    407         return rotation_gates
    408 

/usr/local/lib/python3.11/dist-packages/pennylane/ops/qubit/observables.py in diagonalizing_gates(self)
    257         """
    258         # note: compute_diagonalizing_gates has a custom signature, which is why we overwrite this method
--> 259         return self.compute_diagonalizing_gates(self.eigendecomposition["eigvec"], self.wires)
    260 
    261 

/usr/local/lib/python3.11/dist-packages/pennylane/ops/qubit/observables.py in eigendecomposition(self)
    152         """
    153         Hmat = self.matrix()
--> 154         Hmat = qml.math.to_numpy(Hmat)
    155         Hkey = tuple(Hmat.flatten().tolist())
    156         if Hkey not in Hermitian._eigs:

/usr/local/lib/python3.11/dist-packages/autoray/autoray.py in do(fn, like, *args, **kwargs)
     79     backend = _choose_backend(fn, args, kwargs, like=like)
     80     func = get_lib_fn(backend, fn)
---> 81     return func(*args, **kwargs)
     82 
     83 

/usr/local/lib/python3.11/dist-packages/pennylane/math/single_dispatch.py in _to_numpy_jax(x)
    798         return np.array(getattr(x, "val", x))
    799     except TracerArrayConversionError as e:
--> 800         raise ValueError(
    801             "Converting a JAX array to a NumPy array not supported when using the JAX JIT."
    802         ) from e

ValueError: Converting a JAX array to a NumPy array not supported when using the JAX JIT.

Here is the code with dm removed almost everywhere and a single fixed dm defined inside the qnode. The inline comments show where I've removed dm.

import pennylane as qml
import jax
import jax.numpy as jnp
import optax

jax.config.update("jax_enable_x64", True)

device = qml.device("default.mixed", wires=1)

dms = jnp.array([
    jnp.array([[1, 0], [0, 0]]),
    jnp.array([[0, 0], [0, 1]]),
])

key = jax.random.PRNGKey(42)
ys = jnp.array([0, 1]).astype(int)
xs = jax.random.uniform(key, shape=(len(ys), ))

@qml.qnode(device, interface='jax')
def qnode(params, x: float) -> float: #, dm: jnp.ndarray
    dm = [[1, 0], [0, 0]]
    qml.RX(x, wires=0)
    qml.RX(params[0], wires=0)
    return qml.expval(qml.Hermitian(dm, wires=0)) 

qnn_vectorized = jax.jit(jax.vmap(qnode, in_axes=(None, 0))) #, 0

def loss_function(params, x_set): #, dm
    fidelities = qnn_vectorized(params, x_set) #, dm
    loss = ((1 - fidelities) ** 2).sum() / len(fidelities)
    return loss

# Define optimizer
params = jnp.array([0.1])
opt = optax.adam(learning_rate=0.1)
opt_state = opt.init(params)
grad = jax.jit(jax.grad(loss_function))

for epoch in range(5):
    print(f'Epoch {epoch}')
    grad_value = grad(params, xs) #, dms
    updates, opt_state = opt.update(grad_value, opt_state)
    params = optax.apply_updates(params, updates)

print("Optimized parameters:", params)

Hello Catalina,

Thank you very much for your time and comments. I've also noticed that the issue lies within the PennyLane source code, where the device computes the eigendecomposition of dm (to get the diagonalizing gates for the measurement). However, since dm is a traced JAX array under jax.jit, converting it to NumPy for that step raises an error.

As you mentioned, I need to vary dm (I am creating a variational measurement). Therefore, this doesn’t resolve my issue (but thank you anyway! :smiley: ). I’ll attempt to resolve it using other approaches. If I find a solution, I’ll update this post in case it helps anyone else who might be interested.
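For reference, one possible stop-gap while the device issue remains: because <H> = Tr(rho H), the expectation value can be computed without diagonalizing H, so dm can stay a traced argument. The sketch below does this in plain JAX for the one-qubit circuit above. It is a hand-written simulation (rx and expval_hermitian are illustrative helpers, not PennyLane API) and it does not include any noise channels.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def rx(theta):
    # Single-qubit RX(theta) matrix (same convention as qml.RX).
    c = jnp.cos(theta / 2)
    s = jnp.sin(theta / 2)
    return jnp.array([[c, -1j * s], [-1j * s, c]])

def expval_hermitian(params, x, dm):
    # Density matrix after RX(x) followed by RX(params[0]) on |0><0|.
    rho0 = jnp.array([[1.0, 0.0], [0.0, 0.0]])
    u = rx(params[0]) @ rx(x)
    rho = u @ rho0 @ u.conj().T
    # <H> = Tr(rho H): no eigendecomposition, so dm may be a traced array.
    return jnp.real(jnp.trace(rho @ dm))

qnn_vectorized = jax.jit(jax.vmap(expval_hermitian, in_axes=(None, 0, 0)))

def loss_function(params, x_set, dm):
    fidelities = qnn_vectorized(params, x_set, dm)
    return ((1 - fidelities) ** 2).mean()

dms = jnp.array([[[1, 0], [0, 0]], [[0, 0], [0, 1]]])
xs = jnp.array([0.0, 0.0])
params = jnp.array([0.1])

print(qnn_vectorized(params, xs, dms))  # [cos^2(0.05), sin^2(0.05)]
print(jax.grad(loss_function)(params, xs, dms))
```

Everything here stays inside JAX, so jax.vmap over dm, jax.jit and jax.grad all apply. Adding noise would mean applying each channel's Kraus operators to rho by hand, which this sketch does not do.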


Hi @Kernel123 ,

I have some good news! It looks like some recent changes to default.mixed should avoid this issue. You can use this with the current master version of PennyLane. You can install this version with

pip install git+https://github.com/PennyLaneAI/pennylane.git#egg=pennylane

This will be available in the next stable release of PennyLane in about 1 month.

Let me know if this solves your current issue!