VQE with shots with JAX


I’m using PennyLane with the JAX backend.
I’m quite experienced with JAX and have often seen large speedups in other projects, so I wanted to test it with PennyLane as well.
The simulations I’m running are VQE type simulations of about 16 qubits.
This works like a charm when I’m doing the exact calculations, and it’s super fast.
Now I’d like to use the same code to add shots.
There is a very brief example of how to pass jax.random.PRNGKeys at each step of the VQE, which also requires jitting the device creation.
I haven’t gotten this to work yet (I ran into a number of problems with an array conversion in some np.array(indices) line when passing a Hamiltonian to expval).
That last point is not the main problem, though. The main issue is that my code runs extremely slowly with shots (I’m only jitting the cost function, not the device creation).
I’m talking about a slowdown of multiple orders of magnitude. I expect some slowdown, of course, but not to the point that I can no longer carry out my VQE.
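For reference, the effect I mean (a finite-shot estimate replacing an exact expectation value) can be sketched with the standard library alone. This is a toy single-qubit example, not PennyLane code: for RY(θ)|0⟩ the exact ⟨Z⟩ is cos θ, and a shot-based device samples ±1 outcomes instead:

```python
import math
import random

def exact_expval(theta):
    # Exact <Z> after RY(theta)|0>: cos(theta)
    return math.cos(theta)

def shot_expval(theta, shots, rng):
    # Finite-shot estimate: sample +1 with probability cos^2(theta/2), else -1,
    # and average over the shots. The error scales like 1/sqrt(shots).
    p_up = math.cos(theta / 2) ** 2
    total = sum(1 if rng.random() < p_up else -1 for _ in range(shots))
    return total / shots

rng = random.Random(0)  # seeded for reproducibility
theta = 0.7
est = shot_expval(theta, shots=10_000, rng=rng)
print(exact_expval(theta), est)
```

With 10,000 shots the estimate fluctuates around cos(0.7) with a standard error of roughly 0.01, which is the statistical noise a shot-based VQE cost function has to live with.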
I’ve read in the thread Speeding up grad computation that this is indeed the current behavior.
So, is it pointless to use JAX for VQE with shots?
Would it be possible to add an example of a VQE of a Hamiltonian with shots, to see how it should be done properly in case it can run fast?

Thanks in advance,

Hi @Jannes_Nys,

The basic problem is that we have to use the parameter-shift rule when finite shots are enabled.

First, note that the statement “estimating a gradient using finite shots” is not well-defined without specifying which circuits are used to compute the gradient. In PennyLane, we assume an experimental scenario in which the parameter-shift rule is used: we create two circuits with shifted parameters (or more, depending on the gate) and add shot noise to their results. We then post-process these results to obtain the gradient with sampling errors. Since we need to evaluate at least two circuits for each parameter, this is where the slowdown comes from.
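To illustrate the parameter-shift rule on a toy example (assuming a gate whose generator has eigenvalues ±1/2, so the shift is ±π/2; the “circuit” here is just the analytic expectation cos θ, not a real device):

```python
import math

def expval(theta):
    # Toy "circuit": exact <Z> for RY(theta)|0>, i.e. cos(theta)
    return math.cos(theta)

def parameter_shift_grad(f, theta):
    # Two circuit evaluations, one at theta + pi/2 and one at theta - pi/2.
    # On hardware, each evaluation would carry its own shot noise.
    return 0.5 * (f(theta + math.pi / 2) - f(theta - math.pi / 2))

theta = 1.1
g = parameter_shift_grad(expval, theta)
# For cos(theta), the parameter-shift result equals the analytic
# derivative -sin(theta) exactly (no finite-difference error).
```

With shot noise, each of the two evaluations becomes a noisy estimate, and the post-processed difference is a noisy but unbiased estimate of the gradient.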

It is theoretically possible to mimic this shot noise directly when the second derivative is available, but we have not implemented this functionality yet (as that information is not available in all simulation devices).

If you would like to discuss this further, please let us know.


Hi @Jannes_Nys,

Thanks for posting your issue. If it is what I think it is, it is indeed something we are actively trying to improve. As @CY_Park says, we currently have to fall back to the parameter-shift method to estimate gradients with finite shots, which scales linearly with the number of parameters. If we were able to make things compatible with standard backpropagation, or its memory-efficient variant, the adjoint method, you would see the kinds of speeds you expect.
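To make that scaling concrete, here is a stdlib-only sketch that counts circuit evaluations for a parameter-shift gradient (the “circuit” is a toy analytic function, not a real device; the names are illustrative):

```python
import math

calls = 0  # counts how many times the "circuit" is executed

def circuit(params):
    # Toy "circuit": product of cosines, one factor per parameter
    global calls
    calls += 1
    return math.prod(math.cos(p) for p in params)

def parameter_shift_grad(f, params):
    # Parameter-shift rule: 2 evaluations per parameter,
    # so 2 * len(params) circuit executions in total.
    grads = []
    for i in range(len(params)):
        plus = params[:i] + [params[i] + math.pi / 2] + params[i + 1:]
        minus = params[:i] + [params[i] - math.pi / 2] + params[i + 1:]
        grads.append(0.5 * (f(plus) - f(minus)))
    return grads

params = [0.1, 0.2, 0.3, 0.4]
grad = parameter_shift_grad(circuit, params)
print(calls)  # 2 * len(params) = 8 circuit evaluations
```

Backpropagation, by contrast, obtains the whole gradient from roughly one forward and one backward pass, independent of the parameter count — which is why exact simulations feel so much faster than shot-based ones today.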

This is an issue we are aware of, but unfortunately resolving it requires research (finding an efficient way to sample from the distribution of a quantum circuit’s gradient); it’s not just a coding/implementation detail. I’m hopeful that we will be able to crack it, but it may take us some time to find a way :pray: