# Trouble calculating two-time correlators with jax.jit

I want to be able to compute two-time correlators of the form $\langle O_1(t_1)O_2(t_2)\rangle$, assuming $t_1>t_2$. One way to do this is to evolve the initial state $|\psi_0\rangle$ under $U(t_1)$ and then calculate the expectation value of the operator $M=O_1U(t_1,t_2)O_2U^\dagger(t_1,t_2)$, where $U(t_1,t_2)=U(t_1)U^\dagger(t_2)$ is the evolution from $t_2$ to $t_1$.

I want to do this where U is the evolution of a parametrized Hamiltonian. Ideally this should be jitted, since I plan to do this calculation many times in a for loop. My code looks like this (the Hamiltonian is an example from PennyLane's website):

import pennylane as qml
import jax
import jax.numpy as jnp

# Time-dependent coefficients f_i(p, t) = p * t, one per qubit.
coeffs = [lambda p, t: p * t for _ in range(4)]
ops = [qml.PauliX(i) for i in range(4)]

# ParametrizedHamiltonian: H(p, t) = sum_i p_i * t * X_i
H = qml.dot(coeffs, ops)

# Parameters (one per coefficient function)
key = jax.random.PRNGKey(0)
param_vector = jax.random.uniform(key, shape=[4])

# Setting device
dev = qml.device("default.qubit.jax", wires=4)

# Setting times and operators (the method assumes t1 > t2)
t1 = 1.0
O1 = qml.PauliY(0)
t2 = 0.5
O2 = qml.PauliZ(1)

@qml.qnode(dev, interface="jax")
def two_point_correlator(t1, O1, t2, O2):
    """Estimate the two-time correlator <O1(t1) O2(t2)>, assuming t1 > t2.

    The circuit evolves the initial state under U(t1) and returns the
    expectation value of M = O1 U(t1, t2) O2 U^dagger(t1, t2).
    """
    # Setting the partial unitary evolution U(t1, t2), from t2 to t1.
    # NOTE(review): this operator is created outside a
    # qml.QueuingManager.stop_recording() context, so it may also be queued
    # as a gate in the circuit, which is not the intent here.
    u_partial = qml.evolve(H)(param_vector, t=[t2, t1])
    # Evolving till t1
    qml.evolve(H)(param_vector, t1)
    return qml.expval(qml.prod(O1, u_partial, O2, qml.adjoint(u_partial)))


This works well! But when I add the jit decorator:

@jax.jit
@qml.qnode(dev, interface="jax")
def two_point_correlator(t1, O1, t2, O2):
    """JIT-compiled QNode estimating <O1(t1) O2(t2)>, assuming t1 > t2.

    Evolves the initial state under U(t1) and returns the expectation value
    of M = O1 U(t1, t2) O2 U^dagger(t1, t2).
    """
    # Partial evolution U(t1, t2), from t2 to t1.
    # NOTE(review): created outside qml.QueuingManager.stop_recording(), so it
    # may also be queued as a gate in the circuit.
    u_partial = qml.evolve(H)(param_vector, t=[t2, t1])
    # Evolve the state up to t1.
    qml.evolve(H)(param_vector, t1)
    return qml.expval(qml.prod(O1, u_partial, O2, qml.adjoint(u_partial)))


I get the following error. Is there a work-around?

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape complex128[16,16].
The error occurred while tracing the function two_point_correlator at /Users/Rodrigo/Library/CloudStorage/OneDrive-HarvardUniversity/Landscapes/Covariances/untitled1.py:35 for jit. This concrete value was not available in Python because it depends on the values of the arguments t1 and t2.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):

File ~/anaconda3/lib/python3.10/site-packages/spyder_kernels/py3compat.py:356 in compat_exec
exec(code, globals, locals)

File ~/Library/CloudStorage/OneDrive-HarvardUniversity/Landscapes/Covariances/untitled1.py:41
print(two_point_correlator(t1,O1,t2,O2))

File ~/anaconda3/lib/python3.10/site-packages/jax/_src/traceback_util.py:177 in reraise_with_filtered_traceback
return fun(*args, **kwargs)

