Mixed precision training + numerical stability for PyTorch + PennyLane?

Hi there! I was reading about mixed precision training in PyTorch. Basically, PyTorch can perform some operations in lower precision (e.g. automatically casting to 16-bit instead of 32-bit floats) to speed up computation, and it's been applied successfully to classical neural networks.
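For concreteness, here's a minimal sketch of the AMP pattern I mean (the model and data are just placeholders, and it assumes a CUDA device is available):

import torch

# Toy model and data, only to show the autocast/GradScaler pattern.
model = torch.nn.Linear(32, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()  # rescales the loss to avoid FP16 gradient underflow

x = torch.randn(64, 32, device="cuda")
y = torch.randn(64, 1, device="cuda")

for _ in range(10):
    optimizer.zero_grad()
    # Ops inside the autocast region run in float16 where PyTorch deems it safe,
    # and stay in float32 where reduced precision is known to be problematic.
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = torch.nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()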

I have tested some PennyLane circuits with PyTorch and have found that they do train faster with mixed precision. The caveat is that mixed precision is not recommended when numerical stability is required, because lower-precision floats have a smaller dynamic range.
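Roughly, this is how I'm plugging PennyLane in (the circuit below is a stand-in for illustration, not my actual model): a QNode wrapped as a TorchLayer so it can be trained like any other Torch module, e.g. inside the AMP loop above.

import pennylane as qml
import torch

n_qubits = 2
dev = qml.device("default.qubit", wires=n_qubits)

@qml.qnode(dev, interface="torch")
def circuit(inputs, weights):
    # Encode the classical inputs, apply a trainable entangling block,
    # and return one expectation value per wire.
    qml.AngleEmbedding(inputs, wires=range(n_qubits))
    qml.BasicEntanglerLayers(weights, wires=range(n_qubits))
    return [qml.expval(qml.PauliZ(w)) for w in range(n_qubits)]

# Wrap the QNode as a Torch layer so it can sit inside an nn.Sequential
# and be trained like any other module.
qlayer = qml.qnn.TorchLayer(circuit, weight_shapes={"weights": (3, n_qubits)})
model = torch.nn.Sequential(qlayer, torch.nn.Linear(n_qubits, 1))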

Here are my questions:

  1. Which of the quantum operations implemented in PennyLane and other device backends are numerically (un)stable?
  2. How would numerical instability in PennyLane affect the correctness of such operations?
  3. Would you recommend mixed precision training with PennyLane?

Hey @schance995, that’s definitely an interesting question – thanks for popping it in here. :slight_smile:

I can’t give you great answers right away, but I’m checking in with the team and we’ll get back to you soon!

Hi @schance995, as you can imagine, the answers to your questions really depend on what you’re trying to do, but here are some general considerations from our developers that might make things simpler. :slight_smile:

PennyLane uses FP64 by default, and if you try to downcast something to a lower precision, something else in the stack gets upcast in turn. In this sense, you don’t have to be too worried about numerical instability in PennyLane. For general-purpose circuits, unless you have absolutely massive ones, floating-point rounding errors shouldn’t be an issue.
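As a quick sanity check of the default precision, here is a minimal sketch using default.qubit (the circuit itself is arbitrary):

import numpy as np
import pennylane as qml

dev = qml.device("default.qubit", wires=1)

@qml.qnode(dev)
def circuit(theta):
    qml.RX(theta, wires=0)
    return qml.state()

# Even with a single-precision input, the returned state is complex128,
# because the simulator works in double precision by default.
print(circuit(np.float32(0.1)).dtype)  # complex128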

Of course, this comes with some fairly normal caveats: are you using very small angles, or more exotic calculations involving density matrices or entropies? If you’re dealing with standard gate operations, you should be good to go.
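To illustrate the small-angle caveat with a back-of-the-envelope example (plain NumPy, no PennyLane involved):

import numpy as np

# The <0|RX(theta)|0> amplitude is cos(theta/2); for a tiny angle its
# deviation from 1 is roughly theta**2/8, which single precision cannot resolve.
theta = 1e-4
print(np.cos(np.float32(theta) / np.float32(2)))  # 1.0 in FP32: the rotation "disappears"
print(np.cos(theta / 2))                          # 0.99999999875 in FP64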

Does this help? As you can imagine, it would be a larger undertaking to figure out the effects of specific lower-precision formats, so we don’t have a speedy answer for you there.
But we would love to know about your conclusions as you play around with PennyLane, so please let us know how it goes. :smiley:


Here’s a brief follow-up on this. It is not currently possible to get significant savings from mixed precision in PennyLane. As an example, here’s the initializer for the default.qubit.jax device:

def __init__(self, wires, *, shots=None, prng_key=None, analytic=None):
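    # JAX only enables 64-bit types when the jax_enable_x64 flag is set,
    # so the complex dtype is either complex128 or complex64, never anything smaller.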
    if jax.config.read("jax_enable_x64"):
        c_dtype = jnp.complex128
        r_dtype = jnp.float64
    else:
        c_dtype = jnp.complex64
        r_dtype = jnp.float32

Although r_dtype could be cast to float16, c_dtype cannot be cast to anything smaller: NumPy and JAX don’t currently support a complex32 type. So mixed precision with PennyLane will have to wait.
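A quick way to confirm the missing complex32 (true as of current NumPy releases; JAX builds on the same dtypes):

import numpy as np

# complex64 packs two float32s (8 bytes total); there is no half-precision
# complex type to pair with float16.
print(np.dtype(np.complex64).itemsize)  # 8
print(hasattr(np, "complex32"))         # False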