Hi,
My use case requires executing the same simple parameterized circuit on a few qubits (~9) over a batch of inputs (~10,000 unique inputs to the circuit, for a minimal POC).
Naturally, I tried to achieve this by wrapping the QNode in a `vmap`, as follows:
```python
from pennylane import qnode, device, RX, CRZ, PauliZ, expval
from catalyst import vmap, qjit
import jax
import jax.numpy as jnp

@qjit(keep_intermediate=True, verbose=True)  # inspect lowering process
@vmap(in_axes=(0, None), out_axes=0)  # vectorize over 'batch' dimension
@qnode(device("lightning.qubit", wires=5), interface="jax")
def simple_quantum_model(inputs, weights):
    for i in range(5):
        RX(inputs[i], wires=i)  # encoding
    for i in range(1, 5):
        CRZ(weights[i], wires=(i, 0))  # ansatz
    return expval(PauliZ(wires=0))  # measurement

test_inputs = jax.random.normal(jax.random.key(0), shape=(10_000, 5), dtype=jnp.float32)
test_weights = jax.random.normal(jax.random.key(0), shape=(5,), dtype=jnp.float32)
print(simple_quantum_model(test_inputs, test_weights).shape)
```
I tried the above with the simulators `lightning.qubit`, `lightning.gpu` and `lightning.kokkos` (Kokkos built with CUDA support). Regardless of the device, the output of the 4th pass, `BufferizationPass` (just before the final `MLIRToLLVMDialect`), looks like the following:
```mlir
func.func public @jit_vmap.simple_quantum_model(%arg0: memref<10000x5xf32>, %arg1: memref<5xf32>) -> memref<10000xf64> attributes {llvm.copy_memref, llvm.emit_c_interface} {
  ... // allocations, copy
  %5 = scf.for %arg2 = %c0 to %c10000 step %c1 iter_args(%arg3 = %alloc_20) -> (memref<10000xf64>) {
    ... // indexing, selection
    %15 = func.call @simple_quantum_model_0(%alloc_17, %arg1) : (memref<5xf32>, memref<5xf32>) -> memref<f64>
    ... // more indexing, copy, selection
    scf.yield %alloc_22 : memref<10000xf64> // final yield
  }
  ... // deallocation, copy
  return %10 : memref<10000xf64>
}
```
This is equivalent to a `for` loop over the circuit.
In other words, the mapped (batch) axis isn't pushed down into the primitive operations, as one would expect from `vmap`.
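For contrast, here is a minimal sketch of the behaviour I'd expect from `vmap` in plain JAX. The toy function `single` is purely illustrative; printing `jax.make_jaxpr` of the vmapped function shows the batch axis absorbed into the primitives themselves rather than wrapped in a loop:

```python
import jax
import jax.numpy as jnp

def single(x, w):
    # per-sample computation: matrix-vector product followed by an elementwise cosine
    return jnp.cos(w @ x)

w = jnp.ones((3, 4))                # shared (unbatched) parameters
xs = jnp.arange(8.0).reshape(2, 4)  # batch of 2 inputs along axis 0

batched = jax.vmap(single, in_axes=(0, None))

# The traced program has one dot_general over the whole batch (plus layout ops)
# followed by cos; there is no while/scan loop over the individual samples.
jaxpr = jax.make_jaxpr(batched)(xs, w)
print(jaxpr)
print(batched(xs, w).shape)  # (2, 3)
```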
Side note: I also tried the above with `jax.jit` and `jax.vmap`, which likewise results in no parallelism and is also orders of magnitude slower to execute.
Is this intentional? If so, could this be raised as a feature request?
Could you also suggest some workarounds for parallel circuit execution on GPU?
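For reference, one stopgap I can think of (my own sketch, not a PennyLane or Catalyst API) is to simulate this small circuit directly as a statevector in JAX, so that `jax.jit(jax.vmap(...))` batches the underlying `tensordot`/elementwise ops on whatever device JAX is using (GPU if available). The helper names `apply_1q`, `rx`, `crz` and `circuit` are hypothetical, and the gate conventions assume PennyLane's definitions of `RX` and `CRZ` with `wires=(control, target)`:

```python
import jax
import jax.numpy as jnp

N_WIRES = 5  # matches the 5-wire example above

def apply_1q(state, gate, wire):
    # Contract a 2x2 gate with axis `wire` of the (2,)*N_WIRES state tensor.
    state = jnp.tensordot(gate, state, axes=([1], [wire]))
    # tensordot puts the gate's output axis first; move it back into place.
    return jnp.moveaxis(state, 0, wire)

def rx(theta):
    # Assumed convention: RX(t) = [[cos(t/2), -i sin(t/2)], [-i sin(t/2), cos(t/2)]]
    c = jnp.cos(theta / 2).astype(jnp.complex64)
    s = (-1j * jnp.sin(theta / 2)).astype(jnp.complex64)
    return jnp.stack([jnp.stack([c, s]), jnp.stack([s, c])])

def crz(state, phi, control, target):
    # CRZ is diagonal: if the control bit is 1, apply RZ(phi) phases
    # e^{-i phi/2} (target bit 0) and e^{+i phi/2} (target bit 1); else identity.
    shape_c = [1] * N_WIRES
    shape_c[control] = 2
    shape_t = [1] * N_WIRES
    shape_t[target] = 2
    c_bit = jnp.arange(2).reshape(shape_c)
    t_bit = jnp.arange(2).reshape(shape_t)
    rz_phase = jnp.exp(1j * (phi / 2) * (2 * t_bit - 1))
    phase = jnp.where(c_bit == 1, rz_phase, 1.0)
    return state * phase

def circuit(inputs, weights):
    # Start in |0...0>.
    state = jnp.zeros((2,) * N_WIRES, dtype=jnp.complex64).at[(0,) * N_WIRES].set(1.0)
    for i in range(N_WIRES):
        state = apply_1q(state, rx(inputs[i]), i)  # encoding
    for i in range(1, N_WIRES):
        state = crz(state, weights[i], control=i, target=0)  # ansatz
    probs = jnp.abs(state) ** 2
    # <Z_0> = P(wire 0 = 0) - P(wire 0 = 1); wire 0 is tensor axis 0.
    return jnp.sum(probs[0]) - jnp.sum(probs[1])

# vmap over the batch axis of `inputs` only; `weights` are shared.
batched_model = jax.jit(jax.vmap(circuit, in_axes=(0, None)))

key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
test_inputs = jax.random.normal(key_x, shape=(10_000, N_WIRES), dtype=jnp.float32)
test_weights = jax.random.normal(key_w, shape=(N_WIRES,), dtype=jnp.float32)
print(batched_model(test_inputs, test_weights).shape)  # (10000,)
```

For this particular circuit the CRZ layer is diagonal and commutes with the `PauliZ(0)` measurement, so the result should equal `cos(inputs[0])`. That is a convenient sanity check for the encoding and measurement, although it does not exercise the CRZ phases.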
Output of `qml.about()`:
```
Name: PennyLane
Version: 0.40.0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: [link]
Author:
Author-email:
License: Apache License 2.0
Location: /home/[user]/.local/lib/python3.11/site-packages
Requires: appdirs, autograd, autoray, cachetools, diastatic-malt, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, tomlkit, typing-extensions
Required-by: amazon-braket-pennylane-plugin, PennyLane-Catalyst, pennylane-qrack, pennylane-qulacs, PennyLane_Lightning, PennyLane_Lightning_GPU, PennyLane_Lightning_Kokkos, PennyLane_Lightning_Tensor
Platform info: Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.35
Python version: 3.11.11
Numpy version: 1.26.4
Scipy version: 1.13.0
Installed devices:
- default.clifford (PennyLane-0.40.0)
- default.gaussian (PennyLane-0.40.0)
- default.mixed (PennyLane-0.40.0)
- default.qubit (PennyLane-0.40.0)
- default.qutrit (PennyLane-0.40.0)
- default.qutrit.mixed (PennyLane-0.40.0)
- default.tensor (PennyLane-0.40.0)
- null.qubit (PennyLane-0.40.0)
- reference.qubit (PennyLane-0.40.0)
- braket.aws.ahs (amazon-braket-pennylane-plugin-1.31.2)
- braket.aws.qubit (amazon-braket-pennylane-plugin-1.31.2)
- braket.local.ahs (amazon-braket-pennylane-plugin-1.31.2)
- braket.local.qubit (amazon-braket-pennylane-plugin-1.31.2)
- nvidia.custatevec (PennyLane-Catalyst-0.11.0.dev15)
- nvidia.cutensornet (PennyLane-Catalyst-0.11.0.dev15)
- oqc.cloud (PennyLane-Catalyst-0.11.0.dev15)
- softwareq.qpp (PennyLane-Catalyst-0.11.0.dev15)
- lightning.gpu (PennyLane_Lightning_GPU-0.40.0)
- lightning.tensor (PennyLane_Lightning_Tensor-0.40.0)
- lightning.qubit (PennyLane_Lightning-0.41.0.dev7)
- lightning.kokkos (PennyLane_Lightning_Kokkos-0.41.0.dev7)
- qulacs.simulator (pennylane-qulacs-0.40.0)
- qrack.simulator (pennylane-qrack-0.12.0)
```
PS: I’m aware there are similar threads; however, I could not find one about `vmap` vectorization.