JAX implementation in VQC

Hi Team. I’m trying to use JAX in a VQC, but I cannot update the weights and the bias simultaneously with optax.apply_updates(). The code is the same as in the codebook; I just want to add JAX to it.

Hi @roysuman088 ,
Could you provide a minimal example of what you are trying to do that shows the incompatibility with JAX? Without having seen your code, here are some pointers:

  • Are you setting the interface of the QNode objects you use, via the interface keyword argument?
  • Are you passing JAX arrays as trainable inputs to the relevant functions?
  • Are you using JAX’s differentiation methods, like jax.jacobian or jax.grad, to evaluate the derivatives involved?
  • Are you marking the relevant function arguments as trainable, using the argnums keyword argument of e.g. jax.jacobian or jax.grad, or the argnum keyword argument of the corresponding gradient transform (in case you are using that)?

Maybe this already helps; otherwise, I’m happy to provide more feedback on any code example you have. Happy coding!