Problems using PennyLane with JAX

I’ve been trying to use JAX to speed up my code, which I run locally. I have a function `circuit` that returns the probability of each measurement outcome across ~10 qubits, and I call it with this line:

    probs = jax.vmap(circuit, in_axes=(0, None))(x, p)
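For context on what that line does: `in_axes=(0, None)` tells `jax.vmap` to map over the leading axis of the first argument (`x`) and to pass the second argument (`p`) unchanged to every call. Below is a minimal sketch with a plain JAX stand-in function; `toy_circuit` and the shapes are hypothetical, not the poster's code.

```python
import jax
import jax.numpy as jnp

def toy_circuit(x, p):
    # Hypothetical stand-in for a QNode: x is ONE sample of shape (3,),
    # p is a parameter vector of shape (3,) shared by all samples.
    return jnp.sin(x * p)

x = jnp.arange(12.0).reshape(4, 3)  # batch of 4 samples
p = jnp.array([0.1, 0.2, 0.3])      # shared parameters

# in_axes=(0, None): map over axis 0 of x, reuse the same p for every sample
batched = jax.vmap(toy_circuit, in_axes=(0, None))(x, p)
print(batched.shape)  # (4, 3)

# Equivalent (but slower) explicit Python loop over the batch
looped = jnp.stack([toy_circuit(xi, p) for xi in x])
```

Because `vmap` traces the function with abstract "tracer" values instead of concrete arrays, everything inside `circuit` has to be expressible in JAX operations.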

But I get the following error:

    C:\Users\samyk\Documents\IISc\Thesis\Sims 2\venv\lib\site-packages\jax\_src\numpy\lax_numpy.py:5154: UserWarning: Explicitly requested dtype <class 'jax._src.numpy.lax_numpy.complex128'> requested in astype is not available, and will be truncated to dtype complex64. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
      lax_internal._check_user_dtype_supported(dtype, "astype")
    Traceback (most recent call last):
      File "Sims 2/main.py", line 146, in <module>
        a = gradient(x, y, params)
      File "Sims 2/main.py", line 132, in gradient
        probs = jax.vmap(circuit, in_axes=(0, None))(x, p)
      File "Sims 2\venv\lib\site-packages\jax\_src\traceback_util.py", line 162, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "Sims 2\venv\lib\site-packages\jax\_src\api.py", line 1438, in vmap_f
        out_flat = batching.batch(
      File "Sims 2\venv\lib\site-packages\jax\linear_util.py", line 166, in call_wrapped
        ans = self.f(*args, **dict(self.params, **kwargs))
      File "Sims 2\venv\lib\site-packages\pennylane\qnode.py", line 578, in __call__
        res = qml.execute(
      File "Sims 2\venv\lib\site-packages\pennylane\interfaces\batch\__init__.py", line 342, in execute
        cache_execute(batch_execute, cache, return_tuple=False, expand_fn=expand_fn)(tapes)
      File "Sims 2\venv\lib\site-packages\pennylane\interfaces\batch\__init__.py", line 173, in wrapper
        res = fn(execution_tapes.values(), **kwargs)
      File "Sims 2\venv\lib\site-packages\pennylane\interfaces\batch\__init__.py", line 125, in fn
        return original_fn(tapes, **kwargs)
      File "AppData\Local\Programs\Python\Python39\lib\contextlib.py", line 79, in inner
        return func(*args, **kwds)
      File "Sims 2\venv\lib\site-packages\pennylane\_qubit_device.py", line 289, in batch_execute
        res = self.execute(circuit)
      File "Sims 2\venv\lib\site-packages\pennylane\_qubit_device.py", line 201, in execute
        self.apply(circuit.operations, rotations=circuit.diagonalizing_gates, **kwargs)
      File "Sims 2\venv\lib\site-packages\pennylane\devices\default_qubit.py", line 215, in apply
        self._apply_state_vector(operation.parameters[0], operation.wires)
      File "Sims 2\venv\lib\site-packages\pennylane\devices\default_qubit.py", line 639, in _apply_state_vector
        state = self._asarray(state, dtype=self.C_DTYPE)
      File "Sims 2\venv\lib\site-packages\pennylane\devices\default_qubit_autograd.py", line 102, in _asarray
        res = np.asarray(array, dtype=dtype)
      File "Sims 2\venv\lib\site-packages\pennylane\numpy\wrapper.py", line 117, in _wrapped
        res = obj(*args, **kwargs)
      File "Sims 2\venv\lib\site-packages\autograd\tracer.py", line 48, in f_wrapped
        return f_raw(*args, **kwargs)
      File "Sims 2\venv\lib\site-packages\pennylane\numpy\tensor.py", line 36, in asarray
        return _np.array(vals, *args, **kwargs)
      File "Sims 2\venv\lib\site-packages\autograd\numpy\numpy_wrapper.py", line 60, in array
        return _array_from_scalar_or_array(args, kwargs, A)
      File "Sims 2\venv\lib\site-packages\autograd\tracer.py", line 48, in f_wrapped
        return f_raw(*args, **kwargs)
      File "Sims 2\venv\lib\site-packages\autograd\numpy\numpy_wrapper.py", line 73, in _array_from_scalar_or_array
        return _np.array(scalar, *array_args, **array_kwargs)
      File "Sims 2\venv\lib\site-packages\jax\core.py", line 470, in __array__
        raise TracerArrayConversionError(self)
    jax._src.traceback_util.UnfilteredStackTrace: jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(complex64[1024])>with<BatchTrace(level=1/0)> with
      val = DeviceArray([[0.+0.j, 0.+0.j, 0.+0.j, ..., 0.+0.j, 0.+0.j, 0.+0.j],
                 [0.+0.j, 0.+0.j, 0.+0.j, ..., 0.+0.j, 0.+0.j, 0.+0.j],
                 [0.+0.j, 0.+0.j, 0.+0.j, ..., 0.+0.j, 0.+0.j, 0.+0.j],
                 ...,
                 [0.+0.j, 0.+0.j, 0.+0.j, ..., 0.+0.j, 0.+0.j, 0.+0.j],
                 [0.+0.j, 0.+0.j, 0.+0.j, ..., 0.+0.j, 0.+0.j, 0.+0.j],
                 [0.+0.j, 0.+0.j, 0.+0.j, ..., 0.+0.j, 0.+0.j, 0.+0.j]],            dtype=complex64)
      batch_dim = 0
    This Tracer was created on line Sims 2\venv\lib\site-packages\autoray\autoray.py:240 (astype)
    See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
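Separately from the error itself, the UserWarning at the top of the traceback says the requested complex128 dtype is being truncated to complex64. If double precision matters for the simulation, JAX's x64 mode can be switched on as the warning suggests (or via the `JAX_ENABLE_X64` environment variable). A minimal sketch; the 1024-entry state size just mirrors the `complex64[1024]` in the traceback:

```python
import jax
import jax.numpy as jnp

# Enable 64-bit dtypes; do this at startup, before any JAX arrays are created
jax.config.update("jax_enable_x64", True)

# With x64 enabled, complex128 is no longer truncated to complex64
state = jnp.zeros(2**10, dtype=jnp.complex128)
print(state.dtype)  # complex128
```

This only addresses the warning; it does not fix the `TracerArrayConversionError`.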

It looks like when JAX runs the default.qubit device, execution ends up in an autograd function (default_qubit_autograd.py) that calls numpy instead of jax.numpy. I tried swapping in jax.numpy in that autograd file, but it doesn’t seem to work. Any ideas whether I’m doing something wrong, or whether there’s something I can modify to get past this?
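The traceback is consistent with that reading: `np.asarray` (reached through `default_qubit_autograd.py` and autograd's NumPy wrapper) is called on a JAX tracer, and JAX refuses to convert a tracer into a concrete NumPy array. A minimal sketch reproducing the same error class outside PennyLane; the functions `uses_numpy` and `uses_jax_numpy` are hypothetical:

```python
import numpy as np
import jax
import jax.numpy as jnp

def uses_numpy(v):
    # Plain NumPy needs concrete values, so calling it on a tracer fails under vmap
    return np.asarray(v, dtype=np.complex64)

def uses_jax_numpy(v):
    # jax.numpy operates on tracers, so this traces fine under vmap
    return jnp.asarray(v, dtype=jnp.complex64)

batch = jnp.ones((4, 8))

try:
    jax.vmap(uses_numpy)(batch)
    failed = False
except jax.errors.TracerArrayConversionError:
    failed = True

ok = jax.vmap(uses_jax_numpy)(batch)
print(failed, ok.shape, ok.dtype)  # True (4, 8) complex64
```

Patching library files is fragile, though; the cleaner route is to make sure PennyLane itself runs the JAX code path.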

Hi @somearthling! Would you be able to post a minimal version of your code, including your QNode and device? That would help with debugging :slight_smile:

Hi, it looks like it was a rather silly mistake: I hadn’t specified the interface as "jax" in the QNode wrapper.

I’m glad you found the cause @somearthling!

Enjoy using PennyLane!