Installing jax-cuda, pennylane-catalyst with Pytorch?

Hello,
I am trying to use a PyTorch model along with my quantum circuit, accelerated by Catalyst.

Initially, I ran the command:
pip install pennylane pennylane-lightning-gpu pennylane-catalyst pennylane-qiskit torch
This installed the latest PennyLane and Catalyst versions, along with torch 2.4.1.

>>> qml.about()
Name: pennylane
Version: 0.42.1
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page:
Author:
Author-email:
License-Expression: Apache-2.0
Location: /home/ashutosh/Test-Code/python/testvenv/lib/python3.10/site-packages
Requires: appdirs, autograd, autoray, cachetools, diastatic-malt, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, tomlkit, typing_extensions
Required-by: PennyLane-qiskit, pennylane_catalyst, pennylane_lightning, pennylane_lightning_gpu

Platform info:           Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.35
Python version:          3.10.12
Numpy version:           2.2.6
Scipy version:           1.15.3
Installed devices:
- lightning.gpu (pennylane_lightning_gpu-0.42.0)
- qiskit.aer (PennyLane-qiskit-0.42.0)
- qiskit.basicaer (PennyLane-qiskit-0.42.0)
- qiskit.basicsim (PennyLane-qiskit-0.42.0)
- qiskit.remote (PennyLane-qiskit-0.42.0)
- nvidia.custatevec (pennylane_catalyst-0.12.0)
- nvidia.cutensornet (pennylane_catalyst-0.12.0)
- oqc.cloud (pennylane_catalyst-0.12.0)
- softwareq.qpp (pennylane_catalyst-0.12.0)
- lightning.qubit (pennylane_lightning-0.42.0)
- default.clifford (pennylane-0.42.1)
- default.gaussian (pennylane-0.42.1)
- default.mixed (pennylane-0.42.1)
- default.qubit (pennylane-0.42.1)
- default.qutrit (pennylane-0.42.1)
- default.qutrit.mixed (pennylane-0.42.1)
- default.tensor (pennylane-0.42.1)
- null.qubit (pennylane-0.42.1)
- reference.qubit (pennylane-0.42.1)
>>>
>>> torch.__version__
'2.4.1+cu121'

However, when trying to execute my JAX circuits I got the following warning:
WARNING:2025-07-31 20:58:21,824:jax._src.xla_bridge:909: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.

Upon looking at other issues with the same warning, I realized that I was missing the JAX CUDA libraries.
I tried to install these libraries, but I get this error:

$ pip install jax[cuda12]
...
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torch 2.4.1 requires nvidia-cudnn-cu12==9.1.0.70; platform_system == "Linux" and platform_machine == "x86_64", but you have nvidia-cudnn-cu12 9.11.0.98 which is incompatible.
Successfully installed jax-cuda12-pjrt-0.6.0 jax-cuda12-plugin-0.6.0 nvidia-cuda-nvcc-cu12-12.9.86 nvidia-cudnn-cu12-9.11.0.98

Is there a way that I can work with both PyTorch and JAX/Catalyst?

Hi @ashutiw2k, thank you for your interest in Catalyst!

I think there might be two different potential issues here:

  • Are you trying to use PyTorch with/inside of Catalyst programs?

    If so, PyTorch is not a supported frontend for Catalyst; only JAX is.

  • Are you just trying to use PyTorch in the same environment as Catalyst?

    If so, this should work in general, although it appears that the GPU dependencies for JAX might be conflicting with the ones for Torch. I’m not quite sure what to do about this. Each Catalyst release generally depends on one specific, recent release of JAX, and it appears that both JAX and Torch might hard-pin a version of the same CUDA package.

    Have you tried running your program anyway? Are there any errors? Sometimes, even though pip identifies a conflict, packages will still work regardless.

    If this is in fact breaking the Torch or JAX package, my next question is whether you are making use of the GPU backends for both Torch and JAX. Note that it is perfectly acceptable to run Catalyst with JAX’s CPU backend; that is in fact the default. This shouldn’t prevent you from still using something like lightning.gpu or lightning.kokkos to run on GPU-accelerated simulators.
    The only problem I can see is if you were planning on using the catalyst.accelerate functionality to run a classical piece of code inside a qjit-compiled function on a GPU. In this case, sticking to the JAX CPU backend would limit you to running that piece of code on the CPU.
    But other than that, you shouldn’t have any issues running JAX in CPU mode. You can also silence the warning if it bothers you, for example as sketched below.
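
    A minimal sketch (assuming you are fine with JAX staying on CPU; the logger name comes from the warning message you quoted):

    import os
    os.environ["JAX_PLATFORMS"] = "cpu"  # select the CPU backend; must be set before jax is imported

    import logging
    logging.getLogger("jax._src.xla_bridge").setLevel(logging.ERROR)  # silence the CUDA fallback warning

    import jax
    print(jax.devices())  # should now list only CPU devices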

Let me know if this is helpful and if you still run into any issues after following the recommendations above.

Cheers,
David


Hi David, thank you for your detailed and informative response!

Yes, I think this describes my current setup. More accurately, I wish to use Catalyst to accelerate a QNode inside a PyTorch model.

Basically, I am trying to run/learn the parameters of a PennyLane quantum circuit inside a PyTorch model’s forward method. I am working on a learning algorithm that requires dynamically inserting gates/operations into the circuit. To do this, I am storing my base_circuit as a list of PennyLane operations and looping over this list every time I want to execute the circuit.

My forward method looks something similar to this:

def forward(self, args):

    @catalyst.qjit  # Perhaps this will help with GPU execution?
    @qml.qnode(self.qdevice, interface='torch', diff_method=self.diff)
    def circuit():

        for i, op in enumerate(self.base_circuit):
            qml.apply(op)

            if i < 5:
                pass  # Some other operations with args
            if i < 10:
                pass  # Some other operations with args, etc.

        return qml.state()

    return circuit()

This works pretty well on the CPU (without the Catalyst decorator), but due to the looping and conditionals, its GPU execution is extremely slow (10x - 25x slower).

I can work with CPU execution for up to 10 qubits, but to go beyond that, to 20-qubit circuits and above, I need to optimize the forward pass to run well on the GPU, which is why I would like to use Catalyst’s qjit decorator to speed up the GPU execution of my model.

Does this make sense? I am extremely new to PennyLane and Catalyst and am trying to find my way around them. I would appreciate any insight into whether I am overcomplicating simple tasks.

No, not yet. But yes, I can try to see if it works.

I think this is exactly what I am trying to do.

Thank you for your reply!

Hi @ashutiw2k ,

Thank you for sharing your code.
I think this counts as trying to use PyTorch with/inside of Catalyst programs.
As David mentioned, this is not supported, so you would need to port your code to JAX.

On the other hand, note that Catalyst probably won’t give you big speedups in this context, since the compilation time may far exceed the actual execution time.

The best thing you can try (and it’s easier than porting everything from Torch to JAX) is to optimize your current program. Use the lightning.qubit device (running on CPU), remove self.diff and let the differentiation method be selected automatically, and remove as many for loops as you can from your code.

I’m assuming you need differentiation, but if you don’t need it, setting diff_method=None can speed things up too, for example:
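
A minimal sketch of these tips (the wire count and gates are just placeholders):

import pennylane as qml

dev = qml.device("lightning.qubit", wires=10)  # fast CPU state-vector simulator

@qml.qnode(dev, diff_method=None)  # skip gradient bookkeeping if you never differentiate
def circuit(params):
    qml.RY(params[0], wires=0)
    return qml.state()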

Let us know if you have any further questions, and let us know if these tips help!

Thanks @CatalinaAlbornoz for your response. Maybe I can clarify a few things below:

Oh got it! Does that involve running the reverse pass over the QNode as well? If not, your use case should actually work just fine (as long as you resolve the package conflicts, e.g. by removing the JAX CUDA packages).
If you do need the reverse pass, and are able to register a custom reverse pass with PyTorch, the approach may still work as well.
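
For reference, a rough and unofficial sketch of what such a custom reverse pass could look like, assuming a scalar-valued QNode and CPU tensors (all names here are placeholders, not an established API):

import numpy as np
import torch
import pennylane as qml
from catalyst import qjit, grad

dev = qml.device("lightning.qubit", wires=2)

@qml.qnode(dev)
def circuit(params):
    qml.RX(params[0], wires=0)
    qml.RY(params[1], wires=1)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.PauliZ(1))

forward_fn = qjit(circuit)  # compiled forward pass

@qjit
def backward_fn(params):
    return grad(circuit)(params)  # compiled gradient of the QNode

class QNodeFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, params):
        ctx.save_for_backward(params)
        out = forward_fn(params.detach().numpy())
        return torch.as_tensor(np.asarray(out), dtype=params.dtype)

    @staticmethod
    def backward(ctx, grad_out):
        (params,) = ctx.saved_tensors
        jac = backward_fn(params.detach().numpy())
        return grad_out * torch.as_tensor(np.asarray(jac), dtype=params.dtype)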

> This works pretty well on the CPU (without the Catalyst decorator), but due to the looping and conditionals, its GPU execution is extremely slow (10x - 25x slower).

Could you elaborate on what you mean when you say you are comparing GPU against CPU (without Catalyst)? By GPU, do you mean you ran the same code with the Catalyst decorator, or something else?
The reason I ask is that there are a few different concepts coming together here, and I’d like to make sure there is no confusion. Catalyst accelerates programs by compiling them to binary code (and doing parametric compilation), but not necessarily by running on GPU. The latter is achieved by using a GPU-based simulator like lightning.gpu, which you can use with or without Catalyst.
The catalyst.accelerate decorator I mentioned is only used to put classical code on the GPU, not the circuit simulation. In the code outline you shared I don’t actually see any classical computation inside the QNode, so you might not need to worry about this at all.
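
To illustrate that these are two independent knobs, a minimal sketch (wire count and gates are placeholders):

import pennylane as qml
from catalyst import qjit

dev = qml.device("lightning.gpu", wires=20)  # GPU-based simulator

@qjit  # compilation with Catalyst; optional and independent of the device choice
@qml.qnode(dev)
def circuit(theta):
    qml.RX(theta, wires=0)
    return qml.expval(qml.PauliZ(0))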

If you are interested in speeding up the qnode with catalyst, there are a few things I recommend:

  • Does the circuit you run change each iteration?

    Catalyst is most effective when you can compile once and reuse the compiled program many times (like in an optimization problem). That doesn’t mean you have to run the same gates every time, though!
    Catalyst supports many dynamic programming features for circuits (in particular control flow operations), so as long as you are able to express your changing circuit with these, you will benefit from greatly reduced processing times each iteration.

    For this to work, you’ll want to add arguments to your quantum function, so that different things can happen in your circuit based on those argument values. You will also need to move the definition of the circuit function out of the forward method and into the global scope, so that the function compiled once can be called repeatedly inside the forward pass (see the sketch after this list).

    Of course there is a chance that the circuits you are running each iteration have absolutely no relation to each other, and so you cannot express them as a parametrized catalyst function. If that’s the case, catalyst may not be able to speed up your code.

  • I already mentioned control flow, but any large loops in your program should be using Catalyst control flow for efficiency (this drastically reduces the size of the program Catalyst has to compile). You don’t actually need to use the control flow functions directly for this; the autograph feature will transform Python loops into Catalyst ones for you.

  • You mention CPU execution bottlenecks at around 10 qubits for you. This is surprising, because typically the PL simulator should easily handle twice as many qubits. Do you have an idea what is bottlenecking at 10 qubits? How long does it take above 10 qubits? Do you know how big your circuit is in terms of number of gates?
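
To make the first two points concrete, here is a rough sketch (gate pattern and names are placeholders) of a circuit whose structure depends on its arguments, defined at global scope so it is compiled once and reused:

import jax.numpy as jnp
import pennylane as qml
from catalyst import qjit

dev = qml.device("lightning.qubit", wires=2)

@qjit(autograph=True)  # autograph converts the Python loop into Catalyst control flow
@qml.qnode(dev)
def circuit(params, n_layers: int):
    for i in range(n_layers):  # dynamic loop bound, captured as a runtime argument
        qml.RY(params[i], wires=0)
        qml.RX(params[i], wires=1)
        qml.CNOT(wires=[0, 1])
    return qml.state()

params = jnp.zeros(8)
state_a = circuit(params, 3)  # compiled on first call
state_b = circuit(params, 8)  # reuses the compiled program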

I hope this helps you achieve what you’re looking for!