Hi all, I am implementing a quantum machine learning model using the lightning.kokkos
device, but I have encountered a memory leak similar to the one reported in this issue from several months ago:
"Memory not releasing after circuit run".
My code is shown here:
import pennylane as qml
from pennylane import numpy as np
import torch
from matplotlib import pyplot as plt
from memory_profiler import profile
# Load the "plus-minus" dataset hosted by PennyLane. The loader's result is
# unpacked as a one-element sequence holding the dataset object.
(pm,) = qml.data.load('other', name='plus-minus')

# Images and labels for the train and test splits.
X_train, X_test = pm.img_train, pm.img_test
Y_train, Y_test = pm.labels_train, pm.labels_test

# One preview sample per class: the first training image carrying each label.
y_vis = [0, 1, 2, 3]
x_vis = [X_train[Y_train == label][0] for label in y_vis]
def visualize_data(x, y, pred=None):
    """Draw the images in ``x`` side by side, each titled with its label.

    Args:
        x: Sequence of images, each accepted by ``Axes.imshow``.
        y: Sequence of class indices (0-3), one per image, used to look up
            the glyph shown in the title.
        pred: Optional sequence of predicted class indices, one per image.
            When given, each title shows both the label and the prediction.

    Returns:
        None. The figure is created but not displayed; call ``plt.show()``
        afterwards to display it.
    """
    n_img = len(x)
    if n_img == 0:
        # Nothing to draw, and a zero-column subplot grid cannot be built.
        return

    # Glyphs for the four classes, indexed by class label.
    labels_list = ["\u2212", "\u002b", "\ua714", "\u02e7"]

    # One column per image (2 inches wide each, so four images give the
    # 8x2 figure). squeeze=False keeps `axes` two-dimensional even for a
    # single image, so the row can always be indexed the same way.
    _, axes = plt.subplots(1, n_img, figsize=(2 * n_img, 2), squeeze=False)
    for i, ax in enumerate(axes[0]):
        ax.imshow(x[i], cmap="gray")
        title = f"Label: {labels_list[y[i]]}"
        if pred is not None:
            title += f", Pred: {labels_list[pred[i]]}"
        ax.set_title(title)
    plt.tight_layout(w_pad=2)
    # plt.show()
# Preview one sample image per class before training.
visualize_data(x_vis, y_vis)
# Length of one flattened input image (presumably 16x16 pixels — TODO confirm
# against the dataset).
input_dim = 256
# Number of classes; passed to the model as the number of PauliZ expectation
# values the circuit returns.
num_classes = 4
# Number of StronglyEntanglingLayers layers and number of qubits in the circuit.
num_layers = 32
num_qubits = 8
# How many copies of the input are stacked side by side in
# QML_classifier.forward before being fed to the circuit.
num_reup = 3
# Torch device that the training/test tensors are moved to.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
class QML_classifier(torch.nn.Module):
    """Torch module that wraps a variational quantum circuit as a classifier.

    The circuit runs on a ``lightning.kokkos`` device and is exposed to
    PyTorch through ``qml.qnn.TorchLayer``, which is given the shapes and
    initial values of the trainable ``weights`` and ``bias`` tensors.
    """

    def __init__(self, input_dim, output_dim, num_qubits, num_layers):
        """Create the device, the QNode and the Torch layer around it.

        Args:
            input_dim: Accepted for the caller's convenience; not used here.
            output_dim: Number of wires (starting at wire 0) whose PauliZ
                expectation value is returned by the circuit.
            num_qubits: Number of wires of the device and of the ansatz.
            num_layers: Number of layers of ``StronglyEntanglingLayers``.
        """
        super().__init__()
        # Fixed seed so the random initial parameters drawn below are
        # reproducible from run to run.
        torch.manual_seed(1337)

        self.num_qubits = num_qubits
        self.output_dim = output_dim
        self.num_layers = num_layers
        self.device = qml.device("lightning.kokkos", wires=self.num_qubits)

        # Parameter shape expected by the ansatz for this depth and width.
        self.weights_shape = qml.StronglyEntanglingLayers.shape(
            n_layers=self.num_layers,
            n_wires=self.num_qubits,
        )

        # NOTE: the argument names matter. TorchLayer feeds the data through
        # `inputs` and matches the remaining arguments against the keys of
        # `weight_shapes`.
        @qml.qnode(self.device)
        def circuit(inputs, weights, bias):
            # Give the flat input the same shape as the trainable tensors so
            # it can be combined with them element-wise.
            encoded = torch.reshape(inputs, self.weights_shape)
            angles = weights * encoded + bias
            qml.StronglyEntanglingLayers(
                weights=angles, wires=range(self.num_qubits)
            )
            return [qml.expval(qml.PauliZ(wire)) for wire in range(self.output_dim)]

        trainable = ("weights", "bias")
        shapes = dict.fromkeys(trainable, self.weights_shape)
        # Drawn in the order "weights", then "bias", scaled into [0, 0.1).
        start_values = {
            name: 0.1 * torch.rand(self.weights_shape) for name in trainable
        }
        self.qcircuit = qml.qnn.TorchLayer(
            qnode=circuit,
            weight_shapes=shapes,
            init_method=start_values,
        )

    # memory_profiler decorator: reports per-line memory usage of each call.
    @profile
    def forward(self, x):
        """Run one sample through the circuit.

        The input is repeated ``num_reup`` times (module-level setting) and
        concatenated with ``torch.hstack`` before being passed to the layer.
        """
        stacked = torch.hstack([x] * num_reup)
        measured = self.qcircuit(stacked)
        return measured
