Default.qubit.jax not working with newer Jax version?

Dear Pennylane Team,

I'm currently trying to combine PennyLane with JAX to speed up my computations.
I have PennyLane==0.33 and jax==0.4.34 installed.
When trying to initialize a device:

dev_qnpu_state = qml.device("default.qubit.jax", wires=n_wires_qnpu_state, shots=None)

I get the error:
ImportError: default.qubit.jax device requires installing jax>0.3.20
which I clearly have. Trying older versions of JAX didn't resolve the issue either. I checked that the expected version is loaded in the code with print(jax.__version__).
Do you know this issue, or could you recommend a JAX version that is known to work?

Thanks and best regards,
Pia

Hi @Pia !

To me it looks like an environment issue. I would recommend creating a new environment if possible.
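Before rebuilding anything, it can help to confirm which interpreter is actually running and which package versions it sees. The sketch below uses only the standard library (`importlib.metadata`); the function name `environment_report` is hypothetical. Including `jaxlib` is an assumption on my part: a jax/jaxlib mismatch is a common source of import failures, so it is worth listing too.

```python
import sys
from importlib import metadata


def environment_report(packages=("pennylane", "jax", "jaxlib")):
    """Return the running interpreter path and the installed versions of `packages`.

    A version of None means the package is not installed for *this* interpreter,
    which can indicate that `pip` and `python` point at different environments.
    """
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return {"python": sys.executable, "versions": versions}


if __name__ == "__main__":
    report = environment_report()
    print("Interpreter:", report["python"])
    for name, version in report["versions"].items():
        print(f"{name:>10}: {version or 'not installed'}")
```

Run it with the same command you use to launch your script or notebook kernel; if the interpreter path is not the environment you expect, that alone would explain the mismatch.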

On the other hand, is there a reason why you need to use PennyLane 0.33? I generally recommend staying up to date with the latest version, which is v0.38 at the moment.

Below are the steps you can follow to create a new environment with the latest version of PennyLane. You will need to adapt them to also install JAX, and to pin a specific PennyLane version if you prefer not to use the latest one.

Let me know if this works for you. If it doesn't, please post the output of qml.about().
I hope this helps!


You can create a virtual environment with Conda and install PennyLane as follows:

  1. Install Miniconda following the instructions here.
  2. Open your terminal (mac) or command line (Windows).
  3. Create a new Conda environment with: conda create --name <name_of_your_environment> python=3.10
  4. Activate the environment with: conda activate <name_of_your_environment>
  5. Install PennyLane with: python -m pip install pennylane
  6. Install other useful packages with: python -m pip install jupyter matplotlib

Note that you will be installing 3 packages here: PennyLane, Jupyter, and Matplotlib. Also, note that where it says <name_of_your_environment> you can choose any name that you want.