Good evening.
I am trying to learn about simulating quantum circuits on the GPU. I was following the demo "Using JAX with PennyLane" (from PennyLane Demos) and tried to change the device on which the circuit is run. Specifically, the code I am running is:
import jax
import jax.numpy as jnp
import pennylane as qml
# Added to silence some warnings (turns on JAX's 64-bit mode).
jax.config.update("jax_enable_x64", True)

# Two-wire ``lightning.gpu`` device that the QNode below is bound to.
dev = qml.device("lightning.gpu", wires=2)


@qml.qnode(dev, interface="jax")
def circuit(param):
    """Apply ``RX(param)`` on wire 0 and a CNOT on wires 0 -> 1, then measure Z.

    Args:
        param: Rotation angle passed to ``qml.RX``.

    Returns:
        The expectation value of ``qml.PauliZ`` on wire 0.
    """
    # These two gates represent our QML model.
    qml.RX(param, wires=0)
    qml.CNOT(wires=[0, 1])

    # The expval here will be the "cost function" we try to minimize.
    # Usually, this would be defined by the problem we want to solve,
    # but for this example we'll just use a single PauliZ.
    return qml.expval(qml.PauliZ(0))


print("\n\nBatching and Evolutionary Strategies")
print("------------------------------------")

# Create a vectorized version of our original circuit.
vcircuit = jax.vmap(circuit)

# Now, we call ``vcircuit`` with multiple parameters at once and get back a
# batch of expectations.
# This example runs 3 quantum circuits in parallel.
batch_params = jnp.array([1.02, 0.123, -0.571])

