JAX and PennyLane optimizers

Hello! I’m trying to extend the code shown in the Using JAX with PennyLane tutorial by replacing the gradient-flow update rule, which works perfectly as shown in the tutorial, with the optimization step of one of the optimizers included in PennyLane.

However, the code:

import jax
import pennylane as qml
from pennylane.optimize import AdamOptimizer

dev = qml.device("default.qubit", wires=2)

@jax.jit
@qml.qnode(dev, interface="jax")
def circuit(param):
    qml.RX(param, wires=0)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.PauliZ(0))

grad_circuit = jax.grad(circuit)
optimizer = AdamOptimizer()
params = 0.123

for i in range(4):
    params = optimizer.step(circuit, params, grad_fn=grad_circuit)

throws the following exception:

  File "/home/theuser/.local/lib/python3.9/site-packages/pennylane/optimize/gradient_descent.py", line 130, in step
    new_args = self.apply_grad(g, args)
  File "/home/theuser/.local/lib/python3.9/site-packages/pennylane/optimize/adam.py", line 93, in apply_grad
    grad_flat = list(_flatten(grad[trained_index]))
  File "/home/theuser/.local/lib/python3.9/site-packages/pennylane/utils.py", line 198, in _flatten
    yield from _flatten(item)
  File "/home/theuser/.local/lib/python3.9/site-packages/pennylane/utils.py", line 197, in _flatten
    for item in x:
  File "/home/theuser/.local/lib/python3.9/site-packages/jax/_src/device_array.py", line 249, in __iter__
    raise TypeError("iteration over a 0-d array")  # same as numpy error
TypeError: iteration over a 0-d array

Removing ‘@jax.jit’ does not resolve the issue. Could you kindly tell me if I’m missing something?

OS: Ubuntu 20.04 LTS (Focal Fossa)
Python version: 3.9.11
jax: 0.3.4
jaxlib: 0.3.2
PennyLane: 0.20.0
PennyLane-Lightning: 0.22.0

Hi @incud, welcome to the forum!
Do you get the same issue if you upgrade to PennyLane v0.22.1? Or is there a reason why you need to use an older version?

Hi @CatalinaAlbornoz, thank you very much for your quick response!
Upgrading PennyLane is a great suggestion. Unfortunately, after the upgrade (PennyLane v0.22.1) the error still shows up in the same place.

Hi @incud! Good catch, this seems to be because the PennyLane optimizers assume that they will only be used with Autograd (not JAX), and so fail to properly manipulate JAX device arrays.
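
For context, this is exactly where the traceback above ends up: a 0-d JAX array defines __iter__ (so PennyLane's _flatten treats it as an iterable), but iterating it immediately raises. A minimal illustration of that failure mode (not part of the original snippet):

import jax
import jax.numpy as jnp
from collections.abc import Iterable

# jax.grad of a scalar-valued function of a scalar parameter returns a 0-d array
g = jax.grad(lambda x: x ** 2)(jnp.array(0.123))

print(g.ndim)                   # 0
print(isinstance(g, Iterable))  # True, because the array type defines __iter__ ...
list(g)                         # ... but raises TypeError: iteration over a 0-d array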

If you apply the following diff, it should fix the issue:

diff --git a/pennylane/optimize/gradient_descent.py b/pennylane/optimize/gradient_descent.py
index f21d9fce6..91d79933b 100644
--- a/pennylane/optimize/gradient_descent.py
+++ b/pennylane/optimize/gradient_descent.py
@@ -185,8 +185,8 @@ class GradientDescentOptimizer:
         trained_index = 0
         for index, arg in enumerate(args):
             if getattr(arg, "requires_grad", True):
-                x_flat = _flatten(arg)
-                grad_flat = _flatten(grad[trained_index])
+                x_flat = list(_flatten(arg))
+                grad_flat = list(_flatten(grad[trained_index]))
                 trained_index += 1
 
                 x_new_flat = [e - self.stepsize * g for g, e in zip(grad_flat, x_flat)]
diff --git a/pennylane/utils.py b/pennylane/utils.py
index 2eba075d5..6a089f5b7 100644
--- a/pennylane/utils.py
+++ b/pennylane/utils.py
@@ -196,6 +196,8 @@ def _flatten(x):
         # Since Wires are always flat, just yield.
         for item in x:
             yield item
+    elif getattr(x, "ndim", None) == 0:
+        yield x
     elif isinstance(x, Iterable) and not isinstance(x, (str, bytes)):
         for item in x:
             yield from _flatten(item)
@@ -222,9 +224,10 @@ def _unflatten(flat, model):
     if isinstance(model, (numbers.Number, str)):
         return flat[0], flat[1:]
 
-    if isinstance(model, np.ndarray):
+    if hasattr(model, "ndim"):
         idx = model.size
-        res = np.array(flat)[:idx].reshape(model.shape)
+        res = qml.math.convert_like(flat, model)[:idx]
+        res = qml.math.reshape(flat, qml.math.shape(model))
         return res, flat[idx:]
 
     if isinstance(model, Iterable):

However, rather than using the PennyLane optimizers, I recommend using optimizers designed for JAX, such as optax:

import jax
from jax import numpy as jnp
import pennylane as qml
import optax

dev = qml.device("default.qubit", wires=2)

@jax.jit
@qml.qnode(dev, interface="jax")
def circuit(param):
    qml.RX(param, wires=0)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.PauliZ(0))

optimizer = optax.adam(learning_rate=0.1)
params = jnp.array(0.123)
opt_state = optimizer.init(params)

for i in range(20):
    # evaluate the cost and its gradient in one pass
    cost, grads = jax.value_and_grad(circuit)(params)
    # compute and apply the Adam update
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    print(f"step {i}, cost {cost}")
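
If it helps, the whole update can also be compiled into a single jitted step. This is a sketch building on the snippet above (it reuses circuit and optimizer from there; the step function and its return values are my own illustration, not an optax API):

@jax.jit
def step(params, opt_state):
    # evaluate cost and gradient, then apply the Adam update, all in one compiled call
    cost, grads = jax.value_and_grad(circuit)(params)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, cost

params = jnp.array(0.123)
opt_state = optimizer.init(params)

for i in range(20):
    params, opt_state, cost = step(params, opt_state)
    print(f"step {i}, cost {cost}")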

@josh Thank you very much :slight_smile:
Both methods are working perfectly!

@josh, will the implementation of the optimizers be changed in an upcoming release? I ran into the same issue, and it's good to find a fix, but I'm not familiar with applying diffs.

Hey @somearthling! At this stage there are no plans I can share for upgrading the optimizers to support JAX. Unfortunately, JAX support was always somewhat of a side effect, since the optimizers were not written with JAX in mind. As a result, more work would be required beyond the small hotfix I posted above.
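
For comparison, the intended usage of these optimizers is with the default Autograd interface and pennylane.numpy parameters, where the same optimizer step works out of the box. A minimal sketch (the circuit is the one from the original post, re-declared without the JAX interface):

import pennylane as qml
from pennylane import numpy as pnp
from pennylane.optimize import AdamOptimizer

dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev)  # default interface: Autograd
def circuit(param):
    qml.RX(param, wires=0)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.PauliZ(0))

optimizer = AdamOptimizer()
params = pnp.array(0.123, requires_grad=True)

for i in range(4):
    # with Autograd parameters, the optimizer differentiates the circuit itself
    params = optimizer.step(circuit, params)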

I highly recommend using an optimizer library designed for JAX, such as optax above :slightly_smiling_face:

Were there any specific optimizers in PennyLane you were looking to make JAX compatible?

Hey Josh, yep, I went ahead and used optax and it works perfectly. Thanks for clarifying.