Hi @Kernel123,
I can replicate your issue.
The traceback shows that the `Hermitian` observable's `eigendecomposition` calls `qml.math.to_numpy` on its matrix. Under `jax.jit` that matrix is a traced value, which cannot be converted to NumPy, so we'd need to avoid that call somehow.
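The issue is caused directly by `dm`. It is passed as an argument to the jitted QNode, so the matrix given to `qml.Hermitian` is a traced value (the traceback says the value "depends on the value of the argument dm"). If you remove `dm` from everywhere in your code, you can use `default.mixed` with no issues.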
The only way I was able to avoid the error was to hard-code a specific `dm` inside the QNode, which I know is not what you want. I've added that code below in case it helps.
Error traceback
Epoch 0
---------------------------------------------------------------------------
TracerArrayConversionError Traceback (most recent call last)
/usr/local/lib/python3.11/dist-packages/pennylane/math/single_dispatch.py in _to_numpy_jax(x)
797 try:
--> 798 return np.array(getattr(x, "val", x))
799 except TracerArrayConversionError as e:
20 frames
/usr/local/lib/python3.11/dist-packages/jax/_src/core.py in __array__(self, *args, **kw)
714 def __array__(self, *args, **kw):
--> 715 raise TracerArrayConversionError(self)
716
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape int64[2,2,2]
The error occurred while tracing the function qnode at <ipython-input-8-d92340179049>:21 for jit. This concrete value was not available in Python because it depends on the value of the argument dm.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
The above exception was the direct cause of the following exception:
ValueError Traceback (most recent call last)
<ipython-input-8-d92340179049> in <cell line: 0>()
40 for epoch in range(5):
41 print(f'Epoch {epoch}')
---> 42 grad_value = grad(params, xs, dms)
43 updates, opt_state = opt.update(grad_value, opt_state)
44 params = optax.apply_updates(params, updates)
[... skipping hidden 21 frame]
<ipython-input-8-d92340179049> in loss_function(params, x_set, dm)
28
29 def loss_function(params, x_set, dm):
---> 30 fidelities = qnn_vectorized(params, x_set, dm)
31 loss = ((1 - fidelities) ** 2).sum() / len(fidelities)
32 return loss
[... skipping hidden 14 frame]
/usr/local/lib/python3.11/dist-packages/pennylane/workflow/qnode.py in __call__(self, *args, **kwargs)
903 if qml.capture.enabled():
904 return capture_qnode(self, *args, **kwargs)
--> 905 return self._impl_call(*args, **kwargs)
906
907
/usr/local/lib/python3.11/dist-packages/pennylane/workflow/qnode.py in _impl_call(self, *args, **kwargs)
879 self._transform_program.set_classical_component(self, args, kwargs)
880
--> 881 res = qml.execute(
882 (tape,),
883 device=self.device,
/usr/local/lib/python3.11/dist-packages/pennylane/workflow/execution.py in execute(tapes, device, diff_method, interface, transform_program, inner_transform, config, grad_on_execution, gradient_kwargs, cache, cachesize, max_diff, device_vjp, mcm_config, gradient_fn)
230 return post_processing(tapes)
231
--> 232 results = run(tapes, device, config, inner_transform)
233 return post_processing(results)
/usr/local/lib/python3.11/dist-packages/pennylane/workflow/run.py in run(tapes, device, config, inner_transform_program)
285 )
286 if no_interface_boundary_required:
--> 287 results = inner_execute(tapes)
288 return results
289
/usr/local/lib/python3.11/dist-packages/pennylane/workflow/run.py in inner_execute(tapes)
245
246 if transformed_tapes:
--> 247 results = device.execute(transformed_tapes, execution_config=execution_config)
248 else:
249 results = ()
/usr/local/lib/python3.11/dist-packages/pennylane/devices/modifiers/single_tape_support.py in execute(self, circuits, execution_config)
30 is_single_circuit = True
31 circuits = (circuits,)
---> 32 results = batch_execute(self, circuits, execution_config)
33 return results[0] if is_single_circuit else results
34
/usr/local/lib/python3.11/dist-packages/pennylane/devices/legacy_facade.py in execute(self, circuits, execution_config)
372 first_shot = circuits[0].shots
373 if all(t.shots == first_shot for t in circuits):
--> 374 return _set_shots(dev, first_shot)(dev.batch_execute)(circuits, **kwargs)
375 return tuple(
376 _set_shots(dev, t.shots)(dev.batch_execute)((t,), **kwargs)[0] for t in circuits
/usr/lib/python3.11/contextlib.py in inner(*args, **kwds)
79 def inner(*args, **kwds):
80 with self._recreate_cm():
---> 81 return func(*args, **kwds)
82 return inner
83
/usr/local/lib/python3.11/dist-packages/pennylane/devices/_qubit_device.py in batch_execute(self, circuits, **kwargs)
481 self.reset()
482
--> 483 res = self.execute(circuit, **kwargs)
484 results.append(res)
485
/usr/local/lib/python3.11/dist-packages/pennylane/logging/decorators.py in wrapper_entry(*args, **kwargs)
59 **_debug_log_kwargs,
60 )
---> 61 return func(*args, **kwargs)
62
63 @wraps(func)
/usr/local/lib/python3.11/dist-packages/pennylane/devices/default_mixed.py in execute(self, circuit, **kwargs)
851 wires_list.append(m.wires)
852 self.measured_wires = qml.wires.Wires.all_wires(wires_list)
--> 853 return super().execute(circuit, **kwargs)
854
855 @debug_logger
/usr/local/lib/python3.11/dist-packages/pennylane/devices/_qubit_device.py in execute(self, circuit, **kwargs)
271 self.apply(
272 circuit.operations,
--> 273 rotations=self._get_diagonalizing_gates(circuit),
274 **kwargs,
275 )
/usr/local/lib/python3.11/dist-packages/pennylane/devices/_qubit_device.py in _get_diagonalizing_gates(self, circuit)
1777 """
1778 # pylint:disable=no-self-use
-> 1779 return circuit.diagonalizing_gates
1780
1781 def _is_lightning_device(self):
/usr/local/lib/python3.11/dist-packages/pennylane/tape/qscript.py in diagonalizing_gates(self)
404 # in which case we just don't append any
405 with contextlib.suppress(qml.operation.DiagGatesUndefinedError):
--> 406 rotation_gates.extend(observable.diagonalizing_gates())
407 return rotation_gates
408
/usr/local/lib/python3.11/dist-packages/pennylane/ops/qubit/observables.py in diagonalizing_gates(self)
257 """
258 # note: compute_diagonalizing_gates has a custom signature, which is why we overwrite this method
--> 259 return self.compute_diagonalizing_gates(self.eigendecomposition["eigvec"], self.wires)
260
261
/usr/local/lib/python3.11/dist-packages/pennylane/ops/qubit/observables.py in eigendecomposition(self)
152 """
153 Hmat = self.matrix()
--> 154 Hmat = qml.math.to_numpy(Hmat)
155 Hkey = tuple(Hmat.flatten().tolist())
156 if Hkey not in Hermitian._eigs:
/usr/local/lib/python3.11/dist-packages/autoray/autoray.py in do(fn, like, *args, **kwargs)
79 backend = _choose_backend(fn, args, kwargs, like=like)
80 func = get_lib_fn(backend, fn)
---> 81 return func(*args, **kwargs)
82
83
/usr/local/lib/python3.11/dist-packages/pennylane/math/single_dispatch.py in _to_numpy_jax(x)
798 return np.array(getattr(x, "val", x))
799 except TracerArrayConversionError as e:
--> 800 raise ValueError(
801 "Converting a JAX array to a NumPy array not supported when using the JAX JIT."
802 ) from e
ValueError: Converting a JAX array to a NumPy array not supported when using the JAX JIT.
Workaround: `dm` is removed almost everywhere, and a single fixed `dm` is defined inside the QNode. The inline comments show where I've removed `dm`.
import pennylane as qml
import jax
import jax.numpy as jnp
import optax
# 64-bit types must be switched on before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)

device = qml.device("default.mixed", wires=1)

# Target density matrices |0><0| and |1><1| (one per entry of ys) from the
# original version of the code; this workaround does not use them.
dms = jnp.stack([
    jnp.array([[1, 0], [0, 0]]),
    jnp.array([[0, 0], [0, 1]]),
])

key = jax.random.PRNGKey(42)
ys = jnp.array([0, 1], dtype=int)
# One random input angle per label.
xs = jax.random.uniform(key, shape=ys.shape)
@qml.qnode(device, interface='jax')
def qnode(params, x: float) -> float:  # removed argument: dm: jnp.ndarray
    """Apply RX(x) then RX(params[0]) on wire 0; return <Hermitian(dm)>.

    ``dm`` is a fixed Python literal (|0><0|) instead of an argument.
    Under ``jax.jit`` an argument would be a tracer. ``qml.Hermitian``
    converts its matrix to NumPy to compute the eigendecomposition used
    for its diagonalizing gates, and that conversion fails on a tracer.
    A literal is a concrete constant, so the conversion succeeds.

    NOTE(review): untested alternative that would keep ``dm`` as a traced
    per-sample argument: return ``qml.density_matrix(wires=0)`` from the
    QNode and compute ``jnp.trace(rho @ dm).real`` outside it.
    """
    dm = [[1, 0], [0, 0]]
    qml.RX(x, wires=0)
    qml.RX(params[0], wires=0)
    return qml.expval(qml.Hermitian(dm, wires=0))


# Batch over the inputs ``x`` only; ``params`` is shared by every sample.
qnn_vectorized = jax.jit(jax.vmap(qnode, in_axes=(None, 0)))  # removed in_axes entry for dm: 0
def loss_function(params, x_set):  # removed argument: dm
    """Return the mean squared infidelity of the circuit over ``x_set``.

    Each entry of ``fidelities`` is the QNode's ``Hermitian(dm)``
    expectation value for one input in ``x_set``. The loss is the mean of
    ``(1 - fidelity) ** 2``, so it is 0 when every fidelity equals 1.
    """
    fidelities = qnn_vectorized(params, x_set)  # removed argument: dm
    return jnp.mean((1 - fidelities) ** 2)
# Define optimizer
params = jnp.array([0.1])
opt = optax.adam(learning_rate=0.1)
opt_state = opt.init(params)
# Gradient of the loss w.r.t. its first argument (params), compiled once.
grad = jax.jit(jax.grad(loss_function))

for epoch in range(5):
    print(f'Epoch {epoch}')
    grad_value = grad(params, xs)  # removed argument: dms
    updates, opt_state = opt.update(grad_value, opt_state)
    params = optax.apply_updates(params, updates)

# Report the parameters once training has finished.
print("Optimized parameters:", params)