Good morning, PennyLane community. I have an issue I would like to discuss with you, in the hope of finding a solution. I am working on QAOA and trying to use JAX for faster simulation, but I run into a problem when I apply the @jax.jit decorator to a QNode.
I started from your tutorial on MaxCut. My idea is to write a function that builds the QNode, since I will create multiple devices with a different number of qubits each time, in order to gather statistics on how the task behaves as the number of qubits changes.
def mutable_qnode(device, new_params, graph, edge=None):
    @jax.jit
    @qml.qnode(device, interface="jax")
    def qnode(new_params=new_params, graph=graph, edge=edge):
        [qml.Hadamard(i) for i in range(qubits)]
        for l in range(fixed_layers):
            gamma_circuit(opt_params[l, 0], graph=graph)
            beta_circuit(opt_params[l, 1])
        # variational block
        gamma_circuit(new_params[0], graph=graph)
        beta_circuit(new_params[1])
        '''if edge is None:
            return qml.counts()'''
        H = qml.PauliZ(edge[0]) @ qml.PauliZ(edge[1])
        return qml.expval(H)
    result = qnode(new_params, graph, edge=edge)
    return result
I have never seen this error before. My goal is to use jax.jit to speed up the simulation.
Many thanks in advance.
francescoaldoventurelli@AirdiFrancesco QAOA % /usr/local/bin/python3.10 /Users/francescoaldoventurelli/qml/QAOA/jax_unmodified.py
/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/_src/core.py:678: FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.
warnings.warn(
/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/_src/core.py:678: FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.
warnings.warn(
jax.pure_callback failed
Traceback (most recent call last):
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/_src/callback.py", line 86, in pure_callback_impl
return tree_util.tree_map(np.asarray, callback(*args))
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/_src/callback.py", line 64, in __call__
return tree_util.tree_leaves(self.callback_func(*args, **kwargs))
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pennylane/workflow/interfaces/jax_jit.py", line 187, in wrapper
return _to_jax(jpc.execute_and_compute_jacobian(new_tapes))
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pennylane/workflow/jacobian_products.py", line 312, in execute_and_compute_jacobian
jac_tapes, jac_postprocessing = self._gradient_transform(tapes, **self._gradient_kwargs)
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pennylane/transforms/core/transform_dispatcher.py", line 135, in __call__
return self._batch_transform(obj, targs, tkwargs)
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pennylane/transforms/core/transform_dispatcher.py", line 343, in _batch_transform
new_tapes, fn = self(t, *targs, **tkwargs)
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pennylane/transforms/core/transform_dispatcher.py", line 100, in __call__
intermediate_tapes, post_processing_fn = self._transform(
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pennylane/gradients/parameter_shift.py", line 1120, in param_shift
diff_methods = find_and_validate_gradient_methods(tape, method, trainable_params)
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pennylane/gradients/gradient_transform.py", line 220, in find_and_validate_gradient_methods
diff_methods = _find_gradient_methods(tape, trainable_param_indices, use_graph=use_graph)
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pennylane/gradients/gradient_transform.py", line 163, in _find_gradient_methods
return {
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pennylane/gradients/gradient_transform.py", line 164, in <dictcomp>
idx: _try_zero_grad_from_graph_or_get_grad_method(
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pennylane/gradients/gradient_transform.py", line 153, in _try_zero_grad_from_graph_or_get_grad_method
if not any(tape.graph.has_path(op_or_mp, mp) for mp in tape.measurements):
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pennylane/gradients/gradient_transform.py", line 153, in <genexpr>
if not any(tape.graph.has_path(op_or_mp, mp) for mp in tape.measurements):
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pennylane/tape/qscript.py", line 966, in graph
self._graph = qml.CircuitGraph(
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pennylane/circuit_graph.py", line 127, in __init__
wire = wires.index(w)
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pennylane/wires.py", line 250, in index
return self._labels.index(wire)
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py", line 739, in op
return getattr(self.aval, f"_{name}")(self, *args)
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py", line 265, in deferring_binary_op
return binary_op(*args)
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/_src/pjit.py", line 327, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/_src/pjit.py", line 185, in _python_pjit_helper
out_flat = pjit_p.bind(*args_flat, **params)
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/_src/core.py", line 2829, in bind
top_trace = find_top_trace(args)
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/_src/core.py", line 1362, in find_top_trace
top_tracer._assert_live()
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 1736, in _assert_live
raise core.escaped_tracer_error(self, None)
jax.errors.UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type int64[] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was qnode at /Users/francescoaldoventurelli/qml/QAOA/jax_unmodified.py:56 traced for jit.
------------------------------
The leaked intermediate value was created on line /Users/francescoaldoventurelli/qml/QAOA/jax_unmodified.py:73 (mutable_qnode).
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
------------------------------
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/Users/francescoaldoventurelli/qml/QAOA/jax_unmodified.py", line 188, in <module>
energy, counts, optimal_last_gamma_beta, ar = qaoa_execution(dev, graph1)
File "/Users/francescoaldoventurelli/qml/QAOA/jax_unmodified.py", line 156, in qaoa_execution
grads = jax.grad(obj_function)(params)
File "/Users/francescoaldoventurelli/qml/QAOA/jax_unmodified.py", line 147, in obj_function
cost -= 0.5 * (1 - mutable_qnode(device, new_params, graph, edge=edge))
File "/Users/francescoaldoventurelli/qml/QAOA/jax_unmodified.py", line 73, in mutable_qnode
result = qnode(new_params, graph, edge=edge)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error: Traceback (most recent call last):
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pennylane/gradients/gradient_transform.py", line 153, in <genexpr>
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pennylane/tape/qscript.py", line 966, in graph
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pennylane/circuit_graph.py", line 127, in __init__
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pennylane/wires.py", line 251, in index
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py", line 739, in op
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py", line 265, in deferring_binary_op
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/_src/pjit.py", line 327, in cache_miss
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/_src/pjit.py", line 194, in _python_pjit_helper
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/_src/core.py", line 2829, in bind
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/_src/core.py", line 1362, in find_top_trace
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 1736, in _assert_live
UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type int64[] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was qnode at /Users/francescoaldoventurelli/qml/QAOA/jax_unmodified.py:56 traced for jit.
------------------------------
The leaked intermediate value was created on line /Users/francescoaldoventurelli/qml/QAOA/jax_unmodified.py:73 (mutable_qnode).
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
------------------------------
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
francescoaldoventurelli@AirdiFrancesco QAOA %
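One possible reading of the traceback, offered as an assumption rather than a confirmed diagnosis: the failure surfaces in `wires.index(wire)`, and the leaked value has type `int64[]`, which would be consistent with the integers in `edge` (and perhaps `graph`) being traced because they are passed as arguments to the jitted `qnode`. The sketch below is plain JAX, not PennyLane. `LABELS`, `zz_product`, and `zz_static` are hypothetical names used only to illustrate how `jax.jit` traces every non-static argument, and how `static_argnames` keeps a structural argument as concrete Python values.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a device's wire labels (not a PennyLane object).
LABELS = [0, 1, 2, 3]


def zz_product(values, edge):
    # list.index needs concrete Python ints, much like a wire lookup does.
    i = LABELS.index(edge[0])
    j = LABELS.index(edge[1])
    return values[i] * values[j]


values = jnp.array([1.0, -1.0, 1.0, -1.0])

# With a plain jax.jit, the ints inside `edge` become tracers, so the
# label lookup cannot be resolved while the function is being traced.
try:
    jax.jit(zz_product)(values, (0, 1))
    traced_call_failed = False
except jax.errors.ConcretizationTypeError:
    traced_call_failed = True

# Declaring `edge` static keeps it a hashable Python tuple inside the function.
# JAX recompiles once per distinct `edge` value.
zz_static = jax.jit(zz_product, static_argnames="edge")

result = zz_static(values, edge=(0, 1))

# Differentiation with respect to `values` still works with a static `edge`.
grad_values = jax.grad(zz_static)(values, edge=(0, 1))
```

If the same reasoning applies to the QNode above, the analogous change would be to let the jitted function take only `new_params` and to close over `graph` and `edge`, or to mark them static, provided they are hashable. I have not verified this against PennyLane 0.37.0.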
PennyLane versions:
`Name: PennyLane
Version: 0.37.0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author:
Author-email:
License: Apache License 2.0
Location: /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, semantic-version, toml, typing-extensions
Required-by: PennyLane_Lightning
Platform info: macOS-14.4.1-arm64-arm-64bit
Python version: 3.10.10
Numpy version: 1.26.4
Scipy version: 1.10.1
Installed devices:
- lightning.qubit (PennyLane_Lightning-0.37.0)
- default.clifford (PennyLane-0.37.0)
- default.gaussian (PennyLane-0.37.0)
- default.mixed (PennyLane-0.37.0)
- default.qubit (PennyLane-0.37.0)
- default.qubit.autograd (PennyLane-0.37.0)
- default.qubit.jax (PennyLane-0.37.0)
- default.qubit.legacy (PennyLane-0.37.0)
- default.qubit.tf (PennyLane-0.37.0)
- default.qubit.torch (PennyLane-0.37.0)
- default.qutrit (PennyLane-0.37.0)
- default.qutrit.mixed (PennyLane-0.37.0)
- default.tensor (PennyLane-0.37.0)
- null.qubit (PennyLane-0.37.0)
None`.