File ~/anaconda3/lib/python3.10/site-packages/jax/_src/pjit.py:256 in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(

File ~/anaconda3/lib/python3.10/site-packages/jax/_src/pjit.py:162 in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(

File ~/anaconda3/lib/python3.10/site-packages/jax/_src/api.py:314 in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)

File ~/anaconda3/lib/python3.10/site-packages/jax/_src/pjit.py:486 in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(

File ~/anaconda3/lib/python3.10/site-packages/jax/_src/pjit.py:963 in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(

File ~/anaconda3/lib/python3.10/site-packages/jax/_src/linear_util.py:349 in memoized_fun
ans = call(fun, *args)

File ~/anaconda3/lib/python3.10/site-packages/jax/_src/pjit.py:916 in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(

File ~/anaconda3/lib/python3.10/site-packages/jax/_src/profiler.py:340 in wrapper
return func(*args, **kwargs)

File ~/anaconda3/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py:2278 in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(

File ~/anaconda3/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py:2300 in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)

File ~/anaconda3/lib/python3.10/site-packages/jax/_src/linear_util.py:191 in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))

File ~/anaconda3/lib/python3.10/site-packages/pennylane/qnode.py:1027 in __call__
res = qml.execute(

File ~/anaconda3/lib/python3.10/site-packages/pennylane/interfaces/execution.py:616 in execute
results = inner_execute(tapes)

File ~/anaconda3/lib/python3.10/site-packages/pennylane/interfaces/execution.py:249 in inner_execute
return cached_device_execution(tapes)

File ~/anaconda3/lib/python3.10/site-packages/pennylane/interfaces/execution.py:371 in wrapper
res = list(fn(tuple(execution_tapes.values()), **kwargs))

File ~/anaconda3/lib/python3.10/contextlib.py:79 in inner
return func(*args, **kwds)

File ~/anaconda3/lib/python3.10/site-packages/pennylane/_qubit_device.py:460 in batch_execute
res = self.execute(circuit)

File ~/anaconda3/lib/python3.10/site-packages/pennylane/_qubit_device.py:279 in execute
self.apply(circuit.operations, rotations=self._get_diagonalizing_gates(circuit), **kwargs)

File ~/anaconda3/lib/python3.10/site-packages/pennylane/devices/default_qubit_legacy.py:1085 in _get_diagonalizing_gates
return super()._get_diagonalizing_gates(qml.tape.QuantumScript(measurements=meas_filtered))

File ~/anaconda3/lib/python3.10/site-packages/pennylane/_qubit_device.py:1720 in _get_diagonalizing_gates
return circuit.diagonalizing_gates

File ~/anaconda3/lib/python3.10/site-packages/pennylane/tape/qscript.py:387 in diagonalizing_gates
rotation_gates.extend(observable.diagonalizing_gates())

File ~/anaconda3/lib/python3.10/site-packages/pennylane/ops/op_math/composite.py:267 in diagonalizing_gates
eigvecs = tmp_sum.eigendecomposition["eigvec"]

File ~/anaconda3/lib/python3.10/site-packages/pennylane/ops/op_math/composite.py:231 in eigendecomposition
mat = math.to_numpy(mat)

File ~/anaconda3/lib/python3.10/site-packages/autoray/autoray.py:80 in do
return get_lib_fn(backend, fn)(*args, **kwargs)

File ~/anaconda3/lib/python3.10/site-packages/pennylane/math/single_dispatch.py:743 in _to_numpy_jax
raise ValueError(

ValueError: Converting a JAX array to a NumPy array not supported when using the JAX JIT.


A quick hotfix is to bypass the device and work with matrices directly. As a PennyLane developer I find this a bit unsatisfying, but it gets the job done:

@jax.jit
def two_point_correlator(t1, O1, t2, O2):
    """Compute <O1(t1) O2(t2)> (t1 > t2) with dense matrices instead of a QNode.

    Builds |psi(t1)> = U(t1)|0...0> explicitly and returns
    <psi(t1)| O1 U(t1, t2) O2 U^dagger(t1, t2) |psi(t1)> as a complex scalar.
    """
    # Use one fixed wire order for every matrix. Otherwise qml.matrix orders
    # each matrix by that operator's own wires, and an O1 acting on a wire
    # other than 0 would put the product matrix in a different basis than U_t1.
    wire_order = dev.wires
    dim = 2 ** len(wire_order)
    psi0 = jnp.eye(dim, dtype=complex)[0]  # |0...0>

    U_t1 = qml.matrix(qml.evolve(H)(param_vector, t1), wire_order=wire_order)
    psi_t1 = U_t1 @ psi0

    # Partial evolution U(t1, t2), from t2 to t1.
    u_partial = qml.evolve(H)(param_vector, t=[t2, t1])
    M = qml.matrix(
        qml.prod(O1, u_partial, O2, qml.adjoint(u_partial)),
        wire_order=wire_order,
    )
    # jnp.vdot conjugates its first argument, giving <psi|M|psi>.
    return jnp.vdot(psi_t1, M @ psi_t1)


I will try to find a solution for the original attempt and will let you know if I find any.

By the way, when you do u_partial = qml.evolve(H)(param_vector, t=[t2, t1]) inside the QNode, this operator is automatically queued. I guess this was not your intention? When you construct operators inside a QNode that you don't want queued, you need to put them in a stop_recording context, e.g.:

@qml.qnode(dev, interface="jax")
def two_point_correlator(t1, O1, t2, O2):
    """Estimate <O1(t1) O2(t2)> (t1 > t2) without queuing the partial evolution.

    Evolves the initial state under U(t1) and returns the expectation value
    of M = O1 U(t1, t2) O2 U^dagger(t1, t2).
    """
    # Setting the partial unitary evolution U(t1, t2). stop_recording() keeps
    # it out of the circuit; it is only used inside the measured observable.
    with qml.QueuingManager.stop_recording():
        u_partial = qml.evolve(H)(param_vector, t=[t2, t1])

    # Evolving till t1
    qml.evolve(H)(param_vector, t1)
    return qml.expval(qml.prod(O1, u_partial, O2, qml.adjoint(u_partial)))


This doesn't fix your problem, but it is a peculiarity that is good to know in general.

The second suggestion, to stop recording, didn't solve the issue, but as you said, it is a good thing to know!

Yes, the solution of using matrices works (and actually cuts the computation time roughly in half for N=4). In this particular case that is fine, although it would be nice to still do this the PennyLane way. Let me know if you come up with a PennyLane-based solution.

We identified this as a bug, and a fix is now merged in the latest version on GitHub (to be released in 3 weeks, but you can also just check out the latest version).

With that, you can run your original code. Remember to put the construction of u_partial inside the stop_recording() context to get correct results:

@jax.jit
@qml.qnode(dev, interface="jax")
def two_point_correlator(t1, O1, t2, O2):
    """JIT-compiled QNode estimating <O1(t1) O2(t2)>, assuming t1 > t2.

    Evolves the initial state under U(t1) and returns the expectation value
    of M = O1 U(t1, t2) O2 U^dagger(t1, t2). The returned expval keeps only
    the real part, since PennyLane treats the product observable as possibly
    non-Hermitian.
    """
    # Setting the partial unitary evolution U(t1, t2). stop_recording() keeps
    # it out of the circuit; it is only used inside the measured observable.
    with qml.QueuingManager.stop_recording():
        u_partial = qml.evolve(H)(param_vector, t=[t2, t1])

    # Evolving till t1
    qml.evolve(H)(param_vector, t1)
    return qml.expval(qml.prod(O1, u_partial, O2, qml.adjoint(u_partial)))


It is still not 100% optimal, but it works.
You will get some warnings:

UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> requested in astype is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
return lax_numpy.astype(arr, dtype)


To remove the warning about precision, call jax.config.update("jax_enable_x64", True) right after importing jax.

UserWarning: Prod might not be hermitian.
warnings.warn(f"{op.name} might not be hermitian.")


This warning comes from the fact that PennyLane does not know whether the observable in the expval is Hermitian. To save computation, it does not check this; instead, it warns the user to make sure that what they are passing is valid.

ComplexWarning: Casting complex values to real discards the imaginary part
out_array: Array = lax_internal._convert_element_type(


This is related to the above. Because expval assumes Hermitian operators, it returns only the real part of the result, and any machine-precision imaginary part gets truncated here.