I am implementing a quantum circuit similar to the one in the Data-reuploading Classifier tutorial. However, I am encountering difficulties when trying to use jax.jit to compile the circuit. For example:
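```python
import jax
import jax.numpy as jnp
import pennylane as qml

# jax.config.update("jax_enable_x64", True)

# Placeholder: the post does not show how `y` is defined. The traceback
# indicates a 2x2 integer jax.numpy array; this projector is only an example.
y = jnp.array([[1, 0], [0, 0]])

dev = qml.device("default.qubit", wires=2)

# @jax.jit  # adding this decorator triggers the TracerBoolConversionError below
@qml.qnode(dev, interface="jax")
def circuit(param):
    qml.RX(param, wires=0)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.Hermitian(y, wires=[1]))

print(f"Result: {repr(circuit(jnp.array([0.123,0.25])))}")
```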
Without the `@jax.jit` decorator, the circuit works as expected. However, when I try to compile it with `@jax.jit`, I get the following error:
```
TracerBoolConversionError                 Traceback (most recent call last)
Cell In[2], line 16
     13     qml.CNOT(wires=[0, 1])
     14     return qml.expval(qml.Hermitian(y, wires=[1]))
---> 16 print(f"Result: {repr(circuit(jnp.array([0.123,0.25])))}")

    [... skipping hidden 11 frame]

File ~/anaconda3/envs/research2/lib/python3.13/site-packages/pennylane/workflow/qnode.py:987, in QNode.__call__(self, *args, **kwargs)
    985 if qml.capture.enabled():
    986     return qml.capture.qnode_call(self, *args, **kwargs)
--> 987 return self._impl_call(*args, **kwargs)

File ~/anaconda3/envs/research2/lib/python3.13/site-packages/pennylane/workflow/qnode.py:963, in QNode._impl_call(self, *args, **kwargs)
    960 def _impl_call(self, *args, **kwargs) -> qml.typing.Result:
    962     # construct the tape
--> 963     self.construct(args, kwargs)
    965     old_interface = self.interface
    966     if old_interface == "auto":

File ~/anaconda3/envs/research2/lib/python3.13/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54 s_caller = "::L".join(
     55     [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56 )
     57 lgr.debug(
     58     f"Calling {f_string} from {s_caller}",
     59     **_debug_log_kwargs,
     60 )
---> 61 return func(*args, **kwargs)

File ~/anaconda3/envs/research2/lib/python3.13/site-packages/pennylane/workflow/qnode.py:857, in QNode.construct(self, args, kwargs)
    855 with pldb_device_manager(self.device):
    856     with qml.queuing.AnnotatedQueue() as q:
--> 857         self._qfunc_output = self.func(*args, **kwargs)
    859 self._tape = QuantumScript.from_queue(q, shots)
    861 params = self.tape.get_parameters(trainable_only=False)

Cell In[2], line 14, in circuit(param)
     12     qml.RX(param, wires=0)
     13     qml.CNOT(wires=[0, 1])
---> 14     return qml.expval(qml.Hermitian(y, wires=[1]))

File ~/anaconda3/envs/research2/lib/python3.13/site-packages/pennylane/capture/capture_meta.py:89, in CaptureMeta.__call__(cls, *args, **kwargs)
     85 if enabled():
     86     # when tracing is enabled, we want to
     87     # use bind to construct the class if we want class construction to add it to the jaxpr
     88     return cls._primitive_bind_call(*args, **kwargs)
---> 89 return type.__call__(cls, *args, **kwargs)

File ~/anaconda3/envs/research2/lib/python3.13/site-packages/pennylane/ops/qubit/observables.py:87, in Hermitian.__init__(self, A, wires, id)
     83 else:
     84     # Assumably wires is an int; further validation checks are performed by calling super().__init__
     85     expected_mx_shape = self._num_basis_states
---> 87 Hermitian._validate_input(A, expected_mx_shape)
     89 super().__init__(A, wires=wires, id=id)

File ~/anaconda3/envs/research2/lib/python3.13/site-packages/pennylane/ops/qubit/observables.py:103, in Hermitian._validate_input(A, expected_mx_shape)
     97 if expected_mx_shape is not None and A.shape[0] != expected_mx_shape:
     98     raise ValueError(
     99         f"Expected input matrix to have shape {expected_mx_shape}x{expected_mx_shape}, but "
    100         f"a matrix with shape {A.shape[0]}x{A.shape[0]} was passed."
    101     )
--> 103 if not qml.math.allclose(A, qml.math.T(qml.math.conj(A))):
    104     raise ValueError("Observable must be Hermitian.")

    [... skipping hidden 1 frame]

File ~/anaconda3/envs/research2/lib/python3.13/site-packages/jax/_src/core.py:1554, in concretization_function_error.<locals>.error(self, arg)
   1553 def error(self, arg):
-> 1554   raise TracerBoolConversionError(arg)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function circuit at /tmp/ipykernel_20529/364763606.py:9 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:i32[2,2] = transpose[permutation=(1, 0)] b
    from line /tmp/ipykernel_20529/364763606.py:14:22 (circuit)

  operation a:bool[] = pjit[
    jaxpr={ lambda ; b:i32[2,2] c:i32[2,2] d:f32[] e:f32[]. let
        f:bool[2,2] = pjit[
          jaxpr={ lambda ; g:i32[2,2] h:i32[2,2] i:f32[] j:f32[]. let
              k:f32[2,2] = convert_element_type[new_dtype=float32 weak_type=False] g
              l:f32[2,2] = convert_element_type[new_dtype=float32 weak_type=False] h
              m:f32[] = convert_element_type[new_dtype=float32 weak_type=False] i
              n:f32[] = convert_element_type[new_dtype=float32 weak_type=False] j
              o:f32[2,2] = sub k l
              p:f32[2,2] = abs o
              q:f32[2,2] = abs l
              r:f32[2,2] = mul m q
              s:f32[2,2] = add n r
              t:bool[2,2] = le p s
              u:bool[2,2] = pjit[
                jaxpr={ lambda ; v:f32[2,2]. let
                    w:f32[2,2] = abs v
                    x:bool[2,2] = eq w inf
                  in (x,) }
              ] k
              y:bool[2,2] = pjit[
                jaxpr={ lambda ; v:f32[2,2]. let
                    w:f32[2,2] = abs v
                    x:bool[2,2] = eq w inf
                  in (x,) }
              ] l
              z:bool[2,2] = or u y
              ba:bool[2,2] = and u y
              bb:bool[2,2] = not z
              bc:bool[2,2] = and t bb
              bd:bool[2,2] = eq k l
              be:bool[2,2] = and ba bd
              bf:bool[2,2] = or bc be
              bg:bool[2,2] = ne k k
              bh:bool[2,2] = ne l l
              bi:bool[2,2] = or bg bh
              bj:bool[2,2] = not bi
              bk:bool[2,2] = and bf bj
            in (bk,) }
        ] b c e d
        bl:bool[] = reduce_and[axes=(0, 1)] f
      in (bl,) }
  ] bm bn bo bp
    from line /tmp/ipykernel_20529/364763606.py:14:22 (circuit)

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
```
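It seems that the issue arises because `qml.Hermitian` checks whether the input matrix `y` is Hermitian. That check, `if not qml.math.allclose(A, qml.math.T(qml.math.conj(A)))`, converts a traced JAX array into a Python `bool`, which is not allowed while `jax.jit` is tracing the function. Is there any way to bypass or resolve this error?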
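A minimal pure-JAX sketch of the mechanism, without PennyLane: `validate_hermitian` below is a hypothetical, simplified stand-in for `Hermitian._validate_input` from the traceback, and the 2x2 matrix is an assumed example. When the check is computed with `jax.numpy` inside a jitted function, the operations are staged into the trace and the Python `if` fails; when the same check is computed with NumPy on a NumPy array, it runs eagerly and yields an ordinary boolean.

```python
import numpy as np
import jax
import jax.numpy as jnp


def validate_hermitian(A, xp):
    # Hypothetical, simplified stand-in for the check in the traceback:
    # a Python `if` on the result of allclose, computed with library `xp`.
    if not xp.allclose(A, xp.conj(A).T):
        raise ValueError("Observable must be Hermitian.")


y_jax = jnp.array([[1, 0], [0, 0]])  # assumed example matrix (jax.numpy)
y_np = np.array([[1, 0], [0, 0]])    # the same matrix as a plain NumPy array


@jax.jit
def with_jax_check(x):
    # jnp operations are staged out under jit, so allclose returns a tracer
    validate_hermitian(y_jax, jnp)
    return 2 * x


@jax.jit
def with_numpy_check(x):
    # NumPy operations run eagerly, so allclose returns a concrete bool
    validate_hermitian(y_np, np)
    return 2 * x


try:
    with_jax_check(1.0)
    jax_failed = False
except jax.errors.TracerBoolConversionError:
    jax_failed = True

print(jax_failed)             # True
print(with_numpy_check(1.0))  # 2.0
```

By the same reasoning, defining `y` with NumPy (`np.array(...)`) instead of `jnp.array(...)` might let the validation in `qml.Hermitian` run on concrete values under `jax.jit`, assuming `qml.math` dispatches the check to NumPy for NumPy inputs; this has not been verified here against PennyLane 0.39.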
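Alternatively, from my understanding, `qml.expval(qml.Hermitian(y, wires=[1]))` is used to calculate the fidelity between the state encoded by `y` and the final state of the circuit: it returns Tr(ρ y), where ρ is the reduced state of wire 1, which equals the fidelity ⟨φ|ρ|φ⟩ when y = |φ⟩⟨φ| is a pure-state density matrix. Are there other methods to calculate the fidelity between `y` and the circuit's output state that work with `jax.jit` compilation?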
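For the two-qubit snippet above, the same quantity can also be computed directly in JAX without `qml.Hermitian`. The sketch below is a hand-written statevector simulation of only that RX + CNOT circuit (not the tutorial's full classifier). It assumes `y` is a pure-state density matrix |φ⟩⟨φ|, and `rx`, `CNOT`, and `overlap_wire1` are hypothetical helper names. It traces out wire 0 and returns Tr(ρ₁ y); it is jit-compatible because there is no Python control flow on traced values.

```python
import jax
import jax.numpy as jnp

# CNOT with wire 0 as control, basis order |q0 q1> (index = 2*q0 + q1).
CNOT = jnp.array([[1, 0, 0, 0],
                  [0, 1, 0, 0],
                  [0, 0, 0, 1],
                  [0, 0, 1, 0]], dtype=jnp.complex64)


def rx(theta):
    # Single-qubit rotation RX(theta) = exp(-i * theta * X / 2).
    c = jnp.cos(theta / 2)
    s = jnp.sin(theta / 2)
    return jnp.array([[c, -1j * s], [-1j * s, c]], dtype=jnp.complex64)


@jax.jit
def overlap_wire1(theta, target_dm):
    # Start in |00>, apply RX(theta) on wire 0, then CNOT(0 -> 1).
    psi = jnp.zeros(4, dtype=jnp.complex64).at[0].set(1.0)
    psi = jnp.kron(rx(theta), jnp.eye(2, dtype=jnp.complex64)) @ psi
    psi = CNOT @ psi
    # Reshape to amplitudes indexed by (wire0, wire1) and trace out wire 0.
    amps = psi.reshape(2, 2)
    rho1 = jnp.einsum("ab,ac->bc", amps, amps.conj())
    # Tr(rho1 @ y); equals <phi|rho1|phi> when y = |phi><phi|.
    return jnp.real(jnp.trace(rho1 @ target_dm))


# Batch over the two parameter values used in the original snippet.
y = jnp.array([[1.0, 0.0], [0.0, 0.0]])  # assumed target state |0><0|
thetas = jnp.array([0.123, 0.25])
fids = jax.vmap(overlap_wire1, in_axes=(0, None))(thetas, y)
print(fids)  # for y = |0><0| this equals cos(theta / 2) ** 2
```

Within PennyLane itself, returning `qml.density_matrix(wires=1)` from the QNode and contracting the result with `y` outside the QNode would be the analogous approach; its behaviour under `jax.jit` has not been checked here.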
Version of Jax:
Output of qml.about()
Name: PennyLane
Version: 0.39.0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
License: Apache License 2.0
Location: /home/ubuntu2022/anaconda3/envs/research2/lib/python3.13/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, toml, typing-extensions
Required-by: PennyLane-qiskit, PennyLane_Lightning
Platform info: Linux-
Python version: 3.13.0
Numpy version: 2.0.2
Scipy version: 1.14.1
Installed devices:
- qiskit.aer (PennyLane-qiskit-0.39.0)
- qiskit.basicaer (PennyLane-qiskit-0.39.0)
- qiskit.basicsim (PennyLane-qiskit-0.39.0)
- qiskit.remote (PennyLane-qiskit-0.39.0)
- default.clifford (PennyLane-0.39.0)
- default.gaussian (PennyLane-0.39.0)
- default.mixed (PennyLane-0.39.0)
- default.qubit (PennyLane-0.39.0)
- default.qutrit (PennyLane-0.39.0)
- default.qutrit.mixed (PennyLane-0.39.0)
- default.tensor (PennyLane-0.39.0)
- null.qubit (PennyLane-0.39.0)
- reference.qubit (PennyLane-0.39.0)
- lightning.qubit (PennyLane_Lightning-0.39.0)