JAX instead of autograd

Can we use JAX instead of autograd for greater speed while training complex circuits?

Hi @kareem_essafty! Not at the moment, but this is on our radar as a feature to add :slightly_smiling_face:

I’d love to get involved and volunteer on this, please.

Hi @kareem_essafty,

We welcome community contributions :slight_smile:
Note that we maintain a high standard of code in PennyLane (including thorough documentation & testing). Please see here for our guidelines, in particular the section on Pull Requests.

Feel free to hack away at a JAX interface and see whether you can get it working. It’s best to keep lines of communication open with the PennyLane dev team if you’ll be working on it.
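
For concreteness, here is a rough sketch of what such an interface might eventually let users write. Note that the `interface="jax"` argument is an assumption for illustration, not an existing feature at this point in the thread:

```python
import jax
import jax.numpy as jnp
import pennylane as qml

dev = qml.device("default.qubit", wires=1)

# hypothetical: interface="jax" is assumed here for illustration
@qml.qnode(dev, interface="jax")
def circuit(theta):
    qml.RX(theta, wires=0)
    return qml.expval(qml.PauliZ(0))

# the circuit would then behave like any other JAX-differentiable function
grad_fn = jax.grad(circuit)
print(grad_fn(jnp.array(0.3)))
```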

So sorry for the late reply, I’ll start working on this ASAP. Besides that, is it okay if I use Cython syntax, or something like predefined C files that contain the repeated operations?

Hi @kareem_essafty, is there any particular reason you don’t want to write in pure Python (like the rest of PL)? Is there something particular to JAX?

Apart from JAX or autograd: in my humble experience, tensor operations and matrix multiplication in general are a bit time consuming in pure Python. Implementing them in Cython or C++ and calling them through a binding method does actually help. In computer vision, for example, I always encounter delays, and moving the heavy routines to C++ or OpenCL makes things really fast.
Regarding JAX, I mainly want to see PennyLane natively support GPUs. It is a wonderful quantum machine learning library, and you, sir, and the wonderful team have made the qubit simulator faster and better than the earlier versions.
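
For context, a minimal sketch of the trade-off being discussed: JAX jit-compiles array code through XLA and dispatches it to a GPU or TPU when one is available, which reduces the need for hand-written C/Cython kernels:

```python
import jax
import jax.numpy as jnp

# JAX jit-compiles this through XLA and runs it on a GPU/TPU when one is
# available; no hand-written C++/Cython kernels or bindings are needed
@jax.jit
def dense_layer(x, w):
    return jnp.tanh(x @ w)

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (1024, 1024))
w = jax.random.normal(key, (1024, 1024))

out = dense_layer(x, w).block_until_ready()  # first call triggers compilation
```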


Hi @kareem_essafty,

From what I understand, you just want to create a new JAX interface to replace the existing autograd interface. This is just an interface, not a simulator, so it doesn’t contain any numerical code (other than perhaps some trivial reshaping of arrays). There shouldn’t be any need for tensor operations or matrix multiplication.
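
To make "just a bridge" concrete, here is a rough sketch of what an interface's job looks like in JAX terms: wrap an opaque circuit evaluation and tell JAX how to differentiate it, here via the parameter-shift rule. `evaluate_circuit` is a stand-in defined for the sketch, not PennyLane API:

```python
import jax
import jax.numpy as jnp

def evaluate_circuit(params):
    # stand-in for a device execution; a real interface would dispatch
    # to the PennyLane device here
    return jnp.cos(params).prod()

@jax.custom_vjp
def qnode(params):
    return evaluate_circuit(params)

def qnode_fwd(params):
    return qnode(params), params

def qnode_bwd(params, g):
    shift = jnp.pi / 2
    # parameter-shift rule: two shifted evaluations per parameter
    grads = jnp.stack([
        (evaluate_circuit(params.at[i].add(shift))
         - evaluate_circuit(params.at[i].add(-shift))) / 2
        for i in range(params.shape[0])
    ])
    return (g * grads,)

qnode.defvjp(qnode_fwd, qnode_bwd)

print(jax.grad(qnode)(jnp.array([0.1, 0.2])))
```

The point of the sketch is that the interface only tells the classical framework *how* to get gradients; the numerically heavy work all happens inside the device call.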


I meant that the quantum routines themselves could be much faster using Cython, for example. Besides that, JAX would let the training loops run on the GPU.

Hi @kareem_essafty,

About your original request to code up a JAX interface, I can only repeat what I said above: there aren’t really any numerical computations or simulations happening at that level of the code (it is just a bridge between an external classical framework and PennyLane).

If you instead want to write a custom simulator, that’s a different story. That would be a new PennyLane plugin/device, and writing it in Cython or some equivalent could certainly be an option.
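
A rough sketch of what such a plugin involves, assuming PennyLane's `QubitDevice` base class (attribute and method names are taken from the plugin development docs and should be checked against the guidelines linked above; the dense-matrix evolution below is deliberately naive):

```python
import numpy as np
import pennylane as qml

class MyFastDevice(qml.QubitDevice):
    """Skeleton of a custom simulator device (names assumed from the
    PennyLane plugin docs; check the current guidelines)."""

    name = "Custom fast simulator"
    short_name = "custom.fast"
    pennylane_requires = ">=0.8"
    version = "0.0.1"
    author = "forum sketch"

    operations = {"RX", "RY", "RZ", "CNOT"}
    observables = {"PauliZ", "Identity"}

    def __init__(self, wires, shots=None):
        super().__init__(wires=wires, shots=shots)
        self._state = None

    def apply(self, operations, **kwargs):
        # start in |0...0> and apply each gate as a full-register matrix;
        # this inner loop is the hot path that Cython/C++ kernels
        # could replace in a serious implementation
        state = np.zeros(2 ** self.num_wires, dtype=complex)
        state[0] = 1.0
        for op in operations:
            U = qml.matrix(op, wire_order=self.wires)
            state = U @ state
        self._state = state

    def analytic_probability(self, wires=None):
        if self._state is None:
            return None
        prob = np.abs(self._state) ** 2
        return self.marginal_prob(prob, wires)
```

Registering the device under its `short_name` via a package entry point is what would let users load it with `qml.device("custom.fast", wires=2)`, as described in the plugin guidelines.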