Facing Issues with JAX jitting the Optimization loop

Hey @G_Akash,

Apologies for the wait :sweat_smile:.

… slow compilation with increasing qubits, and no amount of GPU is helping to reduce the compilation time; on the contrary, the compilation time increases with an increase in GPU performance.

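One thing that might be causing this is the lack of JAX control flow operations (see here: Writing TPU kernels with Pallas - JAX documentation). When `jax.jit` traces a plain Python `for` loop, it unrolls it, so the program XLA has to compile grows with every iteration (and with every qubit or layer you loop over in Python), which can make compilation slow. I would try using the native control flow operations, such as `jax.lax.fori_loop` or `jax.lax.scan`, and see how that goes! The other thing I can recommend is to use Catalyst. Might be worth revisiting the post you made here?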
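To illustrate the difference, here is a minimal sketch in plain JAX. `cost`, `STEPS`, and `LR` are hypothetical placeholders rather than anything from your code; with PennyLane you would substitute your QNode for `cost`. The Python-loop version is traced once per step, while the `lax.fori_loop` and `lax.scan` versions trace the step body only once, which you can see by counting the equations in each jaxpr.

```python
import jax
import jax.numpy as jnp

STEPS = 100  # number of optimization steps (illustrative value)
LR = 0.1     # learning rate (illustrative value)


def cost(params):
    # Hypothetical stand-in for a circuit cost; swap in your QNode here.
    return jnp.sum(jnp.sin(params) ** 2)


def unrolled_loop(params):
    # Plain Python loop: JAX traces the body STEPS times, so the compiled
    # program contains STEPS copies of the gradient step.
    for _ in range(STEPS):
        params = params - LR * jax.grad(cost)(params)
    return params


def fori_version(params):
    # lax.fori_loop: the body is traced once and compiled as a single loop.
    def body(_, p):
        return p - LR * jax.grad(cost)(p)
    return jax.lax.fori_loop(0, STEPS, body, params)


def scan_version(params):
    # lax.scan: same idea, but it also stacks the loss from every step.
    def step(p, _):
        loss, g = jax.value_and_grad(cost)(p)
        return p - LR * g, loss
    return jax.lax.scan(step, params, xs=None, length=STEPS)


train_unrolled = jax.jit(unrolled_loop)
train_fori = jax.jit(fori_version)
train_scan = jax.jit(scan_version)

if __name__ == "__main__":
    p0 = jnp.array([0.5, 1.0, -0.3])
    # The equation count grows with STEPS for the Python loop only.
    n_unrolled = len(jax.make_jaxpr(unrolled_loop)(p0).jaxpr.eqns)
    n_fori = len(jax.make_jaxpr(fori_version)(p0).jaxpr.eqns)
    print("traced equations:", n_unrolled, "vs", n_fori)
    params, losses = train_scan(p0)
    print("first/last loss:", losses[0], losses[-1])
```

All three versions compute the same updates; `lax.scan` is handy when you want the loss history, and `lax.fori_loop` when you only need the final parameters. The same unrolling applies to Python loops over qubits or layers inside the cost function itself.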

I would like to know how to modify my circuit architecture to make it less computationally intensive while simultaneously increasing its ability to understand the complex features in the dataset.

That’s the million dollar question! Unfortunately I don’t have a one-size-fits-all answer.