PennyLane's QNGOptimizer with the JAX interface

Hello! I’d like to use the quantum natural gradient for my optimization problem with multiple starting points. I know it’s possible to use optimizers from JAX with QNodes, but as far as I understand there’s no natural-gradient optimizer in JAX. So is there any way to combine qml.QNGOptimizer with jax.jit and jax.vmap?

I tried this:

import pennylane as qml
import jax
jax.config.update("jax_enable_x64", True)
from pennylane import numpy as np

num_qubits = 2
learning_rate = 0.005
num_iterations = 10
dev = qml.device("default.qubit", wires=num_qubits)

@qml.qnode(dev, interface='jax-jit')
def circuit(param_vector):
    qml.RX(param_vector[0], wires=0)
    qml.RX(param_vector[1], wires=1)
    qml.CNOT([0, 1])
    qml.RY(param_vector[2], wires=0)
    qml.RY(param_vector[3], wires=1)
    return qml.expval(qml.PauliZ(0))

opt = qml.QNGOptimizer(stepsize=learning_rate)

x0 = np.random.uniform(size=4)
theta = x0

for _ in range(num_iterations):
    theta = opt.step(circuit, theta)

but got the error:

ValueError: The objective function must either be encoded as a single QNode or an ExpvalCost object for the natural gradient to be automatically computed. Otherwise, metric_tensor_fn must be explicitly provided to the optimizer.

After that I also tried to provide the metric tensor to the optimizer’s step function in the following way:

theta = opt.step(circuit, theta, metric_tensor_fn=qml.metric_tensor(circuit, approx='diag'))

and received a different error:

TransformError: Impossible to dispatch your transform on quantum function, because more than one tape is returned

Hi @TheMadTiX, welcome to the PennyLane forum!

Unfortunately, if you want to use JAX you need to use JAX optimizers :smiling_face_with_tear:
You can use JAXopt or Optax, as described in the PennyLane docs.

I’m sorry we don’t have better news for you.