# Optimizer step size, number of passes over the training subset, and number
# of samples per mini-batch.
learning_rate = 0.1
epochs = 5
batch_size = 50
# Work on a subset: the first 200 training and the first 50 test samples.
# Each image is flattened to one row and the tensors are moved to the
# selected torch device.
feats_train = torch.from_numpy(X_train[:200]).reshape(200, -1).to(device)
feats_test = torch.from_numpy(X_test[:50]).reshape(50, -1).to(device)
labels_train = torch.from_numpy(Y_train[:200]).to(device)
labels_test = torch.from_numpy(Y_test[:50]).to(device)
num_train = feats_train.shape[0]
# Initialize the model, the loss function and the optimization algorithm (Adam optimizer).
qml_model = QML_classifier(input_dim, num_classes, num_qubits, num_layers)
loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(qml_model.parameters(), lr=learning_rate)
# Number of full mini-batches per epoch.
num_batches = feats_train.shape[0] // batch_size
def accuracy(labels, predictions):
    """Return the fraction of predictions whose arg-max equals the label.

    Args:
        labels: Sequence (or 1-D tensor) of integer class labels.
        predictions: Sequence of per-class score tensors, aligned with
            ``labels``.

    Returns:
        The accuracy as a float in [0, 1]; ``0.0`` when ``labels`` is empty.
    """
    total = len(labels)
    if total == 0:
        # An empty evaluation set would otherwise raise ZeroDivisionError.
        return 0.0
    hits = sum(
        1
        for label, scores in zip(labels, predictions)
        if torch.argmax(scores) == label
    )
    return hits / total
# Generate randomly permuted batches of sample indices for mini-batch training.
def gen_batches(num_samples, num_batches):
    """Split a random permutation of ``range(num_samples)`` into equal batches.

    Args:
        num_samples: Total number of samples to index.
        num_batches: Number of batches; must be positive and divide
            ``num_samples`` evenly.

    Returns:
        An integer tensor of shape ``(num_batches, num_samples // num_batches)``
        in which every sample index appears exactly once.

    Raises:
        ValueError: If ``num_batches`` is not positive or does not divide
            ``num_samples``.
    """
    # Explicit checks instead of `assert`, which is removed under `python -O`.
    if num_batches <= 0:
        raise ValueError(f"num_batches must be positive, got {num_batches}")
    if num_samples % num_batches != 0:
        raise ValueError(
            f"num_samples ({num_samples}) must be divisible by "
            f"num_batches ({num_batches})"
        )
    per_batch = num_samples // num_batches
    return torch.reshape(torch.randperm(num_samples), (num_batches, per_batch))
# Display accuracy and loss after each epoch (to keep the runtime down, the
# training metrics are computed on the first 50 training samples only).
def print_acc(epoch, max_ep=5):
    """Print approximate train/validation cost and accuracy.

    Uses the module-level ``qml_model``, ``loss``, ``feats_train``,
    ``labels_train``, ``feats_test`` and ``labels_test``.

    Args:
        epoch: Epoch number shown in the printed line.
        max_ep: Total number of epochs shown in the printed line.
    """
    # This is evaluation only: no gradients are needed. Running the model
    # under torch.no_grad() avoids building, and holding on to, an autograd
    # graph for every evaluated sample.
    with torch.no_grad():
        predictions_train = [qml_model(f) for f in feats_train[:50]]
        predictions_test = [qml_model(f) for f in feats_test]
        cost_approx_train = loss(torch.stack(predictions_train), labels_train[:50])
        cost_approx_test = loss(torch.stack(predictions_test), labels_test)
    acc_approx_train = accuracy(labels_train[:50], predictions_train)
    acc_approx_test = accuracy(labels_test, predictions_test)
    print(
        f"Epoch {epoch}/{max_ep} | Approx Cost (train): {cost_approx_train:0.7f} | Cost (val): {cost_approx_test:0.7f} |"
        f" Approx Acc train: {acc_approx_train:0.7f} | Acc val: {acc_approx_test:0.7f}"
    )
