Parallelization of circuit executions

I’m interested in quantum supervised machine learning methods such as QSVM.
Training a QSVM takes a lot of time because N*(N-1)/2 individual circuit executions are required for N training samples (see the PennyLane demo “Kernel-based training of quantum models with scikit-learn”).
If the number of training samples is 1,000, the number of circuit executions is about 500,000 (1,000 × 999 / 2 = 499,500).
Therefore, the training step takes a very long time, even with PennyLane.
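The N*(N-1)/2 count comes from the kernel matrix being symmetric with a known diagonal: for a fidelity-type kernel, k(x, x) = 1 and k(a, b) = k(b, a), so only the pairs i < j need a circuit run. A minimal sketch, where `toy_kernel` is a hypothetical classical stand-in for one circuit execution:

```python
import itertools

import numpy as np


def toy_kernel(a, b):
    # Hypothetical stand-in for one circuit execution.
    # Like a fidelity kernel, it is symmetric and returns 1.0 when a == b.
    return float(np.prod(np.cos((a - b) / 2) ** 2))


def symmetric_kernel_matrix(X, kernel):
    """Fill an N x N Gram matrix using only the N*(N-1)/2 pairs with i < j."""
    n = len(X)
    K = np.eye(n)  # assumes kernel(x, x) == 1, as for fidelity kernels
    n_calls = 0
    for i, j in itertools.combinations(range(n), 2):
        K[i, j] = K[j, i] = kernel(X[i], X[j])
        n_calls += 1
    return K, n_calls


X = np.random.default_rng(0).normal(size=(6, 2))
K, n_calls = symmetric_kernel_matrix(X, toy_kernel)
print(n_calls)          # 15 == 6 * 5 // 2
print(1000 * 999 // 2)  # 499500 circuit runs for 1,000 training samples
```

Computing the full N x N matrix instead (as `SVC(kernel=...)` with a plain callable does) roughly doubles the number of circuit runs.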

My question is: can I parallelize circuit executions?
For the SV1 remote simulator on AWS, I confirmed that parallelization is possible, because SV1 natively supports it.

However, I want to parallelize circuit executions on local simulators.
I tried the QNodeCollection class with Dask, but no acceleration was observed:

PennyLane version: v0.20.0

import pennylane as qml
from pennylane import numpy as np
n_qubits = 20
dev = qml.device('lightning.qubit', wires=n_qubits)
@qml.qnode(dev)
def circuit():
    for i in range(n_qubits):
        qml.Hadamard(wires=i)
    return qml.expval(qml.PauliZ(0))

# one circuit execution
%time circuit() 

Wall time: 220 ms

Num_of_circuits = 2
# construct qnodes for two circuit executions
circuits = qml.QNodeCollection([circuit]*Num_of_circuits)
# two circuit executions
%time circuits(parallel=True)

Wall time: 592 ms

This is more than twice the time of a single circuit execution, i.e. no faster than running the two circuits one after the other…

Is there a technique for parallelization?
This may be more of a general Python question than a PennyLane one…

Hey @Kuma-quant, you can use JAX and the various functionalities it provides to speed up your simulations significantly. This demo might help.

Hi @Kuma-quant, thanks for the question.
When you say parallelization here, I assume you mean parallelization over data, i.e. different inputs to the same circuit. If you intend to do model parallelism (i.e. distributing parts of the same computation over multiple cores), it may be best to explore the PyTorch or TensorFlow interfaces. Also, since the example has many iterative steps, it will not in general be possible to parallelize between iterations. That said, if parallel circuit evaluations are needed, we can explore this as follows:

  1. Firstly, the QNodeCollection Dask support remains experimental. The existing Dask support offloads operations as threads to the chosen device. This works fine in the example from https://pennylane.readthedocs.io/en/stable/code/api/pennylane.QNodeCollection.html, since the chosen devices there (Rigetti QVMs) run in separate processes. However, since the device here runs within the same process, we may hit the usual threading problem Python users tend to hit: the GIL (global interpreter lock). It may be possible to direct the Dask backend to use processes rather than threads, though the Dask docs recommend the distributed scheduler for these problems (https://docs.dask.org/en/latest/how-to/deploy-dask/single-machine.html#use-the-distributed-scheduler). We are currently working on a native Dask-distributed capable backend that will allow parallelization across multiple machines; however, this is still under active development.
  2. It may be possible to offload tasks using the Python multiprocessing module, and I think this may be the preferred way to go here if you wish to carry out multiple executions on your local machine. I have adapted your supplied script to examine a few cases with some heavier circuits: (i) a single execution of a parametric circuit, (ii) serial execution using the same QNodeCollection you have examined, and (iii) offloading the operations to a multiprocessing Pool.
import pennylane as qml
from pennylane import numpy as np
n_qubits = 20
from multiprocessing import Pool, Lock
from timeit import default_timer as timer

dev = qml.device('lightning.qubit', wires=n_qubits)
x_params, y_params = np.random.rand(n_qubits), np.random.rand(n_qubits)
Num_of_circuits = 20

# timing results data
timings = {}

@qml.qnode(dev)
def circuit1(x, y):
    for i in range(n_qubits):
        qml.Hadamard(wires=i)
        qml.RX(x[i], wires=i)
        qml.RY(y[i], wires=i)
        qml.CNOT(wires=[i, (i+1)%n_qubits])
    return [qml.var(qml.PauliZ(i)) for i in range(n_qubits)]

# Plain module-level wrapper used for async offloading, since the decorated QNode does not play nicely with multiprocessing async calls
def process_f(x,y):
    return circuit1(x,y)

# Callback from async operation to add results to array
def collect_result(result):
    lock.acquire()
    try:
        results_mp.append(result)
    finally:
        lock.release()

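# The Pool and Lock must be created after the functions they use are defined.
# Creating the Pool at module level relies on the "fork" start method (the Linux default);
# on Windows/macOS, put this code under `if __name__ == "__main__":`.
results_mp = []
pool = Pool(4) # assumes 4 available physical cores
lock = Lock()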

# one circuit execution
try:
    start_1 = timer()
    circuit1(x_params, y_params) 
    end_1 = timer()
    timings["single"] = end_1 - start_1
except Exception as e:
    print(f"Failed to run single execution: {e}")

# construct qnodes for Num_of_circuits circuit executions
circuits = qml.QNodeCollection([circuit1]*Num_of_circuits)

# Explicitly time serial execution
try:
    start_s = timer()
    results_s = circuits(x_params, y_params)
    end_s = timer()
    timings["serial"] = end_s - start_s
except Exception as e:
    print(f"Failed to run serial code: {e}")

# time parallel execution using multiprocessing pool
try:
    start_mp = timer()
    futures = [pool.apply_async(process_f, args=(x_params, y_params), callback=collect_result) for _ in range(Num_of_circuits)]
    # Synchronize when finished
    pool.close()
    pool.join()
    end_mp = timer()
    timings["multiproc"] = end_mp - start_mp
except Exception as e:
    print(f"Failed to run multiprocessing code: {e}")

print(f"Elapsed times: {timings}")

On my machine, which has more than 4 cores available, the execution gives:
Elapsed times: {'single': 0.6598089630133472, 'serial': 12.97552589897532, 'multiproc': 4.972487455990631}
which is roughly a 2.6x speedup over the serial execution (I also tried to include the overheads in the timing). However, there can be many sharp edges with this type of workflow, and you may have limited success (see https://bugs.python.org/issue25053 for an example of a race-condition bug in the Pool). Still, it may help with your example.
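The script above creates the Pool at module level, which relies on the "fork" start method. A minimal sketch of the same idea that should also work under "spawn" (Windows, macOS), using `Pool.imap_unordered` over only the upper-triangle pairs of the kernel matrix. `toy_kernel` is a hypothetical classical stand-in for a circuit execution, and the pool size of 4 is an assumption about available cores:

```python
import itertools
import math
from multiprocessing import Pool


def toy_kernel(a, b):
    # Hypothetical stand-in for one circuit execution. In practice this would
    # call a module-level wrapper around the QNode, like process_f in the script above.
    return math.prod(math.cos((x - y) / 2) ** 2 for x, y in zip(a, b))


def kernel_entry(task):
    # Top-level function, so the pool can pickle it by reference.
    i, j, a, b = task
    return i, j, toy_kernel(a, b)


def parallel_kernel_matrix(X, processes=4):
    """Fill a symmetric Gram matrix, sending only the pairs i < j to workers."""
    n = len(X)
    K = [[1.0] * n for _ in range(n)]  # assumes k(x, x) == 1 (fidelity kernel)
    tasks = [(i, j, X[i], X[j]) for i, j in itertools.combinations(range(n), 2)]
    with Pool(processes) as pool:
        # chunksize groups several pairs per message to cut IPC overhead
        for i, j, value in pool.imap_unordered(kernel_entry, tasks, chunksize=64):
            K[i][j] = K[j][i] = value
    return K


if __name__ == "__main__":
    # Required under the "spawn" start method, where each worker re-imports
    # this module; harmless otherwise.
    X = [(0.1 * k, -0.2 * k) for k in range(50)]
    K = parallel_kernel_matrix(X, processes=4)
    print(len(K), K[3][7] == K[7][3])
```

Whether this pays off for a real QNode depends on how much work each circuit does relative to the cost of pickling inputs and results, so it is worth timing against the serial loop as in the script above.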

Feel free to follow up if the above does not help with your problem; if it does, we’d be happy to know.

Thanks, @ankit27kh and @mlxd !!

I’m trying two methods:

  1. the multiprocessing module
  2. JAX with jax.jit

Rewriting my code for multiprocessing seems like tough work for me…

On the other hand, rewriting my code for JAX was easy.

Here is part of my sample code.
With it, I got a significant speedup from JAX!


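import jax
from jax import device_put
import numpy as np
import pennylane as qml
from sklearn.svm import SVC

# dev, n_qubits, n_dim, projector, X_train and y_train are defined elsewhere (not shown)
X_train_jnp = device_put(X_train)
y_train_jnp = device_put(y_train)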

@qml.qnode(dev, interface="jax")
def kernel(x1, x2):
    """The quantum kernel."""
    #S(x=x1)
    for i in range(n_qubits):
        qml.RX(x1[i%n_dim], wires=[i]) 
    for i in range(n_qubits):
        qml.CNOT(wires=[i,(i+1)%n_qubits])
    
    #S^dagger(x=x2)
    for i in range(n_qubits-1,-1,-1):
        qml.CNOT(wires=[i,(i+1)%n_qubits])
    for i in range(n_qubits-1,-1,-1):
        qml.RX(-1*x2[i%n_dim], wires=[i]) 
    return qml.expval(qml.Hermitian(projector, wires=range(n_qubits))) 

jit_kernel = jax.jit(kernel)

def kernel_matrix(A, B):
    """Compute the matrix whose entries are the kernel
       evaluated on pairwise data from sets A and B."""
    return np.array([[jit_kernel(a, b) for b in B] for a in A])

%time svm = SVC(kernel=kernel_matrix).fit(X_train_jnp, y_train_jnp)

The processing time w/o jit was 41 sec.
The processing time w/ jit was 21.8 msec.

Amazing!

Moreover, the processing time w/o jit but w/ lightning.qubit was ~4 sec.

Therefore, jax with jax.jit was the fastest method!

Unfortunately, the acceleration I observed may just have been caching…

Up to 10 qubits, I can’t find any improvement with JAX compared with lightning.qubit, even when jax.jit and jax.vmap are used.
For more than 12 qubits, JAX consumed almost all of the RAM on my PC.
I believe JAX would give a speedup if I used more resources, such as multiple GPUs with jax.pmap.

I want to test multiprocessing.
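The kernel in the snippet above measures `qml.Hermitian(projector, wires=range(n_qubits))`. If `projector` is the dense |0...0><0...0| matrix used in the PennyLane kernel demo (an assumption, since its definition is not shown), it has 4^n entries. A rough numpy sketch of that size, and of why the all-zeros probability gives the same number (which is what the later scripts in this thread obtain with `qml.probs(...)[0]`). Whether this matrix contributed to the RAM usage reported above is not established in the thread; it is simply an inexpensive thing to rule out:

```python
import numpy as np


def projector_bytes(n_qubits, itemsize=8):
    """Memory of a dense 2^n x 2^n projector; itemsize=8 assumes float64 entries."""
    dim = 2 ** n_qubits
    return dim * dim * itemsize


for n in (10, 12, 20):
    print(f"{n} qubits: {projector_bytes(n) / 1e9:.3g} GB")

# The projector |0...0><0...0| only picks out the first amplitude:
# <psi|P|psi> = |psi_0|^2, i.e. the probability of the all-zeros outcome.
rng = np.random.default_rng(0)
n = 3
psi = rng.normal(size=2**n) + 1j * rng.normal(size=2**n)
psi /= np.linalg.norm(psi)
P = np.zeros((2**n, 2**n))
P[0, 0] = 1.0
expval = np.vdot(psi, P @ psi).real  # vdot conjugates its first argument
prob0 = abs(psi[0]) ** 2
print(np.isclose(expval, prob0))  # True
```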

Hi @Kuma-quant, thanks for the feedback. Mapping code to run in parallel always comes with some challenges. In this case, JAX is fantastic when we can reuse the JIT-compiled code (for example, when calculating gradients the circuit can be re-executed many times with different parameters). However, for a single execution it may not work as well, since JIT compilation happens on the first run of the code (so you pay for compilation AND runtime).
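The compile-once behaviour can be seen by timing the first and second call of a jitted function. A sketch on a plain JAX function (`toy_kernel` is hypothetical, not a PennyLane circuit); actual timings depend on the machine, so no numbers are claimed here:

```python
import time

import jax
import jax.numpy as jnp


def toy_kernel(a, b):
    # Hypothetical stand-in for a kernel evaluation (not a PennyLane circuit).
    return jnp.prod(jnp.cos((a - b) / 2) ** 2)


jit_kernel = jax.jit(toy_kernel)
a = jnp.array([0.3, -1.2])
b = jnp.array([0.5, 0.4])

t0 = time.perf_counter()
first = jit_kernel(a, b).block_until_ready()   # traces + compiles, then runs
t1 = time.perf_counter()
second = jit_kernel(b, a).block_until_ready()  # same shapes/dtypes: reuses compiled code
t2 = time.perf_counter()

print(f"first call (compile + run): {t1 - t0:.4f} s")
print(f"second call (run only):     {t2 - t1:.6f} s")

# A new input shape triggers a fresh trace and compilation.
c = jnp.array([0.1, 0.2, 0.3])
jit_kernel(c, c).block_until_ready()
```

Keeping input shapes fixed across calls lets the compiled code be reused, which is where JIT pays off.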

That said, I am surprised to hear that JAX is eating all available RAM. I wonder if this comes from JAX's compilation process or from something on our side. Could you provide an example of the code that caused this? Any details on your machine (OS, CPU, available RAM) would also be great.

If you have access to a CUDA-capable GPU, you can also use the TensorFlow or PyTorch GPU support to run your circuit. While this is aimed more at backpropagation gradient calculations, it may be worth trying to see if it helps (https://pennylane.readthedocs.io/en/stable/introduction/interfaces/torch.html#gpu-and-cuda-support).

As for mapping the problem to multiprocessing, we are currently putting together a device that should allow this to be more seamless.

@mlxd
I found a bug in my code and fixed it.
Now JAX shows a significant speedup!

I have evaluated 10^6 circuits with 10 qubits.

pennylane w/lightning.qubit -> 20~30 sec.  
pennylane w/jax.jit -> ~10 sec.  
pennylane w/jax.jit and jax.vmap -> ~2 sec.  
qiskit_machine_learning -> ~2 sec.

A significant speedup was achieved with jax.jit and jax.vmap.
I also found that qiskit_machine_learning is very fast
(I used the “QuantumKernel” class in qiskit_machine_learning).
I don’t know the reason…
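The jax.jit + jax.vmap combination traces the function once and evaluates a whole batch in one compiled call, instead of one Python call per circuit. As a sketch of the pattern on a plain JAX function (`toy_kernel` is hypothetical; whether a given PennyLane QNode version accepts nested `vmap` is not verified here), two nested `vmap`s build the full kernel matrix without pre-building pair batches:

```python
import jax
import jax.numpy as jnp


def toy_kernel(a, b):
    # Hypothetical stand-in: closed form of a product-state fidelity kernel.
    return jnp.prod(jnp.cos((a - b) / 2) ** 2)


# Inner vmap maps over rows of B for a fixed a; outer vmap maps over rows of A.
kernel_matrix = jax.jit(
    jax.vmap(jax.vmap(toy_kernel, in_axes=(None, 0)), in_axes=(0, None))
)

A = jnp.array([[0.0, 0.0], [1.0, -0.5], [0.3, 0.7]])
K = kernel_matrix(A, A)  # K[i, j] == toy_kernel(A[i], A[j]), no Python loops
print(K.shape)  # (3, 3)
```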

Thanks all!!

Hi @Kuma-quant, I’m glad you could fix your problem!

It’s interesting to see the speed comparison that you made.

Thanks for sharing it!

We will take a look and hopefully we can understand why this is happening.

Hi @Kuma-quant.

We’re looking into why this is happening and it would be very helpful for us if you could share your full code. You should be able to share it here as a Python file.

Please let me know if you have any questions or issues sharing your code.

Hi @CatalinaAlbornoz -san,

Here are Python scripts to compare pennylane+lightning, pennylane+jax.jit, pennylane+jax.jit+jax.vmap, and qiskit-ml.
qiskit_machine_learning must be installed in advance.

The assumed use case is a QSVM with a quantum kernel.
The number of circuits for kernel evaluation is 102 x 102 ≈ 10^4.

[Main results]
pennylane+lightning: 26.5 sec.
pennylane+jax.jit: 12.4 sec.
pennylane+jax.jit+jax.vmap: 2.13 sec.
qiskit-ml: 1.34 sec.

Sorry, my code may not be pretty.

(1) Pennylane+lightning

import pennylane as qml
from pennylane import numpy as np
from sklearn.svm import SVC
from sklearn.datasets import make_classification
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

np.random.seed(1)

X, y = make_classification(n_features=2, n_redundant=0, n_informative=1, n_clusters_per_class=1, n_samples=1024)

# scaling the inputs is important since the embedding we use is periodic
scaler = StandardScaler().fit(X)
X_scaled = scaler.transform(X)

# scaling the labels to -1, 1 is important for the SVM and the
# definition of a hinge loss
y_scaled = 2 * (y - 0.5)
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y_scaled, train_size=0.1)

n_qubits = 10
dev = qml.device("lightning.qubit", wires=n_qubits)
n_dim = len(X_train[0])

@qml.qnode(dev)
def kernel(x1, x2):
    """The quantum kernel."""
    #S(x=x1)
    for i in range(n_qubits):
        qml.Hadamard(wires=[i])
        qml.RZ(x1[i%n_dim], wires=[i])     
    #S^dagger(x=x2)
    for i in range(n_qubits-1,-1,-1):
        qml.RZ(-1*x2[i%n_dim], wires=[i]) 
        qml.Hadamard(wires=[i])
    return qml.probs(wires=range(n_qubits))

def kernel_matrix(A, B):
    """Compute the matrix whose entries are the kernel
       evaluated on pairwise data from sets A and B."""
    return np.array([[kernel(a, b)[0] for b in B] for a in A])

print(np.shape(X_train)[0],'x',np.shape(X_train)[0],'kernel_matrix','with', n_qubits,'qubits circuits')
%time kernel_matrix(X_train,X_train)

102 x 102 kernel_matrix with 10 qubits circuits
CPU times: user 26.4 s, sys: 78.1 ms, total: 26.5 s
Wall time: 26.5 s
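Every gate in this `kernel` circuit acts on a single qubit, so each qubit evolves independently as H, RZ(x1), RZ(-x2), H. The two RZ angles add, leaving an amplitude cos((x1 - x2)/2) on |0>, so `kernel(a, b)[0]` should equal the product over qubits i of cos^2((a[i % n_dim] - b[i % n_dim]) / 2). A numpy sketch of that closed form, which could serve as a cheap correctness check for the benchmarked implementations (the function names are hypothetical):

```python
import numpy as np


def rz(theta):
    # PennyLane's RZ convention: diag(exp(-i*theta/2), exp(i*theta/2))
    return np.diag([np.exp(-0.5j * theta), np.exp(0.5j * theta)])


H = np.array([[1, 1], [1, -1]]) / np.sqrt(2)


def kernel_closed_form(x1, x2, n_qubits=10):
    """Probability of |0...0> for the H-RZ(x1)-RZ(-x2)-H circuit on every qubit."""
    n_dim = len(x1)
    idx = np.arange(n_qubits) % n_dim  # same feature-to-qubit map as x[i % n_dim]
    return float(np.prod(np.cos((x1[idx] - x2[idx]) / 2) ** 2))


def kernel_single_qubit_sim(x1, x2, n_qubits=10):
    """Brute-force check: product of per-qubit |<0|H RZ(-x2) RZ(x1) H|0>|^2."""
    n_dim = len(x1)
    p = 1.0
    for i in range(n_qubits):
        u = H @ rz(-x2[i % n_dim]) @ rz(x1[i % n_dim]) @ H  # rightmost gate acts first
        p *= abs(u[0, 0]) ** 2
    return p


x1 = np.array([0.4, -1.1])
x2 = np.array([-0.2, 0.9])
print(np.isclose(kernel_closed_form(x1, x2), kernel_single_qubit_sim(x1, x2)))  # True
```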

(2) Pennylane+jax.jit
Note: jax might not work well on Windows.

# Enable 64-bit floats (added to silence some warnings).
import jax
jax.config.update("jax_enable_x64", True)

import jax
from jax import device_put, jit
import pennylane as qml
from jax import numpy as np

from sklearn.svm import SVC
import numpy as vnp
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_classification
import matplotlib.pyplot as plt

vnp.random.seed(1)

X, y = make_classification(n_features=2, n_redundant=0, n_informative=1, n_clusters_per_class=1, n_samples=1024)

# scaling the inputs is important since the embedding we use is periodic
scaler = StandardScaler().fit(X)
X_scaled = scaler.transform(X)

# scaling the labels to -1, 1 is important for the SVM and the
# definition of a hinge loss
y_scaled = 2 * (y - 0.5)
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y_scaled, train_size=0.1)



n_qubits = 10
dev = qml.device("default.qubit", wires=n_qubits)
n_dim = len(X_train[0])

X_train = device_put(X_train)
y_train = device_put(y_train)

@qml.qnode(dev, interface="jax")
def kernel(x1, x2):
    """The quantum kernel."""
    #S(x=x1)
    for i in range(n_qubits):
        qml.Hadamard(wires=[i])
        qml.RZ(x1[i%n_dim], wires=[i]) 
    
    #S^dagger(x=x2)
    for i in range(n_qubits-1,-1,-1):
        qml.RZ(-1*x2[i%n_dim], wires=[i]) 
        qml.Hadamard(wires=[i])
    return qml.probs(wires=range(n_qubits))

jit_kernel = jax.jit(kernel)

def kernel_matrix(A, B):
    """Compute the matrix whose entries are the kernel
       evaluated on pairwise data from sets A and B."""
    return np.array([[jit_kernel(a, b).block_until_ready()[0] for b in B] for a in A])

print(np.shape(X_train)[0],'x',np.shape(X_train)[0],'kernel_matrix','with', n_qubits,'qubits circuits')
%time kernel_matrix(X_train,X_train)

102 x 102 kernel_matrix with 10 qubits circuits
CPU times: user 12 s, sys: 219 ms, total: 12.3 s
Wall time: 12.4 s

(3) Pennylane+jax.jit+jax.vmap

# Enable 64-bit floats (added to silence some warnings).
import jax
jax.config.update("jax_enable_x64", True)

import jax
from jax import device_put, jit
import pennylane as qml
from jax import numpy as np

from sklearn.svm import SVC
import numpy as vnp
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_classification
import matplotlib.pyplot as plt

vnp.random.seed(1)

X, y = make_classification(n_features=2, n_redundant=0, n_informative=1, n_clusters_per_class=1, n_samples=1024)

# scaling the inputs is important since the embedding we use is periodic
scaler = StandardScaler().fit(X)
X_scaled = scaler.transform(X)

# scaling the labels to -1, 1 is important for the SVM and the
# definition of a hinge loss
y_scaled = 2 * (y - 0.5)
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y_scaled, train_size=0.1)



n_qubits = 10
dev = qml.device("default.qubit", wires=n_qubits)
n_dim = len(X_train[0])

@qml.qnode(dev, interface="jax")
def kernel(x1, x2):
    """The quantum kernel."""
    #S(x=x1)
    for i in range(n_qubits):
        qml.Hadamard(wires=[i])
        qml.RZ(x1[i%n_dim], wires=[i]) 
    
    #S^dagger(x=x2)
    for i in range(n_qubits-1,-1,-1):
        qml.RZ(-1*x2[i%n_dim], wires=[i])
        qml.Hadamard(wires=[i])
    return qml.probs(wires=range(n_qubits))

vectorized_kernel = jax.vmap(kernel)
jit_vectorized_kernel = jax.jit(vectorized_kernel)

# batching
result_0 = []
result_1 = []
for i in range(np.shape(X_train)[0]):
    for k in range(np.shape(X_train)[0]):
        result_0.append(X_train[i,:])
        result_1.append(X_train[k,:])
x0_batch = np.array(result_0)
x1_batch = np.array(result_1)

print(np.shape(X_train)[0],'x',np.shape(X_train)[0],'kernel_matrix','with', n_qubits,'qubits circuits')
%time my_kernel_matrix = jit_vectorized_kernel(x0_batch,x1_batch).block_until_ready()

102 x 102 kernel_matrix with 10 qubits circuits
CPU times: user 4.47 s, sys: 1.14 s, total: 5.61 s
Wall time: 2.13 s
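The double loop that builds `x0_batch` and `x1_batch` enumerates all N x N pairs in row-major order, so the same batches can be built with `repeat` and `tile`. A numpy sketch on hypothetical data (`jnp.repeat` and `jnp.tile` behave the same way for these calls):

```python
import numpy as np

X_train = np.random.default_rng(1).normal(size=(4, 2))  # hypothetical small data
n = X_train.shape[0]

# Loop version, as in the script above
result_0, result_1 = [], []
for i in range(n):
    for k in range(n):
        result_0.append(X_train[i, :])
        result_1.append(X_train[k, :])
loop_x0, loop_x1 = np.array(result_0), np.array(result_1)

# Vectorized version: each row is repeated n times; the whole set is tiled n times
x0_batch = np.repeat(X_train, n, axis=0)
x1_batch = np.tile(X_train, (n, 1))

print(np.array_equal(x0_batch, loop_x0), np.array_equal(x1_batch, loop_x1))  # True True
```

Since `kernel` returns `qml.probs`, each row of `my_kernel_matrix` holds 2^10 probabilities; the kernel matrix itself would be the first column reshaped, i.e. `my_kernel_matrix[:, 0].reshape(n, n)`.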

(4) qiskit_machine_learning

from sklearn.svm import SVC
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_classification
import matplotlib.pyplot as plt

np.random.seed(1)

X, y = make_classification(n_features=2, n_redundant=0, n_informative=1, n_clusters_per_class=1, n_samples=1024)

# scaling the inputs is important since the embedding we use is periodic
scaler = StandardScaler().fit(X)
X_scaled = scaler.transform(X)

# scaling the labels to -1, 1 is important for the SVM and the
# definition of a hinge loss
y_scaled = 2 * (y - 0.5)
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y_scaled, train_size=0.1)



n_qubits = 10
n_dim = len(X_train[0])

from qiskit import BasicAer
from qiskit.circuit.library import ZZFeatureMap, ZFeatureMap, PauliFeatureMap
from qiskit.utils import QuantumInstance, algorithm_globals
from qiskit_machine_learning.algorithms import QSVC
from qiskit_machine_learning.kernels import QuantumKernel

n_reps = 1
input_data_dimension = n_dim
n_duplicate = int(n_qubits/input_data_dimension)
_X_train = np.tile(X_train, n_duplicate)


adhoc_feature_map = ZFeatureMap(feature_dimension=n_qubits, reps=1)
seed = 12345
adhoc_backend = QuantumInstance(BasicAer.get_backend('statevector_simulator'), shots=1, seed_simulator=seed, seed_transpiler=seed)
adhoc_kernel = QuantumKernel(feature_map=adhoc_feature_map, quantum_instance=adhoc_backend)
print(np.shape(X_train)[0],'x',np.shape(X_train)[0],'kernel_matrix','with', n_qubits,'qubits circuits')
%time adhoc_kernel.evaluate(_X_train,_X_train)

102 x 102 kernel_matrix with 10 qubits circuits
CPU times: user 672 ms, sys: 266 ms, total: 938 ms
Wall time: 1.34 s

I hope this code is helpful for PennyLane lovers.

Hi @Kuma-quant, thank you very much for sharing your code!

Your code may be helpful for other PennyLane lovers and it might also help us make lightning faster.

Thank you again and keep enjoying PennyLane!