I want to be able to compute two-time correlators of the form \langle O_1(t_1)O_2(t_2)\rangle assuming t_1>t_2. One way to do this is by instead evolving the initial state |\psi_0\rangle under U(t_1) and then calculating the expectation value of the operator M=O_1U(t_1,t_2)O_2U^\dagger(t_1,t_2).
I want to do this where U is the evolution of a parametrized Hamiltonian, and ideally this should be jitted since I plan to do calculation many times in a for loop. My code looks like this (this Hamiltonian is an example from pennylane’s website):
import pennylane as qml
import jax
import jax.numpy as jnp
coeffs = [lambda p, t: p * t for _ in range(4)]
ops = [qml.PauliX(i) for i in range(4)]
# ParametrizedHamiltonian
H = qml.dot(coeffs, ops)
#Parameters
key = jax.random.PRNGKey(0)
param_vector = jax.random.uniform(key, shape = [4])
#Setting device
dev = qml.device("default.qubit.jax", wires=4)
#Stting times and operators
t1=1.0
O1 = qml.PauliY(0)
t2=0.5
O2 = qml.PauliZ(1)
@qml.qnode(dev, interface="jax")
def two_point_correlator(t1,O1,t2,O2):
#Setting the partial unitary evolution
u_partial= qml.evolve(H)(param_vector, t=[t2,t1])
#Evolving till t1
qml.evolve(H)(param_vector, t1)
return qml.expval(O1@(u_partial)@(O2@qml.adjoint(u_partial)))
This works well! But when I add the jit decorator:
@jax.jit
@qml.qnode(dev, interface="jax")
def two_point_correlator(t1,O1,t2,O2):
u_partial= qml.evolve(H)(param_vector, t=[t2,t1])
qml.evolve(H)(param_vector, t1)
return qml.expval(O1@(u_partial)@(O2@qml.adjoint(u_partial)))
I get the following error. Is there a work-around?
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape complex128[16,16].
The error occurred while tracing the function two_point_correlator at /Users/Rodrigo/Library/CloudStorage/OneDrive-HarvardUniversity/Landscapes/Covariances/untitled1.py:35 for jit. This concrete value was not available in Python because it depends on the values of the arguments t1 and t2.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File ~/anaconda3/lib/python3.10/site-packages/spyder_kernels/py3compat.py:356 in compat_exec
exec(code, globals, locals)
File ~/Library/CloudStorage/OneDrive-HarvardUniversity/Landscapes/Covariances/untitled1.py:41
print(two_point_correlator(t1,O1,t2,O2))
File ~/anaconda3/lib/python3.10/site-packages/jax/_src/traceback_util.py:177 in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File ~/anaconda3/lib/python3.10/site-packages/jax/_src/pjit.py:256 in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
File ~/anaconda3/lib/python3.10/site-packages/jax/_src/pjit.py:162 in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
File ~/anaconda3/lib/python3.10/site-packages/jax/_src/api.py:314 in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
File ~/anaconda3/lib/python3.10/site-packages/jax/_src/pjit.py:486 in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
File ~/anaconda3/lib/python3.10/site-packages/jax/_src/pjit.py:963 in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
File ~/anaconda3/lib/python3.10/site-packages/jax/_src/linear_util.py:349 in memoized_fun
ans = call(fun, *args)
File ~/anaconda3/lib/python3.10/site-packages/jax/_src/pjit.py:916 in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
File ~/anaconda3/lib/python3.10/site-packages/jax/_src/profiler.py:340 in wrapper
return func(*args, **kwargs)
File ~/anaconda3/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py:2278 in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File ~/anaconda3/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py:2300 in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File ~/anaconda3/lib/python3.10/site-packages/jax/_src/linear_util.py:191 in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File ~/anaconda3/lib/python3.10/site-packages/pennylane/qnode.py:1027 in __call__
res = qml.execute(
File ~/anaconda3/lib/python3.10/site-packages/pennylane/interfaces/execution.py:616 in execute
results = inner_execute(tapes)
File ~/anaconda3/lib/python3.10/site-packages/pennylane/interfaces/execution.py:249 in inner_execute
return cached_device_execution(tapes)
File ~/anaconda3/lib/python3.10/site-packages/pennylane/interfaces/execution.py:371 in wrapper
res = list(fn(tuple(execution_tapes.values()), **kwargs))
File ~/anaconda3/lib/python3.10/contextlib.py:79 in inner
return func(*args, **kwds)
File ~/anaconda3/lib/python3.10/site-packages/pennylane/_qubit_device.py:460 in batch_execute
res = self.execute(circuit)
File ~/anaconda3/lib/python3.10/site-packages/pennylane/_qubit_device.py:279 in execute
self.apply(circuit.operations, rotations=self._get_diagonalizing_gates(circuit), **kwargs)
File ~/anaconda3/lib/python3.10/site-packages/pennylane/devices/default_qubit_legacy.py:1085 in _get_diagonalizing_gates
return super()._get_diagonalizing_gates(qml.tape.QuantumScript(measurements=meas_filtered))
File ~/anaconda3/lib/python3.10/site-packages/pennylane/_qubit_device.py:1720 in _get_diagonalizing_gates
return circuit.diagonalizing_gates
File ~/anaconda3/lib/python3.10/site-packages/pennylane/tape/qscript.py:387 in diagonalizing_gates
rotation_gates.extend(observable.diagonalizing_gates())
File ~/anaconda3/lib/python3.10/site-packages/pennylane/ops/op_math/composite.py:267 in diagonalizing_gates
eigvecs = tmp_sum.eigendecomposition["eigvec"]
File ~/anaconda3/lib/python3.10/site-packages/pennylane/ops/op_math/composite.py:231 in eigendecomposition
mat = math.to_numpy(mat)
File ~/anaconda3/lib/python3.10/site-packages/autoray/autoray.py:80 in do
return get_lib_fn(backend, fn)(*args, **kwargs)
File ~/anaconda3/lib/python3.10/site-packages/pennylane/math/single_dispatch.py:743 in _to_numpy_jax
raise ValueError(
ValueError: Converting a JAX array to a NumPy array not supported when using the JAX JIT.