Hardware-compatible differentiation and QML

Hello,
I am currently trying to set up a hybrid quantum machine learning experiment using JAX and PennyLane. My problem arises when I change the gradient method from Simulation-based differentiation to Hardware-compatible differentiation.
Using backprop works with jax.jit (see our paper "Hybrid quantum tensor networks for aeroelastic applications" for details of the setup). As a next step, we wanted to test Hardware-compatible differentiation methods. However, as stated in the documentation, using the jitted version of nnx.grad(loss_fn), where loss_fn generates predictions and computes optax.losses.huber_loss, leads to errors (just using jax.vmap also leads to errors).

Before rewriting all of my code, I wanted to ask for clarification of the following statement in the documentation of Quantum gradients using JAX:

The output of vector-valued QNodes can, however, be used in the definition of scalar-valued cost functions whose gradients can be computed.

Does this not apply to the setup described above, i.e., using the probs output to compute a scalar loss?

The main problem currently is how to batch the input, since vmap seems to lead to errors. Furthermore, would it be possible to use Hardware-compatible differentiation methods within a PyTorch setup, or with Catalyst and jax.vmap?

Any thoughts on this would be very much appreciated. Thanks!

Hi @LautaroH , welcome to the Forum!

I can confirm that hardware-compatible methods can be used with JAX and PennyLane. This includes the parameter-shift method and finite differences. However, some restrictions apply to the measurements in the QNode:

  • Sample and probability measurements cannot be mixed with other measurement types in QNodes;
  • Multiple probability measurements need to have the same number of wires specified.

As for batching, both jax.vmap and catalyst.vmap should work. However, PyTorch is not compatible with Catalyst.

I’d love to understand your issues better. Do you have some code or pseudocode showing the “ideal scenario” that you would like to have?

From what I understand so far, you would like to have just-in-time compilation, hardware-compatible differentiation, and PyTorch integration. Unfortunately, JIT compilation (with jax.jit or Catalyst) and PyTorch are not compatible with each other, so which one is most critical for you?

I hope this helps!