Hi everyone,
I wanted to share a project I recently worked on: a port of the classic Quanvolutional Neural Networks tutorial from TensorFlow/Keras to the JAX ecosystem.
**The Motivation**

My primary motivation came from running into library incompatibilities with the original TensorFlow implementation. I noticed the PennyLane ecosystem has been gravitating towards PyTorch and JAX recently. Since I was already comfortable with PyTorch, I saw this as the perfect opportunity to dive into JAX and see what it offers for QML.
During this process, I tackled the main bottleneck: processing image patches. In the standard implementation, the quantum circuit is applied to each patch sequentially in a Python loop, which is slow. Using jax.vmap, I vectorized the quantum circuit execution, processing all 196 2×2 patches of a 28×28 MNIST image simultaneously in a single compiled call.
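To give a flavour of the pattern, here is a minimal sketch rather than the exact code from my notebook: a 4-qubit circuit in the spirit of the original tutorial, with the patch axis mapped over by jax.vmap and the whole batched call jit-compiled. The `batched_circuit` name, the dummy inputs, and the weight shape are purely illustrative.

```python
import jax
import jax.numpy as jnp
import pennylane as qml

n_qubits = 4  # one qubit per pixel of a 2x2 patch, as in the original tutorial

dev = qml.device("default.qubit", wires=n_qubits)

@qml.qnode(dev, interface="jax")
def circuit(patch, weights):
    # Encode the four pixel values of one patch as Y-rotation angles
    qml.AngleEmbedding(jnp.pi * patch, wires=range(n_qubits), rotation="Y")
    # Random layer standing in for the tutorial's quantum "filter"
    qml.RandomLayers(weights, wires=range(n_qubits))
    return [qml.expval(qml.PauliZ(w)) for w in range(n_qubits)]

# vmap over the leading (patch) axis of the inputs, share the weights,
# then jit-compile the whole batched call
batched_circuit = jax.jit(jax.vmap(circuit, in_axes=(0, None)))

key = jax.random.PRNGKey(0)
patches = jax.random.uniform(key, (196, n_qubits))  # 14*14 flattened 2x2 patches
weights = jnp.ones((1, n_qubits))                   # RandomLayers shape: (n_layers, n_rotations)
features = batched_circuit(patches, weights)        # four arrays, each of shape (196,)
```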
**The Benchmarks**

We observed a significant performance jump on the single-image processing step; the benchmark code is included in the repository linked below.
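A quick note for anyone reproducing timings like these: JAX dispatches asynchronously and compiles on the first call, so a fair measurement needs a warm-up call and an explicit block on the result. A sketch, reusing the `batched_circuit`, `patches`, and `weights` from above:

```python
import time

# First call triggers compilation; exclude it from the measurement
jax.block_until_ready(batched_circuit(patches, weights))

start = time.perf_counter()
jax.block_until_ready(batched_circuit(patches, weights))
print(f"batched call: {time.perf_counter() - start:.4f} s")
```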
**The Repository**

The full notebook, including the custom Flax training loop and benchmark code, is available here: GitHub - Spartoons/quanvolution-jax: A modern JAX port of the PennyLane "Quanvolution" tutorial. Demonstrates how to vectorize hybrid quantum-classical networks for massive performance gains.
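For anyone who hasn’t used Flax before, the custom training loop in the notebook follows the standard Flax/Optax pattern. Here is a simplified sketch of that general pattern only; the `Head` module, shapes, and hyperparameters below are made up for illustration and are not the ones from the repo:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

# Hypothetical dense head over the 14x14x4 quanvolution feature maps
class Head(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = x.reshape((x.shape[0], -1))  # flatten the feature maps
        return nn.Dense(features=10)(x)  # 10 MNIST classes

model = Head()
params = model.init(jax.random.PRNGKey(0), jnp.zeros((1, 14, 14, 4)))
tx = optax.adam(1e-3)
opt_state = tx.init(params)

@jax.jit
def train_step(params, opt_state, images, labels):
    def loss_fn(p):
        logits = model.apply(p, images)
        return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss
```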
I hope this is helpful for anyone looking to optimize hybrid QML pipelines or to make a similar switch from TensorFlow! I’d love to hear any feedback on the JAX implementation.
Hi @Spartoons, this is amazing! Thanks for sharing it here!
Would you like to submit this as a community demo?
Community demos are hosted on your own repo, but shown on the PennyLane community demos page.
Community demos have a few advantages:
- You host them on your own repo so you have plenty of freedom for structuring the demo and including additional datasets or files.
- They’re generally easy to make since you can use a Jupyter notebook or a Python file, whichever you prefer.
- They generally get published fairly quickly (within a couple of weeks).
You’ll find the submission guidelines for community demos at the bottom of our demo submission page. The process requires you to open an issue, and we’ll get back to you on that issue once you’ve opened it.
Let us know if you have any questions about this!
Hi @CatalinaAlbornoz, done! The submission issue is open here: [issue]
Thank you @Spartoons! Let’s continue the conversation on the issue. We’ll assign someone from our team to review it and get back to you with next steps there.
Note that since it’s the holiday season we may take a bit longer than usual, but we’ll do our best to provide a quick review.
Wow, this is really impressive!
I love how you leveraged jax.vmap to vectorize the quantum circuit execution; those speed improvements are massive. It’s inspiring to see a clean JAX implementation for QML, especially with the full Flax training loop included. Definitely bookmarking this for when I dive into hybrid quantum-classical networks. Thanks for sharing your work and benchmarks!
Welcome to the Forum, @Khadija_Shafiq!
Thanks for sharing these words; it’s great to see this positive feedback!
We will be sharing @Spartoons’ demo on our Community Demos page and PennyLane’s social media over the next few weeks. Stay tuned and feel free to share this with your network!
Hi @Khadija_Shafiq , thank you very much! I’m still new to this space myself, so I’m happy to share my progress and path with the community. Glad you found the JAX/Flax integration useful!