I’m using PennyLane with the JAX backend.
I’m quite experienced with JAX and have often seen large speedups from it in other projects, so I wanted to try it in PennyLane as well.
The simulations I’m running are VQE-type simulations on about 16 qubits.
This works like a charm when I compute exact (analytic) expectation values, and it’s super fast.
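For context, here is a minimal sketch of that exact pattern in plain JAX, using a hypothetical one-qubit toy problem instead of PennyLane (so `exact_cost` and `run_exact_vqe` are illustrative names, not my actual code). For the ansatz RY(θ)|0⟩, the exact expectation ⟨Z⟩ is cos θ, which is minimized at θ = π. The value-and-gradient function is jitted once and reused at every step:

```python
import jax
import jax.numpy as jnp

def exact_cost(theta):
    # Hypothetical toy "VQE": for RY(theta)|0>, the exact <Z> is cos(theta).
    return jnp.cos(theta)

# Compile value-and-gradient once; later calls reuse the compiled function.
cost_and_grad = jax.jit(jax.value_and_grad(exact_cost))

def run_exact_vqe(theta0=0.5, lr=0.4, steps=100):
    theta = jnp.asarray(theta0, dtype=jnp.float32)
    for _ in range(steps):
        energy, grad = cost_and_grad(theta)
        theta = theta - lr * grad  # plain gradient descent step
    return theta, exact_cost(theta)

theta_opt, energy_opt = run_exact_vqe()
print(float(theta_opt), float(energy_opt))  # theta close to pi, energy close to -1
```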
Now I’d like to run the same code with shots.
There is a very brief example showing how to pass a jax.random.PRNGKey at each step of the VQE, which also requires jitting the device creation.
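The key-threading idea can be sketched in plain JAX on the same hypothetical one-qubit toy problem. This is not PennyLane's API: the PennyLane example hands the key to the device created inside the jitted function, whereas here a toy shot sampler (`shot_cost`, with an assumed budget `N_SHOTS`) takes the key directly. A fresh key is split off at every step so each estimate uses new samples, and the gradient uses the parameter-shift rule with each shifted term estimated from its own shots:

```python
import jax
import jax.numpy as jnp

N_SHOTS = 1000  # hypothetical shot budget per expectation value

@jax.jit
def shot_cost(theta, key):
    # Toy model: measuring Z on RY(theta)|0> gives +1 with probability
    # cos^2(theta/2), so the exact <Z> is cos(theta).
    p_plus = jnp.cos(theta / 2) ** 2
    outcomes = jax.random.bernoulli(key, p_plus, shape=(N_SHOTS,))
    # Map True/False to eigenvalues +1/-1 and average over the shots.
    return jnp.mean(jnp.where(outcomes, 1.0, -1.0))

def run_shot_vqe(theta0=0.5, lr=0.4, steps=100, seed=0):
    key = jax.random.PRNGKey(seed)
    theta = jnp.asarray(theta0, dtype=jnp.float32)
    shift = jnp.pi / 2
    for _ in range(steps):
        # Split off fresh keys every step so no samples are reused.
        key, k_plus, k_minus = jax.random.split(key, 3)
        # Parameter-shift gradient of cos(theta), estimated from shots.
        grad = 0.5 * (shot_cost(theta + shift, k_plus)
                      - shot_cost(theta - shift, k_minus))
        theta = theta - lr * grad
    key, k_final = jax.random.split(key)
    return theta, shot_cost(theta, k_final)

theta_s, energy_s = run_shot_vqe()
print(float(theta_s), float(energy_s))  # noisy, but theta ends up near pi
```

Because the key is an ordinary argument of the jitted function, changing it each step does not trigger recompilation.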
I haven’t gotten the jitted-device-creation approach to work yet (when I pass a Hamiltonian to expval, I get a bunch of errors from an np.array(indices) conversion somewhere).
That last point is not the main problem, though. The main issue is that my code runs extremely slowly with shots (here I’m only jitting the cost function, not the device creation).
I’m talking about a slowdown of multiple orders of magnitude. I expect some slowdown, of course, but not so much that I can no longer carry out my VQE.
I’ve read in “Speeding up grad computation” that this is indeed the current behavior.
So, is it pointless to use JAX for VQE with shots?
Would it be possible to add an example of a VQE on a Hamiltonian with shots, to show how it should be done properly (assuming it can run fast at all)?
Thanks in advance,