Parallel vectorized circuit execution with vmap

Hi,

My use case requires executing the same simple parameterized circuit on a few qubits (~9) over a batch of inputs (~10,000 unique inputs to the circuit for a minimal POC).

Naturally, I tried to achieve this by wrapping the QNode in vmap, as follows.

from pennylane import qnode, device, RX, CRZ, PauliZ, expval
from catalyst import vmap, qjit
import jax
import jax.numpy as jnp

@qjit(keep_intermediate=True, verbose=True) # inspect lowering process
@vmap(in_axes=(0, None), out_axes=0)    # vectorize over 'batch' dimension
@qnode(device("lightning.qubit", wires=5), interface="jax")
def simple_quantum_model(inputs, weights):
    for i in range(5): RX(inputs[i], wires=i)       # encoding
    for i in range(1, 5): CRZ(weights[i], (i, 0))   # ansatz
    return expval(PauliZ(wires=0))                  # measurement

test_inputs = jax.random.normal(jax.random.key(0), shape=(10_000, 5), dtype=jnp.float32)
test_weights = jax.random.normal(jax.random.key(0), shape=(5,), dtype=jnp.float32)

print(simple_quantum_model(test_inputs, test_weights).shape)

The above was tried with simulators: “lightning.qubit”, “lightning.gpu”, “lightning.kokkos” (kokkos built with CUDA support). Regardless of the device, the result of the 4th pass “BufferizationPass” (before the final “MLIRToLLVMDialect”) looks like the following:

func.func public @jit_vmap.simple_quantum_model(%arg0: memref<10000x5xf32>, %arg1: memref<5xf32>) -> memref<10000xf64> attributes {llvm.copy_memref, llvm.emit_c_interface} {
  ... // allocations, copy
  %5 = scf.for %arg2 = %c0 to %c10000 step %c1 iter_args(%arg3 = %alloc_20) -> (memref<10000xf64>) {
    ... // indexing, selection
    %15 = func.call @simple_quantum_model_0(%alloc_17, %arg1) : (memref<5xf32>, memref<5xf32>) -> memref<f64>
    ... // more indexing, copy, selection
    scf.yield %alloc_22 : memref<10000xf64> // final yield
  }
  ...// deallocation, copy
  return %10 : memref<10000xf64>
}

This is equivalent to a for loop over the circuit.
In other words, the mapped (batch) axis isn’t pushed down into the primitive operations, as one would expect from vmap.

Sidenote: The above was also tried with jax.jit and jax.vmap, which also results in no parallelism and is also orders of magnitude slower to execute.

Would this be intentional? If so, could this be raised as a feature request?
Could you suggest some workarounds for parallel circuit execution on GPU?

qml.about()

Name: PennyLane
Version: 0.40.0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: [link]
Author:
Author-email:
License: Apache License 2.0
Location: /home/[user]/.local/lib/python3.11/site-packages
Requires: appdirs, autograd, autoray, cachetools, diastatic-malt, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, tomlkit, typing-extensions
Required-by: amazon-braket-pennylane-plugin, PennyLane-Catalyst, pennylane-qrack, pennylane-qulacs, PennyLane_Lightning, PennyLane_Lightning_GPU, PennyLane_Lightning_Kokkos, PennyLane_Lightning_Tensor

Platform info: Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.35
Python version: 3.11.11
Numpy version: 1.26.4
Scipy version: 1.13.0
Installed devices:

  • default.clifford (PennyLane-0.40.0)
  • default.gaussian (PennyLane-0.40.0)
  • default.mixed (PennyLane-0.40.0)
  • default.qubit (PennyLane-0.40.0)
  • default.qutrit (PennyLane-0.40.0)
  • default.qutrit.mixed (PennyLane-0.40.0)
  • default.tensor (PennyLane-0.40.0)
  • null.qubit (PennyLane-0.40.0)
  • reference.qubit (PennyLane-0.40.0)
  • braket.aws.ahs (amazon-braket-pennylane-plugin-1.31.2)
  • braket.aws.qubit (amazon-braket-pennylane-plugin-1.31.2)
  • braket.local.ahs (amazon-braket-pennylane-plugin-1.31.2)
  • braket.local.qubit (amazon-braket-pennylane-plugin-1.31.2)
  • nvidia.custatevec (PennyLane-Catalyst-0.11.0.dev15)
  • nvidia.cutensornet (PennyLane-Catalyst-0.11.0.dev15)
  • oqc.cloud (PennyLane-Catalyst-0.11.0.dev15)
  • softwareq.qpp (PennyLane-Catalyst-0.11.0.dev15)
  • lightning.gpu (PennyLane_Lightning_GPU-0.40.0)
  • lightning.tensor (PennyLane_Lightning_Tensor-0.40.0)
  • lightning.qubit (PennyLane_Lightning-0.41.0.dev7)
  • lightning.kokkos (PennyLane_Lightning_Kokkos-0.41.0.dev7)
  • qulacs.simulator (pennylane-qulacs-0.40.0)
  • qrack.simulator (pennylane-qrack-0.12.0)

PS: I’m aware there are similar threads, however I could not find one for vmap vectorization.

Hi @LasradoRohan, thank you for sharing your use case!

You’re right, catalyst.vmap does not push the batch dimensions into primitives, because the quantum device interface was not designed with this in mind (for instance, how would a QPU vectorize a single gate, with no input/output arrays to speak of?). Classical primitives from the NumPy API, on the other hand, are designed to handle additional input/output dimensions on arrays directly in their kernels.
Instead, we created catalyst.vmap with the goal of enabling easy batching of functions, using the familiar API from jax.vmap. While we don’t vectorize primitive kernels, the compiled for loop should still be faster than launching those 10,000 function executions from Python. If this is not the case, please let us know and we’ll look into it!

Sidenote: The above was also tried with jax.jit and jax.vmap, which also results in no parallelism

You mention parallelism, but vmap is not a parallelizing operation; it is a batching operation (explanation from the JAX developers). Maybe what you are looking for is jax.pmap? Unfortunately, pmap is not currently supported by Catalyst, although you should be able to observe parallelism with it in pure JAX programs.
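The batching semantics can be seen on a pure JAX function. This is a minimal sketch, and the names `apply_gate`, `gate`, `states` and `batched_apply` are made up for illustration rather than taken from the code in this thread. vmap rewrites the function so that the batch axis is carried by the primitive itself; it says nothing about how that primitive is then scheduled on the hardware.

```python
import jax
import jax.numpy as jnp

def apply_gate(gate, state):
    # One 2x2 matrix acting on one 2-element state vector.
    return gate @ state

gate = jnp.array([[0.0, 1.0], [1.0, 0.0]])                # Pauli-X
states = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]])  # batch of 3 states

# Share `gate` across the batch, map over the first axis of `states`.
batched_apply = jax.vmap(apply_gate, in_axes=(None, 0))

# The traced program contains a dot_general acting on the whole (3, 2) array,
# not a loop of three separate calls.
jaxpr_text = str(jax.make_jaxpr(batched_apply)(gate, states))
print(jaxpr_text)

batched_out = batched_apply(gate, states)
print(batched_out.shape)
```

Whether the resulting batched primitive runs in parallel depends on the backend kernel that implements it, which is why the distinction between batching and parallelizing matters here.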

and is also orders of magnitude slower to execute.

What are you comparing against?

Would this be intentional? If so, could this be raised as a feature request?
Could you suggest some workarounds for parallel circuit execution on GPU?

Let me get some additional info on this and get back to you :slight_smile:

Hi @David_Ittah, thanks for the answer.

how would a QPU vectorize a single gate

Although it makes sense that a QPU would not be able to execute gates over multiple states in parallel, I believe state vector simulators that work directly with matrices could benefit from such a batching operation. For example, the compiled for loop (which launches one kernel per batch index) leads to very low GPU utilization in my case.
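To make the idea concrete, here is a rough NumPy sketch of what a batched gate application could look like in a state vector simulator. It is not how any Lightning device is implemented; `rx_matrices` and `apply_rx_batched` are hypothetical helpers written only for this illustration. Each sample in the batch gets its own RX angle, and the whole batch is updated with a single einsum instead of one call per sample.

```python
import numpy as np

def rx_matrices(thetas):
    """Stack of RX(theta) matrices, shape (batch, 2, 2)."""
    c = np.cos(thetas / 2)
    s = np.sin(thetas / 2)
    mats = np.empty((thetas.shape[0], 2, 2), dtype=np.complex128)
    mats[:, 0, 0] = c
    mats[:, 0, 1] = -1j * s
    mats[:, 1, 0] = -1j * s
    mats[:, 1, 1] = c
    return mats

def apply_rx_batched(states, thetas, wire, n_qubits):
    """Apply RX(thetas[b]) on `wire` to states[b] for every b in one einsum."""
    batch = states.shape[0]
    psi = states.reshape((batch,) + (2,) * n_qubits)
    psi = np.moveaxis(psi, wire + 1, 1)      # bring the target qubit next to the batch axis
    moved_shape = psi.shape
    psi = psi.reshape(batch, 2, -1)
    psi = np.einsum("bij,bjk->bik", rx_matrices(thetas), psi)
    psi = np.moveaxis(psi.reshape(moved_shape), 1, wire + 1)
    return psi.reshape(batch, -1)

n_qubits, batch = 3, 4
rng = np.random.default_rng(0)
thetas = rng.normal(size=batch)

# Every sample starts in |000>.
states = np.zeros((batch, 2**n_qubits), dtype=np.complex128)
states[:, 0] = 1.0

batched = apply_rx_batched(states, thetas, wire=0, n_qubits=n_qubits)

# Reference: the same update done one sample at a time (the "for loop" version).
looped = np.concatenate(
    [apply_rx_batched(states[b:b + 1], thetas[b:b + 1], 0, n_qubits) for b in range(batch)]
)

# <Z> on wire 0, which is the most significant bit of the basis index.
probs = np.abs(batched.reshape(batch, 2, -1)) ** 2
z0 = probs[:, 0].sum(axis=1) - probs[:, 1].sum(axis=1)
```

The batched and looped versions give the same states; the difference is that the batched one issues a single large array operation, which is the kind of work a GPU can keep busy with.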

but vmap is not a parallelizing operation

Thanks for clarifying this about vmap. Although it doesn’t guarantee parallelization, in ML use cases where matrix operations can be fused it is widely used to achieve parallelism over a batch.

Maybe what you are looking for is jax.pmap?

jax.pmap distributes each batch index across multiple XLA devices. It requires that the batch dimension be smaller than or equal to the number of XLA devices. However, my issue is with underutilization of a single device.

Currently, for example, the only apparent option to increase device utilization is to double (triple, etc.) the number of qubits and run two (or more) copies of the same circuit side by side on disjoint wires, obviously wasting computation for the sake of parallelism.

What are you comparing against?

Just a sidenote that executing qjit(vmap(qnode))(inputs) is much faster than jax.jit(jax.vmap(qnode))(inputs), probably due to the compiled circuit.

Let me get some additional info on this and get back to you :slight_smile:

Yes, more information on what I could do to achieve this parallelism would be game changing for me.
Thanks :slight_smile:

That’s true, and I believe that is what would happen if you use JAX directly with the default.qubit device. Because it is written in Python using JAX instructions, it should be able to vectorize the simulation of each individual gate. If you haven’t tried this configuration yet, this might be what you are looking for :slight_smile:

Unfortunately, because device internals (living in pre-compiled libraries) are opaque to Catalyst, we cannot vectorize operations like that even if they are simulators.

However, my issue is with underutilization of a single device.

Got it, thanks for clarifying the use case!

Hello @David_Ittah,

Thank you for your detailed answer. I have a similar issue and would like to understand it better. As I have tested, and as is reported here, the torch-, JAX- and TF-based QNodes with default.qubit can use GPUs for parallel optimization. Is this also possible with more optimized devices, e.g. with a special build from source or a special simulator?
I am interested in executing the same hybrid model (500-1000 times) with low qubit numbers (8-16) for different input data. So far, all configurations (apart from those with default.qubit) were executed serially. Here are some benchmark results for an optimization step:

  • Case 1: 0.2 s (default.qubit, torch interface, 12 qubits, batch size 1024)
  • Case 2: 19.5 s (JIT lightning.kokkos, jax interface, 12 qubits, batch size 1024)

Case 2 delivers the best performance apart from default.qubit. The limiting factor is the serial execution, which is quite frustrating.
Do you have any recommendations for parallel optimization of hybrid ML models on GPUs (apart from Case 1: GPU + default.qubit)?

Best regards

minn-bj

Thanks for sharing @minn-bj. I think for your use case you might indeed get the best performance by running the default.qubit simulator in your ML library of choice. Is there anything you are missing when using this configuration?

Hey @David_Ittah,

Thank you for your fast reply. In principle, I just hoped that there could be a way to compile the circuit for parallel execution or something similar. I am surprised that there is no feature that enables parallel execution of smaller circuits, as this is one of the basic features in classical machine learning.

In addition, scalability with additional hardware is something that I miss. In this scenario one is limited to using default.qubit for a single model with a single CPU or GPU. Scalability and batch parallelization are key features for real-world investigations.

Do you know any other way to speed up my calculations?

Best regards

minn-bj

Hi @minn-bj ,

Thank you for your feedback. The points you raise can be inputs to future roadmaps for the development of Catalyst and PennyLane.

In your particular case, how deep are your circuits? For 12-qubit circuits lightning.qubit should work well, and if you’re using JAX then you can compile your circuit with Catalyst to speed up the computation. In the numbers that you mentioned before, were you using Catalyst?

Note that even if you don’t use Catalyst, you can still parallelize the computation of gradients when using the Lightning suite, although using Catalyst may be the easier option.

Also note that using GPUs for the quantum computation probably won’t help: there is a lot of back and forth between CPU and GPU, and the overheads will just make it slower.

Finally, note that some basic features in classical machine learning can be hard or impossible to implement in the case of quantum computing. Hence the need to understand the core of the field and the problems that quantum computers are most likely to solve well.

So if you can give us more details, I can recommend things that may help speed up your program. Sometimes, though, the best speedups come from rewriting the program itself, removing control flow, etc. This is an active field of research, so there’s still a lot to discover!

Hey @CatalinaAlbornoz,

Thank you for your detailed answer. I ran a benchmark with a reasonably shallow circuit (my actual use case involves reasonably deep circuits, so I messed up a bit) on all Lightning devices for 8, 12 and 16 qubits, using TensorFlow, PyTorch and JAX, with and without Catalyst. Unfortunately, all CPU-based approaches (apart from default.qubit) fail to parallelize over the batch and are therefore very slow. The best combination seems to be default.qubit with torch on a GPU. If there is another way to batch-parallelize an optimization step on a GPU I would be interested in it, but it seems not to be possible, as @David_Ittah mentioned.

Best regards

minn-bj

Hi @minn-bj ,

Given the number of qubits you’re working with it makes sense for default.qubit to give you the best performance.