# NOTE(review): this is the call the traceback below points at. The trace shows
# the vmapped parameter arriving as a JAX BatchTracer and failing when PennyLane
# tries to convert tape parameters to NumPy in its device-derivatives path.
# Whether another ``interface``/``diff_method`` setting avoids this is
# unverified -- TODO confirm against the PennyLane JAX interface docs.
batched_results = vcircuit(batch_params)
print(f"Batched result: {batched_results}")
and I see the following error trace:
---------------------------------------------------------------------------
ConcretizationTypeError Traceback (most recent call last)
File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/pennylane/math/single_dispatch.py:876, in _to_numpy_jax(x)
875 try:
--> 876 x = concrete_or_error(None, x)
877 return np.array(x)
File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/jax/_src/core.py:1603, in concrete_or_error(force, val, context)
1602 if maybe_concrete is None:
-> 1603 raise ConcretizationTypeError(val, context)
1604 else:
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape float64[]
This BatchTracer with object id 140399171717936 was created on line:
/home/ashutosh/Research-Code/nequa/.venv/lib/python3.10/site-packages/pennylane/workflow/interfaces/jax.py:230 (jax_jvp_execute)
See https://docs.jax.dev/en/latest/errors.html#jax.errors.ConcretizationTypeError
The above exception was the direct cause of the following exception:
ValueError Traceback (most recent call last)
Cell In[1], line 30
25 # Now, we call the ``vcircuit`` with multiple parameters at once and get back a
26 # batch of expectations.
27 # This examples runs 3 quantum circuits in parallel.
28 batch_params = jnp.array([1.02, 0.123, -0.571])
---> 30 batched_results = vcircuit(batch_params)
31 print(f"Batched result: {batched_results}")
[... skipping hidden 7 frame]
File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/pennylane/workflow/qnode.py:922, in QNode.__call__(self, *args, **kwargs)
919 from ._capture_qnode import capture_qnode # pylint: disable=import-outside-toplevel
921 return capture_qnode(self, *args, **kwargs)
--> 922 return self._impl_call(*args, **kwargs)
File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/pennylane/workflow/qnode.py:895, in QNode._impl_call(self, *args, **kwargs)
892 # Calculate the classical jacobians if necessary
893 self._transform_program.set_classical_component(self, args, kwargs)
--> 895 res = execute(
896 (tape,),
897 device=self.device,
898 diff_method=self.diff_method,
899 interface=self.interface,
900 transform_program=self._transform_program,
901 gradient_kwargs=self.gradient_kwargs,
902 **self.execute_kwargs,
903 )
904 res = res[0]
906 # convert result to the interface in case the qfunc has no parameters
File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/pennylane/workflow/execution.py:233, in execute(tapes, device, diff_method, interface, grad_on_execution, cache, cachesize, max_diff, device_vjp, postselect_mode, mcm_method, gradient_kwargs, transform_program, executor_backend)
229 tapes, outer_post_processing = outer_transform(tapes)
231 assert not outer_transform.is_informative, "should only contain device preprocessing"
--> 233 results = run(tapes, device, config, inner_transform)
234 return user_post_processing(outer_post_processing(results))
File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/pennylane/workflow/run.py:338, in run(tapes, device, config, inner_transform_program)
335 params = tape.get_parameters(trainable_only=False)
336 tape.trainable_params = qml.math.get_trainable_indices(params)
--> 338 results = ml_execute(tapes, execute_fn, jpc, device=device)
339 return results
File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/pennylane/workflow/interfaces/jax.py:230, in jax_jvp_execute(tapes, execute_fn, jpc, device)
226 logger.debug("Entry with (tapes=%s, execute_fn=%s, jpc=%s)", tapes, execute_fn, jpc)
228 parameters = tuple(tuple(t.get_parameters()) for t in tapes)
--> 230 return _execute_jvp(parameters, _NonPytreeWrapper(tuple(tapes)), execute_fn, jpc)
[... skipping hidden 13 frame]
File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/pennylane/workflow/interfaces/jax.py:195, in _execute_wrapper(params, tapes, execute_fn, jpc)
193 """Executes ``tapes`` with ``params`` via ``execute_fn``"""
194 new_tapes = set_parameters_on_copy_and_unwrap(tapes.vals, params, unwrap=False)
--> 195 return _to_jax(execute_fn(new_tapes))
File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/pennylane/workflow/jacobian_products.py:487, in DeviceDerivatives.execute_and_cache_jacobian(self, tapes)
485 if logger.isEnabledFor(logging.DEBUG): # pragma: no cover
486 logger.debug("Forward pass called with %s", tapes)
--> 487 results, jac = self._dev_execute_and_compute_derivatives(tapes)
488 self._results_cache[tapes] = results
489 self._jacs_cache[tapes] = jac
File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/pennylane/workflow/jacobian_products.py:451, in DeviceDerivatives._dev_execute_and_compute_derivatives(self, tapes)
445 def _dev_execute_and_compute_derivatives(self, tapes: QuantumScriptBatch):
446 """
447 Converts tapes to numpy before computing the the results and derivatives on the device.
448
449 Dispatches between the two different device interfaces.
450 """
--> 451 numpy_tapes, _ = qml.transforms.convert_to_numpy_parameters(tapes)
452 return self._device.execute_and_compute_derivatives(numpy_tapes, self._execution_config)
File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/pennylane/transforms/core/transform_dispatcher.py:281, in TransformDispatcher.__call__(self, *targs, **tkwargs)
278 return self._qfunc_transform(obj, targs, tkwargs)
280 if isinstance(obj, Sequence) and all(isinstance(q, qml.tape.QuantumScript) for q in obj):
--> 281 return self._batch_transform(obj, targs, tkwargs)
283 # Input is not a QNode nor a quantum tape nor a device.
284 # Assume Python decorator syntax:
285 #
286 # result = some_transform(*transform_args)(qnode)(*qnode_args)
288 raise TransformError(
289 "Decorating a QNode with @transform_fn(**transform_kwargs) has been "
290 "removed. Please decorate with @functools.partial(transform_fn, **transform_kwargs) "
(...)
293 "https://docs.pennylane.ai/en/stable/development/deprecations.html#completed-deprecation-cycles",
294 )
File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/pennylane/transforms/core/transform_dispatcher.py:489, in TransformDispatcher._batch_transform(self, original_batch, targs, tkwargs)
483 tape_counts = []
485 for t in original_batch:
486 # Preprocess the tapes by applying transforms
487 # to each tape, and storing corresponding tapes
488 # for execution, processing functions, and list of tape lengths.
--> 489 new_tapes, fn = self(t, *targs, **tkwargs)
490 execution_tapes.extend(new_tapes)
491 batch_fns.append(fn)
File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/pennylane/transforms/core/transform_dispatcher.py:253, in TransformDispatcher.__call__(self, *targs, **tkwargs)
250 return expand_processing(processed_results)
252 else:
--> 253 transformed_tapes, processing_fn = self._transform(obj, *targs, **tkwargs)
255 if self.is_informative:
256 return processing_fn(transformed_tapes)
File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/pennylane/transforms/convert_to_numpy_parameters.py:84, in convert_to_numpy_parameters(tape)
82 new_ops = (_convert_op_to_numpy_data(op) for op in tape.operations)
83 new_measurements = (_convert_measurement_to_numpy_data(m) for m in tape.measurements)
---> 84 new_circuit = tape.__class__(
85 new_ops, new_measurements, shots=tape.shots, trainable_params=tape.trainable_params
86 )
88 def null_postprocessing(results):
89 """A postprocesing function returned by a transform that only converts the batch of results
90 into a result for a single ``QuantumTape``.
91 """
File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/pennylane/tape/qscript.py:194, in QuantumScript.__init__(self, ops, measurements, shots, trainable_params)
187 def __init__(
188 self,
189 ops: Optional[Iterable[Operator]] = None,
(...)
192 trainable_params: Optional[Sequence[int]] = None,
193 ):
--> 194 self._ops = [] if ops is None else list(ops)
195 self._measurements = [] if measurements is None else list(measurements)
196 self._shots = Shots(shots)
File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/pennylane/transforms/convert_to_numpy_parameters.py:82, in <genexpr>(.0)
47 @transform
48 def convert_to_numpy_parameters(tape: QuantumScript) -> tuple[QuantumScriptBatch, PostprocessingFn]:
49 """Transforms a circuit to one with purely numpy parameters.
50
51 Args:
(...)
80
81 """
---> 82 new_ops = (_convert_op_to_numpy_data(op) for op in tape.operations)
83 new_measurements = (_convert_measurement_to_numpy_data(m) for m in tape.measurements)
84 new_circuit = tape.__class__(
85 new_ops, new_measurements, shots=tape.shots, trainable_params=tape.trainable_params
86 )
File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/pennylane/transforms/convert_to_numpy_parameters.py:30, in _convert_op_to_numpy_data(op)
28 return op
29 # Use operator method to change parameters when it become available
---> 30 return qml.ops.functions.bind_new_parameters(op, math.unwrap(op.data))
File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/pennylane/math/multi_dispatch.py:814, in unwrap(values, max_depth)
811 return new_val.tolist() if isinstance(new_val, ndarray) and not new_val.shape else new_val
813 if isinstance(values, (tuple, list)):
--> 814 return type(values)(convert(val) for val in values)
815 return (
816 np.to_numpy(values, max_depth=max_depth)
817 if isinstance(values, ArrayBox)
818 else np.to_numpy(values)
819 )
File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/pennylane/math/multi_dispatch.py:814, in <genexpr>(.0)
811 return new_val.tolist() if isinstance(new_val, ndarray) and not new_val.shape else new_val
813 if isinstance(values, (tuple, list)):
--> 814 return type(values)(convert(val) for val in values)
815 return (
816 np.to_numpy(values, max_depth=max_depth)
817 if isinstance(values, ArrayBox)
818 else np.to_numpy(values)
819 )
File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/pennylane/math/multi_dispatch.py:809, in unwrap.<locals>.convert(val)
806 if isinstance(val, (tuple, list)):
807 return unwrap(val)
808 new_val = (
--> 809 np.to_numpy(val, max_depth=max_depth) if isinstance(val, ArrayBox) else np.to_numpy(val)
810 )
811 return new_val.tolist() if isinstance(new_val, ndarray) and not new_val.shape else new_val
File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/autoray/autoray.py:81, in do(fn, like, *args, **kwargs)
79 backend = _choose_backend(fn, args, kwargs, like=like)
80 func = get_lib_fn(backend, fn)
---> 81 return func(*args, **kwargs)
File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/pennylane/math/single_dispatch.py:879, in _to_numpy_jax(x)
877 return np.array(x)
878 except (ConcretizationTypeError, TracerArrayConversionError) as e:
--> 879 raise ValueError(
880 "Converting a JAX array to a NumPy array not supported when using the JAX JIT."
881 ) from e
ValueError: Converting a JAX array to a NumPy array not supported when using the JAX JIT.
I am not sure where this error originates, since I am certain that I am passing jnp arrays to the circuit.
Note: I also tried running the code with lightning.qubit and got the same error.
Thank you for your help!