Issue using the lightning.qubit device with JAX

Hello, I wanted to run the following code with the lightning device from one of the demos, but it runs into the error below. It works fine with the default.qubit device, but I would like to work with more qubits and get a speedup from the lightning device if possible.

from typing import Callable

import jax
import pennylane as qml
from flax import linen as nn
from jax import numpy as jnp
from jax import random as jrand

def make_circuit(dev, n_qubits):
    """Build a batched variational circuit that runs on ``dev``.

    The QNode RY-encodes ``x[w]`` onto wire ``w`` for every wire, applies
    ``qml.BasicEntanglerLayers`` with ``params``, and returns one PauliZ
    expectation value per wire.

    The returned function is ``jax.vmap``-ed with ``in_axes=(0, None)``:
    ``x`` is mapped over its leading (batch) axis, while the same
    ``params`` array is shared by every batch element.
    """
    wires = list(range(n_qubits))

    @qml.qnode(dev)
    def circuit(x, params):
        # Angle-encode one input feature per wire.
        for w in wires:
            qml.RY(x[w], wires=w)
        qml.BasicEntanglerLayers(params, wires=wires)
        return [qml.expval(qml.PauliZ(wires=w)) for w in wires]

    batch_axes = (0, None)  # batch over x, broadcast params
    return jax.vmap(circuit, in_axes=batch_axes)

class QuantumCircuit(nn.Module):
    """Flax module wrapping a batched quantum circuit with trainable weights.

    The module owns a single parameter, ``circuit_weights``, of shape
    ``(num_layers, num_qubits)``, initialised with
    ``nn.initializers.normal()``. It calls ``circuit(x, circuit_weights)``
    and returns the result unchanged.
    """

    num_qubits: int
    num_layers: int
    circuit: Callable

    @nn.compact
    def __call__(self, x):
        weight_shape = (self.num_layers, self.num_qubits)
        weights = self.param('circuit_weights', nn.initializers.normal(), weight_shape)
        return self.circuit(x, weights)

num_qubits = 16
num_layers = 4
dev = qml.device('lightning.qubit', wires=num_qubits)
circuit = make_circuit(dev, num_qubits)

dqc = QuantumCircuit(
    circuit=circuit,
    num_qubits=num_qubits,
    num_layers=num_layers,
)

# The circuit reads x[i] for every qubit i, so each input row needs
# num_qubits features. A (8, 4) input indexes past the end of each row;
# JAX clamps out-of-range indices instead of raising, so qubits 4..15 would
# silently all receive x[3].
batch_size = 8
zero_image = jnp.zeros((batch_size, num_qubits))
key = jrand.PRNGKey(42)
# NOTE: with lightning.qubit, this call is where the reported
# "Converting a JAX array to a NumPy array not supported..." ValueError occurs.
params = dqc.init(key, zero_image)

If you want help with diagnosing an error, please put the full error message below:

ValueError: Converting a JAX array to a NumPy array not supported when using the JAX JIT.

Versions:

Version: 0.42.1
lightning.qubit (pennylane_lightning-0.42.0)

Hi @apple_jay , welcome to the Forum!

Your code has a lot of added complexity. Could you please post a minimal reproducible example and the full error traceback? A minimal example means reducing your code to only the components needed to trigger the error. The full traceback should include every line of the error output, since that shows where the problem originates.

By creating a minimal example you will probably also be able to debug your issue.

Let us know if you have any further questions.

Thank you! The issue seems to be with qml.RY(x[i], wires=i), since commenting that line out removes the error.

Code:
from typing import Callable

import jax
import pennylane as qml
from flax import linen as nn
from jax import numpy as jnp
from jax import random as jrand

def make_circuit(dev, n_qubits):
    """Return a vmapped QNode that measures <Z> on wire 0 after RY encoding.

    Each row of ``x`` supplies one RY angle per wire, read as ``x[0]`` through
    ``x[n_qubits - 1]``. ``jax.vmap`` maps the QNode over the leading axis
    of ``x``.
    """
    encoded_wires = range(n_qubits)

    @qml.qnode(dev)
    def circuit(x):
        for wire in encoded_wires:
            qml.RY(x[wire], wires=wire)  # one rotation angle per wire
        return qml.expval(qml.PauliZ(wires=0))

    return jax.vmap(circuit, in_axes=0)

class QuantumCircuit(nn.Module):
    """Flax module that applies ``circuit`` to its input.

    It declares no trainable parameters; ``__call__`` simply forwards
    ``x`` to ``circuit`` and returns the result.
    """

    num_qubits: int  # not read by __call__; kept as part of the constructor interface
    circuit: Callable

    @nn.compact
    def __call__(self, x):
        return self.circuit(x)

# Define num_qubits here so this example is self-contained; it was used
# below without being defined in this snippet.
num_qubits = 16
dev = qml.device('lightning.qubit', wires=num_qubits)
circuit = make_circuit(dev, num_qubits)

qc = QuantumCircuit(
    circuit=circuit,
    num_qubits=num_qubits,
)

# One feature per qubit: the circuit reads x[i] for i in range(num_qubits).
# A narrower input would be silently clamped by JAX indexing, not rejected.
zero_x = jnp.zeros((8, num_qubits))
key = jrand.PRNGKey(42)
params = qc.init(key, zero_x)

Here’s the full stack trace:

Stack Trace
ConcretizationTypeError                   Traceback (most recent call last)
File /opt/anaconda3/envs/model/lib/python3.12/site-packages/pennylane/math/single_dispatch.py:876, in _to_numpy_jax(x)
875 try:
--> 876     x = concrete_or_error(None, x)
877     return np.array(x)

File /opt/anaconda3/envs/model/lib/python3.12/site-packages/jax/_src/core.py:1603, in concrete_or_error(force, val, context)
1602 if maybe_concrete is None:
-> 1603   raise ConcretizationTypeError(val, context)
1604 else:

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape float32[]

This BatchTracer with object id 14633270640 was created on line:
    /var/folders/lh/byq6_6bx3_n7tg89dvcqsc040000gn/T/ipykernel_32164/2605667695.py:24:12 (QuantumCircuit.__call__)

See https://docs.jax.dev/en/latest/errors.html#jax.errors.ConcretizationTypeError

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
Cell In[34], line 36
34 zero_x = jnp.empty((8, 4))
35 key = jrand.PRNGKey(42)
---> 36 params = qc.init(key, zero_x)

[... skipping hidden 9 frame]

Cell In[34], line 24, in QuantumCircuit.__call__(self, x)
22 @nn.compact
23 def __call__(self, x):
---> 24     x = self.circuit(x)
25     return x

[... skipping hidden 7 frame]

File /opt/anaconda3/envs/model/lib/python3.12/site-packages/pennylane/workflow/qnode.py:922, in QNode.__call__(self, *args, **kwargs)
919     from ._capture_qnode import capture_qnode  # pylint: disable=import-outside-toplevel
921     return capture_qnode(self, *args, **kwargs)
--> 922 return self._impl_call(*args, **kwargs)

File /opt/anaconda3/envs/model/lib/python3.12/site-packages/pennylane/workflow/qnode.py:895, in QNode._impl_call(self, *args, **kwargs)
892 # Calculate the classical jacobians if necessary
893 self._transform_program.set_classical_component(self, args, kwargs)
--> 895 res = execute(
    896     (tape,),
    897     device=self.device,
898     diff_method=self.diff_method,
899     interface=self.interface,
900     transform_program=self._transform_program,
901     gradient_kwargs=self.gradient_kwargs,
902     **self.execute_kwargs,
903 )
904 res = res[0]
906 # convert result to the interface in case the qfunc has no parameters

File /opt/anaconda3/envs/model/lib/python3.12/site-packages/pennylane/workflow/execution.py:233, in execute(tapes, device, diff_method, interface, grad_on_execution, cache, cachesize, max_diff, device_vjp, postselect_mode, mcm_method, gradient_kwargs, transform_program, executor_backend)
229 tapes, outer_post_processing = outer_transform(tapes)
231 assert not outer_transform.is_informative, "should only contain device preprocessing"
--> 233 results = run(tapes, device, config, inner_transform)
234 return user_post_processing(outer_post_processing(results))

File /opt/anaconda3/envs/model/lib/python3.12/site-packages/pennylane/workflow/run.py:338, in run(tapes, device, config, inner_transform_program)
335         params = tape.get_parameters(trainable_only=False)
336         tape.trainable_params = qml.math.get_trainable_indices(params)
--> 338 results = ml_execute(tapes, execute_fn, jpc, device=device)
339 return results

File /opt/anaconda3/envs/model/lib/python3.12/site-packages/pennylane/workflow/interfaces/jax.py:230, in jax_jvp_execute(tapes, execute_fn, jpc, device)
226     logger.debug("Entry with (tapes=%s, execute_fn=%s, jpc=%s)", tapes, execute_fn, jpc)
228 parameters = tuple(tuple(t.get_parameters()) for t in tapes)
--> 230 return _execute_jvp(parameters, _NonPytreeWrapper(tuple(tapes)), execute_fn, jpc)

[... skipping hidden 13 frame]

File /opt/anaconda3/envs/model/lib/python3.12/site-packages/pennylane/workflow/interfaces/jax.py:195, in _execute_wrapper(params, tapes, execute_fn, jpc)
193 """Executes ``tapes`` with ``params`` via ``execute_fn``"""
194 new_tapes = set_parameters_on_copy_and_unwrap(tapes.vals, params, unwrap=False)
--> 195 return _to_jax(execute_fn(new_tapes))

File /opt/anaconda3/envs/model/lib/python3.12/site-packages/pennylane/workflow/jacobian_products.py:487, in DeviceDerivatives.execute_and_cache_jacobian(self, tapes)
485 if logger.isEnabledFor(logging.DEBUG):  # pragma: no cover
    486     logger.debug("Forward pass called with %s", tapes)
--> 487 results, jac = self._dev_execute_and_compute_derivatives(tapes)
488 self._results_cache[tapes] = results
489 self._jacs_cache[tapes] = jac

File /opt/anaconda3/envs/model/lib/python3.12/site-packages/pennylane/workflow/jacobian_products.py:451, in DeviceDerivatives._dev_execute_and_compute_derivatives(self, tapes)
445 def _dev_execute_and_compute_derivatives(self, tapes: QuantumScriptBatch):
    446     """
    447     Converts tapes to numpy before computing the the results and derivatives on the device.
    448
    449     Dispatches between the two different device interfaces.
    450     """
--> 451     numpy_tapes, _ = qml.transforms.convert_to_numpy_parameters(tapes)
452     return self._device.execute_and_compute_derivatives(numpy_tapes, self._execution_config)

File /opt/anaconda3/envs/model/lib/python3.12/site-packages/pennylane/transforms/core/transform_dispatcher.py:281, in TransformDispatcher.__call__(self, *targs, **tkwargs)
278     return self._qfunc_transform(obj, targs, tkwargs)
280 if isinstance(obj, Sequence) and all(isinstance(q, qml.tape.QuantumScript) for q in obj):
--> 281     return self._batch_transform(obj, targs, tkwargs)
283 # Input is not a QNode nor a quantum tape nor a device.
284 # Assume Python decorator syntax:
285 #
286 # result = some_transform(*transform_args)(qnode)(*qnode_args)
288 raise TransformError(
    289     "Decorating a QNode with @transform_fn(**transform_kwargs) has been "
290     "removed. Please decorate with @functools.partial(transform_fn, **transform_kwargs) "
(...)    293     "https://docs.pennylane.ai/en/stable/development/deprecations.html#completed-deprecation-cycles",
294 )

File /opt/anaconda3/envs/model/lib/python3.12/site-packages/pennylane/transforms/core/transform_dispatcher.py:489, in TransformDispatcher._batch_transform(self, original_batch, targs, tkwargs)
483 tape_counts = []
485 for t in original_batch:
    486     # Preprocess the tapes by applying transforms
487     # to each tape, and storing corresponding tapes
488     # for execution, processing functions, and list of tape lengths.
--> 489     new_tapes, fn = self(t, *targs, **tkwargs)
490     execution_tapes.extend(new_tapes)
491     batch_fns.append(fn)

File /opt/anaconda3/envs/model/lib/python3.12/site-packages/pennylane/transforms/core/transform_dispatcher.py:253, in TransformDispatcher.__call__(self, *targs, **tkwargs)
250         return expand_processing(processed_results)
252 else:
--> 253     transformed_tapes, processing_fn = self._transform(obj, *targs, **tkwargs)
255 if self.is_informative:
    256     return processing_fn(transformed_tapes)

File /opt/anaconda3/envs/model/lib/python3.12/site-packages/pennylane/transforms/convert_to_numpy_parameters.py:84, in convert_to_numpy_parameters(tape)
82 new_ops = (_convert_op_to_numpy_data(op) for op in tape.operations)
83 new_measurements = (_convert_measurement_to_numpy_data(m) for m in tape.measurements)
---> 84 new_circuit = tape.__class__(
    85     new_ops, new_measurements, shots=tape.shots, trainable_params=tape.trainable_params
86 )
88 def null_postprocessing(results):
    89     """A postprocesing function returned by a transform that only converts the batch of results
     90     into a result for a single ``QuantumTape``.
     91     """

File /opt/anaconda3/envs/model/lib/python3.12/site-packages/pennylane/tape/qscript.py:194, in QuantumScript.__init__(self, ops, measurements, shots, trainable_params)
187 def __init__(
        188     self,
        189     ops: Optional[Iterable[Operator]] = None,
        (...)    192     trainable_params: Optional[Sequence[int]] = None,
193 ):
--> 194     self._ops = [] if ops is None else list(ops)
195     self._measurements = [] if measurements is None else list(measurements)
196     self._shots = Shots(shots)

File /opt/anaconda3/envs/model/lib/python3.12/site-packages/pennylane/transforms/convert_to_numpy_parameters.py:82, in <genexpr>(.0)
47 @transform
48 def convert_to_numpy_parameters(tape: QuantumScript) -> tuple[QuantumScriptBatch, PostprocessingFn]:
    49     """Transforms a circuit to one with purely numpy parameters.
     50
     51     Args:
   (...)     80
     81     """
---> 82     new_ops = (_convert_op_to_numpy_data(op) for op in tape.operations)
83     new_measurements = (_convert_measurement_to_numpy_data(m) for m in tape.measurements)
84     new_circuit = tape.__class__(
    85         new_ops, new_measurements, shots=tape.shots, trainable_params=tape.trainable_params
86     )

File /opt/anaconda3/envs/model/lib/python3.12/site-packages/pennylane/transforms/convert_to_numpy_parameters.py:30, in _convert_op_to_numpy_data(op)
28     return op
29 # Use operator method to change parameters when it become available
---> 30 return qml.ops.functions.bind_new_parameters(op, math.unwrap(op.data))

File /opt/anaconda3/envs/model/lib/python3.12/site-packages/pennylane/math/multi_dispatch.py:814, in unwrap(values, max_depth)
811     return new_val.tolist() if isinstance(new_val, ndarray) and not new_val.shape else new_val
813 if isinstance(values, (tuple, list)):
--> 814     return type(values)(convert(val) for val in values)
815 return (
    816     np.to_numpy(values, max_depth=max_depth)
817     if isinstance(values, ArrayBox)
818     else np.to_numpy(values)
819 )

File /opt/anaconda3/envs/model/lib/python3.12/site-packages/pennylane/math/multi_dispatch.py:814, in <genexpr>(.0)
811     return new_val.tolist() if isinstance(new_val, ndarray) and not new_val.shape else new_val
813 if isinstance(values, (tuple, list)):
--> 814     return type(values)(convert(val) for val in values)
815 return (
    816     np.to_numpy(values, max_depth=max_depth)
817     if isinstance(values, ArrayBox)
818     else np.to_numpy(values)
819 )

File /opt/anaconda3/envs/model/lib/python3.12/site-packages/pennylane/math/multi_dispatch.py:809, in unwrap.<locals>.convert(val)
806 if isinstance(val, (tuple, list)):
    807     return unwrap(val)
808 new_val = (
    --> 809     np.to_numpy(val, max_depth=max_depth) if isinstance(val, ArrayBox) else np.to_numpy(val)
810 )
811 return new_val.tolist() if isinstance(new_val, ndarray) and not new_val.shape else new_val

File /opt/anaconda3/envs/model/lib/python3.12/site-packages/autoray/autoray.py:81, in do(fn, like, *args, **kwargs)
79 backend = _choose_backend(fn, args, kwargs, like=like)
80 func = get_lib_fn(backend, fn)
---> 81 return func(*args, **kwargs)

File /opt/anaconda3/envs/model/lib/python3.12/site-packages/pennylane/math/single_dispatch.py:879, in _to_numpy_jax(x)
877     return np.array(x)
878 except (ConcretizationTypeError, TracerArrayConversionError) as e:
--> 879     raise ValueError(
    880         "Converting a JAX array to a NumPy array not supported when using the JAX JIT."
881     ) from e

ValueError: Converting a JAX array to a NumPy array not supported when using the JAX JIT.

Hi!
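I was able to run your code without errors by adding the @qml.qjit decorator on top of the @qml.qnode decorator. Note that qml.qjit compiles the circuit with Catalyst, so it requires the pennylane-catalyst package to be installed.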

def make_circuit(dev, n_qubits):
    """Return a vmapped, qjit-compiled QNode measuring <Z> on wire 0.

    ``qml.qjit`` compiles the QNode with Catalyst (requires the
    ``pennylane-catalyst`` package). The compiled function is then mapped
    over the leading axis of ``x`` with ``jax.vmap``.

    The circuit body uses the same RY angle encoding as the uncompiled
    version, so the decorator is the only change.
    """
    @qml.qjit
    @qml.qnode(dev)
    def circuit(x):
        for i in range(n_qubits):
            # Angle-encode x[i] on wire i with RY (not RX), keeping the
            # encoding identical to the circuit being compiled.
            qml.RY(x[i], wires=i)
        return qml.expval(qml.PauliZ(wires=0))
    return jax.vmap(circuit, in_axes=0)

You can see this post where a similar issue was discussed.

Thank you for the answer!