Parallel vectorized circuit execution with vmap

Hi,

My use case requires me to execute the same simple parameterized circuit on a few qubits (~9) over a batch of inputs (~10,000 unique inputs for a minimal POC).

Naturally, I have tried to achieve this as follows, by wrapping the qnode in a vmap.

from pennylane import qnode, device, RX, CRZ, PauliZ, expval
from catalyst import vmap, qjit
import jax
import jax.numpy as jnp

@qjit(keep_intermediate=True, verbose=True) # inspect lowering process
@vmap(in_axes=(0, None), out_axes=0)    # vectorize over 'batch' dimension
@qnode(device("lightning.qubit", wires=5), interface="jax")
def simple_quantum_model(inputs, weights):
    for i in range(5): RX(inputs[i], wires=i)       # encoding
    for i in range(1, 5): CRZ(weights[i], (i, 0))   # ansatz
    return expval(PauliZ(wires=0))                  # measurement

test_inputs = jax.random.normal(jax.random.key(0), shape=(10_000, 5), dtype=jnp.float32)
test_weights = jax.random.normal(jax.random.key(0), shape=(5,), dtype=jnp.float32)

print(simple_quantum_model(test_inputs, test_weights).shape)

I tried the above with the simulators “lightning.qubit”, “lightning.gpu”, and “lightning.kokkos” (Kokkos built with CUDA support). Regardless of the device, the result of the fourth pass, “BufferizationPass” (before the final “MLIRToLLVMDialect”), looks like the following:

func.func public @jit_vmap.simple_quantum_model(%arg0: memref<10000x5xf32>, %arg1: memref<5xf32>) -> memref<10000xf64> attributes {llvm.copy_memref, llvm.emit_c_interface} {
  ... // allocations, copy
  %5 = scf.for %arg2 = %c0 to %c10000 step %c1 iter_args(%arg3 = %alloc_20) -> (memref<10000xf64>) {
    ... // indexing, selection
    %15 = func.call @simple_quantum_model_0(%alloc_17, %arg1) : (memref<5xf32>, memref<5xf32>) -> memref<f64>
    ... // more indexing, copy, selection
    scf.yield %alloc_22 : memref<10000xf64> // final yield
  }
  ...// deallocation, copy
  return %10 : memref<10000xf64>
}

This is equivalent to a plain for loop over the circuit.
In other words, the mapped (batch) axis isn’t pushed down into primitive operations, as one would expect from vmap.

Sidenote: The above was also tried with jax.jit and jax.vmap, which likewise results in no parallelism and is orders of magnitude slower to execute.
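
Roughly, that pure-JAX variant looked like this (a sketch using the same circuit body, with the Catalyst decorators swapped for jax.jit and jax.vmap):

from functools import partial

@jax.jit
@partial(jax.vmap, in_axes=(0, None))   # batch over the inputs only
@qnode(device("lightning.qubit", wires=5), interface="jax")
def plain_quantum_model(inputs, weights):
    for i in range(5): RX(inputs[i], wires=i)
    for i in range(1, 5): CRZ(weights[i], wires=(i, 0))
    return expval(PauliZ(wires=0))

print(plain_quantum_model(test_inputs, test_weights).shape)  # runs, but executes the circuits one at a time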

Is this intentional? If so, could this be raised as a feature request?
Could you suggest some workarounds for parallel circuit execution on GPU?

qml.about()

Name: PennyLane
Version: 0.40.0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: [link]
Author:
Author-email:
License: Apache License 2.0
Location: /home/[user]/.local/lib/python3.11/site-packages
Requires: appdirs, autograd, autoray, cachetools, diastatic-malt, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, tomlkit, typing-extensions
Required-by: amazon-braket-pennylane-plugin, PennyLane-Catalyst, pennylane-qrack, pennylane-qulacs, PennyLane_Lightning, PennyLane_Lightning_GPU, PennyLane_Lightning_Kokkos, PennyLane_Lightning_Tensor

Platform info: Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.35
Python version: 3.11.11
Numpy version: 1.26.4
Scipy version: 1.13.0
Installed devices:

  • default.clifford (PennyLane-0.40.0)
  • default.gaussian (PennyLane-0.40.0)
  • default.mixed (PennyLane-0.40.0)
  • default.qubit (PennyLane-0.40.0)
  • default.qutrit (PennyLane-0.40.0)
  • default.qutrit.mixed (PennyLane-0.40.0)
  • default.tensor (PennyLane-0.40.0)
  • null.qubit (PennyLane-0.40.0)
  • reference.qubit (PennyLane-0.40.0)
  • braket.aws.ahs (amazon-braket-pennylane-plugin-1.31.2)
  • braket.aws.qubit (amazon-braket-pennylane-plugin-1.31.2)
  • braket.local.ahs (amazon-braket-pennylane-plugin-1.31.2)
  • braket.local.qubit (amazon-braket-pennylane-plugin-1.31.2)
  • nvidia.custatevec (PennyLane-Catalyst-0.11.0.dev15)
  • nvidia.cutensornet (PennyLane-Catalyst-0.11.0.dev15)
  • oqc.cloud (PennyLane-Catalyst-0.11.0.dev15)
  • softwareq.qpp (PennyLane-Catalyst-0.11.0.dev15)
  • lightning.gpu (PennyLane_Lightning_GPU-0.40.0)
  • lightning.tensor (PennyLane_Lightning_Tensor-0.40.0)
  • lightning.qubit (PennyLane_Lightning-0.41.0.dev7)
  • lightning.kokkos (PennyLane_Lightning_Kokkos-0.41.0.dev7)
  • qulacs.simulator (pennylane-qulacs-0.40.0)
  • qrack.simulator (pennylane-qrack-0.12.0)

PS: I’m aware there are similar threads, however I could not find one for vmap vectorization.

Hi @LasradoRohan, thank you for sharing your use case!

You’re right, catalyst.vmap does not push the batch dimensions into primitives, because the quantum device interface was not designed with this in mind (for instance, how would a QPU vectorize a single gate, with no input/output arrays to speak of?). Classical primitives from the NumPy API, on the other hand, are designed to handle additional input/output dimensions on arrays directly in their kernels.
Instead, we created catalyst.vmap with the goal of enabling easy batching of functions, using the familiar API of jax.vmap. While we don’t vectorize primitive kernels, the compiled for loop should still be faster than launching those 10,000 function executions from Python. If this is not the case, please let us know and we’ll look into it!
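
To illustrate the classical side of that statement (just a toy example of mine, not Catalyst internals): JAX’s batching rules rewrite a classical primitive to carry the extra dimension inside the kernel, instead of looping over calls:

import jax
import jax.numpy as jnp

def f(x, w):
    return jnp.dot(x, w)    # unbatched: (5,) . (5,) -> scalar

batched = jax.vmap(f, in_axes=(0, None))
print(jax.make_jaxpr(batched)(jnp.ones((10, 5)), jnp.ones(5)))
# The jaxpr contains a single dot_general over the (10, 5) array,
# not a loop of 10 separate dot products.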

Sidenote: The above was also tried with jax.jit and jax.vmap, which likewise results in no parallelism

You mention parallelism, but vmap is not a parallelizing operation; it is a batching operation (see the explanation from the JAX developers). Maybe what you are looking for is jax.pmap? Unfortunately, pmap is currently unsupported with Catalyst, although you should be able to observe parallelism with it in pure JAX programs.
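
For completeness, a minimal pure-JAX pmap sketch would look something like this (my own example; it assumes the mapped axis does not exceed the number of visible XLA devices):

import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()   # pmap assigns one shard to each device

def f(x):
    return jnp.sum(jnp.sin(x))

xs = jnp.ones((n_dev, 1000))       # leading axis must not exceed the device count
print(jax.pmap(f)(xs))             # each device executes its shard in parallel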

and is orders of magnitude slower to execute.

What are you comparing against?

Is this intentional? If so, could this be raised as a feature request?
Could you suggest some workarounds for parallel circuit execution on GPU?

Let me get some additional info on this and get back to you :slight_smile:

Hi @David_Ittah, thanks for the answer.

how would a QPU vectorize a single gate

Although it makes sense that a QPU would not be able to execute gates over multiple states in parallel, I believe state-vector simulators that work directly with matrices could benefit from such a batching operation. For example, the compiled for loop (which launches one kernel per batch index) leads to very low GPU utilization in my case.

but vmap is not a parallelizing operation

Thanks for clarifying this about vmap. Although it doesn’t guarantee parallelization, in ML use cases where matrix operations can be fused it is widely used to achieve parallelism over a batch.

Maybe what you are looking for is jax.pmap?

jax.pmap distributes each batch index across multiple XLA devices, and it requires the batch dimension to be no larger than the number of XLA devices. However, my issue is with underutilization of a single device.

Currently, for example, the only apparent option to increase device utilization is to double (triple, etc.) the number of qubits and run two (or more) copies of the same circuit in parallel as disjoint sections of one larger circuit, obviously wasting computation for the sake of parallelism.
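
A rough sketch of that workaround, just to illustrate what I mean (my own illustration, reusing the imports from my first post): two batch elements are encoded on disjoint groups of 5 wires and simulated as a single 10-qubit circuit.

@qjit
@qnode(device("lightning.qubit", wires=10), interface="jax")
def doubled_model(inputs_pair, weights):      # inputs_pair: shape (2, 5)
    for b in range(2):                        # two independent copies, side by side
        offset = 5 * b
        for i in range(5): RX(inputs_pair[b, i], wires=offset + i)
        for i in range(1, 5): CRZ(weights[i], wires=(offset + i, offset))
    return expval(PauliZ(wires=0)), expval(PauliZ(wires=5))   # one result per copy

Of course, the state vector then has 2^10 amplitudes rather than two independent vectors of 2^5 each, which is exactly the computation wasted for the sake of parallelism.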

What are you comparing against?

Just a sidenote: executing qjit(vmap(qnode))(inputs) is much faster than jax.jit(jax.vmap(qnode))(inputs), probably due to the compiled circuit.

Let me get some additional info on this and get back to you :slight_smile:

Yes, more information on what I could do to achieve this parallelism would be game-changing for me.
Thanks :slight_smile:

That’s true. I believe that is what would happen if you use JAX directly with the default.qubit device: because it is written in Python using JAX instructions, it should be able to vectorize the simulation of each individual gate. If you haven’t tried this configuration yet, it might be what you are looking for :slight_smile:
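
Something along these lines might work (an untested sketch on my end, reusing the shape of your circuit):

import jax
import pennylane as qml

dev = qml.device("default.qubit", wires=5)

@qml.qnode(dev, interface="jax")
def circuit(inputs, weights):
    for i in range(5):
        qml.RX(inputs[i], wires=i)
    for i in range(1, 5):
        qml.CRZ(weights[i], wires=(i, 0))
    return qml.expval(qml.PauliZ(0))

batched = jax.jit(jax.vmap(circuit, in_axes=(0, None)))

inputs = jax.random.normal(jax.random.key(0), (10_000, 5))
weights = jax.random.normal(jax.random.key(1), (5,))
print(batched(inputs, weights).shape)   # (10000,); default.qubit should vectorize the gate applications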

Unfortunately, because device internals (living in pre-compiled libraries) are opaque to Catalyst, we cannot vectorize operations like that even when the device is a simulator.

However, my issue is with underutilization of a single device.

Got it, thanks for clarifying the use case!