Hello! I’m trying to extend the code from the "Using JAX with PennyLane" tutorial by replacing the manual gradient-descent update, which works perfectly as shown in the tutorial, with the optimization step of one of the optimizers included in PennyLane.
However, the code:
import jax
import pennylane as qml
from pennylane.optimize import AdamOptimizer

dev = qml.device("default.qubit", wires=2)

@jax.jit
@qml.qnode(dev, interface="jax")
def circuit(param):
    qml.RX(param, wires=0)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.PauliZ(0))

grad_circuit = jax.grad(circuit)
optimizer = AdamOptimizer()

params = 0.123
for i in range(4):
    params = optimizer.step(circuit, params, grad_fn=grad_circuit)
throws the following exception:
  File "/home/theuser/.local/lib/python3.9/site-packages/pennylane/optimize/gradient_descent.py", line 130, in step
    new_args = self.apply_grad(g, args)
  File "/home/theuser/.local/lib/python3.9/site-packages/pennylane/optimize/adam.py", line 93, in apply_grad
    grad_flat = list(_flatten(grad[trained_index]))
  File "/home/theuser/.local/lib/python3.9/site-packages/pennylane/utils.py", line 198, in _flatten
    yield from _flatten(item)
  File "/home/theuser/.local/lib/python3.9/site-packages/pennylane/utils.py", line 197, in _flatten
    for item in x:
  File "/home/theuser/.local/lib/python3.9/site-packages/jax/_src/device_array.py", line 249, in __iter__
    raise TypeError("iteration over a 0-d array")  # same as numpy error
TypeError: iteration over a 0-d array
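Judging from the last frames, PennyLane's _flatten ends up calling `for item in x:` on the parameter itself, which JAX holds as a 0-d (scalar) array. The same TypeError can be reproduced with plain JAX, independently of PennyLane (a minimal sketch):

```python
import jax.numpy as jnp

x = jnp.array(0.123)  # a 0-d (scalar) JAX array, like the traced parameter
print(x.shape)        # ()

# Iterating over a 0-d JAX array raises the same error as in the traceback.
try:
    for item in x:
        pass
except TypeError as err:
    print(err)  # iteration over a 0-d array
```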
Removing the @jax.jit decorator does not resolve the issue. Can you tell me if I’m missing something?
OS: Ubuntu 20.04 LTS (Focal Fossa)
Python version: 3.9.11
jax: 0.3.4
jaxlib: 0.3.2
PennyLane: 0.20.0
PennyLane-Lightning: 0.22.0