Hi there! I was reading about mixed precision training in PyTorch. Basically, PyTorch can perform some operations at lower precision (e.g., automatically casting to 16-bit instead of 32-bit floats) to speed up computation, and this has been applied successfully to classical neural networks.
I have tested some PennyLane circuits with PyTorch and found that they do train faster with mixed precision. The caveat is that mixed precision is not recommended when numerical stability is required, due to the smaller dynamic range of lower-precision floats.
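For context, here's roughly the kind of experiment I ran (a minimal sketch; the circuit, data, and optimizer are placeholders rather than my actual setup, and on GPU you would use device_type="cuda" with float16 instead):

import torch
import pennylane as qml

dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev, interface="torch")
def circuit(weights, x):
    qml.RX(x, wires=0)
    qml.RY(weights[0], wires=0)
    qml.RY(weights[1], wires=1)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.PauliZ(1))

weights = torch.randn(2, requires_grad=True)
opt = torch.optim.Adam([weights], lr=0.1)

for step in range(10):
    opt.zero_grad()
    # Autocast selects lower-precision kernels where PyTorch deems it safe;
    # the QNode itself still runs at whatever precision the device uses.
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        loss = (circuit(weights, torch.tensor(0.5)) - 1.0) ** 2
    loss.backward()
    opt.step()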
Here are my questions:
- Which of the quantum operations implemented in PennyLane and other device backends are numerically (un)stable?
- How would numeric instability in PennyLane impact the correctness of such operations?
- Would you recommend mixed precision training with PennyLane?
Hey @schance995 , that’s definitely an interesting question – thanks for popping it in here. 
I can’t give you great answers right away, but I’m checking in with the team and we’ll get back to you soon!
Hi @schance995 , as you can imagine, the answers to your questions can really depend on what you’re trying to do, but here are some general considerations that might make it simpler, as shared by our developers. 
PennyLane uses FP64 by default, and if you try to downcast something to a lower precision, something else through the stack gets upcast in turn. In this sense, you don't have to be too worried about numerical instability in PennyLane. For general-purpose deep circuits, floating-point rounding errors shouldn't be an issue unless your circuits are absolutely massive.
Of course, this comes with the usual caveats: are you using very small angles, or more exotic calculations involving density matrices or entropies? If you're dealing with standard gate operations, you should be good to go.
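As a quick way to see the FP64 default in action (a minimal sketch; exact dtypes can vary between devices and versions):

import numpy as np
import pennylane as qml

dev = qml.device("default.qubit", wires=1)

@qml.qnode(dev)
def circuit(theta):
    qml.RX(theta, wires=0)
    return qml.expval(qml.PauliZ(0))

# Even with a float32 input, the simulation itself runs with a
# double-precision (complex128) state, so the result comes back as float64.
res = circuit(np.float32(0.3))
print(res.dtype)  # float64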
Does this help? As you can imagine, it would be a bit of a larger undertaking to figure out the effects of specific lower-precision formats and we don’t have a speedy answer for you.
But we would love to know about your conclusions as you play around with PennyLane, so please let us know how it goes. 
Here’s a brief follow-up on this. It is not currently possible to get significant savings from mixed precision. As an example, here’s the initializer for default.qubit.jax:
def __init__(self, wires, *, shots=None, prng_key=None, analytic=None):
    if jax.config.read("jax_enable_x64"):
        c_dtype = jnp.complex128
        r_dtype = jnp.float64
    else:
        c_dtype = jnp.complex64
        r_dtype = jnp.float32
Although r_dtype could be cast to float16, c_dtype cannot be cast to anything smaller: NumPy and JAX don’t currently support complex32. So mixed precision with PennyLane will have to wait.
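For reference, this is the underlying JAX behaviour, independent of PennyLane (a small sketch):

import jax
import jax.numpy as jnp

# By default JAX keeps everything at 32-bit precision...
print(jnp.array(1.0).dtype)        # float32
print(jnp.array(1.0 + 0j).dtype)   # complex64

# ...and enabling x64 upgrades both the real and complex dtypes.
jax.config.update("jax_enable_x64", True)
print(jnp.array(1.0).dtype)        # float64
print(jnp.array(1.0 + 0j).dtype)   # complex128

# There is no complex32 to drop down to: jnp.complex32 does not exist,
# so the complex state vector can't go below 64 bits total.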
Hey @schance995,
default.qubit.jax is using our old device API. If you use default.qubit, you’ll access the new changes and won’t have to worry about specifying an interface at any point in your workflow. default.qubit knows what to do! Does that change anything for you?
The same story applies to the new device, see: pennylane/pennylane/workflow/interfaces/jax_jit.py at 1043c55ccf1ae9001a1ee7c851451f1133da4869 · PennyLaneAI/pennylane · GitHub
def _jax_dtype(m_type):
    if m_type == int:
        return jnp.int64 if jax.config.jax_enable_x64 else jnp.int32
    if m_type == float:
        return jnp.float64 if jax.config.jax_enable_x64 else jnp.float32
    if m_type == complex:
        return jnp.complex128 if jax.config.jax_enable_x64 else jnp.complex64
    return jnp.dtype(m_type)
Avoiding 64-bit operations with JAX is the best way to minimize memory usage for now, at the cost of giving up the extra precision.
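If anyone wants to try the 32-bit route, here's a minimal sketch of what I mean (the circuit is just a placeholder; I believe both the result and the gradient come back as float32 in this configuration):

import jax
import jax.numpy as jnp
import pennylane as qml

# Keep JAX at 32-bit precision to roughly halve the memory of the state vector.
jax.config.update("jax_enable_x64", False)

dev = qml.device("default.qubit", wires=4)

@qml.qnode(dev, interface="jax")
def circuit(weights):
    for w in range(4):
        qml.RY(weights[w], wires=w)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.PauliZ(0))

weights = jnp.zeros(4)
print(circuit(weights).dtype)            # float32
print(jax.grad(circuit)(weights).dtype)  # float32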
Hi @schance995 ,
Thanks for your post! We’re taking a deeper look into this. We’ll be back with more info soon.
Hi @schance995 ,
My colleague Vincent helped me find these answers to your questions.
- Question: Which of the quantum operations implemented in PennyLane and other device backends are numerically (un)stable?
- Answer: As complex32 isn’t supported by linear algebra frameworks, with good reasons and no plan for support, I think any circuit using complex gates won’t be able to benefit. PennyLane also uses complex casts in several places, so even trying to keep the computation real might not work either. That said, almost everything PennyLane does is matrix multiplication deep down, so it should be stable and maintain the precision of the chosen binary representation as expected.
- Question: How would numeric instability in PennyLane impact the correctness of such operations?
- Answer: I would expect round-off error accumulation proportional to the total number of operations, but no instability from running a circuit (see the short sketch after this list). Anything coming on top of that, such as an iterative solver, could indeed lead to instability, but that is out of our control.
- Question: Would you recommend mixed precision training with PennyLane?
- Answer: This is not supported at the moment.
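To illustrate the round-off point above, here is a small sketch that is independent of PennyLane: it multiplies many single-qubit rotation matrices in single versus double precision and compares the products.

import numpy as np

def rx(theta, dtype):
    """2x2 matrix of an RX rotation at the given complex precision."""
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([[c, -1j * s], [-1j * s, c]], dtype=dtype)

rng = np.random.default_rng(0)
angles = rng.uniform(-np.pi, np.pi, size=10_000)

u64 = np.eye(2, dtype=np.complex128)
u32 = np.eye(2, dtype=np.complex64)
for theta in angles:
    u64 = rx(theta, np.complex128) @ u64
    u32 = rx(np.float32(theta), np.complex64) @ u32

# The deviation grows roughly with the number of gates, but the single-precision
# product stays unitary to within that accumulated round-off -- no blow-up.
print(np.max(np.abs(u64 - u32)))                       # accumulated round-off
print(np.max(np.abs(u32.conj().T @ u32 - np.eye(2))))  # still ~unitary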
I hope this helps you.