Error while trying to train with JAX-JIT and finite shots

Hi, after some discussions elsewhere, I believe the example below should work.

import pennylane as qml
import numpy as np
import jax

jax.config.update("jax_enable_x64", True)
jax.config.update('jax_platform_name', 'cpu')


@jax.jit
def sample_circuit(phi, theta, key):
    dev = qml.device('default.qubit', wires=2, seed=key, shots=10) # Any finite number 

    @qml.qnode(dev, interface='jax')
    def circuit(phi, theta):
        qml.RX(phi[0], wires=0)
        qml.RZ(phi[1], wires=1)
        qml.CNOT(wires=[0, 1])
        qml.RX(theta[0], wires=0)
        return qml.expval(qml.PauliZ(0))

    return circuit(phi, theta)


phi = np.array([0.2, 1.0])
theta = np.array([0.2])
key = jax.random.PRNGKey(10)
print(sample_circuit(phi, theta, key))
print(sample_circuit(phi, theta, key))
print(sample_circuit(phi, theta, key))
print(jax.grad(sample_circuit)(phi, theta, key))

This results in a jax.errors.UnexpectedTracerError.
The full traceback:

jax.pure_callback failed
Traceback (most recent call last):
  File "/home/.../lib/python3.12/site-packages/jax/_src/callback.py", line 94, in pure_callback_impl
    return tree_util.tree_map(np.asarray, callback(*args))
                                          ^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.12/site-packages/jax/_src/callback.py", line 71, in __call__
    return tree_util.tree_leaves(self.callback_func(*args, **kwargs))
                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.12/site-packages/pennylane/workflow/interfaces/jax_jit.py", line 168, in pure_callback_wrapper
    return _to_jax(execute_fn(new_tapes))
                   ^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.12/site-packages/pennylane/workflow/execution.py", line 202, in inner_execute
    results = device.execute(transformed_tapes, execution_config=execution_config)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.12/site-packages/pennylane/devices/modifiers/simulator_tracking.py", line 30, in execute
    results = untracked_execute(self, circuits, execution_config)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.12/site-packages/pennylane/devices/modifiers/single_tape_support.py", line 32, in execute
    results = batch_execute(self, circuits, execution_config)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.12/site-packages/pennylane/devices/default_qubit.py", line 639, in execute
    prng_keys = [self.get_prng_keys()[0] for _ in range(len(circuits))]
                 ^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.12/site-packages/pennylane/devices/default_qubit.py", line 451, in get_prng_keys
    self._prng_key, *keys = jax_random_split(self._prng_key)
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.12/site-packages/pennylane/devices/qubit/sampling.py", line 42, in jax_random_split
    return split(prng_key, num=num)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.12/site-packages/jax/_src/random.py", line 292, in split
    typed_key, wrapped = _check_prng_key("split", key)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.12/site-packages/jax/_src/random.py", line 79, in _check_prng_key
    wrapped_key = prng.random_wrap(key, impl=default_prng_impl())
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.12/site-packages/jax/_src/prng.py", line 696, in random_wrap
    return random_wrap_p.bind(base_arr, impl=impl)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.12/site-packages/jax/_src/core.py", line 438, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
                                ^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.12/site-packages/jax/_src/core.py", line 1391, in find_top_trace
    top_tracer._assert_live()
  File "/home/.../lib/python3.12/site-packages/jax/_src/interpreters/partial_eval.py", line 1658, in _assert_live
    raise core.escaped_tracer_error(self, None)
jax.errors.UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type uint32[2] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was sample_circuit at /home/.../temp.py:44 traced for jit.
------------------------------
The leaked intermediate value was created on line /home/.../temp.py:62:6 (<module>).
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
/snap/pycharm-professional/443/plugins/python-ce/helpers/pydev/_pydev_bundle/pydev_ipython_console_011.py:438:26 (_PyDevIPythonFrontEnd.add_exec)
<ipython-input-2-9a5282ec6b15>:1 (<module>)
/snap/pycharm-professional/443/plugins/python-ce/helpers/pydev/_pydev_bundle/pydev_umd.py:197:12 (runfile)
/snap/pycharm-professional/443/plugins/python-ce/helpers/pydev/_pydev_imps/_pydev_execfile.py:18:4 (execfile)
/home/.../temp.py:62:6 (<module>)
------------------------------
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
Traceback (most recent call last):
  File "/home/.../lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-2-9a5282ec6b15>", line 1, in <module>
    runfile('/home/.../temp.py', wdir='/home/...')
  File "/snap/pycharm-professional/443/plugins/python-ce/helpers/pydev/_pydev_bundle/pydev_umd.py", line 197, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/snap/pycharm-professional/443/plugins/python-ce/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/home/.../temp.py", line 62, in <module>
    print(sample_circuit(phi, theta, key))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.12/site-packages/jax/_src/pjit.py", line 338, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
                                                                ^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.12/site-packages/jax/_src/pjit.py", line 188, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **p.params)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.12/site-packages/jax/_src/core.py", line 2803, in bind
    return self.bind_with_trace(top_trace, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.12/site-packages/jax/_src/core.py", line 442, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.12/site-packages/jax/_src/core.py", line 955, in process_primitive
    return primitive.impl(*tracers, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.12/site-packages/jax/_src/pjit.py", line 1738, in _pjit_call_impl
    return xc._xla.pjit(
           ^^^^^^^^^^^^^
  File "/home/.../lib/python3.12/site-packages/jax/_src/pjit.py", line 1714, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
                         ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.12/site-packages/jax/_src/pjit.py", line 1668, in _pjit_call_impl_python
    return compiled.unsafe_call(*args), compiled
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.12/site-packages/jax/_src/profiler.py", line 333, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 1278, in __call__
    results = self.xla_executable.execute_sharded(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: CpuCallback error: Traceback (most recent call last):
  File "/snap/pycharm-professional/443/plugins/python-ce/helpers/pydev/pydevconsole.py", line 570, in <module>
  File "/snap/pycharm-professional/443/plugins/python-ce/helpers/pydev/pydevconsole.py", line 498, in start_client
  File "/snap/pycharm-professional/443/plugins/python-ce/helpers/pydev/pydevconsole.py", line 287, in process_exec_queue
  File "/snap/pycharm-professional/443/plugins/python-ce/helpers/pydev/_pydev_bundle/pydev_code_executor.py", line 109, in add_exec
  File "/snap/pycharm-professional/443/plugins/python-ce/helpers/pydev/_pydev_bundle/pydev_ipython_console.py", line 34, in do_add_exec
  File "/snap/pycharm-professional/443/plugins/python-ce/helpers/pydev/_pydev_bundle/pydev_ipython_console_011.py", line 438, in add_exec
  File "/home/.../lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3075, in run_cell
  File "/home/.../lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3130, in _run_cell
  File "/home/.../lib/python3.12/site-packages/IPython/core/async_helpers.py", line 128, in _pseudo_sync_runner
  File "/home/.../lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3334, in run_cell_async
  File "/home/.../lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3517, in run_ast_nodes
  File "/home/.../lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
  File "<ipython-input-2-9a5282ec6b15>", line 1, in <module>
  File "/snap/pycharm-professional/443/plugins/python-ce/helpers/pydev/_pydev_bundle/pydev_umd.py", line 197, in runfile
  File "/snap/pycharm-professional/443/plugins/python-ce/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
  File "/home/.../temp.py", line 62, in <module>
  File "/home/.../lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
  File "/home/.../lib/python3.12/site-packages/jax/_src/pjit.py", line 338, in cache_miss
  File "/home/.../lib/python3.12/site-packages/jax/_src/pjit.py", line 188, in _python_pjit_helper
  File "/home/.../lib/python3.12/site-packages/jax/_src/core.py", line 2803, in bind
  File "/home/.../lib/python3.12/site-packages/jax/_src/core.py", line 442, in bind_with_trace
  File "/home/.../lib/python3.12/site-packages/jax/_src/core.py", line 955, in process_primitive
  File "/home/.../lib/python3.12/site-packages/jax/_src/pjit.py", line 1738, in _pjit_call_impl
  File "/home/.../lib/python3.12/site-packages/jax/_src/pjit.py", line 1714, in call_impl_cache_miss
  File "/home/.../lib/python3.12/site-packages/jax/_src/pjit.py", line 1668, in _pjit_call_impl_python
  File "/home/.../lib/python3.12/site-packages/jax/_src/profiler.py", line 333, in wrapper
  File "/home/.../lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 1278, in __call__
  File "/home/.../lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 2768, in _wrapped_callback
  File "/home/.../lib/python3.12/site-packages/jax/_src/callback.py", line 269, in _callback
  File "/home/.../lib/python3.12/site-packages/jax/_src/callback.py", line 97, in pure_callback_impl
  File "/home/.../lib/python3.12/site-packages/jax/_src/callback.py", line 71, in __call__
  File "/home/.../lib/python3.12/site-packages/pennylane/workflow/interfaces/jax_jit.py", line 168, in pure_callback_wrapper
  File "/home/.../lib/python3.12/site-packages/pennylane/workflow/execution.py", line 202, in inner_execute
  File "/home/.../lib/python3.12/site-packages/pennylane/devices/modifiers/simulator_tracking.py", line 30, in execute
  File "/home/.../lib/python3.12/site-packages/pennylane/devices/modifiers/single_tape_support.py", line 32, in execute
  File "/home/.../lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
  File "/home/.../lib/python3.12/site-packages/pennylane/devices/default_qubit.py", line 639, in execute
  File "/home/.../lib/python3.12/site-packages/pennylane/devices/default_qubit.py", line 451, in get_prng_keys
  File "/home/.../lib/python3.12/site-packages/pennylane/devices/qubit/sampling.py", line 42, in jax_random_split
  File "/home/.../lib/python3.12/site-packages/jax/_src/random.py", line 292, in split
  File "/home/.../lib/python3.12/site-packages/jax/_src/random.py", line 79, in _check_prng_key
  File "/home/.../lib/python3.12/site-packages/jax/_src/prng.py", line 696, in random_wrap
  File "/home/.../lib/python3.12/site-packages/jax/_src/core.py", line 438, in bind
  File "/home/.../lib/python3.12/site-packages/jax/_src/core.py", line 1391, in find_top_trace
  File "/home/.../lib/python3.12/site-packages/jax/_src/interpreters/partial_eval.py", line 1658, in _assert_live
UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type uint32[2] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was sample_circuit at /home/.../temp.py:44 traced for jit.
------------------------------
The leaked intermediate value was created on line /home/.../temp.py:62:6 (<module>).
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
/snap/pycharm-professional/443/plugins/python-ce/helpers/pydev/_pydev_bundle/pydev_ipython_console_011.py:438:26 (_PyDevIPythonFrontEnd.add_exec)
<ipython-input-2-9a5282ec6b15>:1 (<module>)
/snap/pycharm-professional/443/plugins/python-ce/helpers/pydev/_pydev_bundle/pydev_umd.py:197:12 (runfile)
/snap/pycharm-professional/443/plugins/python-ce/helpers/pydev/_pydev_imps/_pydev_execfile.py:18:4 (execfile)
/home/.../temp.py:62:6 (<module>)
------------------------------
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
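
The last PennyLane frames in the traceback above show get_prng_keys splitting self._prng_key inside the callback, which runs after tracing has ended. Below is a minimal pure-JAX sketch of that class of failure. It assumes (this is a reading of the traceback, not confirmed behaviour) that the device keeps a reference to the traced key passed as seed. The names stash and traced are hypothetical stand-ins, and no PennyLane code is involved.

```python
import jax

# Hypothetical stand-in for mutable state on an object that is
# created inside the jitted function (e.g. a device holding a seed).
stash = {}


@jax.jit
def traced(x, key):
    # Side effect: the traced key is stored and outlives the trace.
    stash["key"] = key
    return x * 2.0


traced(1.0, jax.random.PRNGKey(10))

leak_detected = False
try:
    # Using the stored tracer after tracing has finished.
    jax.random.split(stash["key"])
except jax.errors.UnexpectedTracerError:
    leak_detected = True

print("leak detected:", leak_detected)
```

If this reading is right, any key that enters the jitted function as a traced argument and is then stored for later use would hit the same error.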

The same code works if we replace shots=10 with shots=None.
For completeness, here is the full script:

import pennylane as qml
import numpy as np
import jax

jax.config.update("jax_enable_x64", True)
jax.config.update('jax_platform_name', 'cpu')


@jax.jit
def sample_circuit(phi, theta, key):
    dev = qml.device('default.qubit', wires=2, seed=key, shots=None)

    @qml.qnode(dev, interface='jax')
    def circuit(phi, theta):
        qml.RX(phi[0], wires=0)
        qml.RZ(phi[1], wires=1)
        qml.CNOT(wires=[0, 1])
        qml.RX(theta[0], wires=0)
        return qml.expval(qml.PauliZ(0))

    return circuit(phi, theta)


phi = np.array([0.2, 1.0])
theta = np.array([0.2])
key = jax.random.PRNGKey(10)
print(sample_circuit(phi, theta, key))
print(sample_circuit(phi, theta, key))
print(sample_circuit(phi, theta, key))
print(jax.grad(sample_circuit)(phi, theta, key))

The output is:

0.9605304970014428
0.9605304970014428
0.9605304970014428
[-1.94709171e-01 -1.48426226e-18]

Info about the setup:

Name: PennyLane
Version: 0.39.0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author: 
Author-email: 
License: Apache License 2.0
Location: /home/.../lib/python3.12/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, toml, typing-extensions
Required-by: PennyLane_Lightning
Platform info:           Linux-6.8.0-50-generic-x86_64-with-glibc2.39
Python version:          3.12.0
Numpy version:           2.0.2
Scipy version:           1.14.1
Installed devices:
- lightning.qubit (PennyLane_Lightning-0.39.0)
- default.clifford (PennyLane-0.39.0)
- default.gaussian (PennyLane-0.39.0)
- default.mixed (PennyLane-0.39.0)
- default.qubit (PennyLane-0.39.0)
- default.qutrit (PennyLane-0.39.0)
- default.qutrit.mixed (PennyLane-0.39.0)
- default.tensor (PennyLane-0.39.0)
- null.qubit (PennyLane-0.39.0)
- reference.qubit (PennyLane-0.39.0)

and

jax==0.4.35
jax-cuda12-pjrt==0.4.35
jax-cuda12-plugin==0.4.35
jaxlib==0.4.34

FYI @CatalinaAlbornoz

I have tried this example on two different machines, with and without a GPU, and I get the same result.

Thanks for posting this here, @ankit27kh.

The problem is resolved by using a fixed seed in the device (e.g. seed=42) instead of seed=key.

This works with finite shots but I agree that it should ideally work with a PRNGKey.

I’ll share this with the team. Thanks for bringing it up!

Is using a fixed seed enough for you right now or do you still need to use a key?

Yes, @CatalinaAlbornoz, trying with a fixed seed worked. I am not sure if a fixed seed will behave correctly for what I want to do. So I’ll test it out. But at least the code is working now.

A different error appears if you pass a plain integer as the key in the function call instead of a PRNGKey. So print(sample_circuit(phi, theta, jax.random.PRNGKey(10))) gives the tracer-leak error, while print(sample_circuit(phi, theta, 10)) raises a different one:

TypeError: JAX encountered invalid PRNG key data: expected key_data.ndim >= 1; got ndim=0

So, the current way to get it to work is to fix a seed in the device itself.

Clearly, a few things are wrong with the current implementation. It would be great if the team could look into this.
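
The TypeError quoted above can be reproduced in plain JAX, without PennyLane. This is a small sketch, assuming the default PRNG implementation, in which a key created by jax.random.PRNGKey is a uint32 array of shape (2,). A bare integer turned into an array has zero dimensions, so jax.random.split rejects it. The variable names are illustrative only.

```python
import jax
import jax.numpy as jnp

good_key = jax.random.PRNGKey(10)  # key data with shape (2,)
bad_key = jnp.asarray(10)          # a bare integer as an array: shape ()

# Splitting a proper key works and yields two new keys.
new_key, subkey = jax.random.split(good_key)

scalar_rejected = False
try:
    jax.random.split(bad_key)
except TypeError as err:
    scalar_rejected = True
    print(err)

print("good key shape:", good_key.shape, "| scalar rejected:", scalar_rejected)
```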

Hi @ankit27kh,

Yes, you cannot pass an integer as the key.

I’ve forwarded the info to the team and also added a comment on this issue (which is similar).

So at the moment I think we can say that differentiation is not supported with jax-jit when using shots, unless you provide a numerical seed directly to the device (or no seed at all if you don’t care about reproducibility).

Note that differentiation with samples isn’t allowed even without JAX since they give stochastic results and are thus non-differentiable.

Hi @ankit27kh,

My colleague Mudit managed to find a solution by using static_argnums!

import pennylane as qml
import numpy as np
import jax
from functools import partial

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")


@partial(jax.jit, static_argnums=2)
def sample_circuit(phi, theta, key):
    dev = qml.device("default.qubit", wires=2, seed=key, shots=100)

    @qml.qnode(dev, interface="jax")
    def circuit(phi, theta):
        qml.RX(phi[0], wires=0)
        qml.RZ(phi[1], wires=1)
        qml.CNOT(wires=[0, 1])
        qml.RX(theta[0], wires=0)
        return qml.expval(qml.PauliZ(0))

    return circuit(phi, theta)


phi = np.array([0.2, 1.0])
theta = np.array([0.2])
key = 10
print(sample_circuit(phi, theta, key))
print(sample_circuit(phi, theta, key))
print(sample_circuit(phi, theta, key))
print(jax.grad(sample_circuit)(phi, theta, key))

Here’s Mudit’s explanation:

Here, I’ve set static_argnums=2 so that key is static, and the input key is just an integer, not a PRNGKey. Without static_argnums, the key will be a JAX tracer during tracing, so it will be assumed to be a PRNGKey because it is “jax-like”. However, since it’s not actually a PRNGKey, the execution fails. With static_argnums=2 and an integer key, the key will have a concrete value, and since it’s not “jax-like”, it will be assumed to be a NumPy seed, which is what then gets used for sampling.
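
The difference described above can be observed in plain JAX. The sketch below only records what the function body sees for the key argument in each case. How PennyLane then interprets the value is as described in the explanation and is not shown here. The function names traced_key and static_key and the dictionary seen are hypothetical.

```python
from functools import partial

import jax

seen = {}


@jax.jit
def traced_key(x, key):
    # Without static_argnums, key is an abstract tracer while tracing.
    seen["traced_type"] = type(key).__name__
    seen["traced_is_int"] = isinstance(key, int)
    return x + 1.0


@partial(jax.jit, static_argnums=1)
def static_key(x, key):
    # With static_argnums, key stays the concrete Python integer.
    seen["static_type"] = type(key).__name__
    seen["static_is_int"] = isinstance(key, int)
    return x + 1.0


traced_key(0.0, 10)
static_key(0.0, 10)
print(seen)
```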
The downside of using static_argnums is that the circuit is recompiled every time a different value for key is provided, which introduces inefficiencies. But that is the only way to get the workflow to work.
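
The recompilation cost mentioned above can be made visible with a Python side effect, which runs only while JAX traces the function. This is a sketch in plain JAX. The function scaled and the list trace_log are hypothetical, and the body is a placeholder rather than a circuit.

```python
from functools import partial

import jax

trace_log = []  # one entry per trace, i.e. per compilation


@partial(jax.jit, static_argnums=1)
def scaled(x, key):
    # This append runs only during tracing, not on cached calls.
    trace_log.append(key)
    return x * key


for key in (10, 10, 11, 12, 10):
    scaled(1.0, key)

print(trace_log)
```

Five calls with three distinct key values produce three traces. Repeated values reuse the cached compilation, so the cost grows with the number of distinct keys rather than the number of calls.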


Thanks, @CatalinaAlbornoz, for the update. I’ll test out these two ways and report back on how they behave in a real application.