print(
    f"Starting training loop for quantum variational classifier ({num_qubits} qubits, {num_layers} layers)..."
)
for ep in range(epochs):
    # Fresh random partition of the training indices for this epoch.
    batch_ind = gen_batches(num_train, num_batches)
    # Report the metrics reached before this epoch's updates.
    print_acc(epoch=ep)
    # Each row of batch_ind holds the sample indices of one mini-batch.
    for batch_rows in batch_ind:
        optimizer.zero_grad()
        feats_train_batch = feats_train[batch_rows]
        labels_train_batch = labels_train[batch_rows]
        # The model is evaluated one sample at a time; the per-sample
        # outputs are stacked into a single batch for the loss.
        outputs = [qml_model(sample) for sample in feats_train_batch]
        stacked_outputs = torch.stack(outputs)
        batch_loss = loss(stacked_outputs, labels_train_batch)
        # Optional regularisation term (disabled):
        # if REG:
        #     loss = loss + lipschitz_regularizer(regularization_rate, model.qcircuit.weights)
        batch_loss.backward()
        optimizer.step()
# Final metrics after the last epoch.
print_acc(epochs)
# Show the model's predictions for the preview images.
# Flatten each preview image to one row; the row count follows the number of
# preview images instead of a fixed 4.
x_vis_torch = torch.from_numpy(np.array(x_vis).reshape(len(x_vis), -1))
y_vis_torch = torch.from_numpy(np.array(y_vis))
# Inference only: skip autograd bookkeeping so no graph is built or retained
# for these forward passes.
with torch.no_grad():
    benign_preds = [qml_model(f) for f in x_vis_torch]
