Greetings PennyLane team, I’m trying to compute the gradient of a circuit using the JAX interface, but my kernel dies. Could you help me understand why this happens?

This is an example that produces the issue:

```python
import jax
from jax import numpy as jnp
import pennylane as qml
import numpy as np

dev = qml.device("default.qubit", wires=2)

amplitude_0 = 1
amplitude_1 = 2
amplitude_2 = 3
amplitude_3 = 4

initial_state = jnp.array([amplitude_0, amplitude_1, amplitude_2, amplitude_3])
norm = jnp.linalg.norm(initial_state)
initial_state = initial_state / norm

@qml.qnode(dev, diff_method="finite-diff", interface="jax")
def circuit(phi):
    qml.QubitStateVector(initial_state, wires=[0, 1])
    qml.IsingZZ(phi, wires=[0, 1])
    return qml.expval(qml.PauliZ(0))

phi = jnp.array(0.1)
res = jax.grad(circuit)(phi)  # the kernel dies here
```

Thank you!

Hi @georgy! I’m not sure why this is happening, but I did manage to make it work by removing `diff_method="finite-diff"` and removing `(phi)` from the last line. I will look deeper into why this is happening, but for now you could work with the following code:

```python
import jax
from jax import numpy as jnp
import pennylane as qml
import numpy as np

dev = qml.device("default.qubit", wires=2)

amplitude_0 = 1
amplitude_1 = 2
amplitude_2 = 3
amplitude_3 = 4

initial_state = jnp.array([amplitude_0, amplitude_1, amplitude_2, amplitude_3])
norm = jnp.linalg.norm(initial_state)
initial_state = initial_state / norm

@qml.qnode(dev, interface="jax")
def circuit(phi):
    qml.QubitStateVector(initial_state, wires=[0, 1])
    qml.IsingZZ(phi, wires=[0, 1])
    return qml.expval(qml.PauliZ(0))

phi = jnp.array(0.1)
res = jax.grad(circuit)
```

Please let me know if this helps, and I’ll let you know when I figure out the cause of this problem!

Hello @CatalinaAlbornoz,

Oh, thanks for taking a look! I tried your suggestion, but something weird is going on. It did execute fine, but when I tried to compute with the result I got another error:

```python
jax.numpy.sum(res)
```

```
~/anaconda3/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in _check_arraylike(fun_name, *args)
    306                     if not _arraylike(arg))
    307     msg = "{} requires ndarray or scalar arguments, got {} at position {}."
--> 308     raise TypeError(msg.format(fun_name, type(arg), pos))
    309
    310 def _check_no_float0s(fun_name, *args):

TypeError: sum requires ndarray or scalar arguments, got <class 'function'> at position 0.
```

What’s wrong here?

Hi @georgy! You’re getting this error because `res` is in fact a function: it’s the function that computes the gradient of your circuit. You can check this with `print(res)`.
To get the result you expect, you need to evaluate `res` at a point, and that’s where `phi` comes in. If you try

```python
print(f"res(phi): {res(phi):0.3f}")
```

you will get the expected result.

You can also use `jax.numpy` to operate on the result the way you did; just remember to evaluate the function first:

```python
print(jax.numpy.sum(res(phi)))
```
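To make the function-versus-value distinction concrete, here is a minimal pure-JAX sketch (no PennyLane needed) in which a toy function `f` stands in for the QNode; `f`, `df`, and `x` are illustrative names, not part of the original code. It shows that `jax.grad` returns a new function, that passing that function to `jnp.sum` raises a `TypeError` like the one above, and that summing its evaluated output works.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Toy scalar function standing in for the QNode `circuit`
    return jnp.sin(x) ** 2


df = jax.grad(f)  # like `res = jax.grad(circuit)`: this is a function, not a number
print(callable(df))  # True

x = jnp.array(0.1)
value = df(x)  # evaluate the gradient at a point: d/dx sin^2(x) = sin(2x)
print(f"df(x): {float(value):0.3f}")  # df(x): 0.199

# Passing the function itself to jnp.sum reproduces the kind of error from the thread
caught = None
try:
    jnp.sum(df)
except TypeError as err:
    caught = err
    print("TypeError:", err)

# Summing the *evaluated* gradient is fine
print(jnp.sum(df(x)))
```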

Also, regarding the initial question of why the finite-difference method didn’t work, I will ask someone from the team to take a deeper look at this.

Hi @georgy,

Thanks for the question!

Unfortunately, when using the `jax` interface, we’re experiencing issues with the `parameter-shift` and `finite-diff` differentiation methods. After some investigation, it looks like the kernel crash is caused by the lack of a way to mark `jax` arrays as trainable. We’re still looking for a suitable solution.
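Roughly speaking, `finite-diff` and `parameter-shift` compute gradients by re-running the circuit with individual parameters shifted, so they need to know which inputs are trainable, whereas `backprop` lets JAX’s own tracing decide what to differentiate. As a hypothetical sketch of that bookkeeping (not PennyLane’s implementation; `finite_diff_grad`, its `trainable` argument, and `cost` are illustrative names), central finite differences over an explicit list of trainable indices could look like this:

```python
import jax.numpy as jnp


def finite_diff_grad(f, params, trainable, h=1e-3):
    """Central finite differences of scalar f, shifting only the indices in `trainable`.

    Hypothetical helper for illustration: parameters not listed in `trainable`
    are never shifted, so their gradient entries stay zero.
    """
    grads = jnp.zeros_like(params)
    for i in trainable:
        shift = jnp.zeros_like(params).at[i].set(h)
        slope = (f(params + shift) - f(params - shift)) / (2 * h)
        grads = grads.at[i].set(slope)
    return grads


def cost(p):
    # Toy two-parameter cost standing in for a circuit
    return jnp.sin(p[0]) * jnp.cos(p[1])


params = jnp.array([0.1, 0.2])
g = finite_diff_grad(cost, params, trainable=[0])
print(g)  # first entry close to cos(0.1) * cos(0.2); second entry 0 (not trainable)
```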

In the meantime, the solution that @CatalinaAlbornoz suggested should work great. That will use the `backprop` differentiation method with the `default.qubit.jax` device.
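For intuition about what `backprop` does here: when the state-vector simulation is itself written in JAX, `jax.grad` can differentiate straight through it. The following is a minimal hand-written sketch of the circuit from this thread in plain `jax.numpy` (a stand-in for illustration, not the internals of `default.qubit.jax`; `ising_zz` and `expval_z0` are hypothetical helpers), with a central finite difference alongside for comparison. Because `IsingZZ` is diagonal in the computational basis, it only changes the phases of the amplitudes, so the expectation value of `PauliZ(0)` is (1 + 4 - 9 - 16)/30 = -2/3 for every `phi`, and both gradients should come out numerically zero.

```python
import jax
import jax.numpy as jnp

# Same normalized initial state as in the question: (1, 2, 3, 4) / sqrt(30)
amplitudes = jnp.array([1.0, 2.0, 3.0, 4.0])
initial_state = (amplitudes / jnp.linalg.norm(amplitudes)).astype(jnp.complex64)

# Diagonal of PauliZ on wire 0 (the most significant qubit) tensored with identity on wire 1
z0_diag = jnp.array([1.0, 1.0, -1.0, -1.0])


def ising_zz(phi):
    # IsingZZ(phi) = exp(-i * phi / 2 * Z⊗Z); Z⊗Z has diagonal (+1, -1, -1, +1)
    zz_diag = jnp.array([1.0, -1.0, -1.0, 1.0])
    return jnp.diag(jnp.exp(-0.5j * phi * zz_diag))


def expval_z0(phi):
    # Apply the gate to the state vector, then compute <psi| Z0 |psi>
    state = ising_zz(phi) @ initial_state
    return jnp.real(jnp.vdot(state, z0_diag * state))


phi = jnp.array(0.1)
backprop_grad = jax.grad(expval_z0)(phi)  # autodiff through the simulation
h = 1e-3
fd_grad = (expval_z0(phi + h) - expval_z0(phi - h)) / (2 * h)  # central difference

print(expval_z0(phi))  # approximately -2/3
print(backprop_grad, fd_grad)  # both approximately 0
```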

Thank you for the answers! Oh, that makes sense. I’ll use backpropagation for the time being, then. Thanks again!