Running quantum circuit in batches using jax.vmap on lightning.gpu device

Good evening.
I am trying to learn about simulating quantum circuits on the GPU. I was following the demo "Using JAX with PennyLane" (PennyLane Demos) and tried to change the device the circuit runs on. Specifically, the code I am running is:

import jax
import jax.numpy as jnp
import pennylane as qml

# Turn on JAX's "jax_enable_x64" flag. Added to silence some warnings.
jax.config.update("jax_enable_x64", True)

# Two-wire PennyLane device on the "lightning.gpu" backend; the QNode below runs on it.
dev = qml.device("lightning.gpu", wires=2)
@qml.qnode(dev, interface="jax")
def circuit(param):
    """Apply RX(param) on wire 0, then CNOT on wires [0, 1], and measure PauliZ on wire 0.

    Args:
        param: rotation angle passed to ``qml.RX``.

    Returns:
        The expectation value of ``qml.PauliZ(0)``.
    """
    # These two gates represent our QML model.
    qml.RX(param, wires=0)
    qml.CNOT(wires=[0, 1])

    # The expval here will be the "cost function" we try to minimize.
    # Usually, this would be defined by the problem we want to solve,
    # but for this example we'll just use a single PauliZ.
    return qml.expval(qml.PauliZ(0))

print("\n\nBatching and Evolutionary Strategies")
print("------------------------------------")

# Create a vectorized version of our original circuit.
vcircuit = jax.vmap(circuit)

# Now, we call the ``vcircuit`` with multiple parameters at once and get back a
# batch of expectations.
# This examples runs 3 quantum circuits in parallel.
batch_params = jnp.array([1.02, 0.123, -0.571])

# NOTE: this is the call flagged in the traceback that follows; it ends in
# "ValueError: Converting a JAX array to a NumPy array not supported when
# using the JAX JIT."
batched_results = vcircuit(batch_params)
print(f"Batched result: {batched_results}")

and I see the following error trace:


---------------------------------------------------------------------------
ConcretizationTypeError                   Traceback (most recent call last)
File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/pennylane/math/single_dispatch.py:876, in _to_numpy_jax(x)
    875 try:
--> 876     x = concrete_or_error(None, x)
    877     return np.array(x)

File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/jax/_src/core.py:1603, in concrete_or_error(force, val, context)
   1602 if maybe_concrete is None:
-> 1603   raise ConcretizationTypeError(val, context)
   1604 else:

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape float64[]

This BatchTracer with object id 140399171717936 was created on line:
  /home/ashutosh/Research-Code/nequa/.venv/lib/python3.10/site-packages/pennylane/workflow/interfaces/jax.py:230 (jax_jvp_execute)

See https://docs.jax.dev/en/latest/errors.html#jax.errors.ConcretizationTypeError

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
Cell In[1], line 30
     25 # Now, we call the ``vcircuit`` with multiple parameters at once and get back a
     26 # batch of expectations.
     27 # This examples runs 3 quantum circuits in parallel.
     28 batch_params = jnp.array([1.02, 0.123, -0.571])
---> 30 batched_results = vcircuit(batch_params)
     31 print(f"Batched result: {batched_results}")

    [... skipping hidden 7 frame]

File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/pennylane/workflow/qnode.py:922, in QNode.__call__(self, *args, **kwargs)
    919     from ._capture_qnode import capture_qnode  # pylint: disable=import-outside-toplevel
    921     return capture_qnode(self, *args, **kwargs)
--> 922 return self._impl_call(*args, **kwargs)

File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/pennylane/workflow/qnode.py:895, in QNode._impl_call(self, *args, **kwargs)
    892 # Calculate the classical jacobians if necessary
    893 self._transform_program.set_classical_component(self, args, kwargs)
--> 895 res = execute(
    896     (tape,),
    897     device=self.device,
    898     diff_method=self.diff_method,
    899     interface=self.interface,
    900     transform_program=self._transform_program,
    901     gradient_kwargs=self.gradient_kwargs,
    902     **self.execute_kwargs,
    903 )
    904 res = res[0]
    906 # convert result to the interface in case the qfunc has no parameters

File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/pennylane/workflow/execution.py:233, in execute(tapes, device, diff_method, interface, grad_on_execution, cache, cachesize, max_diff, device_vjp, postselect_mode, mcm_method, gradient_kwargs, transform_program, executor_backend)
    229 tapes, outer_post_processing = outer_transform(tapes)
    231 assert not outer_transform.is_informative, "should only contain device preprocessing"