# Plain ints, so they can index the label glyph list in visualize_data directly.
benign_class_output = [int(torch.argmax(p)) for p in benign_preds]
visualize_data(x_vis, y_vis, benign_class_output)
Running this version of the code requires the installation of the python memory-profiler package:
https://pypi.org/project/memory-profiler/
The output that I get from memory-profiler is not exactly an error message, but it shows that more memory is allocated with certain calls to the node and that, for some reason, most of that memory is not released after the forward call is complete. This output is from the first epoch of training:
Line # Mem usage Increment Occurrences Line Contents
=============================================================
93 689.2 MiB 689.2 MiB 1 @profile
94 def forward(self, x):
95 693.5 MiB 4.2 MiB 1 inputs_stack = torch.hstack([x] * num_reup)
96 738.0 MiB 44.5 MiB 1 results = self.qcircuit(inputs_stack)
97 738.0 MiB 0.0 MiB 1 return results
Filename: /home/owner/quantum_adv_demo.py
Line # Mem usage Increment Occurrences Line Contents
=============================================================
93 738.0 MiB 738.0 MiB 1 @profile
94 def forward(self, x):
95 738.0 MiB 0.0 MiB 1 inputs_stack = torch.hstack([x] * num_reup)
96 740.7 MiB 2.8 MiB 1 results = self.qcircuit(inputs_stack)
97 740.7 MiB 0.0 MiB 1 return results
Filename: /home/owner/quantum_adv_demo.py
Line # Mem usage Increment Occurrences Line Contents
=============================================================
93 740.7 MiB 740.7 MiB 1 @profile
94 def forward(self, x):
95 740.7 MiB 0.0 MiB 1 inputs_stack = torch.hstack([x] * num_reup)
96 743.5 MiB 2.8 MiB 1 results = self.qcircuit(inputs_stack)
97 743.5 MiB 0.0 MiB 1 return results
Filename: /home/owner/quantum_adv_demo.py
Line # Mem usage Increment Occurrences Line Contents
=============================================================
93 743.5 MiB 743.5 MiB 1 @profile
94 def forward(self, x):
95 743.5 MiB 0.0 MiB 1 inputs_stack = torch.hstack([x] * num_reup)
96 746.0 MiB 2.5 MiB 1 results = self.qcircuit(inputs_stack)
97 746.0 MiB 0.0 MiB 1 return results
Filename: /home/owner/quantum_adv_demo.py
Line # Mem usage Increment Occurrences Line Contents
=============================================================
93 746.0 MiB 746.0 MiB 1 @profile
94 def forward(self, x):
95 746.0 MiB 0.0 MiB 1 inputs_stack = torch.hstack([x] * num_reup)
96 748.7 MiB 2.8 MiB 1 results = self.qcircuit(inputs_stack)
97 748.7 MiB 0.0 MiB 1 return results
Filename: /home/owner/quantum_adv_demo.py
Line # Mem usage Increment Occurrences Line Contents
=============================================================
93 748.7 MiB 748.7 MiB 1 @profile
94 def forward(self, x):
95 748.7 MiB 0.0 MiB 1 inputs_stack = torch.hstack([x] * num_reup)
96 751.5 MiB 2.8 MiB 1 results = self.qcircuit(inputs_stack)
97 751.5 MiB 0.0 MiB 1 return results
Filename: /home/owner/quantum_adv_demo.py
Line # Mem usage Increment Occurrences Line Contents
=============================================================
93 751.5 MiB 751.5 MiB 1 @profile
94 def forward(self, x):
95 751.5 MiB 0.0 MiB 1 inputs_stack = torch.hstack([x] * num_reup)
96 754.0 MiB 2.5 MiB 1 results = self.qcircuit(inputs_stack)
97 754.0 MiB 0.0 MiB 1 return results
Filename: /home/owner/quantum_adv_demo.py
Line # Mem usage Increment Occurrences Line Contents
=============================================================
93 754.0 MiB 754.0 MiB 1 @profile
94 def forward(self, x):
95 754.0 MiB 0.0 MiB 1 inputs_stack = torch.hstack([x] * num_reup)
96 756.7 MiB 2.8 MiB 1 results = self.qcircuit(inputs_stack)
97 756.7 MiB 0.0 MiB 1 return results
Filename: /home/owner/quantum_adv_demo.py
Line # Mem usage Increment Occurrences Line Contents
=============================================================
93 756.7 MiB 756.7 MiB 1 @profile
94 def forward(self, x):
95 756.7 MiB 0.0 MiB 1 inputs_stack = torch.hstack([x] * num_reup)
96 759.5 MiB 2.8 MiB 1 results = self.qcircuit(inputs_stack)
97 759.5 MiB 0.0 MiB 1 return results
Filename: /home/owner/quantum_adv_demo.py
Line # Mem usage Increment Occurrences Line Contents
=============================================================
93 759.5 MiB 759.5 MiB 1 @profile
94 def forward(self, x):
95 759.5 MiB 0.0 MiB 1 inputs_stack = torch.hstack([x] * num_reup)
96 762.0 MiB 2.5 MiB 1 results = self.qcircuit(inputs_stack)
97 762.0 MiB 0.0 MiB 1 return results
Filename: /home/owner/quantum_adv_demo.py
Line # Mem usage Increment Occurrences Line Contents
=============================================================
93 762.0 MiB 762.0 MiB 1 @profile
94 def forward(self, x):
95 762.0 MiB 0.0 MiB 1 inputs_stack = torch.hstack([x] * num_reup)
96 764.7 MiB 2.8 MiB 1 results = self.qcircuit(inputs_stack)
97 764.7 MiB 0.0 MiB 1 return results
Filename: /home/owner/quantum_adv_demo.py
Line # Mem usage Increment Occurrences Line Contents
=============================================================
93 764.7 MiB 764.7 MiB 1 @profile
94 def forward(self, x):
95 764.7 MiB 0.0 MiB 1 inputs_stack = torch.hstack([x] * num_reup)
96 767.2 MiB 2.5 MiB 1 results = self.qcircuit(inputs_stack)
97 767.2 MiB 0.0 MiB 1 return results
Filename: /home/owner/quantum_adv_demo.py
Line # Mem usage Increment Occurrences Line Contents
=============================================================
93 767.2 MiB 767.2 MiB 1 @profile
94 def forward(self, x):
95 767.2 MiB 0.0 MiB 1 inputs_stack = torch.hstack([x] * num_reup)
96 769.7 MiB 2.5 MiB 1 results = self.qcircuit(inputs_stack)
97 769.7 MiB 0.0 MiB 1 return results
Filename: /home/owner/quantum_adv_demo.py
Line # Mem usage Increment Occurrences Line Contents
=============================================================
93 769.7 MiB 769.7 MiB 1 @profile
94 def forward(self, x):
95 769.7 MiB 0.0 MiB 1 inputs_stack = torch.hstack([x] * num_reup)
96 772.2 MiB 2.5 MiB 1 results = self.qcircuit(inputs_stack)
97 772.2 MiB 0.0 MiB 1 return results
Filename: /home/owner/quantum_adv_demo.py
Line # Mem usage Increment Occurrences Line Contents
=============================================================
93 772.2 MiB 772.2 MiB 1 @profile
94 def forward(self, x):
95 772.2 MiB 0.0 MiB 1 inputs_stack = torch.hstack([x] * num_reup)
96 775.0 MiB 2.8 MiB 1 results = self.qcircuit(inputs_stack)
97 775.0 MiB 0.0 MiB 1 return results
Filename: /home/owner/quantum_adv_demo.py
Line # Mem usage Increment Occurrences Line Contents
=============================================================
93 775.0 MiB 775.0 MiB 1 @profile
94 def forward(self, x):
95 775.0 MiB 0.0 MiB 1 inputs_stack = torch.hstack([x] * num_reup)
96 777.7 MiB 2.8 MiB 1 results = self.qcircuit(inputs_stack)
97 777.7 MiB 0.0 MiB 1 return results
Filename: /home/owner/quantum_adv_demo.py
Line # Mem usage Increment Occurrences Line Contents
=============================================================
93 777.7 MiB 777.7 MiB 1 @profile
94 def forward(self, x):
95 777.7 MiB 0.0 MiB 1 inputs_stack = torch.hstack([x] * num_reup)
96 780.5 MiB 2.8 MiB 1 results = self.qcircuit(inputs_stack)
97 780.5 MiB 0.0 MiB 1 return results
Filename: /home/owner/quantum_adv_demo.py
Line # Mem usage Increment Occurrences Line Contents
=============================================================
93 780.5 MiB 780.5 MiB 1 @profile
94 def forward(self, x):
95 780.5 MiB 0.0 MiB 1 inputs_stack = torch.hstack([x] * num_reup)
96 783.0 MiB 2.5 MiB 1 results = self.qcircuit(inputs_stack)
97 783.0 MiB 0.0 MiB 1 return results
Filename: /home/owner/quantum_adv_demo.py
Line # Mem usage Increment Occurrences Line Contents
=============================================================
93 783.0 MiB 783.0 MiB 1 @profile
94 def forward(self, x):
95 783.0 MiB 0.0 MiB 1 inputs_stack = torch.hstack([x] * num_reup)
96 785.7 MiB 2.8 MiB 1 results = self.qcircuit(inputs_stack)
97 785.7 MiB 0.0 MiB 1 return results
My specs are here:
Name: PennyLane
Version: 0.38.0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author:
Author-email:
License: Apache License 2.0
Location: /usr/local/lib/python3.11/dist-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, toml, typing-extensions
Required-by: PennyLane-qiskit, pennylane-qulacs, PennyLane_Lightning, PennyLane_Lightning_GPU, PennyLane_Lightning_Kokkos
Platform info: Linux-6.8.0-40-generic-x86_64-with-glibc2.35
Python version: 3.11.0
Numpy version: 1.26.3
Scipy version: 1.12.0
Installed devices:
- lightning.kokkos (PennyLane_Lightning_Kokkos-0.38.0)
- qiskit.aer (PennyLane-qiskit-0.37.0)
- qiskit.basicaer (PennyLane-qiskit-0.37.0)
- qiskit.basicsim (PennyLane-qiskit-0.37.0)
- qiskit.ibmq (PennyLane-qiskit-0.37.0)
- qiskit.ibmq.circuit_runner (PennyLane-qiskit-0.37.0)
- qiskit.ibmq.sampler (PennyLane-qiskit-0.37.0)
- qiskit.remote (PennyLane-qiskit-0.37.0)
- default.clifford (PennyLane-0.38.0)
- default.gaussian (PennyLane-0.38.0)
- default.mixed (PennyLane-0.38.0)
- default.qubit (PennyLane-0.38.0)
- default.qubit.autograd (PennyLane-0.38.0)
- default.qubit.jax (PennyLane-0.38.0)
- default.qubit.legacy (PennyLane-0.38.0)
- default.qubit.tf (PennyLane-0.38.0)
- default.qubit.torch (PennyLane-0.38.0)
- default.qutrit (PennyLane-0.38.0)
- default.qutrit.mixed (PennyLane-0.38.0)
- default.tensor (PennyLane-0.38.0)
- null.qubit (PennyLane-0.38.0)
- lightning.qubit (PennyLane_Lightning-0.38.0)
- lightning.gpu (PennyLane_Lightning_GPU-0.35.1)
- qulacs.simulator (pennylane-qulacs-0.36.0)
What I need to do is find a way to free the memory that is being allocated by the calls to the node once it is no longer needed. I followed the instructions provided in the earlier issue mentioned in the link, and I also tried initializing a new device and node in each forward call. However, none of these methods worked, and I am not sure where the memory is actually allocated. Any guidance on how to address this issue would be greatly appreciated.