Hello! I’d like to use the quantum natural gradient (QNG) for an optimization problem with multiple starting points. I know that JAX optimizers can be used to train QNodes, but as far as I understand, JAX has no natural-gradient optimizer. Is there any way to combine `qml.QNGOptimizer` with `jax.jit` and `jax.vmap`?
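For reference, the QNG update described in the PennyLane documentation is the following. Here $\eta$ is the step size, $\mathcal{L}$ is the cost, and $g^{+}$ is the pseudo-inverse of the Fubini-Study metric tensor of the circuit state $|\psi(\theta)\rangle$:

$$
\theta_{t+1} = \theta_t - \eta\, g^{+}(\theta_t)\, \nabla\mathcal{L}(\theta_t),
\qquad
g_{ij} = \operatorname{Re}\left[\langle \partial_i \psi \mid \partial_j \psi \rangle - \langle \partial_i \psi \mid \psi \rangle \langle \psi \mid \partial_j \psi \rangle\right]
$$

If both $\nabla\mathcal{L}$ and $g$ can be computed with JAX operations, then the whole update is a pure function of $\theta$ that `jax.jit` and `jax.vmap` can transform.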

I tried the following:

```python
import pennylane as qml
import jax
jax.config.update("jax_enable_x64", True)
from pennylane import numpy as np

num_qubits = 2
learning_rate = 0.005
num_iterations = 10
dev = qml.device("default.qubit", wires=num_qubits)

@jax.jit
@qml.qnode(dev, interface='jax-jit')
def circuit(param_vector):
    qml.RX(param_vector[0], wires=0)
    qml.RX(param_vector[1], wires=1)
    qml.CNOT([0, 1])
    qml.RY(param_vector[2], wires=0)
    qml.RY(param_vector[3], wires=1)
    return qml.expval(qml.PauliZ(0))

opt = qml.QNGOptimizer(stepsize=learning_rate)
np.random.seed(10)
x0 = np.random.uniform(size=4)
theta = x0
for _ in range(num_iterations):
    theta = opt.step(circuit, theta)
```

but got the error:

```
ValueError: The objective function must either be encoded as a single QNode or an ExpvalCost object for the natural gradient to be automatically computed. Otherwise, metric_tensor_fn must be explicitly provided to the optimizer.
```

After that, I also tried passing the metric tensor explicitly to the optimizer’s `step` method:

```python
theta = opt.step(circuit, theta, metric_tensor_fn=qml.metric_tensor(circuit, approx='diag'))
```

but got a different error:

```
TransformError: Impossible to dispatch your transform on quantum function, because more than one tape is returned
```
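One plausible reading of the first error is that `@jax.jit` wraps the QNode, so `opt.step` no longer receives a QNode object. A possible workaround is to skip `opt.step` and write the QNG update as a pure JAX function. The sketch below makes several assumptions. It avoids PennyLane entirely by hand-coding the 2-qubit circuit from the question as a statevector, with wire 0 as the first tensor axis. It uses the full metric from the formula above rather than `approx='diag'`, computed from the Jacobian of the state. The helper names (`statevector`, `metric_tensor`, `qng_step`), the step size, the iteration count and the three random starting points are all illustrative choices, not PennyLane API.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def rx(theta):
    """Matrix of RX(theta) = exp(-i * theta * X / 2)."""
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.array([[c, -1j * s], [-1j * s, c]])


def ry(theta):
    """Matrix of RY(theta) = exp(-i * theta * Y / 2)."""
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.array([[c, -s], [s, c]], dtype=jnp.complex128)


def statevector(params):
    """Hypothetical hand-written simulation of the question's circuit."""
    state = jnp.zeros((2, 2), dtype=jnp.complex128).at[0, 0].set(1.0)
    state = rx(params[0]) @ state                  # RX on wire 0 (axis 0)
    state = state @ rx(params[1]).T                # RX on wire 1 (axis 1)
    state = jnp.stack([state[0], state[1, ::-1]])  # CNOT(0, 1): flip wire 1 if wire 0 is |1>
    state = ry(params[2]) @ state                  # RY on wire 0
    state = state @ ry(params[3]).T                # RY on wire 1
    return state.reshape(-1)


def cost(params):
    """<Z> on wire 0: P(wire 0 = 0) - P(wire 0 = 1)."""
    probs = jnp.abs(statevector(params).reshape(2, 2)) ** 2
    return jnp.sum(probs[0]) - jnp.sum(probs[1])


def metric_tensor(params):
    """Fubini-Study metric Re[<d_i psi|d_j psi> - <d_i psi|psi><psi|d_j psi>]."""
    psi = statevector(params)
    jac = jax.jacfwd(statevector)(params)  # shape (amplitudes, params), complex
    overlaps = jac.conj().T @ jac          # <d_i psi | d_j psi>
    berry = jac.conj().T @ psi             # <d_i psi | psi>
    return jnp.real(overlaps - jnp.outer(berry, berry.conj()))


def qng_step(params, stepsize=0.05):
    """One QNG update: params - stepsize * pinv(g) @ grad."""
    grad = jax.grad(cost)(params)
    g = metric_tensor(params)
    return params - stepsize * (jnp.linalg.pinv(g) @ grad)


# vmap maps the step over a batch of starting points; jit compiles the batch.
batched_step = jax.jit(jax.vmap(qng_step))
batched_cost = jax.jit(jax.vmap(cost))

key = jax.random.PRNGKey(10)
theta = jax.random.uniform(key, (3, 4), dtype=jnp.float64)  # 3 starting points
initial_costs = batched_cost(theta)
for _ in range(20):
    theta = batched_step(theta)
final_costs = batched_cost(theta)
print(initial_costs, final_costs)
```

Because `qng_step` depends only on `params`, `jax.vmap` handles the multiple starting points and `jax.jit` compiles the batched step once. The Python loop then only calls the compiled function.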