Trouble Compiling Circuits with qml.Hermitian Using jax.jit

Hi,
I am implementing a quantum circuit similar to the one in the Data-reuploading Classifier tutorial. However, I am encountering difficulties when trying to use jax.jit to compile the circuit. For example:

import jax
import jax.numpy as jnp
import pennylane as qml

# jax.config.update("jax_enable_x64", True)
dev = qml.device("default.qubit", wires=2)
y=jnp.array([[1,0],[0,-1]])

@jax.jit
@qml.qnode(dev, interface="jax")
def circuit(param):
    qml.RX(param, wires=0)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.Hermitian(y, wires=[1]))

print(f"Result: {repr(circuit(jnp.array([0.123,0.25])))}")

Without the @jax.jit decorator, the circuit works as expected. However, when I compile it with @jax.jit, I get the following error message:

TracerBoolConversionError                 Traceback (most recent call last)
Cell In[2], line 16
     13     qml.CNOT(wires=[0, 1])
     14     return qml.expval(qml.Hermitian(y, wires=[1]))
---> 16 print(f"Result: {repr(circuit(jnp.array([0.123,0.25])))}")

    [... skipping hidden 11 frame]

File ~/anaconda3/envs/research2/lib/python3.13/site-packages/pennylane/workflow/qnode.py:987, in QNode.__call__(self, *args, **kwargs)
    985 if qml.capture.enabled():
    986     return qml.capture.qnode_call(self, *args, **kwargs)
--> 987 return self._impl_call(*args, **kwargs)

File ~/anaconda3/envs/research2/lib/python3.13/site-packages/pennylane/workflow/qnode.py:963, in QNode._impl_call(self, *args, **kwargs)
    960 def _impl_call(self, *args, **kwargs) -> qml.typing.Result:
    961 
    962     # construct the tape
--> 963     self.construct(args, kwargs)
    965     old_interface = self.interface
    966     if old_interface == "auto":

File ~/anaconda3/envs/research2/lib/python3.13/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/anaconda3/envs/research2/lib/python3.13/site-packages/pennylane/workflow/qnode.py:857, in QNode.construct(self, args, kwargs)
    855 with pldb_device_manager(self.device):
    856     with qml.queuing.AnnotatedQueue() as q:
--> 857         self._qfunc_output = self.func(*args, **kwargs)
    859 self._tape = QuantumScript.from_queue(q, shots)
    861 params = self.tape.get_parameters(trainable_only=False)

Cell In[2], line 14, in circuit(param)
     12 qml.RX(param, wires=0)
     13 qml.CNOT(wires=[0, 1])
---> 14 return qml.expval(qml.Hermitian(y, wires=[1]))

File ~/anaconda3/envs/research2/lib/python3.13/site-packages/pennylane/capture/capture_meta.py:89, in CaptureMeta.__call__(cls, *args, **kwargs)
     85 if enabled():
     86     # when tracing is enabled, we want to
     87     # use bind to construct the class if we want class construction to add it to the jaxpr
     88     return cls._primitive_bind_call(*args, **kwargs)
---> 89 return type.__call__(cls, *args, **kwargs)

File ~/anaconda3/envs/research2/lib/python3.13/site-packages/pennylane/ops/qubit/observables.py:87, in Hermitian.__init__(self, A, wires, id)
     83     else:
     84         # Assumably wires is an int; further validation checks are performed by calling super().__init__
     85         expected_mx_shape = self._num_basis_states
---> 87     Hermitian._validate_input(A, expected_mx_shape)
     89 super().__init__(A, wires=wires, id=id)

File ~/anaconda3/envs/research2/lib/python3.13/site-packages/pennylane/ops/qubit/observables.py:103, in Hermitian._validate_input(A, expected_mx_shape)
     97 if expected_mx_shape is not None and A.shape[0] != expected_mx_shape:
     98     raise ValueError(
     99         f"Expected input matrix to have shape {expected_mx_shape}x{expected_mx_shape}, but "
    100         f"a matrix with shape {A.shape[0]}x{A.shape[0]} was passed."
    101     )
--> 103 if not qml.math.allclose(A, qml.math.T(qml.math.conj(A))):
    104     raise ValueError("Observable must be Hermitian.")

    [... skipping hidden 1 frame]

File ~/anaconda3/envs/research2/lib/python3.13/site-packages/jax/_src/core.py:1554, in concretization_function_error.<locals>.error(self, arg)
   1553 def error(self, arg):
-> 1554   raise TracerBoolConversionError(arg)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function circuit at /tmp/ipykernel_20529/364763606.py:9 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:i32[2,2] = transpose[permutation=(1, 0)] b
    from line /tmp/ipykernel_20529/364763606.py:14:22 (circuit)

  operation a:bool[] = pjit[
  name=allclose
  jaxpr={ lambda ; b:i32[2,2] c:i32[2,2] d:f32[] e:f32[]. let
      f:bool[2,2] = pjit[
        name=isclose
        jaxpr={ lambda ; g:i32[2,2] h:i32[2,2] i:f32[] j:f32[]. let
            k:f32[2,2] = convert_element_type[new_dtype=float32 weak_type=False] g
            l:f32[2,2] = convert_element_type[new_dtype=float32 weak_type=False] h
            m:f32[] = convert_element_type[new_dtype=float32 weak_type=False] i
            n:f32[] = convert_element_type[new_dtype=float32 weak_type=False] j
            o:f32[2,2] = sub k l
            p:f32[2,2] = abs o
            q:f32[2,2] = abs l
            r:f32[2,2] = mul m q
            s:f32[2,2] = add n r
            t:bool[2,2] = le p s
            u:bool[2,2] = pjit[
              name=isinf
              jaxpr={ lambda ; v:f32[2,2]. let
                  w:f32[2,2] = abs v
                  x:bool[2,2] = eq w inf
                in (x,) }
            ] k
            y:bool[2,2] = pjit[
              name=isinf
              jaxpr={ lambda ; v:f32[2,2]. let
                  w:f32[2,2] = abs v
                  x:bool[2,2] = eq w inf
                in (x,) }
            ] l
            z:bool[2,2] = or u y
            ba:bool[2,2] = and u y
            bb:bool[2,2] = not z
            bc:bool[2,2] = and t bb
            bd:bool[2,2] = eq k l
            be:bool[2,2] = and ba bd
            bf:bool[2,2] = or bc be
            bg:bool[2,2] = ne k k
            bh:bool[2,2] = ne l l
            bi:bool[2,2] = or bg bh
            bj:bool[2,2] = not bi
            bk:bool[2,2] = and bf bj
          in (bk,) }
      ] b c e d
      bl:bool[] = reduce_and[axes=(0, 1)] f
    in (bl,) }
] bm bn bo bp
    from line /tmp/ipykernel_20529/364763606.py:14:22 (circuit)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

It seems that the issue arises because qml.Hermitian checks whether the input matrix y is Hermitian, which appears to be incompatible with JAX’s compilation process. Is there any way to bypass or resolve this error?
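The mechanism can be reproduced without PennyLane. The following is a minimal sketch, assuming only that the check behaves like the `allclose` line in the traceback; the function names are hypothetical. Inside a jitted function, jax.numpy calls are staged out even when their inputs are constants, so `jnp.allclose` returns a traced `bool[]` and the Python `if not ...` fails. The same check done with NumPy on a NumPy constant stays concrete. Whether passing a NumPy matrix to qml.Hermitian avoids the error in PennyLane 0.39 is not verified here.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Same matrix as in the question, once as a JAX array and once as a NumPy array.
y_jax = jnp.array([[1, 0], [0, -1]])
y_np = np.array([[1, 0], [0, -1]])


@jax.jit
def check_with_jax(x):
    # jax.numpy calls made while tracing are staged out, so allclose returns
    # a traced bool[] and `not ...` tries to convert it to a Python bool.
    if not jnp.allclose(y_jax, jnp.conj(y_jax).T):
        raise ValueError("Observable must be Hermitian.")
    return x


@jax.jit
def check_with_numpy(x):
    # NumPy runs eagerly on the concrete constant, so the branch is resolved
    # at trace time and never touches a tracer.
    if not np.allclose(y_np, np.conj(y_np).T):
        raise ValueError("Observable must be Hermitian.")
    return x


def raises_tracer_error(fn, arg):
    """Return True if calling fn(arg) raises TracerBoolConversionError."""
    try:
        fn(arg)
    except jax.errors.TracerBoolConversionError:
        return True
    return False


jax_check_fails = raises_tracer_error(check_with_jax, 0.123)
numpy_check_fails = raises_tracer_error(check_with_numpy, 0.123)
print(jax_check_fails, numpy_check_fails)
```

The two flags show which variant of the check survives tracing; the matrix itself is Hermitian in both cases, so the ValueError branch is never taken.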

Alternatively, from my understanding, qml.expval(qml.Hermitian) is used to calculate the fidelity between the input matrix y and the final state of the circuit. Are there other methods to calculate the fidelity between y and the circuit’s output state that work with jax.jit compilation?

Version of Jax:
jax==0.4.35
jaxlib==0.4.35

Output of qml.about():
Name: PennyLane
Version: 0.39.0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author:
Author-email:
License: Apache License 2.0
Location: /home/ubuntu2022/anaconda3/envs/research2/lib/python3.13/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, toml, typing-extensions
Required-by: PennyLane-qiskit, PennyLane_Lightning

Platform info: Linux-5.15.133.1-microsoft-standard-WSL2-x86_64-with-glibc2.35
Python version: 3.13.0
Numpy version: 2.0.2
Scipy version: 1.14.1
Installed devices:

  • qiskit.aer (PennyLane-qiskit-0.39.0)
  • qiskit.basicaer (PennyLane-qiskit-0.39.0)
  • qiskit.basicsim (PennyLane-qiskit-0.39.0)
  • qiskit.remote (PennyLane-qiskit-0.39.0)
  • default.clifford (PennyLane-0.39.0)
  • default.gaussian (PennyLane-0.39.0)
  • default.mixed (PennyLane-0.39.0)
  • default.qubit (PennyLane-0.39.0)
  • default.qutrit (PennyLane-0.39.0)
  • default.qutrit.mixed (PennyLane-0.39.0)
  • default.tensor (PennyLane-0.39.0)
  • null.qubit (PennyLane-0.39.0)
  • reference.qubit (PennyLane-0.39.0)
  • lightning.qubit (PennyLane_Lightning-0.39.0)

Hi @PoJung-Lu,

I can replicate your issue.

I’m not sure what’s going on. Let me check and get back to you.

Hi @PoJung-Lu ,

So it looks like there’s a bug. I’ve opened a bug report here. Part of the solution will be to indent one line of code but we still need to check if this fully solves the problem.

Would you like to try tackling this bug? No problem if not, I’m asking just in case. Edit: this bug seems more complicated than expected. It’s probably better for the PennyLane team to tackle this one.

On the other hand I just realized that you asked about fidelity. Maybe you can use qml.math.fidelity instead. Let me know if this works for you or if you have any questions about it!

Hi @CatalinaAlbornoz,

Thank you for opening the bug report and for the detailed follow-up! As you mentioned, tackling this problem might be beyond my current capability, so I’ll wait for updates from the team and hope for good news.

Regarding fidelity, from my understanding, in the ideal case (without noise), the final state of a quantum circuit can be represented as a pure state (ρ), and its fidelity with another pure state (σ) can be written as


[Image: equation for the fidelity between the pure states ρ and σ]

This is why I think measuring the expectation value of an operator could be equivalent to calculating the fidelity between the final state (ρ) of a quantum circuit and a given state (σ).

However, I’m not sure qml.math.fidelity fits my use case, as I might need to run my code on a real quantum device in the future. In that context, performing full state tomography to obtain a complete understanding of the state could be too costly.

Hi @PoJung-Lu ,

I’m not sure about your equations. Take for example this excerpt of the Distance Measures module of the PennyLane Codebook.

Edit: I think the right-hand side should be squared: $\vert \langle \phi \vert \psi \rangle \vert^2$

As far as I understand, calculating the fidelity requires having both quantum states. Maybe you can look into other distance measures, or there may be tricks to get the fidelity that I’m not aware of.

Hi @CatalinaAlbornoz,

Thank you for your reply.

I realized I made a mistake in my equations but couldn’t find the edit button to correct it. Could you let me know where it is?

Regarding the definitions of fidelity, my equations were derived from the Wikipedia page, which uses $F = (F')^2$ as its definition of fidelity, whereas the PennyLane Codebook and Nielsen’s book (Eq. 9.53) use $F'$. Here’s an excerpt from the Wikipedia page for reference:

Despite the squaring difference, the derivation should only differ by a square root factor. Nielsen’s book provides a similar example:

I think this equation ($F' = \sqrt{\langle \sigma \rangle}$, or $F = \langle \sigma \rangle$) might be the trick used in the Data-reuploading Classifier tutorial to design their circuit. What are your thoughts on this?
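For reference, here is a sketch of how this relation follows from the general (squared) definition of fidelity, assuming the circuit output is a pure state $\rho = \vert\psi\rangle\langle\psi\vert$ and $\sigma$ is any density matrix. The symbols $\psi$ and $\phi$ are introduced only for this derivation.

```latex
\begin{align}
F(\rho,\sigma) &= \left( \operatorname{tr} \sqrt{\sqrt{\rho}\,\sigma\,\sqrt{\rho}} \right)^2,
  \qquad \rho = \vert\psi\rangle\langle\psi\vert \;\Rightarrow\; \sqrt{\rho} = \rho \\
\sqrt{\rho}\,\sigma\,\sqrt{\rho}
  &= \vert\psi\rangle\langle\psi\vert \sigma \vert\psi\rangle\langle\psi\vert
   = \langle\psi\vert\sigma\vert\psi\rangle \, \rho \\
F'(\rho,\sigma) &= \operatorname{tr} \sqrt{\langle\psi\vert\sigma\vert\psi\rangle \, \rho}
   = \sqrt{\langle\psi\vert\sigma\vert\psi\rangle}
   = \sqrt{\langle \sigma \rangle} \\
F(\rho,\sigma) &= (F')^2 = \langle\psi\vert\sigma\vert\psi\rangle
   = \operatorname{tr}(\rho\,\sigma) = \langle \sigma \rangle
\end{align}
```

If $\sigma = \vert\phi\rangle\langle\phi\vert$ is also pure, the last line reduces to $F = \vert\langle\phi\vert\psi\rangle\vert^2$, which matches the squared overlap mentioned earlier in the thread.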

Lastly, I’d appreciate it if you could let me know when the JAX jitting issue is resolved. Thank you very much for your help!

Hi @PoJung-Lu ,

You’re right, you could get the expectation value of the Hermitian operator rho to obtain the fidelity between rho and the state (as explained in the data-reuploading classifier demo).
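As a quick numerical check of this statement, here is a small NumPy sketch. It assumes a single-qubit pure state prepared by RX(θ) on |0⟩ and an arbitrarily chosen target state; the helper names and the target state are illustrative and not taken from the demo. It compares the expectation value of the projector onto the target state with the squared overlap of the two state vectors.

```python
import numpy as np


def rx_state(theta):
    """State RX(theta)|0> as a length-2 complex vector."""
    return np.array([np.cos(theta / 2), -1j * np.sin(theta / 2)])


def projector(phi):
    """Density matrix |phi><phi| of a normalized state vector."""
    return np.outer(phi, np.conj(phi))


theta = 0.123
psi = rx_state(theta)

# Hypothetical target state, chosen only for illustration.
phi = np.array([1.0, 1.0j]) / np.sqrt(2)
sigma = projector(phi)

# Expectation value <psi|sigma|psi> of the Hermitian operator sigma.
expval = np.real(np.vdot(psi, sigma @ psi))

# Squared overlap |<phi|psi>|^2, i.e. the fidelity in the squared convention.
overlap_sq = np.abs(np.vdot(phi, psi)) ** 2

print(expval, overlap_sq)
```

For this particular choice of states both quantities equal (1 − sin θ)/2. On hardware, the expectation value would be estimated from shots rather than computed from the state vector.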

We have a Pull Request open for the jitting issue but I don’t know how long it will take to get it reviewed and merged. I’ll let you know once it’s finished.

Maybe you could try something like classical shadows in the meantime in order to obtain the state. Let me know if this works for you or if the jitting of the Hermitian is still needed and blocking work.

Hi @CatalinaAlbornoz ,

I see. I tried installing PennyLane from this branch in my virtual environment, and it fixed the jitting issue in my code. I think I’ll deploy this environment for now. Hopefully, this version will be merged soon. Thank you very much for your help!

Thanks for confirming that the fix worked for you @PoJung-Lu !

Our next stable release will be around mid January so most likely it will be merged by then.

I hope it gets merged into master sooner though!

Hi @PoJung-Lu ,

I wanted to let you know that the PR with the fix has been merged into master!
You can now use the master branch if you prefer. Please let us know if you encounter any other issues!