--> 233 results = run(tapes, device, config, inner_transform)
    234 return user_post_processing(outer_post_processing(results))

File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/pennylane/workflow/run.py:338, in run(tapes, device, config, inner_transform_program)
    335         params = tape.get_parameters(trainable_only=False)
    336         tape.trainable_params = qml.math.get_trainable_indices(params)
--> 338 results = ml_execute(tapes, execute_fn, jpc, device=device)
    339 return results

File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/pennylane/workflow/interfaces/jax.py:230, in jax_jvp_execute(tapes, execute_fn, jpc, device)
    226     logger.debug("Entry with (tapes=%s, execute_fn=%s, jpc=%s)", tapes, execute_fn, jpc)
    228 parameters = tuple(tuple(t.get_parameters()) for t in tapes)
--> 230 return _execute_jvp(parameters, _NonPytreeWrapper(tuple(tapes)), execute_fn, jpc)

    [... skipping hidden 13 frame]

File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/pennylane/workflow/interfaces/jax.py:195, in _execute_wrapper(params, tapes, execute_fn, jpc)
    193 """Executes ``tapes`` with ``params`` via ``execute_fn``"""
    194 new_tapes = set_parameters_on_copy_and_unwrap(tapes.vals, params, unwrap=False)
--> 195 return _to_jax(execute_fn(new_tapes))

File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/pennylane/workflow/jacobian_products.py:487, in DeviceDerivatives.execute_and_cache_jacobian(self, tapes)
    485 if logger.isEnabledFor(logging.DEBUG):  # pragma: no cover
    486     logger.debug("Forward pass called with %s", tapes)
--> 487 results, jac = self._dev_execute_and_compute_derivatives(tapes)
    488 self._results_cache[tapes] = results
    489 self._jacs_cache[tapes] = jac

File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/pennylane/workflow/jacobian_products.py:451, in DeviceDerivatives._dev_execute_and_compute_derivatives(self, tapes)
    445 def _dev_execute_and_compute_derivatives(self, tapes: QuantumScriptBatch):
    446     """
    447     Converts tapes to numpy before computing the the results and derivatives on the device.
    448 
    449     Dispatches between the two different device interfaces.
    450     """
--> 451     numpy_tapes, _ = qml.transforms.convert_to_numpy_parameters(tapes)
    452     return self._device.execute_and_compute_derivatives(numpy_tapes, self._execution_config)

File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/pennylane/transforms/core/transform_dispatcher.py:281, in TransformDispatcher.__call__(self, *targs, **tkwargs)
    278     return self._qfunc_transform(obj, targs, tkwargs)
    280 if isinstance(obj, Sequence) and all(isinstance(q, qml.tape.QuantumScript) for q in obj):
--> 281     return self._batch_transform(obj, targs, tkwargs)
    283 # Input is not a QNode nor a quantum tape nor a device.
    284 # Assume Python decorator syntax:
    285 #
    286 # result = some_transform(*transform_args)(qnode)(*qnode_args)
    288 raise TransformError(
    289     "Decorating a QNode with @transform_fn(**transform_kwargs) has been "
    290     "removed. Please decorate with @functools.partial(transform_fn, **transform_kwargs) "
   (...)
    293     "https://docs.pennylane.ai/en/stable/development/deprecations.html#completed-deprecation-cycles",
    294 )

File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/pennylane/transforms/core/transform_dispatcher.py:489, in TransformDispatcher._batch_transform(self, original_batch, targs, tkwargs)
    483 tape_counts = []
    485 for t in original_batch:
    486     # Preprocess the tapes by applying transforms
    487     # to each tape, and storing corresponding tapes
    488     # for execution, processing functions, and list of tape lengths.
--> 489     new_tapes, fn = self(t, *targs, **tkwargs)
    490     execution_tapes.extend(new_tapes)
    491     batch_fns.append(fn)

File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/pennylane/transforms/core/transform_dispatcher.py:253, in TransformDispatcher.__call__(self, *targs, **tkwargs)
    250         return expand_processing(processed_results)
    252 else:
--> 253     transformed_tapes, processing_fn = self._transform(obj, *targs, **tkwargs)
    255 if self.is_informative:
    256     return processing_fn(transformed_tapes)

File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/pennylane/transforms/convert_to_numpy_parameters.py:84, in convert_to_numpy_parameters(tape)
     82 new_ops = (_convert_op_to_numpy_data(op) for op in tape.operations)
     83 new_measurements = (_convert_measurement_to_numpy_data(m) for m in tape.measurements)
