Hi Team. I’m trying to use JAX with a VQC, but I cannot update the weights and the bias simultaneously with `optax.apply_updates()`. The code is the same as in the codebook; I just want to use JAX there.
Hi @roysuman088,
Could you provide a minimal example of what you are trying to do that shows the incompatibility with JAX? Without having seen your code, here are some pointers:
- Are you setting the interface of the `QNode` objects you use via the `interface` keyword argument?
- Are you passing JAX arrays as trainable inputs to the relevant functions?
- Are you using JAX’s differentiation functions, such as `jax.jacobian` or `jax.grad`, to evaluate the derivatives involved?
- Are you marking the relevant function arguments as trainable, using the `argnums` keyword argument of, e.g., `jax.jacobian` or `jax.grad`, or the `argnum` keyword argument of the corresponding gradient transform (if you are using one)?
Maybe this already helps; otherwise, I’m happy to provide more feedback on any code example you have. Happy coding!