Best practices when combining PennyLane with TensorFlow

Hi,

My use case essentially combines classical TensorFlow networks with quantum networks, both as single networks containing classical and quantum layers and as multiple networks of this type working together. I’ve been experimenting with different qml devices (mainly default.qubit, lightning.gpu, and lightning.qubit) and with the use of @tf.function.
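For context, here’s a stripped-down sketch of the kind of hybrid model I mean; the circuit, ansatz, and layer sizes are hypothetical placeholders for my actual models:

```python
import pennylane as qml
import tensorflow as tf

n_qubits = 4
dev = qml.device("default.qubit", wires=n_qubits)

@qml.qnode(dev, interface="tf", diff_method="backprop")
def circuit(inputs, weights):
    qml.AngleEmbedding(inputs, wires=range(n_qubits))
    qml.BasicEntanglerLayers(weights, wires=range(n_qubits))
    return [qml.expval(qml.PauliZ(w)) for w in range(n_qubits)]

# Wrap the QNode as a Keras layer and sandwich it between classical layers
weight_shapes = {"weights": (2, n_qubits)}  # (n_layers, n_qubits)
qlayer = qml.qnn.KerasLayer(circuit, weight_shapes, output_dim=n_qubits)

model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(n_qubits, activation="tanh"),
    qlayer,
    tf.keras.layers.Dense(1),
])
```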

From my experiments, default.qubit combined with @tf.function seems to produce the best results. My use case involves lots of derivatives, sometimes up to third order. On the more complex examples, avoiding @tf.function sometimes seems best, as tracing the graph can take 2+ hours even on a GPU.
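To be concrete about the @tf.function usage, my training step is wrapped roughly like this (continuing the model sketched above, with a hypothetical Adam optimizer; on the larger models, it’s this tracing that takes hours):

```python
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)

@tf.function  # compiles the step into a graph; the tracing itself is the slow part
def train_step(x, y):
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean((model(x) - y) ** 2)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss
```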

Are there any suggestions for the setup to use here? I have a hard time believing default.qubit is the best choice, given the comments I’ve seen on other threads about it being slow. I can attach code examples, but I’m really asking for general advice.

Thanks

Hi @Bnesh!

You’re touching on some interesting points. Note that using GPUs won’t always make things faster, and using default.qubit won’t always make things slower. On our performance page you’ll see a section on which simulator to use according to your needs.

As general guidance, consider that moving data between the CPU and GPU is expensive, so GPUs generally won’t provide improvements unless you have large circuits with many qubits and gates, requiring a lot of matrix multiplication without frequent data transfer in and out.
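As a rough illustration, you could pick the device based on problem size; the 20-qubit threshold below is just an assumption for the sketch, not a fixed rule:

```python
import pennylane as qml

n_wires = 22

# CPU simulators usually win on small circuits; a GPU simulator only pays
# off once the statevector is big enough to amortize host-device transfers.
name = "lightning.gpu" if n_wires >= 20 else "lightning.qubit"
dev = qml.device(name, wires=n_wires)
```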

On the other hand, differentiation methods can make a huge difference, and not all of them are compatible with every device. Each device defaults to the differentiation method that is optimal for it, but this can mean that certain workflows work better on default.qubit than on other devices. Adjoint differentiation, for example, is usually the fastest, but it’s not compatible with actual quantum hardware.
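As a quick sketch of what explicitly selecting a method on a QNode looks like:

```python
import pennylane as qml

dev = qml.device("lightning.qubit", wires=4)

# Adjoint differentiation: typically the fastest on statevector simulators,
# but it only supports expectation values and doesn't run on real hardware.
@qml.qnode(dev, diff_method="adjoint")
def circuit(weights):
    qml.StronglyEntanglingLayers(weights, wires=range(4))
    return qml.expval(qml.PauliZ(0))
```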

Finally, simplifying your code can bring big advantages. Avoid nested for loops and overly complex examples that are hard to debug and optimize.
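For example, instead of looping over parameter values in Python, devices like default.qubit support parameter broadcasting, so a whole batch can be evaluated in a single call. A minimal sketch:

```python
import pennylane as qml
import tensorflow as tf

dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev, interface="tf")
def circuit(theta):
    qml.RX(theta, wires=0)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.PauliZ(1))

thetas = tf.constant([0.1, 0.2, 0.3], dtype=tf.float64)

# Instead of a Python loop over parameter values...
results_loop = [circuit(t) for t in thetas]

# ...broadcast the whole batch through the device in one call
results_batched = circuit(thetas)
```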

Note: if you’re able to move away from TensorFlow and use JAX instead, this can bring big advantages too.

I hope this helps!

Hi,

Thanks for the response. I’ve found that the CPU can outperform the GPU/TPU on the smaller circuit examples, so it’s good to know that’s expected behaviour. When it comes to the differentiation method, I’ve found that changing diff_method to anything other than backprop causes my code to fail, with varying errors depending on the choice of method. I figured this was expected behaviour, as the TensorFlow interface expects to use backprop for the derivatives. If that isn’t the case, I can attach some code examples showing this.

Note that I’m calculating higher-order derivatives; I don’t know if I’d still get the errors without them.

When it comes to moving to JAX, what sort of speedups could be expected? I’m obviously just after a very rough idea here.

Thanks

Hi @Bnesh,

Changing diff_method won’t fail because of TensorFlow; it usually fails because of the combination of workflow, derivatives, and device. So it may well be that backprop is the best option for your specific case.
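One thing worth checking for your higher-order derivatives: gradient-transform methods like parameter-shift require the QNode’s max_diff argument (which defaults to 1) to be raised, otherwise second derivatives can fail or come out wrong. A minimal sketch of what that looks like with the TensorFlow interface:

```python
import pennylane as qml
import tensorflow as tf

dev = qml.device("default.qubit", wires=1)

# max_diff defaults to 1; raise it when you need second (or higher)
# derivatives with gradient transforms such as parameter-shift.
@qml.qnode(dev, interface="tf", diff_method="parameter-shift", max_diff=2)
def circuit(x):
    qml.RX(x, wires=0)
    return qml.expval(qml.PauliZ(0))

x = tf.Variable(0.5, dtype=tf.float64)
with tf.GradientTape() as outer:
    with tf.GradientTape() as inner:
        y = circuit(x)
    dy = inner.gradient(y, x)   # first derivative
d2y = outer.gradient(dy, x)     # second derivative
```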

I don’t have a specific number for the JAX speedups. If you use jitting (with jax-jit), you’ll see speedups when you have for loops or elements that keep the same structure but vary their parameters at every iteration. I don’t know whether this is compatible with higher-order derivatives, though. So if you’re happy with the performance you’re getting, there’s no need to move to JAX. You can find some of our demos using JAX here.
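For reference, a minimal sketch of what jitting a QNode can look like; the circuit here is just an illustrative example:

```python
import jax
import jax.numpy as jnp
import pennylane as qml

dev = qml.device("default.qubit", wires=2)

@jax.jit  # traced and compiled once; later calls reuse the cached program
@qml.qnode(dev, interface="jax")
def circuit(params):
    qml.RX(params[0], wires=0)
    qml.RY(params[1], wires=1)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.PauliZ(1))

params = jnp.array([0.1, 0.2])
print(circuit(params))            # first call pays the tracing cost
print(jax.grad(circuit)(params))  # gradients can be traced and jitted too
```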

I hope this helps!