I want to use the JAX interface to speed up my code with jax.jit and jax.vmap. However, I am not even able to create a circuit instance with the JAX interface. Following the JAX tutorial here:
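from jax.config import config
config.update("jax_enable_x64", True)  # use 64-bit floats in JAX

import jax
import jax.numpy as jnp
import pennylane as qml

# Assumed device definition, as in the tutorial (two wires for the CNOT below)
dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev, interface="jax")
def circuit(param):
    qml.RX(param, wires=0)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.PauliZ(0))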
I get the error message:
TypeError: Can't instantiate abstract class DefaultQubitJax with abstract methods _abs, _conj, _dot
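For context, as far as I understand, this kind of message comes from Python's abc module whenever a class that still has unimplemented abstract methods is instantiated. Below is a toy reproduction of that mechanism. The classes Base, Incomplete, and Complete are hypothetical and are not PennyLane's actual classes, and the exact wording of the message varies between Python versions.

```python
from abc import ABC, abstractmethod


class Base(ABC):
    """Hypothetical base class declaring one abstract method."""

    @abstractmethod
    def _dot(self, a, b):
        ...


class Incomplete(Base):
    """Subclass that does not implement _dot, so it stays abstract."""


class Complete(Base):
    """Subclass that implements _dot, so it can be instantiated."""

    def _dot(self, a, b):
        return sum(x * y for x, y in zip(a, b))


def try_instantiate(cls):
    """Return None on success, or the TypeError message on failure."""
    try:
        cls()
        return None
    except TypeError as exc:
        return str(exc)


print(try_instantiate(Complete))    # None: all abstract methods are implemented
print(try_instantiate(Incomplete))  # message naming the missing method _dot
```

So the error suggests that the DefaultQubitJax class being loaded does not provide _abs, _conj, and _dot in the form its base class expects, which might point to mismatched package versions rather than to the circuit code itself.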
I have installed jax, jaxlib, and pennylane, and made sure that all three are on the newest version.
Thank you for your quick reply. I tried your suggestion, but I get the same error. I am running this code on a cloud backend with a Jupyter notebook as the frontend. Could this explain the error I am getting?
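One way to check this might be to ask the notebook kernel itself which interpreter and package versions it sees, since a cloud kernel can run in a different environment from the one where the packages were upgraded. This is a minimal sketch using only the standard library; the helper name installed_versions is my own and not part of any of these packages.

```python
import sys
from importlib import metadata


def installed_versions(packages):
    """Map each package name to its installed version, or None if not found."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Interpreter that the running kernel actually uses
print("Interpreter:", sys.executable)

# Versions visible to that interpreter
for name, version in installed_versions(["pennylane", "jax", "jaxlib"]).items():
    print(f"{name}: {version or 'not installed'}")
```

If the versions printed inside the notebook differ from what pip reports in the terminal, the kernel is probably using a different environment. Running qml.about() in the notebook should also print PennyLane's own view of the installed version and devices.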