---> 84 new_circuit = tape.__class__(
     85     new_ops, new_measurements, shots=tape.shots, trainable_params=tape.trainable_params
     86 )
     88 def null_postprocessing(results):
     89     """A postprocesing function returned by a transform that only converts the batch of results
     90     into a result for a single ``QuantumTape``.
     91     """

File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/pennylane/tape/qscript.py:194, in QuantumScript.__init__(self, ops, measurements, shots, trainable_params)
    187 def __init__(
    188     self,
    189     ops: Optional[Iterable[Operator]] = None,
   (...)
    192     trainable_params: Optional[Sequence[int]] = None,
    193 ):
--> 194     self._ops = [] if ops is None else list(ops)
    195     self._measurements = [] if measurements is None else list(measurements)
    196     self._shots = Shots(shots)

File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/pennylane/transforms/convert_to_numpy_parameters.py:82, in <genexpr>(.0)
     47 @transform
     48 def convert_to_numpy_parameters(tape: QuantumScript) -> tuple[QuantumScriptBatch, PostprocessingFn]:
     49     """Transforms a circuit to one with purely numpy parameters.
     50 
     51     Args:
   (...)
     80 
     81     """
---> 82     new_ops = (_convert_op_to_numpy_data(op) for op in tape.operations)
     83     new_measurements = (_convert_measurement_to_numpy_data(m) for m in tape.measurements)
     84     new_circuit = tape.__class__(
     85         new_ops, new_measurements, shots=tape.shots, trainable_params=tape.trainable_params
     86     )

File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/pennylane/transforms/convert_to_numpy_parameters.py:30, in _convert_op_to_numpy_data(op)
     28     return op
     29 # Use operator method to change parameters when it become available
---> 30 return qml.ops.functions.bind_new_parameters(op, math.unwrap(op.data))

File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/pennylane/math/multi_dispatch.py:814, in unwrap(values, max_depth)
    811     return new_val.tolist() if isinstance(new_val, ndarray) and not new_val.shape else new_val
    813 if isinstance(values, (tuple, list)):
--> 814     return type(values)(convert(val) for val in values)
    815 return (
    816     np.to_numpy(values, max_depth=max_depth)
    817     if isinstance(values, ArrayBox)
    818     else np.to_numpy(values)
    819 )

File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/pennylane/math/multi_dispatch.py:814, in <genexpr>(.0)
    811     return new_val.tolist() if isinstance(new_val, ndarray) and not new_val.shape else new_val
    813 if isinstance(values, (tuple, list)):
--> 814     return type(values)(convert(val) for val in values)
    815 return (
    816     np.to_numpy(values, max_depth=max_depth)
    817     if isinstance(values, ArrayBox)
    818     else np.to_numpy(values)
    819 )

File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/pennylane/math/multi_dispatch.py:809, in unwrap.<locals>.convert(val)
    806 if isinstance(val, (tuple, list)):
    807     return unwrap(val)
    808 new_val = (
--> 809     np.to_numpy(val, max_depth=max_depth) if isinstance(val, ArrayBox) else np.to_numpy(val)
    810 )
    811 return new_val.tolist() if isinstance(new_val, ndarray) and not new_val.shape else new_val

File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/autoray/autoray.py:81, in do(fn, like, *args, **kwargs)
     79 backend = _choose_backend(fn, args, kwargs, like=like)
     80 func = get_lib_fn(backend, fn)
---> 81 return func(*args, **kwargs)

File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/pennylane/math/single_dispatch.py:879, in _to_numpy_jax(x)
    877     return np.array(x)
    878 except (ConcretizationTypeError, TracerArrayConversionError) as e:
--> 879     raise ValueError(
    880         "Converting a JAX array to a NumPy array not supported when using the JAX JIT."
    881     ) from e

ValueError: Converting a JAX array to a NumPy array not supported when using the JAX JIT.

I am not sure where this error comes from, since I am certain I am passing jnp arrays to the circuit.

Note: I tried to run the code with lightning.qubit and got the same error.

Thank you for your help!

Hi @ashutiw2k , welcome to the Forum!

It looks like an issue with batching.
I’m not sure if this will work but you could try using @qml.transforms.broadcast_expand right above your qnode.

Let us know if this solves the issue!

Hi,
Thank you for your reply, but it looks like qml.broadcast and any associated decorators have been deprecated. (Deprecations — PennyLane 0.42.1 documentation).

