Greetings PennyLane team, I’m trying to compute the gradient of a circuit using the JAX interface, but my kernel dies. Could you help me understand why this happens?

Here is a minimal example that reproduces the issue:

```
import jax
from jax import numpy as jnp
import pennylane as qml
import numpy as np

dev = qml.device("default.qubit", wires=2)

# Unnormalized amplitudes for |00>, |01>, |10>, |11>
amplitude_0 = 1
amplitude_1 = 2
amplitude_2 = 3
amplitude_3 = 4
initial_state = jnp.array([amplitude_0, amplitude_1, amplitude_2, amplitude_3])

# Normalize so the vector is a valid quantum state
norm = jnp.linalg.norm(initial_state)
initial_state = initial_state / norm

@qml.qnode(dev, diff_method="finite-diff", interface="jax")
def circuit(phi):
    qml.QubitStateVector(initial_state, wires=[0, 1])
    qml.IsingZZ(phi, wires=[0, 1])
    return qml.expval(qml.PauliZ(0))

phi = jnp.array(0.1)
# Gradient of the expectation value with respect to phi (this is where the kernel dies)
res = jax.grad(circuit, argnums=0)(phi)
```
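For reference, here is a minimal plain-JAX sketch (no PennyLane) of the same expectation value, which can help tell a crash in the QNode apart from a problem with JAX itself. It assumes PennyLane's `IsingZZ(phi) = exp(-i phi/2 Z⊗Z)` and the convention that wire 0 is the most significant qubit; the helper `expval_z0` is hypothetical. Because `IsingZZ` is diagonal and commutes with `PauliZ(0)`, the expected gradient is zero.

```python
import jax
import jax.numpy as jnp

# Same normalized initial state as above: amplitudes for |00>, |01>, |10>, |11>
initial_state = jnp.array([1.0, 2.0, 3.0, 4.0])
initial_state = initial_state / jnp.linalg.norm(initial_state)

# Diagonal of Z on wire 0 (assumed most significant qubit) and of Z⊗Z
Z0 = jnp.array([1.0, 1.0, -1.0, -1.0])
ZZ = jnp.array([1.0, -1.0, -1.0, 1.0])

def expval_z0(phi):
    # IsingZZ is diagonal, so applying it is just an elementwise phase
    phases = jnp.exp(-0.5j * phi * ZZ)
    state = phases * initial_state
    # <psi| Z0 |psi>, real-valued so that jax.grad can be applied
    return jnp.real(jnp.sum(jnp.conj(state) * Z0 * state))

phi = jnp.array(0.1)
print(expval_z0(phi))            # about -2/3
print(jax.grad(expval_z0)(phi))  # about 0
```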

Thank you!