Hello,
I am currently trying to set up a hybrid quantum machine learning experiment using JAX and PennyLane. My problem arises when I change the gradient method from simulation-based differentiation to hardware-compatible differentiation.
Using backprop works with jax.jit; see our paper "Hybrid quantum tensor networks for aeroelastic applications" for details of the setup. As a next step, we wanted to test hardware-compatible differentiation methods. However, as stated in the documentation, using the jitted version of nnx.grad(loss_fn), where loss_fn generates predictions and applies optax.losses.huber_loss, leads to errors (using jax.vmap on its own also leads to errors).
Before rewriting my whole code, I wanted to ask for clarification of the following statement in the documentation on "Quantum gradients using JAX":
"The output of vector-valued QNodes can, however, be used in the definition of scalar-valued cost functions whose gradients can be computed."
Does this not apply to the setup described above, i.e. using the probs output in a scalar loss?
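To make the question concrete, here is a minimal pure-JAX sketch of how that sentence could be read. It is not the actual code from the setup above and does not call PennyLane at all. `probs` is a hypothetical analytic stand-in for a single-qubit QNode returning probabilities, `huber` is a hand-written stand-in for optax.losses.huber_loss, and `shift_grad` applies the two-term parameter-shift rule by hand, with the classical part of the chain rule handled by jax.grad. All names are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def probs(theta, x):
    # Toy stand-in for a vector-valued QNode: RX(theta + x) applied to |0>
    # has measurement probabilities [cos^2(a/2), sin^2(a/2)], a = theta + x.
    a = theta + x
    return jnp.stack([jnp.cos(a / 2) ** 2, jnp.sin(a / 2) ** 2])


def huber(pred, target, delta=1.0):
    # Hand-written Huber loss (stand-in for optax.losses.huber_loss).
    err = jnp.abs(pred - target)
    quad = jnp.minimum(err, delta)
    return 0.5 * quad**2 + delta * (err - quad)


# Batch over inputs x only; theta is shared across the batch.
batched_probs = jax.vmap(probs, in_axes=(None, 0))


def loss(theta, xs, targets):
    # Scalar cost built from the vector-valued circuit output.
    return jnp.mean(huber(batched_probs(theta, xs), targets))


def shift_grad(theta, xs, targets):
    # Jacobian of the probabilities w.r.t. theta via the parameter-shift rule:
    # dp/dtheta = (p(theta + pi/2) - p(theta - pi/2)) / 2.
    preds = batched_probs(theta, xs)
    jac = 0.5 * (
        batched_probs(theta + jnp.pi / 2, xs)
        - batched_probs(theta - jnp.pi / 2, xs)
    )
    # Classical part of the chain rule: d(loss)/d(preds) by ordinary autodiff.
    dl_dpreds = jax.grad(lambda p: jnp.mean(huber(p, targets)))(preds)
    return jnp.sum(dl_dpreds * jac)


theta = jnp.array(0.3)
xs = jnp.array([0.1, 0.7, 1.9])
targets = jnp.array([[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]])

g_shift = jax.jit(shift_grad)(theta, xs, targets)  # shift rule, jitted
g_auto = jax.grad(loss)(theta, xs, targets)        # backprop reference
```

In this toy version the shift-rule gradient and the backprop gradient of the scalar loss agree, and both jax.jit and jax.vmap work because everything is a plain JAX function. Whether the same holds once `probs` is a real QNode with a hardware-compatible diff_method is exactly what I am unsure about.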
The main problem currently is how to batch the input, since jax.vmap seems to lead to errors. Furthermore, would it be possible to use hardware-compatible differentiation methods within a PyTorch setup, or with Catalyst and jax.vmap?
Any thoughts on this would be very much appreciated. Thanks!