Hi team. I’m trying to use JAX with a VQC, but I cannot update the weights and the bias simultaneously with `optax.apply_updates()`. The code is the same as in the codebook; I just want to use JAX there.
Hi @roysuman088,
Could you provide a minimal example of what you are trying to do that shows the incompatibility with JAX? Without having seen your code, here are some things to check:
- Are you setting the interface of the `QNode` objects you use via the `interface` keyword argument?
- Are you passing JAX arrays as trainable inputs to the relevant functions?
- Are you using JAX’s differentiation functions, such as `jax.jacobian` or `jax.grad`, to evaluate the derivatives involved?
- Are you marking the relevant function arguments as trainable, using the `argnums` keyword argument of e.g. `jax.jacobian` or `jax.grad`, or the `argnum` keyword argument of the corresponding gradient transform (in case you are using one)?
Maybe this already helps; otherwise, I’m happy to provide more feedback on any code example you have. Happy coding!