DefaultQubitJax class cannot be built

Hello people,

I want to use the JAX interface to speed up my code with jax.jit and jax.vmap. However, I am not even able to create a circuit instance with the JAX interface. Following the JAX tutorial here:

from jax.config import config
config.update("jax_enable_x64", True)

import jax
import jax.numpy as jnp
import pennylane as qml

@qml.qnode(dev, interface="jax")
def circuit(param):
    qml.RX(param, wires=0)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.PauliZ(0))

I get the error message:

TypeError: Can’t instantiate abstract class DefaultQubitJax with abstract methods _abs, _conj, _dot
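This TypeError is raised by Python's abc module: a class cannot be instantiated while any name in its __abstractmethods__ still resolves to something abc treats as abstract, meaning either an inherited abstract method that was never overridden, or an attribute whose __isabstractmethod__ flag is truthy. The stdlib-only sketch below illustrates both cases. The class names (Backend, ForgotMethods, FlaggedAttribute, Complete, StillAbstract) and the helper try_instantiate are hypothetical, and the sketch makes no claim about what exactly changed in JAX 0.4.4.

```python
import abc


class Backend(abc.ABC):
    """Hypothetical stand-in for a device base class with abstract helpers."""

    @abc.abstractmethod
    def _abs(self, x):
        ...

    @abc.abstractmethod
    def _conj(self, x):
        ...


class StillAbstract:
    """A callable whose __isabstractmethod__ flag is truthy."""

    __isabstractmethod__ = True

    def __call__(self, x):
        return x


class ForgotMethods(Backend):
    # Only _abs is overridden, so _conj remains abstract.
    def _abs(self, x):
        return abs(x)


class FlaggedAttribute(Backend):
    # Both names are assigned, but the object bound to _conj reports itself
    # as abstract, so abc still counts _conj as unimplemented.
    _abs = staticmethod(abs)
    _conj = StillAbstract()


class Complete(Backend):
    # Plain functions wrapped in staticmethod are not flagged as abstract.
    _abs = staticmethod(abs)
    _conj = staticmethod(lambda x: x.conjugate())


def try_instantiate(cls):
    """Return "ok" if cls() succeeds, otherwise the TypeError message."""
    try:
        cls()
        return "ok"
    except TypeError as err:
        return str(err)


if __name__ == "__main__":
    for cls in (ForgotMethods, FlaggedAttribute, Complete):
        print(cls.__name__, "->", try_instantiate(cls))
```

The exact wording of the message differs between Python versions (3.12 says "without an implementation for abstract method ..."), but in every case it names the class and the methods abc still considers abstract, just like the _abs, _conj, _dot list above.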

I have installed jax, jaxlib and pennylane and also made sure that they are all on the latest version.

Any help would be greatly appreciated.

Best regards!

Hi @huhnja, welcome to the forum!

I think you forgot to create your device.

You can use something like:
dev = qml.device("default.qubit.jax", wires=2)

This should make your code run smoothly :sunglasses:

Let me know if this fixes your issue!

Hi @CatalinaAlbornoz,

Thank you for your quick reply. I tried your suggestion, but I get the same error. I am running this code on a cloud backend with a Jupyter notebook as the frontend. Could this explain the error I am getting?

Hi @huhnja! There is an issue with JAX 0.4.4; you should downgrade to 0.4.3. The issue will be solved in 0.4.5.

Let me know if there is anything else.

Thank you, @Romain_Moyard,

This solves the problem! And for other people reading this, you also need to downgrade jaxlib to version 0.4.3 :slight_smile:
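Putting the fix from this thread together: pin both packages, for example with pip install "jax==0.4.3" "jaxlib==0.4.3", and then confirm which versions the running interpreter actually sees. Below is a minimal stdlib-only sketch of that check. The helper names (parse_version, installed_version, affected_packages) are hypothetical, and the only version facts it encodes are the ones stated above (0.4.4 affected, 0.4.3 working); it does not check PennyLane compatibility in general.

```python
from importlib import metadata
from typing import Dict, List, Optional, Tuple

# Based only on this thread: jax/jaxlib 0.4.4 triggered the DefaultQubitJax
# error, 0.4.3 worked, and a fix was expected in 0.4.5.
AFFECTED = (0, 4, 4)
KNOWN_GOOD = "0.4.3"


def parse_version(text: str) -> Tuple[int, ...]:
    """Leading numeric components, e.g. '0.4.4' -> (0, 4, 4), '0.4.4rc1' -> (0, 4, 4)."""
    parts = []
    for piece in text.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def installed_version(package: str) -> Optional[str]:
    """Version string of an installed distribution, or None if it is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def affected_packages(versions: Dict[str, Optional[str]]) -> List[str]:
    """Names whose version matches the problematic release from this thread."""
    return [
        name
        for name, version in versions.items()
        if version is not None and parse_version(version)[:3] == AFFECTED
    ]


if __name__ == "__main__":
    found = {name: installed_version(name) for name in ("jax", "jaxlib")}
    print(found)
    for name in affected_packages(found):
        print(f"{name} {found[name]} matches the affected release; "
              f"consider pinning {name}=={KNOWN_GOOD}")
```

Running this inside the notebook kernel itself (rather than in a separate shell) is the safer check on a cloud setup, since the kernel may use a different Python environment than the one where pip was run.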
