Integrating Catalyst with PyTorch

Hi all,

I am working on a quantum machine learning project written in PyTorch and PennyLane, and I would like to try using PennyLane-Catalyst to speed up the training process.

My specs are as follows:

Name: PennyLane
Version: 0.35.0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author: 
Author-email: 
License: Apache License 2.0
Location: /usr/local/lib/python3.11/dist-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, pennylane-lightning, requests, rustworkx, scipy, semantic-version, toml, typing-extensions
Required-by: pennylane-qulacs, PennyLane_Lightning, PennyLane_Lightning_GPU

Platform info:           Linux-6.5.0-25-generic-x86_64-with-glibc2.35
Python version:          3.11.0
Numpy version:           1.26.3
Scipy version:           1.12.0
Installed devices:
- lightning.qubit (PennyLane_Lightning-0.35.1)
- qulacs.simulator (pennylane-qulacs-0.32.0)
- default.clifford (PennyLane-0.35.0)
- default.gaussian (PennyLane-0.35.0)
- default.mixed (PennyLane-0.35.0)
- default.qubit (PennyLane-0.35.0)
- default.qubit.autograd (PennyLane-0.35.0)
- default.qubit.jax (PennyLane-0.35.0)
- default.qubit.legacy (PennyLane-0.35.0)
- default.qubit.tf (PennyLane-0.35.0)
- default.qubit.torch (PennyLane-0.35.0)
- default.qutrit (PennyLane-0.35.0)
- null.qubit (PennyLane-0.35.0)
- lightning.gpu (PennyLane_Lightning_GPU-0.35.1)

I understand that Catalyst normally only supports JAX, but I’m wondering whether there is a technique I could use to integrate Catalyst with PyTorch without having to overhaul my code. Thank you very much for your help.

Hi @justin6626,

Unfortunately I’m not aware of any way you could compile PyTorch code with Catalyst right now. But it is valuable feedback that there is interest in it!

Out of curiosity, do you have some ideas about which part of your code is slow or which part you would like to speed up? What sort of structure does your program follow: is it mainly PyTorch code with some PennyLane functions embedded, or mostly PennyLane code with a few PyTorch calls for classical pre- and post-processing?

If you are curious about just-in-time compilation, I would definitely encourage you to try writing the model in JAX & Catalyst. If you do, be sure to let us know how it goes :slight_smile:

Since lightning.qubit is CPU only, unless you run PyTorch on CPU there’ll be some overhead from passing data between CPU and GPU. But one idea is to define a custom PyTorch module that calls into Catalyst and computes gradients that are attached to the PyTorch graph.
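As a rough sketch of that idea, here is how an externally computed function and its gradient can be attached to the PyTorch graph with torch.autograd.Function. Everything named here (external_forward, external_grad, ExternalCircuit, ExternalLayer) is hypothetical, and the two NumPy functions are only placeholders standing in for a compiled circuit and its derivative. I have not confirmed that a Catalyst-compiled function can actually be dropped in this way, so treat it as an illustration of the PyTorch side only.

```python
import numpy as np
import torch


def external_forward(theta: np.ndarray) -> np.ndarray:
    # Placeholder for a compiled circuit (assumption, not a Catalyst call):
    # <Z> after RX(theta) on |0> equals cos(theta), applied elementwise.
    return np.cos(theta)


def external_grad(theta: np.ndarray) -> np.ndarray:
    # Placeholder for the compiled derivative d<Z>/dtheta = -sin(theta).
    return -np.sin(theta)


class ExternalCircuit(torch.autograd.Function):
    """Runs the forward pass outside PyTorch and supplies the gradient by hand."""

    @staticmethod
    def forward(ctx, theta):
        ctx.save_for_backward(theta)
        # The .cpu() copy is where the CPU/GPU transfer overhead would show up.
        out = external_forward(theta.detach().cpu().numpy())
        return torch.as_tensor(out, dtype=theta.dtype, device=theta.device)

    @staticmethod
    def backward(ctx, grad_output):
        (theta,) = ctx.saved_tensors
        jac = external_grad(theta.detach().cpu().numpy())
        jac = torch.as_tensor(jac, dtype=theta.dtype, device=theta.device)
        # The placeholder is elementwise, so the chain rule is a simple product.
        return grad_output * jac


class ExternalLayer(torch.nn.Module):
    """Thin module wrapper so the function can sit inside a larger model."""

    def forward(self, theta):
        return ExternalCircuit.apply(theta)


theta = torch.tensor([0.1, 0.5], dtype=torch.float64, requires_grad=True)
loss = ExternalLayer()(theta).sum()
loss.backward()  # fills theta.grad using external_grad
```

With a real circuit the Jacobian would generally not be diagonal, so backward would need a vector-Jacobian product instead of the elementwise multiplication used here.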

Hi @schance995,

It’s true that lightning.qubit is CPU only and switching between CPU and GPU can lead to significant overheads.

About your second point, I’m not sure whether that would work. Let me check with the team, because I know Catalyst will be getting some nice upgrades soon and maybe they can help in this case.

Edit: unfortunately, Catalyst requires things that PyTorch doesn’t have, so the best option is to use JAX together with Catalyst instead of PyTorch.

Hiya,
Is there any possibility that PyTorch and Catalyst will work with one another in future updates, even if that functionality isn’t possible currently?

Hi Aaron, it’s definitely something we want to do, but there are some technical challenges which make this difficult and will take time to resolve. Hopefully we can share more with you on this front in the future.
