Hi,
My use case requires executing the same simple parameterized circuit on a few qubits (~9) over a batch of inputs (~10,000 unique inputs to the circuit for a minimal POC).
Naturally, I tried to achieve this by wrapping the QNode in a vmap, as follows:
from pennylane import qnode, device, RX, CRZ, PauliZ, expval
from catalyst import vmap, qjit
import jax
import jax.numpy as jnp

@qjit(keep_intermediate=True, verbose=True)   # inspect the lowering process
@vmap(in_axes=(0, None), out_axes=0)          # vectorize over the 'batch' dimension
@qnode(device("lightning.qubit", wires=5), interface="jax")
def simple_quantum_model(inputs, weights):
    for i in range(5): RX(inputs[i], wires=i)               # encoding
    for i in range(1, 5): CRZ(weights[i], wires=(i, 0))     # ansatz
    return expval(PauliZ(wires=0))                          # measurement

test_inputs = jax.random.normal(jax.random.key(0), shape=(10_000, 5), dtype=jnp.float32)
test_weights = jax.random.normal(jax.random.key(0), shape=(5,), dtype=jnp.float32)
print(simple_quantum_model(test_inputs, test_weights).shape)
I tried the above with the simulators “lightning.qubit”, “lightning.gpu”, and “lightning.kokkos” (Kokkos built with CUDA support). Regardless of the device, the output of the fourth pass, “BufferizationPass” (just before the final “MLIRToLLVMDialect”), looks like the following:
func.func public @jit_vmap.simple_quantum_model(%arg0: memref<10000x5xf32>, %arg1: memref<5xf32>) -> memref<10000xf64> attributes {llvm.copy_memref, llvm.emit_c_interface} {
  ... // allocations, copy
  %5 = scf.for %arg2 = %c0 to %c10000 step %c1 iter_args(%arg3 = %alloc_20) -> (memref<10000xf64>) {
    ... // indexing, selection
    %15 = func.call @simple_quantum_model_0(%alloc_17, %arg1) : (memref<5xf32>, memref<5xf32>) -> memref<f64>
    ... // more indexing, copy, selection
    scf.yield %alloc_22 : memref<10000xf64> // final yield
  }
  ... // deallocation, copy
  return %10 : memref<10000xf64>
}
This is equivalent to a plain for loop over the circuit: the mapped (batch) axis isn’t pushed down into the primitive operations, as one would expect from vmap.
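In plain Python terms, the lowered program behaves roughly like the sketch below (semantics only; unbatched_circuit is a hypothetical name standing in for the same QNode without the vmap wrapper):

def looped_model(batched_inputs, weights):
    # What the scf.for above amounts to: one full circuit execution per
    # batch element, strictly sequentially. (unbatched_circuit is a
    # hypothetical stand-in for the QNode without the vmap wrapper.)
    return jnp.stack([unbatched_circuit(x, weights) for x in batched_inputs])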
Side note: I also tried the above with jax.jit and jax.vmap, which likewise yields no parallelism and is orders of magnitude slower to execute.
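Concretely, that attempt looked roughly like the following (a sketch; the exact setup is an assumption — the same QNode body, but vectorized with jax.vmap and compiled with jax.jit instead of catalyst.vmap and qjit, reusing test_inputs and test_weights from above):

import jax
from pennylane import qnode, device, RX, CRZ, PauliZ, expval

@qnode(device("lightning.qubit", wires=5), interface="jax")
def circuit(inputs, weights):  # same body as simple_quantum_model above
    for i in range(5): RX(inputs[i], wires=i)
    for i in range(1, 5): CRZ(weights[i], wires=(i, 0))
    return expval(PauliZ(wires=0))

# Vectorize over the batch dimension and compile with plain JAX.
batched_circuit = jax.jit(jax.vmap(circuit, in_axes=(0, None)))
print(batched_circuit(test_inputs, test_weights).shape)  # runs, but serially and slowly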
Is this behaviour intentional? If so, could parallel batch execution be raised as a feature request?
In the meantime, could you suggest some workarounds for parallel circuit execution on GPU?
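For completeness, one related approach would be PennyLane's native parameter broadcasting, i.e. feeding the whole batch directly as gate parameters instead of using vmap. A minimal sketch, assuming default.qubit handles the broadcast natively (I don't know whether the lightning devices vectorize this or fall back to a loop), reusing test_inputs and test_weights from above:

from pennylane import qnode, device, RX, CRZ, PauliZ, expval

@qnode(device("default.qubit", wires=5), interface="jax")
def broadcast_model(inputs, weights):
    # inputs has shape (batch, 5); each RX receives a whole batch column and
    # relies on PennyLane's parameter broadcasting rather than vmap.
    for i in range(5): RX(inputs[:, i], wires=i)
    for i in range(1, 5): CRZ(weights[i], wires=(i, 0))
    return expval(PauliZ(wires=0))

print(broadcast_model(test_inputs, test_weights).shape)  # expected: (10000,)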
qml.about()
Name: PennyLane
Version: 0.40.0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: [link]
Author:
Author-email:
License: Apache License 2.0
Location: /home/[user]/.local/lib/python3.11/site-packages
Requires: appdirs, autograd, autoray, cachetools, diastatic-malt, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, tomlkit, typing-extensions
Required-by: amazon-braket-pennylane-plugin, PennyLane-Catalyst, pennylane-qrack, pennylane-qulacs, PennyLane_Lightning, PennyLane_Lightning_GPU, PennyLane_Lightning_Kokkos, PennyLane_Lightning_Tensor
Platform info: Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.35
Python version: 3.11.11
Numpy version: 1.26.4
Scipy version: 1.13.0
Installed devices:
- default.clifford (PennyLane-0.40.0)
- default.gaussian (PennyLane-0.40.0)
- default.mixed (PennyLane-0.40.0)
- default.qubit (PennyLane-0.40.0)
- default.qutrit (PennyLane-0.40.0)
- default.qutrit.mixed (PennyLane-0.40.0)
- default.tensor (PennyLane-0.40.0)
- null.qubit (PennyLane-0.40.0)
- reference.qubit (PennyLane-0.40.0)
- braket.aws.ahs (amazon-braket-pennylane-plugin-1.31.2)
- braket.aws.qubit (amazon-braket-pennylane-plugin-1.31.2)
- braket.local.ahs (amazon-braket-pennylane-plugin-1.31.2)
- braket.local.qubit (amazon-braket-pennylane-plugin-1.31.2)
- nvidia.custatevec (PennyLane-Catalyst-0.11.0.dev15)
- nvidia.cutensornet (PennyLane-Catalyst-0.11.0.dev15)
- oqc.cloud (PennyLane-Catalyst-0.11.0.dev15)
- softwareq.qpp (PennyLane-Catalyst-0.11.0.dev15)
- lightning.gpu (PennyLane_Lightning_GPU-0.40.0)
- lightning.tensor (PennyLane_Lightning_Tensor-0.40.0)
- lightning.qubit (PennyLane_Lightning-0.41.0.dev7)
- lightning.kokkos (PennyLane_Lightning_Kokkos-0.41.0.dev7)
- qulacs.simulator (pennylane-qulacs-0.40.0)
- qrack.simulator (pennylane-qrack-0.12.0)
PS: I’m aware there are similar threads; however, I could not find one about vmap vectorization.