Hi, after some discussion elsewhere, I believe the example below should work:
# Minimal reproduction: a jitted QNode on a finite-shot device whose seed is a
# JAX PRNG key passed in as an argument of the jitted function.
import pennylane as qml
import numpy as np
import jax

jax.config.update("jax_enable_x64", True)
jax.config.update('jax_platform_name', 'cpu')


@jax.jit
def sample_circuit(phi, theta, key):
    """Run the two-wire circuit on a device seeded with ``key``.

    Args:
        phi: array with two rotation angles (``phi[0]`` for RX on wire 0,
            ``phi[1]`` for RZ on wire 1).
        theta: array with one rotation angle (``theta[0]`` for the final RX
            on wire 0).
        key: JAX PRNG key used as the device seed.

    Returns:
        The expectation value of PauliZ on wire 0.
    """
    # The device is built inside the jitted function, so `key` is a traced
    # value at this point and is handed to the device as its seed.
    dev = qml.device('default.qubit', wires=2, seed=key, shots=10) # Any finite number

    @qml.qnode(dev, interface='jax')
    def circuit(phi, theta):
        qml.RX(phi[0], wires=0)
        qml.RZ(phi[1], wires=1)
        qml.CNOT(wires=[0, 1])
        qml.RX(theta[0], wires=0)
        return qml.expval(qml.PauliZ(0))

    return circuit(phi, theta)


phi = np.array([0.2, 1.0])
theta = np.array([0.2])
key = jax.random.PRNGKey(10)

# Three plain calls with the same key, followed by a gradient with respect to
# the first argument (`phi`, the jax.grad default).
print(sample_circuit(phi, theta, key))
print(sample_circuit(phi, theta, key))
print(sample_circuit(phi, theta, key))
print(jax.grad(sample_circuit)(phi, theta, key))
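Running this raises a `jax.errors.UnexpectedTracerError`. The failure already occurs on a plain `print(sample_circuit(phi, theta, key))` call (line 62 of my script), i.e. before `jax.grad` is reached.
The full traceback: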
jax.pure_callback failed
Traceback (most recent call last):
File "/home/.../lib/python3.12/site-packages/jax/_src/callback.py", line 94, in pure_callback_impl
return tree_util.tree_map(np.asarray, callback(*args))
^^^^^^^^^^^^^^^
File "/home/.../lib/python3.12/site-packages/jax/_src/callback.py", line 71, in __call__
return tree_util.tree_leaves(self.callback_func(*args, **kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/.../lib/python3.12/site-packages/pennylane/workflow/interfaces/jax_jit.py", line 168, in pure_callback_wrapper
return _to_jax(execute_fn(new_tapes))
^^^^^^^^^^^^^^^^^^^^^
File "/home/.../lib/python3.12/site-packages/pennylane/workflow/execution.py", line 202, in inner_execute
results = device.execute(transformed_tapes, execution_config=execution_config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/.../lib/python3.12/site-packages/pennylane/devices/modifiers/simulator_tracking.py", line 30, in execute
results = untracked_execute(self, circuits, execution_config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/.../lib/python3.12/site-packages/pennylane/devices/modifiers/single_tape_support.py", line 32, in execute
results = batch_execute(self, circuits, execution_config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/.../lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/.../lib/python3.12/site-packages/pennylane/devices/default_qubit.py", line 639, in execute
prng_keys = [self.get_prng_keys()[0] for _ in range(len(circuits))]
^^^^^^^^^^^^^^^^^^^^
File "/home/.../lib/python3.12/site-packages/pennylane/devices/default_qubit.py", line 451, in get_prng_keys
self._prng_key, *keys = jax_random_split(self._prng_key)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/.../lib/python3.12/site-packages/pennylane/devices/qubit/sampling.py", line 42, in jax_random_split
return split(prng_key, num=num)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/.../lib/python3.12/site-packages/jax/_src/random.py", line 292, in split
typed_key, wrapped = _check_prng_key("split", key)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/.../lib/python3.12/site-packages/jax/_src/random.py", line 79, in _check_prng_key
wrapped_key = prng.random_wrap(key, impl=default_prng_impl())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/.../lib/python3.12/site-packages/jax/_src/prng.py", line 696, in random_wrap
return random_wrap_p.bind(base_arr, impl=impl)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/.../lib/python3.12/site-packages/jax/_src/core.py", line 438, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
^^^^^^^^^^^^^^^^^^^^
File "/home/.../lib/python3.12/site-packages/jax/_src/core.py", line 1391, in find_top_trace
top_tracer._assert_live()
File "/home/.../lib/python3.12/site-packages/jax/_src/interpreters/partial_eval.py", line 1658, in _assert_live
raise core.escaped_tracer_error(self, None)
jax.errors.UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type uint32[2] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was sample_circuit at /home/.../temp.py:44 traced for jit.
------------------------------
The leaked intermediate value was created on line /home/.../temp.py:62:6 (<module>).
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
/snap/pycharm-professional/443/plugins/python-ce/helpers/pydev/_pydev_bundle/pydev_ipython_console_011.py:438:26 (_PyDevIPythonFrontEnd.add_exec)
<ipython-input-2-9a5282ec6b15>:1 (<module>)
/snap/pycharm-professional/443/plugins/python-ce/helpers/pydev/_pydev_bundle/pydev_umd.py:197:12 (runfile)
/snap/pycharm-professional/443/plugins/python-ce/helpers/pydev/_pydev_imps/_pydev_execfile.py:18:4 (execfile)
/home/.../temp.py:62:6 (<module>)
------------------------------
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
Traceback (most recent call last):
File "/home/.../lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-2-9a5282ec6b15>", line 1, in <module>
runfile('/home/.../temp.py', wdir='/home/...')
File "/snap/pycharm-professional/443/plugins/python-ce/helpers/pydev/_pydev_bundle/pydev_umd.py", line 197, in runfile
pydev_imports.execfile(filename, global_vars, local_vars) # execute the script
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/snap/pycharm-professional/443/plugins/python-ce/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
exec(compile(contents+"\n", file, 'exec'), glob, loc)
File "/home/.../temp.py", line 62, in <module>
print(sample_circuit(phi, theta, key))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/.../lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "/home/.../lib/python3.12/site-packages/jax/_src/pjit.py", line 338, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
^^^^^^^^^^^^^^^^^^^^
File "/home/.../lib/python3.12/site-packages/jax/_src/pjit.py", line 188, in _python_pjit_helper
out_flat = pjit_p.bind(*args_flat, **p.params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/.../lib/python3.12/site-packages/jax/_src/core.py", line 2803, in bind
return self.bind_with_trace(top_trace, args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/.../lib/python3.12/site-packages/jax/_src/core.py", line 442, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/.../lib/python3.12/site-packages/jax/_src/core.py", line 955, in process_primitive
return primitive.impl(*tracers, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/.../lib/python3.12/site-packages/jax/_src/pjit.py", line 1738, in _pjit_call_impl
return xc._xla.pjit(
^^^^^^^^^^^^^
File "/home/.../lib/python3.12/site-packages/jax/_src/pjit.py", line 1714, in call_impl_cache_miss
out_flat, compiled = _pjit_call_impl_python(
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/.../lib/python3.12/site-packages/jax/_src/pjit.py", line 1668, in _pjit_call_impl_python
return compiled.unsafe_call(*args), compiled
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/.../lib/python3.12/site-packages/jax/_src/profiler.py", line 333, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/.../lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 1278, in __call__
results = self.xla_executable.execute_sharded(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: CpuCallback error: Traceback (most recent call last):
File "/snap/pycharm-professional/443/plugins/python-ce/helpers/pydev/pydevconsole.py", line 570, in <module>
File "/snap/pycharm-professional/443/plugins/python-ce/helpers/pydev/pydevconsole.py", line 498, in start_client
File "/snap/pycharm-professional/443/plugins/python-ce/helpers/pydev/pydevconsole.py", line 287, in process_exec_queue
File "/snap/pycharm-professional/443/plugins/python-ce/helpers/pydev/_pydev_bundle/pydev_code_executor.py", line 109, in add_exec
File "/snap/pycharm-professional/443/plugins/python-ce/helpers/pydev/_pydev_bundle/pydev_ipython_console.py", line 34, in do_add_exec
File "/snap/pycharm-professional/443/plugins/python-ce/helpers/pydev/_pydev_bundle/pydev_ipython_console_011.py", line 438, in add_exec
File "/home/.../lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3075, in run_cell
File "/home/.../lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3130, in _run_cell
File "/home/.../lib/python3.12/site-packages/IPython/core/async_helpers.py", line 128, in _pseudo_sync_runner
File "/home/.../lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3334, in run_cell_async
File "/home/.../lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3517, in run_ast_nodes
File "/home/.../lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
File "<ipython-input-2-9a5282ec6b15>", line 1, in <module>
File "/snap/pycharm-professional/443/plugins/python-ce/helpers/pydev/_pydev_bundle/pydev_umd.py", line 197, in runfile
File "/snap/pycharm-professional/443/plugins/python-ce/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
File "/home/.../temp.py", line 62, in <module>
File "/home/.../lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
File "/home/.../lib/python3.12/site-packages/jax/_src/pjit.py", line 338, in cache_miss
File "/home/.../lib/python3.12/site-packages/jax/_src/pjit.py", line 188, in _python_pjit_helper
File "/home/.../lib/python3.12/site-packages/jax/_src/core.py", line 2803, in bind
File "/home/.../lib/python3.12/site-packages/jax/_src/core.py", line 442, in bind_with_trace
File "/home/.../lib/python3.12/site-packages/jax/_src/core.py", line 955, in process_primitive
File "/home/.../lib/python3.12/site-packages/jax/_src/pjit.py", line 1738, in _pjit_call_impl
File "/home/.../lib/python3.12/site-packages/jax/_src/pjit.py", line 1714, in call_impl_cache_miss
File "/home/.../lib/python3.12/site-packages/jax/_src/pjit.py", line 1668, in _pjit_call_impl_python
File "/home/.../lib/python3.12/site-packages/jax/_src/profiler.py", line 333, in wrapper
File "/home/.../lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 1278, in __call__
File "/home/.../lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 2768, in _wrapped_callback
File "/home/.../lib/python3.12/site-packages/jax/_src/callback.py", line 269, in _callback
File "/home/.../lib/python3.12/site-packages/jax/_src/callback.py", line 97, in pure_callback_impl
File "/home/.../lib/python3.12/site-packages/jax/_src/callback.py", line 71, in __call__
File "/home/.../lib/python3.12/site-packages/pennylane/workflow/interfaces/jax_jit.py", line 168, in pure_callback_wrapper
File "/home/.../lib/python3.12/site-packages/pennylane/workflow/execution.py", line 202, in inner_execute
File "/home/.../lib/python3.12/site-packages/pennylane/devices/modifiers/simulator_tracking.py", line 30, in execute
File "/home/.../lib/python3.12/site-packages/pennylane/devices/modifiers/single_tape_support.py", line 32, in execute
File "/home/.../lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
File "/home/.../lib/python3.12/site-packages/pennylane/devices/default_qubit.py", line 639, in execute
File "/home/.../lib/python3.12/site-packages/pennylane/devices/default_qubit.py", line 451, in get_prng_keys
File "/home/.../lib/python3.12/site-packages/pennylane/devices/qubit/sampling.py", line 42, in jax_random_split
File "/home/.../lib/python3.12/site-packages/jax/_src/random.py", line 292, in split
File "/home/.../lib/python3.12/site-packages/jax/_src/random.py", line 79, in _check_prng_key
File "/home/.../lib/python3.12/site-packages/jax/_src/prng.py", line 696, in random_wrap
File "/home/.../lib/python3.12/site-packages/jax/_src/core.py", line 438, in bind
File "/home/.../lib/python3.12/site-packages/jax/_src/core.py", line 1391, in find_top_trace
File "/home/.../lib/python3.12/site-packages/jax/_src/interpreters/partial_eval.py", line 1658, in _assert_live
UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type uint32[2] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was sample_circuit at /home/.../temp.py:44 traced for jit.
------------------------------
The leaked intermediate value was created on line /home/.../temp.py:62:6 (<module>).
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
/snap/pycharm-professional/443/plugins/python-ce/helpers/pydev/_pydev_bundle/pydev_ipython_console_011.py:438:26 (_PyDevIPythonFrontEnd.add_exec)
<ipython-input-2-9a5282ec6b15>:1 (<module>)
/snap/pycharm-professional/443/plugins/python-ce/helpers/pydev/_pydev_bundle/pydev_umd.py:197:12 (runfile)
/snap/pycharm-professional/443/plugins/python-ce/helpers/pydev/_pydev_imps/_pydev_execfile.py:18:4 (execfile)
/home/.../temp.py:62:6 (<module>)
------------------------------
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
The same code works if we replace `shots=10` with `shots=None`.
For completeness, here is the full working version:
# Control case: the same script as the failing reproduction, except that the
# device is created with shots=None instead of a finite shot count.
import pennylane as qml
import numpy as np
import jax

jax.config.update("jax_enable_x64", True)
jax.config.update('jax_platform_name', 'cpu')


@jax.jit
def sample_circuit(phi, theta, key):
    """Run the two-wire circuit on a device seeded with ``key``.

    Args:
        phi: array with two rotation angles (``phi[0]`` for RX on wire 0,
            ``phi[1]`` for RZ on wire 1).
        theta: array with one rotation angle (``theta[0]`` for the final RX
            on wire 0).
        key: JAX PRNG key used as the device seed.

    Returns:
        The expectation value of PauliZ on wire 0.
    """
    # The only difference from the failing script: shots=None.
    dev = qml.device('default.qubit', wires=2, seed=key, shots=None)

    @qml.qnode(dev, interface='jax')
    def circuit(phi, theta):
        qml.RX(phi[0], wires=0)
        qml.RZ(phi[1], wires=1)
        qml.CNOT(wires=[0, 1])
        qml.RX(theta[0], wires=0)
        return qml.expval(qml.PauliZ(0))

    return circuit(phi, theta)


phi = np.array([0.2, 1.0])
theta = np.array([0.2])
key = jax.random.PRNGKey(10)

# Three plain calls with the same key, followed by a gradient with respect to
# the first argument (`phi`, the jax.grad default).
print(sample_circuit(phi, theta, key))
print(sample_circuit(phi, theta, key))
print(sample_circuit(phi, theta, key))
print(jax.grad(sample_circuit)(phi, theta, key))
The output is:
0.9605304970014428
0.9605304970014428
0.9605304970014428
[-1.94709171e-01 -1.48426226e-18]
Info about the setup:
Name: PennyLane
Version: 0.39.0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author:
Author-email:
License: Apache License 2.0
Location: /home/.../lib/python3.12/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, toml, typing-extensions
Required-by: PennyLane_Lightning
Platform info: Linux-6.8.0-50-generic-x86_64-with-glibc2.39
Python version: 3.12.0
Numpy version: 2.0.2
Scipy version: 1.14.1
Installed devices:
- lightning.qubit (PennyLane_Lightning-0.39.0)
- default.clifford (PennyLane-0.39.0)
- default.gaussian (PennyLane-0.39.0)
- default.mixed (PennyLane-0.39.0)
- default.qubit (PennyLane-0.39.0)
- default.qutrit (PennyLane-0.39.0)
- default.qutrit.mixed (PennyLane-0.39.0)
- default.tensor (PennyLane-0.39.0)
- null.qubit (PennyLane-0.39.0)
- reference.qubit (PennyLane-0.39.0)
and the installed JAX packages:
jax==0.4.35
jax-cuda12-pjrt==0.4.35
jax-cuda12-plugin==0.4.35
jaxlib==0.4.34