Hello dear PennyLane team,
I tried to follow the PennyLane demo "Using JAX with PennyLane" because I heard that JAX can speed up the optimization of a VQC considerably, and I'd like to look into it.
However, evaluating the gradient of a QNode at a specific input value fails with an error.
Searching online, the only fix I found was to downgrade JAX to version 0.4.26. That helped, but it isn't a clean solution :/.
Could you help me with this?
Here is a small example that I think should work, but doesn't:
import jax
import jax.numpy as jnp
import pennylane as qml
dev = qml.device('default.qubit', wires=1)
@qml.qnode(dev, interface='jax')
def circuit(x):
    qml.Hadamard(wires=0)
    qml.RX(x, wires=0)
    return qml.expval(qml.PauliZ(0))
x = jnp.array(0.5)
grad_circuit = jax.grad(circuit)
# the following line produces the error
grad_value = grad_circuit(x)
print(grad_value)
The full error message is as follows:
Traceback (most recent call last):
File "C:\Users\berg_j7\Documents\Codebooks\3_IntroductionVQAs\minimalExample.py", line 17, in <module>
grad_value = grad_circuit(x)
^^^^^^^^^^^^^^^
File "C:\Users\berg_j7\Documents\Codebooks\3_IntroductionVQAs\.venv\Lib\site-packages\pennylane\workflow\qnode.py", line 905, in __call__
return self._impl_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\berg_j7\Documents\Codebooks\3_IntroductionVQAs\.venv\Lib\site-packages\pennylane\workflow\qnode.py", line 881, in _impl_call
res = qml.execute(
^^^^^^^^^^^^
File "C:\Users\berg_j7\Documents\Codebooks\3_IntroductionVQAs\.venv\Lib\site-packages\pennylane\workflow\execution.py", line 195, in execute
interface = _resolve_interface(interface, tapes)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\berg_j7\Documents\Codebooks\3_IntroductionVQAs\.venv\Lib\site-packages\pennylane\workflow\resolution.py", line 124, in _resolve_interface
interface = _get_jax_interface_name(tapes)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\berg_j7\Documents\Codebooks\3_IntroductionVQAs\.venv\Lib\site-packages\pennylane\workflow\resolution.py", line 68, in _get_jax_interface_name
if any(qml.math.is_abstract(param) for param in op.data):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\berg_j7\Documents\Codebooks\3_IntroductionVQAs\.venv\Lib\site-packages\pennylane\workflow\resolution.py", line 68, in <genexpr>
if any(qml.math.is_abstract(param) for param in op.data):
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\berg_j7\Documents\Codebooks\3_IntroductionVQAs\.venv\Lib\site-packages\pennylane\math\utils.py", line 282, in is_abstract
return not isinstance(tensor.aval, jax.core.ConcreteArray)
^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\berg_j7\Documents\Codebooks\3_IntroductionVQAs\.venv\Lib\site-packages\jax\_src\deprecations.py", line 57, in getattr
raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax.core' has no attribute 'ConcreteArray'
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
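The traceback ends in `qml.math.is_abstract` looking up `jax.core.ConcreteArray`, which the installed JAX no longer provides. A quick way to check whether a given JAX installation still exposes that attribute is a sketch like the one below. The helper name `jax_has_concrete_array` is hypothetical (it is not part of PennyLane or JAX), and the check only reproduces the attribute lookup from the traceback; it does not fix anything.

```python
import importlib


def jax_has_concrete_array():
    """Hypothetical diagnostic helper (not part of PennyLane or JAX).

    Returns True/False depending on whether jax.core exposes ConcreteArray,
    or None if JAX cannot be imported at all.
    """
    try:
        jax_core = importlib.import_module("jax.core")
    except ImportError:
        return None
    # hasattr() returns False when the module's __getattr__ raises
    # AttributeError, which is what the traceback shows for this name.
    return hasattr(jax_core, "ConcreteArray")


print(jax_has_concrete_array())
```

Since downgrading to JAX 0.4.26 made the error go away, one would expect this to print True there and False on the current setup, but that expectation is an assumption rather than something checked with this exact snippet.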
The output of qml.about() is:
Python version: 3.12.2
Numpy version: 2.0.2
Scipy version: 1.15.1
Installed devices:
- default.clifford (PennyLane-0.40.0)
- default.gaussian (PennyLane-0.40.0)
- default.mixed (PennyLane-0.40.0)
- default.qubit (PennyLane-0.40.0)
- default.qutrit (PennyLane-0.40.0)
- default.qutrit.mixed (PennyLane-0.40.0)
- default.tensor (PennyLane-0.40.0)
- null.qubit (PennyLane-0.40.0)
- reference.qubit (PennyLane-0.40.0)
- lightning.qubit (PennyLane_Lightning-0.40.0)
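The qml.about() output above does not show the JAX or jaxlib versions, which seem to be the relevant ones here. A minimal way to report them is sketched below, using only the standard library's `importlib.metadata`; the helper name `installed_version` is illustrative and not part of any library.

```python
from importlib import metadata


def installed_version(dist_name):
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


# qml.about() (as shown above) lists Python, NumPy and SciPy, but not the JAX packages.
for name in ("pennylane", "jax", "jaxlib"):
    print(f"{name}: {installed_version(name)}")
```

Reading the versions from package metadata avoids importing JAX itself, so the report works even when importing it would fail.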
Thanks in advance for any help <3