Greetings PennyLane team, I’m trying to compute the gradient of a circuit using the JAX interface, but my kernel dies. Could you help me understand why this happens?
Hi @georgy! I’m not sure why this is happening, but I did manage to make it work by removing diff_method="finite-diff" and removing (phi) from the last line. I will look deeper into why this happens, but for now you can work with the following code:
---------------
import jax
from jax import numpy as jnp
import pennylane as qml
import numpy as np
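# NOTE: the rest of the original snippet isn't shown here; the circuit below
# is just an illustrative single-qubit example, so substitute your own circuit.
dev = qml.device("default.qubit.jax", wires=1)

# No diff_method specified, so the device's native backprop is used.
@qml.qnode(dev, interface="jax")
def circuit(phi):
    qml.RX(phi, wires=0)
    return qml.expval(qml.PauliZ(0))

phi = jnp.array(0.54)
res = jax.grad(circuit)  # note: without (phi), res is the gradient *function*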
Oh, thanks for taking a look! I tried your suggestion, but something weird is going on. It does execute fine now, but when I try to compute with the result, I get another error:
jax.numpy.sum(res)
~/anaconda3/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in _check_arraylike(fun_name, *args)
306 if not _arraylike(arg))
307 msg = "{} requires ndarray or scalar arguments, got {} at position {}."
--> 308 raise TypeError(msg.format(fun_name, type(arg), pos))
309
310 def _check_no_float0s(fun_name, *args):
TypeError: sum requires ndarray or scalar arguments, got <class 'function'> at position 0.
Hi @georgy! You’re getting this error because res is in fact a function. It’s the function that gives the gradient of your circuit! You can check this by using print(res).
In order to get the result you expect, you need to evaluate the res function at a point. This is where you pass in phi.
If you try print(f"res(phi): {res(phi):0.3f}")
you will get the expected result.
You can also use numpy to perform operations on the result the way you did, but remember to evaluate the function first.
In your case your code would look like this: print(jax.numpy.sum(res(phi)))
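If it helps, here is the same pattern with plain JAX, independent of PennyLane (the function f below is just a toy example):

import jax
from jax import numpy as jnp

def f(x):
    return jnp.sin(x) ** 2

df = jax.grad(f)          # df is a function, just like res
print(df)                 # prints something like <function ...>
print(df(0.3))            # evaluate at a point to get a value
print(jnp.sum(df(0.3)))   # now numpy-style operations work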
Please let me know if this answered your question!
Also, regarding the initial question of why the finite-difference method didn’t work, I will ask someone from the team to take a deeper look at this.
Unfortunately, when using the JAX interface, we’re experiencing issues with the parameter-shift and finite-diff differentiation methods. After some investigation, it looks like the kernel crash is due to the inability to mark JAX arrays as trainable. We’re still looking for a suitable solution.
In the meantime, the solution that @CatalinaAlbornoz suggested should work great. That will use the backprop differentiation method with the default.qubit.jax device.
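For completeness, here’s a minimal sketch with the diff method spelled out explicitly (again, the circuit is only illustrative, not the one from the original post):

import jax
from jax import numpy as jnp
import pennylane as qml

dev = qml.device("default.qubit.jax", wires=1)

# Explicitly request backprop; on this device that is also what the
# default method selection should resolve to.
@qml.qnode(dev, interface="jax", diff_method="backprop")
def circuit(phi):
    qml.RX(phi, wires=0)
    return qml.expval(qml.PauliZ(0))

phi = jnp.array(0.54)
print(jax.grad(circuit)(phi))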