AttributeError: module 'jax.core' has no attribute 'ConcreteArray' - Problems using jax.grad(circuit)

Hello dear PennyLane team :slight_smile:
I tried to follow the tutorial on JAX and PennyLane (Using JAX with PennyLane | PennyLane Demos) because I heard that JAX speeds up the optimization of a VQC a lot, and I want to look into it.

But there seems to be a problem with evaluating the gradient at a specific point.
I also tried searching for the problem online, and the only fix I found was to downgrade JAX (to 0.4.26), which helped but isn’t a clean solution :/.
Could you help me with this?
I have put together a small example that I think should work, but it doesn’t.

import jax
import jax.numpy as jnp
import pennylane as qml

dev = qml.device('default.qubit', wires=1)

@qml.qnode(dev, interface='jax')
def circuit(x):
    qml.Hadamard(wires=0)
    qml.RX(x, wires=0)
    return qml.expval(qml.PauliZ(0))

x = jnp.array(0.5)
grad_circuit = jax.grad(circuit)

# the following line produces the error
grad_value = grad_circuit(x)

print(grad_value)
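
Once the versions are sorted out, it helps to know what `print(grad_value)` should show. A quick hand calculation (assuming PennyLane's usual convention RX(x) = cos(x/2) I - i sin(x/2) X) suggests the gradient of this particular circuit is zero for every `x`: the Hadamard prepares an eigenstate of X, so `RX` only adds a global phase and the expectation value never changes.

```latex
\begin{aligned}
H\lvert 0\rangle &= \lvert + \rangle, \qquad X\lvert + \rangle = \lvert + \rangle,\\
R_X(x) &= \cos\tfrac{x}{2}\, I - i\sin\tfrac{x}{2}\, X
  \;\Rightarrow\; R_X(x)\lvert + \rangle = e^{-ix/2}\lvert + \rangle,\\
f(x) &= \langle + \rvert R_X(x)^{\dagger} Z\, R_X(x) \lvert + \rangle
      = \langle + \rvert Z \lvert + \rangle = 0
  \;\Rightarrow\; f'(x) = 0.
\end{aligned}
```

So a working setup should print a value that is zero up to floating-point rounding. Without the `qml.Hadamard`, the same reasoning gives f(x) = cos(x) and f'(x) = -sin(x), which makes a less trivial sanity check.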

The full error message is as follows:

Traceback (most recent call last):
  File "C:\Users\berg_j7\Documents\Codebooks\3_IntroductionVQAs\minimalExample.py", line 17, in <module>
    grad_value = grad_circuit(x)
                 ^^^^^^^^^^^^^^^
  File "C:\Users\berg_j7\Documents\Codebooks\3_IntroductionVQAs\.venv\Lib\site-packages\pennylane\workflow\qnode.py", line 905, in __call__
    return self._impl_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\berg_j7\Documents\Codebooks\3_IntroductionVQAs\.venv\Lib\site-packages\pennylane\workflow\qnode.py", line 881, in _impl_call
    res = qml.execute(
          ^^^^^^^^^^^^
  File "C:\Users\berg_j7\Documents\Codebooks\3_IntroductionVQAs\.venv\Lib\site-packages\pennylane\workflow\execution.py", line 195, in execute
    interface = _resolve_interface(interface, tapes)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\berg_j7\Documents\Codebooks\3_IntroductionVQAs\.venv\Lib\site-packages\pennylane\workflow\resolution.py", line 124, in _resolve_interface
    interface = _get_jax_interface_name(tapes)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\berg_j7\Documents\Codebooks\3_IntroductionVQAs\.venv\Lib\site-packages\pennylane\workflow\resolution.py", line 68, in _get_jax_interface_name
    if any(qml.math.is_abstract(param) for param in op.data):
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\berg_j7\Documents\Codebooks\3_IntroductionVQAs\.venv\Lib\site-packages\pennylane\workflow\resolution.py", line 68, in <genexpr>
    if any(qml.math.is_abstract(param) for param in op.data):
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\berg_j7\Documents\Codebooks\3_IntroductionVQAs\.venv\Lib\site-packages\pennylane\math\utils.py", line 282, in is_abstract
    return not isinstance(tensor.aval, jax.core.ConcreteArray)
                                       ^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\berg_j7\Documents\Codebooks\3_IntroductionVQAs\.venv\Lib\site-packages\jax\_src\deprecations.py", line 57, in getattr
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax.core' has no attribute 'ConcreteArray'
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
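
The last PennyLane frame shows where it breaks: `qml.math.is_abstract` in PennyLane 0.40 checks `isinstance(tensor.aval, jax.core.ConcreteArray)`, and the installed JAX no longer provides that attribute. A minimal sketch for checking this directly, before running any circuit; the helper name `jax_has_concrete_array` is hypothetical and not part of PennyLane or JAX:

```python
import importlib
from typing import Optional


def jax_has_concrete_array() -> Optional[bool]:
    """Return whether jax.core still provides ConcreteArray (None if JAX is missing)."""
    try:
        jax_core = importlib.import_module("jax.core")
    except ImportError:
        return None  # JAX is not installed in this environment
    # hasattr() returns False when JAX's module-level getattr hook raises
    # AttributeError, which is what the traceback above ends with.
    return hasattr(jax_core, "ConcreteArray")


print(jax_has_concrete_array())
```

If this prints `False`, PennyLane 0.40's JAX interface will likely hit the same `AttributeError`, and pinning an older JAX release is the workaround discussed in this thread.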

The output of qml.about() is:

Python version: 3.12.2
Numpy version: 2.0.2
Scipy version: 1.15.1
Installed devices:

  • default.clifford (PennyLane-0.40.0)
  • default.gaussian (PennyLane-0.40.0)
  • default.mixed (PennyLane-0.40.0)
  • default.qubit (PennyLane-0.40.0)
  • default.qutrit (PennyLane-0.40.0)
  • default.qutrit.mixed (PennyLane-0.40.0)
  • default.tensor (PennyLane-0.40.0)
  • null.qubit (PennyLane-0.40.0)
  • reference.qubit (PennyLane-0.40.0)
  • lightning.qubit (PennyLane_Lightning-0.40.0)
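
The `qml.about()` output above lists the PennyLane devices but not the JAX version, which is what matters for this error. A small sketch using only the standard library's `importlib.metadata` to print the installed versions of the packages involved; the package list is an assumption, with `jaxlib` included because JAX's compiled backend ships as a separate distribution:

```python
from importlib import metadata
from typing import Dict, Iterable, Optional


def installed_versions(
    packages: Iterable[str] = ("pennylane", "jax", "jaxlib"),
) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if it is missing."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed in this environment
    return versions


for name, version in installed_versions().items():
    print(f"{name}: {version or 'not installed'}")
```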

Thanks in advance for any help <3

Hi @johananasBeere ,

I just ran your code example with JAX v0.4.33 and it ran with no issues.
The demo also runs well with the same version of JAX.

Are you able to run the demo with JAX v0.4.33?

Hi @CatalinaAlbornoz,

Yes, it works with that version, just not with the newest JAX version, which is 0.5.1 (version 0.5.0 doesn’t work either).
Is that an issue just from JAX which should be reported there?

Greetings Johanna :slight_smile:

Oh I see.

I don’t think there’s an issue with JAX; it’s just that we don’t support version 0.5.1 yet. We do plan to add support for newer versions of JAX, but this will take a few months, so in the meantime I’d recommend using a lower version of JAX. For the demo and the code you shared, v0.4.33 should work. Let me know if this works for you, Johanna!
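
Until newer JAX versions are supported, a script can fail early with a clear message instead of the `ConcreteArray` traceback. A minimal sketch, assuming 0.4.33 (the version recommended here) as the newest JAX release to accept; the constant and helper names are hypothetical, and the version parsing is deliberately simple rather than a full PEP 440 implementation:

```python
import re
import warnings
from importlib import metadata
from typing import Tuple

# Assumption: newest JAX release confirmed in this thread to work with PennyLane 0.40.
MAX_TESTED_JAX: Tuple[int, ...] = (0, 4, 33)


def parse_release(version: str) -> Tuple[int, ...]:
    """Keep the leading numeric components, e.g. '0.4.36.dev20241122' -> (0, 4, 36)."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # stop at suffixes such as 'dev20241122'
        parts.append(int(match.group()))
        if match.end() != len(piece):
            break  # stop after the numeric prefix of pieces like '0rc1'
    return tuple(parts)


def is_tested_jax(version: str) -> bool:
    """True if `version` is not newer than MAX_TESTED_JAX."""
    return parse_release(version) <= MAX_TESTED_JAX


try:
    jax_version = metadata.version("jax")
except metadata.PackageNotFoundError:
    jax_version = None

if jax_version is not None and not is_tested_jax(jax_version):
    warnings.warn(
        f"JAX {jax_version} is newer than "
        f"{'.'.join(map(str, MAX_TESTED_JAX))}; "
        "consider `pip install jax==0.4.33`."
    )
```

Tuple comparison does the ordering here, so (0, 5, 1) counts as newer than (0, 4, 33) even though 5 < 33.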


Yes, it works :sunny: Thanks!


AttributeError: module 'jax.core' has no attribute 'ConcreteArray'

I ran the code successfully on Colab about a week to ten days ago. Now that JAX has been upgraded, it’s not working.
I know I have to downgrade JAX to 0.4.33. Can you let me know where to run the downgrade commands? The page keeps reinitialising.

I never expected JAX to remove something like ConcreteArray.
The error I’m facing is:


AttributeError                            Traceback (most recent call last)
in <cell line: 0>()
     22 # run training for multiple sizes
     23 train_sizes = [2, 5, 10, 20, 40, 80, 160, 320, 500]
---> 24 results_df = run_iterations(n_train=2)
     25 for n_train in train_sizes[1:]:
     26     results_df = pd.concat([results_df, run_iterations(n_train=n_train)])

15 frames
[... skipping hidden 29 frame]

[... skipping hidden 13 frame]

[... skipping hidden 7 frame]

/usr/local/lib/python3.11/dist-packages/jax/_src/deprecations.py in getattr(name)
     55 warnings.warn(message, DeprecationWarning, stacklevel=2)
     56 return fn
---> 57 raise AttributeError(f"module {module!r} has no attribute {name!r}")
     58
     59 return getattr

AttributeError: module 'jax.core' has no attribute 'ConcreteArray'

Hi @Kiran_Kumar_M , welcome to the Forum!

For now I’d recommend installing JAX 0.4.33 on Google Colab. For example, running the line below in a code cell installs PennyLane together with JAX v0.4.33.

!pip install pennylane jax==0.4.33

If you already have a session running, you may need to restart it after the install finishes (from the “Runtime” menu).
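
On Colab, modules imported before the `pip install` stay loaded in memory, so the old JAX keeps being used until the runtime restarts. A rough heuristic for spotting that situation, assuming the package exposes `__version__` (JAX does); `needs_restart` is a hypothetical helper, not a Colab API:

```python
import sys
from importlib import metadata


def needs_restart(package: str = "jax") -> bool:
    """True if the already-imported module's version differs from the installed one."""
    module = sys.modules.get(package)
    if module is None:
        return False  # not imported yet; the next import loads the installed version
    try:
        installed = metadata.version(package)
    except metadata.PackageNotFoundError:
        return False  # no installed distribution to compare against
    loaded = getattr(module, "__version__", installed)
    return loaded != installed


if needs_restart("jax"):
    print("The loaded JAX differs from the installed one: restart the runtime.")
else:
    print("The loaded JAX (if any) matches the installed version.")
```

Running this in a cell right after the install shows whether the restart is actually needed before re-running the training code.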

Let us know if this works for you!