Parallelization of circuit executions

I’m interested in quantum supervised machine learning methods such as QSVM.
The QSVM takes a long time to train because N*(N-1)/2 individual circuit executions are required.
Kernel-based training of quantum models with scikit-learn — PennyLane
Assuming the number of training samples is 1,000, the number of circuit executions is as large as 500,000.
Therefore, the training step takes a huge amount of time even if I use PennyLane.
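For concreteness, the pairwise count above works out as follows for N = 1,000 training samples:

```python
# Number of distinct training-data pairs, i.e. kernel-matrix entries
# (and hence circuit executions) for N training samples
N = 1000
executions = N * (N - 1) // 2
print(executions)  # 499500
```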

My question is, can I parallelize circuit executions?
In the case of the SV1 remote simulator on AWS, I confirmed that parallelization is possible, because SV1 natively supports parallelization.

But I want to parallelize the circuit executions for local simulators.
I have tried the QNodeCollection class with Dask; however, no acceleration was observed, as follows.

PennyLane version: v0.20.0

import pennylane as qml
from pennylane import numpy as np

n_qubits = 20
dev = qml.device('lightning.qubit', wires=n_qubits)

@qml.qnode(dev)
def circuit():
    for i in range(n_qubits):
        qml.Hadamard(wires=i)  # loop body missing in the original post; a Hadamard layer is assumed
    return qml.expval(qml.PauliZ(0))

# one circuit execution
%time circuit() 

Wall time: 220 ms

Num_of_circuits = 2
# construct qnodes for two circuit executions
circuits = qml.QNodeCollection([circuit]*Num_of_circuits)
# two circuit executions
%time circuits(parallel=True)

Wall time: 592 ms

The above time is almost twice that of a single circuit execution…

Is there a technique for parallelization?
This may be a Python question rather than a PennyLane question…

Hey @Kuma-quant, you can use JAX and the various functionalities it provides to speed up your simulations significantly. This demo might help.

Hi @Kuma-quant thanks for the question on this.
When you say parallelization here, I assume you mean parallelization over data — i.e., different inputs to the same circuits. If you intend to do model parallelism (i.e., distributing parts of the same computation over multiple cores), it may be best to explore the use of the PyTorch or TensorFlow interfaces. Also, given the example has many iterative steps, it will not in general be possible to parallelize between iterations. That being said, if parallel circuit evaluations are needed, we can explore this as follows:

  1. Firstly, the QNodeCollection Dask support remains experimental. The existing Dask support offloads operations as threads to the chosen device. This works fine in the provided example, as the chosen devices exist in separate processes (Rigetti QVMs, as listed). However, since we are using the device within the same process here, we may hit the usual threading problem that Python users tend to hit: the GIL. It may be possible to direct the Dask backend to use processes rather than threads, though the recommended approach from the Dask docs is to favour the distributed backend for these problems. We are currently working on a native Dask-distributed-capable backend that will allow parallelization across multiple machines. However, this is still under active development.
  2. It may be possible to offload tasks using the Python multiprocessing module, and I think this may be the preferred way to go here if you wish to carry out multiple executions on your local machine. I have adapted your supplied script to examine a few cases with some heavier circuits: (i) a single execution of a parametric circuit, (ii) serial execution using the same QNodeCollection you have examined, and (iii) offloading the operations to a multiprocessing Pool.
import pennylane as qml
from pennylane import numpy as np
from multiprocessing import Pool, Lock
from timeit import default_timer as timer

n_qubits = 20
dev = qml.device('lightning.qubit', wires=n_qubits)
x_params, y_params = np.random.rand(n_qubits), np.random.rand(n_qubits)
Num_of_circuits = 20

# timing results data
timings = {}

@qml.qnode(dev)
def circuit1(x, y):
    for i in range(n_qubits):
        qml.RX(x[i], wires=i)
        qml.RY(y[i], wires=i)
        qml.CNOT(wires=[i, (i+1)%n_qubits])
    return [qml.var(qml.PauliZ(i)) for i in range(n_qubits)]

# Used for async offloading as decorators do not play nicely with multiproc async
def process_f(x,y):
    return circuit1(x,y)

# Callback from async operation to add results to array
def collect_result(result):
    with lock:
        results_mp.append(result)
# Data for Multiprocessing env Pool and Lock must be created below required functions
results_mp = []
pool = Pool(4) # assumes 4 available physical cores
lock = Lock()

# one circuit execution
try:
    start_1 = timer()
    circuit1(x_params, y_params)
    end_1 = timer()
    timings["single"] = end_1 - start_1
except Exception as e:
    print(f"Failed to run single execution: {e}")

# construct qnodes for Num_of_circuits circuit executions
circuits = qml.QNodeCollection([circuit1]*Num_of_circuits)

# Explicitly time serial execution
try:
    start_s = timer()
    results_s = circuits(x_params, y_params)
    end_s = timer()
    timings["serial"] = end_s - start_s
except Exception as e:
    print(f"Failed to run serial code: {e}")

# time parallel execution using multiprocessing pool
try:
    start_mp = timer()
    futures = [pool.apply_async(process_f, args=(x_params, y_params), callback=collect_result) for _ in range(Num_of_circuits)]
    # Synchronize when finished
    for f in futures:
        f.wait()
    end_mp = timer()
    timings["multiproc"] = end_mp - start_mp
except Exception as e:
    print(f"Failed to run multiprocessing code: {e}")

print(f"Elapsed times: {timings}")

On my machine, I have >4 cores available, so my execution gives:
Elapsed times: {'single': 0.6598089630133472, 'serial': 12.97552589897532, 'multiproc': 4.972487455990631}
which is roughly a 2.5x speedup over the serial execution (I made sure to attempt timing the overheads involved also). However, there can be many sharp edges with this type of workflow, and you may have limited success (race-condition bugs in the Pool are one known example). Though, it may help you with your example.
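To see the GIL effect mentioned in point 1 without any PennyLane machinery, here is a minimal, self-contained sketch comparing thread- and process-based pools on a CPU-bound function. On a multi-core machine, the process pool should scale while the thread pool typically does not (the function names here are illustrative, not from any library):

```python
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
from timeit import default_timer as timer

def cpu_bound(n):
    # Pure-Python busy loop; holds the GIL the whole time it runs
    total = 0
    for i in range(n):
        total += i * i
    return total

def run(executor_cls, n_tasks=4, n=2_000_000):
    # Run n_tasks identical CPU-bound jobs in the given pool and time them
    start = timer()
    with executor_cls(max_workers=n_tasks) as ex:
        results = list(ex.map(cpu_bound, [n] * n_tasks))
    return timer() - start, results

if __name__ == "__main__":
    t_threads, _ = run(ThreadPoolExecutor)
    t_procs, _ = run(ProcessPoolExecutor)
    print(f"threads: {t_threads:.2f}s, processes: {t_procs:.2f}s")
```

The process pool pays a cost for pickling arguments and spawning workers, so for very short tasks the thread pool can still win; the crossover depends on how much work each task does.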

Feel free to follow-up if the above does not help with your problem, or if it does, we’d be happy to know.

Thanks, @ankit27kh and @mlxd !!

I’m trying two methods.

  1. multiprocessing module
  2. jax and the jax.jit

Rewriting my code for multiprocessing seems like tough work for me…

On the other hand, rewriting my code for JAX was easy for me.

Here is a part of the sample code.
Then I got a significant speedup from JAX!

import jax
from jax import device_put

# dev, n_dim, projector, X_train, y_train are defined elsewhere in the full script
X_train_jnp = device_put(X_train)
y_train_jnp = device_put(y_train)

@qml.qnode(dev, interface="jax")
def kernel(x1, x2):
    """The quantum kernel."""
    for i in range(n_qubits):
        qml.RX(x1[i % n_dim], wires=[i])
    for i in range(n_qubits):
        qml.CNOT(wires=[i, (i + 1) % n_qubits])  # entangling layer; loop body missing in the original post
    for i in range(n_qubits - 1, -1, -1):
        qml.CNOT(wires=[i, (i + 1) % n_qubits])  # adjoint entangling layer; loop body missing in the original post
    for i in range(n_qubits - 1, -1, -1):
        qml.RX(-1 * x2[i % n_dim], wires=[i])
    return qml.expval(qml.Hermitian(projector, wires=range(n_qubits)))

jit_kernel = jax.jit(kernel)

def kernel_matrix(A, B):
    """Compute the matrix whose entries are the kernel
       evaluated on pairwise data from sets A and B."""
    return np.array([[jit_kernel(a, b) for b in B] for a in A])

%time svm = SVC(kernel=kernel_matrix).fit(X_train_jnp, y_train_jnp)

The processing time w/o jit was 41 sec.
The processing time w/ jit was 21.8 msec.


The processing time w/o jit but w/ lightning.qubit was ~4 sec.

Therefore, jax with jax.jit was the fastest method!

Unfortunately, the acceleration I observed may have been just the effect of caching…

Up to 10 qubits, I can’t find any improvement from JAX compared with lightning.qubit, even though jax.jit and jax.vmap were also used.
For >12 qubits, JAX consumed almost all of the RAM in my PC.
I believe that acceleration from JAX would be realized if I used more resources, such as multiple GPUs with jax.pmap.
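For reference, the kernel-matrix computation can be batched with jax.vmap along these lines. This sketch uses a toy classical kernel as a hypothetical stand-in for the quantum kernel, but the vmap structure is the same:

```python
import jax
import jax.numpy as jnp

def toy_kernel(a, b):
    # Hypothetical stand-in for the quantum kernel: a simple RBF-like function
    return jnp.exp(-jnp.sum((a - b) ** 2))

@jax.jit
def kernel_matrix(A, B):
    # Inner vmap maps over rows of B, outer vmap over rows of A,
    # yielding the full pairwise kernel matrix in one batched call
    return jax.vmap(lambda a: jax.vmap(lambda b: toy_kernel(a, b))(B))(A)

A = jnp.ones((4, 3))
B = jnp.zeros((5, 3))
K = kernel_matrix(A, B)
print(K.shape)  # (4, 5)
```

Note that vmap vectorizes over the batch dimension within a single device, which is why it can grow memory quickly at larger qubit counts; jax.pmap distributes the batch across devices instead.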

I want to test multiprocessing.

Hi @Kuma-quant thanks for the feedback. Mapping code to run in parallel can always be fraught with some challenges. In this case, JAX is fantastic when we can reuse the JIT-compiled code (as an example, when calculating gradients the circuit can be re-executed multiple times with different provided parameters). However, when opting for a single execution, it may not work as well, since the JIT process only compiles with the first run of the code (so you factor in compilation AND runtime).
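A quick way to see this compile-once behaviour is to count how many times the Python body of a jitted function is actually traced; a minimal sketch:

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def f(x):
    global trace_count
    trace_count += 1  # the Python body only runs while JAX traces/compiles
    return 2.0 * x

f(jnp.ones(3))   # first call: traces, compiles, then runs
f(jnp.ones(3))   # second call: reuses the cached compiled code
f(jnp.zeros(3))  # same shape/dtype: still served from the cache
print(trace_count)  # 1
```

The body runs once despite three calls, because JAX caches the compiled program per input shape and dtype; a timing of only the first call therefore includes the full compilation cost.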

Though, I am surprised to hear the JAX engine is eating all available RAM. I wonder if this is the compilation process of JAX, or something on our side. Can you provide an example of your code that caused this to happen? Also, any details on your working machine (OS, CPU, available RAM) would be great.

If you have access to a CUDA-capable GPU, you can always make use of the TF or Torch GPU support to run your circuit. While the use here is more for backpropagation gradient calculations of the circuit, it may be useful to try it out and see if it helps.

As for mapping the problem to multiprocessing, we are currently putting together a device that should allow this to be more seamless.