Trouble calculating two-time correlators with jax.jit

I want to compute two-time correlators of the form \langle O_1(t_1)O_2(t_2)\rangle with t_1>t_2. One way to do this is to evolve the initial state |\psi_0\rangle under U(t_1) and then compute the expectation value of the operator M=O_1U(t_1,t_2)O_2U^\dagger(t_1,t_2), where U(t_1,t_2) is the evolution from t_2 to t_1.
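For reference, here is why this works. It uses the Heisenberg picture O(t)=U^\dagger(t)\,O\,U(t) with U(t)\equiv U(t,0), together with the composition rule U(t_1)=U(t_1,t_2)\,U(t_2), so that U(t_1)U^\dagger(t_2)=U(t_1,t_2) and U(t_2)=U^\dagger(t_1,t_2)\,U(t_1):

```latex
\begin{aligned}
\langle O_1(t_1)O_2(t_2)\rangle
  &= \langle\psi_0|\,U^\dagger(t_1)\,O_1\,U(t_1)\,U^\dagger(t_2)\,O_2\,U(t_2)\,|\psi_0\rangle \\
  &= \langle\psi_0|\,U^\dagger(t_1)\,O_1\,U(t_1,t_2)\,O_2\,U^\dagger(t_1,t_2)\,U(t_1)\,|\psi_0\rangle \\
  &= \langle\psi(t_1)|\,M\,|\psi(t_1)\rangle,
  \qquad |\psi(t_1)\rangle = U(t_1)|\psi_0\rangle .
\end{aligned}
```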

I want to do this where U is the evolution under a parametrized Hamiltonian, and ideally the function should be jitted, since I plan to run this calculation many times in a for loop. My code looks like this (the Hamiltonian is an example from PennyLane's website):

import pennylane as qml
import jax
import jax.numpy as jnp

coeffs = [lambda p, t: p * t for _ in range(4)]
ops = [qml.PauliX(i) for i in range(4)]

# ParametrizedHamiltonian
H = qml.dot(coeffs, ops)

#Parameters 
key = jax.random.PRNGKey(0)
param_vector = jax.random.uniform(key, shape = [4])

#Setting device
dev = qml.device("default.qubit.jax", wires=4)

# Setting times and operators
t1=1.0
O1 = qml.PauliY(0)
t2=0.5
O2 = qml.PauliZ(1)

@qml.qnode(dev, interface="jax")
def two_point_correlator(t1,O1,t2,O2):
    #Setting the partial unitary evolution
    u_partial= qml.evolve(H)(param_vector, t=[t2,t1])
    #Evolving till t1
    qml.evolve(H)(param_vector, t1)
    return qml.expval(O1@(u_partial)@(O2@qml.adjoint(u_partial)))

This works well! But when I add the jit decorator:

@jax.jit
@qml.qnode(dev, interface="jax")
def two_point_correlator(t1,O1,t2,O2):
    u_partial= qml.evolve(H)(param_vector, t=[t2,t1])
    qml.evolve(H)(param_vector, t1)
    return qml.expval(O1@(u_partial)@(O2@qml.adjoint(u_partial)))

I get the following error. Is there a workaround?

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape complex128[16,16].
The error occurred while tracing the function two_point_correlator at /Users/Rodrigo/Library/CloudStorage/OneDrive-HarvardUniversity/Landscapes/Covariances/untitled1.py:35 for jit. This concrete value was not available in Python because it depends on the values of the arguments t1 and t2.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError


The above exception was the direct cause of the following exception:

Traceback (most recent call last):

  File ~/anaconda3/lib/python3.10/site-packages/spyder_kernels/py3compat.py:356 in compat_exec
    exec(code, globals, locals)

  File ~/Library/CloudStorage/OneDrive-HarvardUniversity/Landscapes/Covariances/untitled1.py:41
    print(two_point_correlator(t1,O1,t2,O2))

  File ~/anaconda3/lib/python3.10/site-packages/jax/_src/traceback_util.py:177 in reraise_with_filtered_traceback
    return fun(*args, **kwargs)

  File ~/anaconda3/lib/python3.10/site-packages/jax/_src/pjit.py:256 in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(

  File ~/anaconda3/lib/python3.10/site-packages/jax/_src/pjit.py:162 in _python_pjit_helper
    args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(

  File ~/anaconda3/lib/python3.10/site-packages/jax/_src/api.py:314 in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)

  File ~/anaconda3/lib/python3.10/site-packages/jax/_src/pjit.py:486 in common_infer_params
    jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(

  File ~/anaconda3/lib/python3.10/site-packages/jax/_src/pjit.py:963 in _pjit_jaxpr
    jaxpr, final_consts, out_type = _create_pjit_jaxpr(

  File ~/anaconda3/lib/python3.10/site-packages/jax/_src/linear_util.py:349 in memoized_fun
    ans = call(fun, *args)

  File ~/anaconda3/lib/python3.10/site-packages/jax/_src/pjit.py:916 in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(

  File ~/anaconda3/lib/python3.10/site-packages/jax/_src/profiler.py:340 in wrapper
    return func(*args, **kwargs)

  File ~/anaconda3/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py:2278 in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(

  File ~/anaconda3/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py:2300 in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)

  File ~/anaconda3/lib/python3.10/site-packages/jax/_src/linear_util.py:191 in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File ~/anaconda3/lib/python3.10/site-packages/pennylane/qnode.py:1027 in __call__
    res = qml.execute(

  File ~/anaconda3/lib/python3.10/site-packages/pennylane/interfaces/execution.py:616 in execute
    results = inner_execute(tapes)

  File ~/anaconda3/lib/python3.10/site-packages/pennylane/interfaces/execution.py:249 in inner_execute
    return cached_device_execution(tapes)

  File ~/anaconda3/lib/python3.10/site-packages/pennylane/interfaces/execution.py:371 in wrapper
    res = list(fn(tuple(execution_tapes.values()), **kwargs))

  File ~/anaconda3/lib/python3.10/contextlib.py:79 in inner
    return func(*args, **kwds)

  File ~/anaconda3/lib/python3.10/site-packages/pennylane/_qubit_device.py:460 in batch_execute
    res = self.execute(circuit)

  File ~/anaconda3/lib/python3.10/site-packages/pennylane/_qubit_device.py:279 in execute
    self.apply(circuit.operations, rotations=self._get_diagonalizing_gates(circuit), **kwargs)

  File ~/anaconda3/lib/python3.10/site-packages/pennylane/devices/default_qubit_legacy.py:1085 in _get_diagonalizing_gates
    return super()._get_diagonalizing_gates(qml.tape.QuantumScript(measurements=meas_filtered))

  File ~/anaconda3/lib/python3.10/site-packages/pennylane/_qubit_device.py:1720 in _get_diagonalizing_gates
    return circuit.diagonalizing_gates

  File ~/anaconda3/lib/python3.10/site-packages/pennylane/tape/qscript.py:387 in diagonalizing_gates
    rotation_gates.extend(observable.diagonalizing_gates())

  File ~/anaconda3/lib/python3.10/site-packages/pennylane/ops/op_math/composite.py:267 in diagonalizing_gates
    eigvecs = tmp_sum.eigendecomposition["eigvec"]

  File ~/anaconda3/lib/python3.10/site-packages/pennylane/ops/op_math/composite.py:231 in eigendecomposition
    mat = math.to_numpy(mat)

  File ~/anaconda3/lib/python3.10/site-packages/autoray/autoray.py:80 in do
    return get_lib_fn(backend, fn)(*args, **kwargs)

  File ~/anaconda3/lib/python3.10/site-packages/pennylane/math/single_dispatch.py:743 in _to_numpy_jax
    raise ValueError(

ValueError: Converting a JAX array to a NumPy array not supported when using the JAX JIT.

A quick hot-fix is to bypass the device and work with matrices directly. As a PennyLane developer I find it a bit unsatisfying, but it gets the job done:

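@jax.jit
def two_point_correlator(t1,O1,t2,O2):
    psi0 = jnp.eye(2**4, dtype=complex)[0] # |0000>
    # |psi(t1)> = U(t1)|psi0>, using the matrix of the evolution operator
    U_t1 = qml.matrix(qml.evolve(H)(param_vector, t1), wire_order=[0, 1, 2, 3])
    psi_t1 = U_t1 @ psi0

    # Evolution from t2 to t1, used to build M = O1 U(t1,t2) O2 U^dagger(t1,t2)
    u_partial= qml.evolve(H)(param_vector, t=[t2,t1])
    M = qml.matrix(qml.prod(O1, u_partial, O2, qml.adjoint(u_partial)), wire_order=[0, 1, 2, 3])
    # <psi(t1)|M|psi(t1)>; the result is complex in general
    return psi_t1.conj() @ M @ psi_t1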
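Because every term of this particular example Hamiltonian is p_i t X_i on its own wire, all terms commute and the evolution has a closed form, U(t_b,t_a) = \prod_i \exp(-i p_i (t_b^2-t_a^2) X_i/2). The following is a minimal plain-JAX sketch (no PennyLane) that uses this to check the M-based formula against the direct Heisenberg-picture product. The helper names `embed`, `evolution`, `correlator_via_M` and `correlator_heisenberg` are hypothetical, the sketch assumes PennyLane's convention that wire 0 is the most significant qubit, and it only applies to this commuting Hamiltonian; it is not a general replacement for `qml.evolve`.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

N = 4  # number of wires, as in the question

I2 = jnp.eye(2, dtype=jnp.complex128)
X = jnp.array([[0, 1], [1, 0]], dtype=jnp.complex128)
Y = jnp.array([[0, -1j], [1j, 0]], dtype=jnp.complex128)
Z = jnp.array([[1, 0], [0, -1]], dtype=jnp.complex128)


def embed(op, wire):
    """Embed a single-qubit matrix on `wire` (wire 0 = most significant)."""
    out = jnp.ones((1, 1), dtype=jnp.complex128)
    for w in range(N):
        out = jnp.kron(out, op if w == wire else I2)
    return out


def evolution(p, t_a, t_b):
    """Closed-form U(t_b, t_a) for H(t) = sum_i p_i * t * X_i (commuting terms)."""
    theta = p * (t_b**2 - t_a**2) / 2  # integral of p_i * t from t_a to t_b
    out = jnp.ones((1, 1), dtype=jnp.complex128)
    for i in range(N):
        # exp(-i theta_i X_i) = cos(theta_i) I - i sin(theta_i) X_i
        out = jnp.kron(out, jnp.cos(theta[i]) * I2 - 1j * jnp.sin(theta[i]) * X)
    return out


O1 = embed(Y, 0)  # PauliY(0)
O2 = embed(Z, 1)  # PauliZ(1)
psi0 = jnp.zeros(2**N, dtype=jnp.complex128).at[0].set(1.0)  # |0000>


@jax.jit
def correlator_via_M(p, t1, t2):
    """<psi(t1)| O1 U(t1,t2) O2 U^dagger(t1,t2) |psi(t1)>."""
    psi_t1 = evolution(p, 0.0, t1) @ psi0
    U12 = evolution(p, t2, t1)
    M = O1 @ U12 @ O2 @ U12.conj().T
    return jnp.vdot(psi_t1, M @ psi_t1)  # vdot conjugates its first argument


@jax.jit
def correlator_heisenberg(p, t1, t2):
    """<psi0| O1(t1) O2(t2) |psi0> with O(t) = U^dagger(t) O U(t)."""
    U1 = evolution(p, 0.0, t1)
    U2 = evolution(p, 0.0, t2)
    O1_t1 = U1.conj().T @ O1 @ U1
    O2_t2 = U2.conj().T @ O2 @ U2
    return jnp.vdot(psi0, O1_t1 @ O2_t2 @ psi0)


p = jax.random.uniform(jax.random.PRNGKey(0), shape=(N,))
print(correlator_via_M(p, 1.0, 0.5), correlator_heisenberg(p, 1.0, 0.5))
```

For this choice of O_1, O_2 and |0000\rangle, working it out by hand gives -\sin(p_0 t_1^2)\cos(p_1 t_2^2), which is a useful third reference value when comparing against the PennyLane-based versions.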

I'll try to find a solution for the original approach and will let you know if I find one.

By the way, when you write u_partial= qml.evolve(H)(param_vector, t=[t2,t1]) inside the qnode, this operator is automatically queued, meaning it is applied to the circuit as a gate. I guess this was not your intention? When you construct operators inside a qnode that you don't want queued, you need to put them in a stop_recording context, e.g.

@qml.qnode(dev, interface="jax")
def two_point_correlator(t1,O1,t2,O2):
    #Setting the partial unitary evolution
    with qml.QueuingManager.stop_recording():
        u_partial= qml.evolve(H)(param_vector, t=[t2,t1])

    #Evolving till t1
    qml.evolve(H)(param_vector, t1)
    return qml.expval(O1@(u_partial)@(O2@qml.adjoint(u_partial)))

This doesn't fix your problem, but it is a peculiarity that is good to know in general :slight_smile:

The second suggestion, to stop recording, didn't solve the issue, but as you said, it is good to know!

Yes, the matrix-based solution works (and actually cuts the computation time roughly in half for N=4). In this particular case that is fine, although it would be nice to still do this with the PennyLane approach. Let me know if you come up with another PennyLane-based solution.

Hi @Rodrigo_Araiza_Bravo

We identified this as a bug, and a fix is now merged into the latest version on GitHub (to be released in 3 weeks, but you can also check out the latest version now).

With that, you can run your original code. Remember to put the construction of u_partial inside the stop_recording() context to get correct results:

@jax.jit
@qml.qnode(dev, interface="jax")
def two_point_correlator(t1,O1,t2,O2):
    #Setting the partial unitary evolution
    with qml.QueuingManager.stop_recording():
        u_partial= qml.evolve(H)(param_vector, t=[t2,t1])

    #Evolving till t1
    qml.evolve(H)(param_vector, t1)
    return qml.expval(O1@(u_partial)@(O2@qml.adjoint(u_partial)))

It is still not 100% optimal, but it works.
You will get a few warnings:

UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> requested in astype is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return lax_numpy.astype(arr, dtype)

To remove the precision warning, call jax.config.update("jax_enable_x64", True) right after importing jax.
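A minimal sketch of where the flag goes. The key point is to set it at startup, before any JAX arrays are created; arrays created before the update keep their 32-bit dtype.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit types before creating any arrays.
jax.config.update("jax_enable_x64", True)

x = jnp.ones(3)                  # default float dtype is now float64
z = jnp.zeros(3, dtype=complex)  # complex128 instead of complex64
print(x.dtype, z.dtype)
```

Setting the JAX_ENABLE_X64 environment variable, as the warning text mentions, has the same effect without touching the code.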

UserWarning: Prod might not be hermitian.
  warnings.warn(f"{op.name} might not be hermitian.")

This warning appears because PennyLane does not know whether the observable passed to expval is Hermitian. To save computation, it does not check this; instead it warns you to make sure that what you pass is valid.

ComplexWarning: Casting complex values to real discards the imaginary part
  out_array: Array = lax_internal._convert_element_type(

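This is related to the above. Because expval assumes a Hermitian observable, it returns only the real part of the result. In your example (O_1 and O_2 on different wires, and a Hamiltonian with no coupling between wires) the discarded imaginary part is only machine-precision noise. In general, however, M is not Hermitian and the two-time correlator has a genuine imaginary part, \langle[O_1(t_1),O_2(t_2)]\rangle/(2i), which expval silently drops; the matrix-based workaround above returns the full complex value.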
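A small single-qubit sketch of why the imaginary part is not always negligible. It keeps the form of the example Hamiltonian, H(t) = p t X, but the choices O_1 = Y, O_2 = Z, initial state |+\rangle and p = 0.7 are hypothetical, picked only for illustration. Working it out by hand gives \langle Y(t_1)Z(t_2)\rangle = \sin(\varphi_2-\varphi_1) + i\cos(\varphi_1-\varphi_2) with \varphi_k = p t_k^2, so the imaginary part is of order one here.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

I2 = jnp.eye(2, dtype=jnp.complex128)
X = jnp.array([[0, 1], [1, 0]], dtype=jnp.complex128)
Y = jnp.array([[0, -1j], [1j, 0]], dtype=jnp.complex128)
Z = jnp.array([[1, 0], [0, -1]], dtype=jnp.complex128)


def evolution(p, t_a, t_b):
    """U(t_b, t_a) for H(t) = p * t * X: exp(-i theta X), theta = p (t_b^2 - t_a^2) / 2."""
    theta = p * (t_b**2 - t_a**2) / 2
    return jnp.cos(theta) * I2 - 1j * jnp.sin(theta) * X


@jax.jit
def correlator(p, t1, t2, psi0):
    """<psi(t1)| Y U(t1,t2) Z U^dagger(t1,t2) |psi(t1)>, kept as a complex number."""
    psi_t1 = evolution(p, 0.0, t1) @ psi0
    U12 = evolution(p, t2, t1)
    M = Y @ U12 @ Z @ U12.conj().T
    return jnp.vdot(psi_t1, M @ psi_t1)


plus = jnp.array([1.0, 1.0], dtype=jnp.complex128) / jnp.sqrt(2.0)  # |+>
c = correlator(0.7, 1.0, 0.5, plus)
print(c.real, c.imag)  # expval would only report c.real
```

The real part is the symmetrized correlator \langle\{O_1(t_1),O_2(t_2)\}\rangle/2 and the imaginary part carries the commutator, so which one you need depends on the physical quantity you are after.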