---------------------------------------------------------------------------
TransformError                            Traceback (most recent call last)
Cell In[1], line 10
      6 jax.config.update("jax_enable_x64", True)
      8 dev = qml.device("lightning.gpu", wires=2)
---> 10 @qml.transforms.broadcast_expand()
     11 @qml.qnode(dev, interface="jax")
     12 def circuit(param):
     13     # These two gates represent our QML model.
     14     qml.RX(param, wires=0)
     15     qml.CNOT(wires=[0, 1])

File ~/Research-Code/nequa/.venv/lib/python3.10/site-packages/pennylane/transforms/core/transform_dispatcher.py:288, in TransformDispatcher.__call__(self, *targs, **tkwargs)
    281     return self._batch_transform(obj, targs, tkwargs)
    283 # Input is not a QNode nor a quantum tape nor a device.
    284 # Assume Python decorator syntax:
    285 #
    286 # result = some_transform(*transform_args)(qnode)(*qnode_args)
--> 288 raise TransformError(
    289     "Decorating a QNode with @transform_fn(**transform_kwargs) has been "
    290     "removed. Please decorate with @functools.partial(transform_fn, **transform_kwargs) "
    291     "instead, or call the transform directly using qnode = transform_fn(qnode, "
    292     "**transform_kwargs). Visit the deprecations page for more details: "
    293     "https://docs.pennylane.ai/en/stable/development/deprecations.html#completed-deprecation-cycles",
    294 )

TransformError: Decorating a QNode with @transform_fn(**transform_kwargs) has been removed. Please decorate with @functools.partial(transform_fn, **transform_kwargs) instead, or call the transform directly using qnode = transform_fn(qnode, **transform_kwargs). Visit the deprecations page for more details: https://docs.pennylane.ai/en/stable/development/deprecations.html#completed-deprecation-cycles

@CatalinaAlbornoz do you have any other suggestions on how to achieve batching and utilizing the entire GPU to simulate circuits?

Hi @ashutiw2k , thanks for pointing it out.

Let me check what other options we have and get back to you.

1 Like

Hi @CatalinaAlbornoz ,

So it looks like the solution is to add the “@qml.qjit” decorator to my function?

My code now looks like

import jax
import jax.numpy as jnp
import pennylane as qml

# Turn on JAX's "jax_enable_x64" flag. Added to silence some warnings.
jax.config.update("jax_enable_x64", True)

# Two-wire PennyLane device on the "lightning.gpu" backend; the QNode below runs on it.
dev = qml.device("lightning.gpu", wires=2)

@qml.qjit(autograph=True) # Added this now; it is the only change from the earlier, failing version
@qml.qnode(dev, interface="jax")
def circuit(param):
    """Apply RX(param) on wire 0, then CNOT on wires [0, 1], and measure PauliZ on wire 0.

    Args:
        param: rotation angle passed to ``qml.RX``.

    Returns:
        The expectation value of ``qml.PauliZ(0)``.
    """
    # These two gates represent our QML model.
    qml.RX(param, wires=0)
    qml.CNOT(wires=[0, 1])

    # The expval here will be the "cost function" we try to minimize.
    # Usually, this would be defined by the problem we want to solve,
    # but for this example we'll just use a single PauliZ.
    return qml.expval(qml.PauliZ(0))

print("\n\nBatching and Evolutionary Strategies")
print("------------------------------------")

# Create a vectorized version of our original circuit.
vcircuit = jax.vmap(circuit)

# Now, we call the ``vcircuit`` with multiple parameters at once and get back a
# batch of expectations.
# This example evaluates the circuit for 3 parameter values in a single batched call.
batch_params = jnp.array([1.02, 0.123, -0.571])

# One expectation value is returned per entry of ``batch_params``.
batched_results = vcircuit(batch_params)
print(f"Batched result: {batched_results}")

And I get the expected results

Batching and Evolutionary Strategies
------------------------------------
Batched result: [0.52336595 0.99244503 0.84136092]

I can’t figure out why this works; any insight would be helpful!

Oh that’s interesting @ashutiw2k .

Using qjit means you’re using our Catalyst compiler. It’s based on JAX but I guess our new additions to it are making it more powerful for things like batching. We have a lot of new features in it so I guess this is one too! Catalyst is also open-source and part of the PennyLane ecosystem so it’s a good thing to start using it!

I’m glad you thought of using this and solved the issue!

1 Like