Memory leak when using lightning.kokkos device

Hi all, I am implementing a quantum machine learning model using the lightning.kokkos device, but I have encountered a memory leak similar to this issue from several months ago:
Memory not releasing after circuit run

My code is shown here:

import pennylane as qml
from pennylane import numpy as np
import torch
from matplotlib import pyplot as plt
from memory_profiler import profile

# we can use the dataset hosted on PennyLane
[pm] = qml.data.load('other', name='plus-minus')

X_train = pm.img_train  
X_test = pm.img_test  
Y_train = pm.labels_train 
Y_test = pm.labels_test  


x_vis = [
    (X_train[Y_train == 0])[0],
    (X_train[Y_train == 1])[0],
    (X_train[Y_train == 2])[0],
    (X_train[Y_train == 3])[0],
]
y_vis = [0, 1, 2, 3]



def visualize_data(x, y, pred=None):
    n_img = len(x)
    labels_list = ["\u2212", "\u002b", "\ua714", "\u02e7"]
    fig, axes = plt.subplots(1, 4, figsize=(8, 2))
    for i in range(n_img):
        axes[i].imshow(x[i], cmap="gray")
        if pred is None:
            axes[i].set_title("Label: {}".format(labels_list[y[i]]))
        else:
            axes[i].set_title("Label: {}, Pred: {}".format(labels_list[y[i]], labels_list[pred[i]]))
    plt.tight_layout(w_pad=2)
    # plt.show()


visualize_data(x_vis, y_vis)


input_dim = 256
num_classes = 4
num_layers = 32
num_qubits = 8
num_reup = 3

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


class QML_classifier(torch.nn.Module):
   
    def __init__(self, input_dim, output_dim, num_qubits, num_layers):
        super().__init__()
        torch.manual_seed(1337) 
        self.num_qubits = num_qubits
        self.output_dim = output_dim
        self.num_layers = num_layers
        self.device = qml.device("lightning.kokkos", wires=self.num_qubits)
        self.weights_shape = qml.StronglyEntanglingLayers.shape(
            n_layers=self.num_layers, n_wires=self.num_qubits
        )

        @qml.qnode(self.device)
        def circuit(inputs, weights, bias):
            inputs = torch.reshape(inputs, self.weights_shape)
            qml.StronglyEntanglingLayers(
                weights=weights * inputs + bias, wires=range(self.num_qubits)
            )
            return [qml.expval(qml.PauliZ(i)) for i in range(self.output_dim)]

        param_shapes = {"weights": self.weights_shape, "bias": self.weights_shape}
        init_vals = {
            "weights": 0.1 * torch.rand(self.weights_shape),
            "bias": 0.1 * torch.rand(self.weights_shape),
        }


        self.qcircuit = qml.qnn.TorchLayer(
            qnode=circuit, weight_shapes=param_shapes, init_method=init_vals
        )

    @profile
    def forward(self, x):
        inputs_stack = torch.hstack([x] * num_reup)
        results = self.qcircuit(inputs_stack)
        return results

learning_rate = 0.1
epochs = 5
batch_size = 50

feats_train = torch.from_numpy(X_train[:200]).reshape(200, -1).to(device)
feats_test = torch.from_numpy(X_test[:50]).reshape(50, -1).to(device)
labels_train = torch.from_numpy(Y_train[:200]).to(device)
labels_test = torch.from_numpy(Y_test[:50]).to(device)
num_train = feats_train.shape[0]

# initialize the model, loss function and optimization algorithm (Adam optimizer)
qml_model = QML_classifier(input_dim, num_classes, num_qubits, num_layers)
loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(qml_model.parameters(), lr=learning_rate)
num_batches = feats_train.shape[0] // batch_size

def accuracy(labels, predictions):
    acc = 0
    for l, p in zip(labels, predictions):
        if torch.argmax(p) == l:
            acc += 1
    acc = acc / len(labels)
    return acc


# generate randomly permuted batches to speed up training
def gen_batches(num_samples, num_batches):
    assert num_samples % num_batches == 0
    perm_ind = torch.reshape(torch.randperm(num_samples), (num_batches, -1))
    return perm_ind


# display accuracy and loss after each epoch (to speed up runtime, only evaluate on the first 50 training samples and the test set)
def print_acc(epoch, max_ep=5):
    predictions_train = [qml_model(f) for f in feats_train[:50]]
    predictions_test = [qml_model(f) for f in feats_test]
    cost_approx_train = loss(torch.stack(predictions_train), labels_train[:50])
    cost_approx_test = loss(torch.stack(predictions_test), labels_test)
    acc_approx_train = accuracy(labels_train[:50], predictions_train)
    acc_approx_test = accuracy(labels_test, predictions_test)
    print(
        f"Epoch {epoch}/{max_ep} | Approx Cost (train): {cost_approx_train:0.7f} | Cost (val): {cost_approx_test:0.7f} |"
        f" Approx Acc train: {acc_approx_train:0.7f} | Acc val: {acc_approx_test:0.7f}"
    )


print(
    f"Starting training loop for quantum variational classifier ({num_qubits} qubits, {num_layers} layers)..."
)

for ep in range(0, epochs):
    batch_ind = gen_batches(num_train, num_batches)
    print_acc(epoch=ep)

    for it in range(num_batches):
        optimizer.zero_grad()
        feats_train_batch = feats_train[batch_ind[it]]
        labels_train_batch = labels_train[batch_ind[it]]

        outputs = [qml_model(f) for f in feats_train_batch]
        batch_loss = loss(torch.stack(outputs), labels_train_batch)
        # if REG:
        #    loss = loss + lipschitz_regularizer(regularization_rate, model.qcircuit.weights)
        batch_loss.backward()
        optimizer.step()

print_acc(epochs)


# show accuracy
x_vis_torch = torch.from_numpy(np.array(x_vis).reshape(4, -1))
y_vis_torch = torch.from_numpy(np.array(y_vis))
benign_preds = [qml_model(f) for f in x_vis_torch]

benign_class_output = [torch.argmax(p) for p in benign_preds]
visualize_data(x_vis, y_vis, benign_class_output)



Running this version of the code requires installing the Python memory-profiler package:
https://pypi.org/project/memory-profiler/
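
As a lighter-weight alternative, the same growth can be confirmed by logging the process RSS around each forward call. Below is a minimal sketch that assumes the psutil package is installed (psutil and log_rss are not part of the original script):

import os
import psutil

_proc = psutil.Process(os.getpid())

def log_rss(tag=""):
    # print the resident set size of the current Python process in MiB
    rss_mib = _proc.memory_info().rss / (1024 ** 2)
    print(f"[{tag}] RSS: {rss_mib:.1f} MiB")

# usage, e.g. around a single forward call:
# log_rss("before forward")
# out = qml_model(feats_train[0])
# log_rss("after forward")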

The output that I get from memory-profiler is not exactly an error message, but it shows that more memory is allocated with each call to the QNode, and for some reason most of that memory is not released after the forward call completes. This output is from the first epoch of training:

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    93    689.2 MiB    689.2 MiB           1       @profile
    94                                             def forward(self, x):
    95    693.5 MiB      4.2 MiB           1           inputs_stack = torch.hstack([x] * num_reup)
    96    738.0 MiB     44.5 MiB           1           results = self.qcircuit(inputs_stack)
    97    738.0 MiB      0.0 MiB           1           return results


Filename: /home/owner/quantum_adv_demo.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    93    738.0 MiB    738.0 MiB           1       @profile
    94                                             def forward(self, x):
    95    738.0 MiB      0.0 MiB           1           inputs_stack = torch.hstack([x] * num_reup)
    96    740.7 MiB      2.8 MiB           1           results = self.qcircuit(inputs_stack)
    97    740.7 MiB      0.0 MiB           1           return results


Filename: /home/owner/quantum_adv_demo.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    93    740.7 MiB    740.7 MiB           1       @profile
    94                                             def forward(self, x):
    95    740.7 MiB      0.0 MiB           1           inputs_stack = torch.hstack([x] * num_reup)
    96    743.5 MiB      2.8 MiB           1           results = self.qcircuit(inputs_stack)
    97    743.5 MiB      0.0 MiB           1           return results


Filename: /home/owner/quantum_adv_demo.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    93    743.5 MiB    743.5 MiB           1       @profile
    94                                             def forward(self, x):
    95    743.5 MiB      0.0 MiB           1           inputs_stack = torch.hstack([x] * num_reup)
    96    746.0 MiB      2.5 MiB           1           results = self.qcircuit(inputs_stack)
    97    746.0 MiB      0.0 MiB           1           return results


Filename: /home/owner/quantum_adv_demo.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    93    746.0 MiB    746.0 MiB           1       @profile
    94                                             def forward(self, x):
    95    746.0 MiB      0.0 MiB           1           inputs_stack = torch.hstack([x] * num_reup)
    96    748.7 MiB      2.8 MiB           1           results = self.qcircuit(inputs_stack)
    97    748.7 MiB      0.0 MiB           1           return results


Filename: /home/owner/quantum_adv_demo.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    93    748.7 MiB    748.7 MiB           1       @profile
    94                                             def forward(self, x):
    95    748.7 MiB      0.0 MiB           1           inputs_stack = torch.hstack([x] * num_reup)
    96    751.5 MiB      2.8 MiB           1           results = self.qcircuit(inputs_stack)
    97    751.5 MiB      0.0 MiB           1           return results


Filename: /home/owner/quantum_adv_demo.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    93    751.5 MiB    751.5 MiB           1       @profile
    94                                             def forward(self, x):
    95    751.5 MiB      0.0 MiB           1           inputs_stack = torch.hstack([x] * num_reup)
    96    754.0 MiB      2.5 MiB           1           results = self.qcircuit(inputs_stack)
    97    754.0 MiB      0.0 MiB           1           return results


Filename: /home/owner/quantum_adv_demo.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    93    754.0 MiB    754.0 MiB           1       @profile
    94                                             def forward(self, x):
    95    754.0 MiB      0.0 MiB           1           inputs_stack = torch.hstack([x] * num_reup)
    96    756.7 MiB      2.8 MiB           1           results = self.qcircuit(inputs_stack)
    97    756.7 MiB      0.0 MiB           1           return results


Filename: /home/owner/quantum_adv_demo.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    93    756.7 MiB    756.7 MiB           1       @profile
    94                                             def forward(self, x):
    95    756.7 MiB      0.0 MiB           1           inputs_stack = torch.hstack([x] * num_reup)
    96    759.5 MiB      2.8 MiB           1           results = self.qcircuit(inputs_stack)
    97    759.5 MiB      0.0 MiB           1           return results


Filename: /home/owner/quantum_adv_demo.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    93    759.5 MiB    759.5 MiB           1       @profile
    94                                             def forward(self, x):
    95    759.5 MiB      0.0 MiB           1           inputs_stack = torch.hstack([x] * num_reup)
    96    762.0 MiB      2.5 MiB           1           results = self.qcircuit(inputs_stack)
    97    762.0 MiB      0.0 MiB           1           return results


Filename: /home/owner/quantum_adv_demo.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    93    762.0 MiB    762.0 MiB           1       @profile
    94                                             def forward(self, x):
    95    762.0 MiB      0.0 MiB           1           inputs_stack = torch.hstack([x] * num_reup)
    96    764.7 MiB      2.8 MiB           1           results = self.qcircuit(inputs_stack)
    97    764.7 MiB      0.0 MiB           1           return results


Filename: /home/owner/quantum_adv_demo.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    93    764.7 MiB    764.7 MiB           1       @profile
    94                                             def forward(self, x):
    95    764.7 MiB      0.0 MiB           1           inputs_stack = torch.hstack([x] * num_reup)
    96    767.2 MiB      2.5 MiB           1           results = self.qcircuit(inputs_stack)
    97    767.2 MiB      0.0 MiB           1           return results


Filename: /home/owner/quantum_adv_demo.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    93    767.2 MiB    767.2 MiB           1       @profile
    94                                             def forward(self, x):
    95    767.2 MiB      0.0 MiB           1           inputs_stack = torch.hstack([x] * num_reup)
    96    769.7 MiB      2.5 MiB           1           results = self.qcircuit(inputs_stack)
    97    769.7 MiB      0.0 MiB           1           return results


Filename: /home/owner/quantum_adv_demo.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    93    769.7 MiB    769.7 MiB           1       @profile
    94                                             def forward(self, x):
    95    769.7 MiB      0.0 MiB           1           inputs_stack = torch.hstack([x] * num_reup)
    96    772.2 MiB      2.5 MiB           1           results = self.qcircuit(inputs_stack)
    97    772.2 MiB      0.0 MiB           1           return results


Filename: /home/owner/quantum_adv_demo.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    93    772.2 MiB    772.2 MiB           1       @profile
    94                                             def forward(self, x):
    95    772.2 MiB      0.0 MiB           1           inputs_stack = torch.hstack([x] * num_reup)
    96    775.0 MiB      2.8 MiB           1           results = self.qcircuit(inputs_stack)
    97    775.0 MiB      0.0 MiB           1           return results


Filename: /home/owner/quantum_adv_demo.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    93    775.0 MiB    775.0 MiB           1       @profile
    94                                             def forward(self, x):
    95    775.0 MiB      0.0 MiB           1           inputs_stack = torch.hstack([x] * num_reup)
    96    777.7 MiB      2.8 MiB           1           results = self.qcircuit(inputs_stack)
    97    777.7 MiB      0.0 MiB           1           return results


Filename: /home/owner/quantum_adv_demo.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    93    777.7 MiB    777.7 MiB           1       @profile
    94                                             def forward(self, x):
    95    777.7 MiB      0.0 MiB           1           inputs_stack = torch.hstack([x] * num_reup)
    96    780.5 MiB      2.8 MiB           1           results = self.qcircuit(inputs_stack)
    97    780.5 MiB      0.0 MiB           1           return results


Filename: /home/owner/quantum_adv_demo.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    93    780.5 MiB    780.5 MiB           1       @profile
    94                                             def forward(self, x):
    95    780.5 MiB      0.0 MiB           1           inputs_stack = torch.hstack([x] * num_reup)
    96    783.0 MiB      2.5 MiB           1           results = self.qcircuit(inputs_stack)
    97    783.0 MiB      0.0 MiB           1           return results


Filename: /home/owner/quantum_adv_demo.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    93    783.0 MiB    783.0 MiB           1       @profile
    94                                             def forward(self, x):
    95    783.0 MiB      0.0 MiB           1           inputs_stack = torch.hstack([x] * num_reup)
    96    785.7 MiB      2.8 MiB           1           results = self.qcircuit(inputs_stack)
    97    785.7 MiB      0.0 MiB           1           return results


My specs are here:

Name: PennyLane
Version: 0.38.0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author: 
Author-email: 
License: Apache License 2.0
Location: /usr/local/lib/python3.11/dist-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, toml, typing-extensions
Required-by: PennyLane-qiskit, pennylane-qulacs, PennyLane_Lightning, PennyLane_Lightning_GPU, PennyLane_Lightning_Kokkos

Platform info:           Linux-6.8.0-40-generic-x86_64-with-glibc2.35
Python version:          3.11.0
Numpy version:           1.26.3
Scipy version:           1.12.0
Installed devices:
- lightning.kokkos (PennyLane_Lightning_Kokkos-0.38.0)
- qiskit.aer (PennyLane-qiskit-0.37.0)
- qiskit.basicaer (PennyLane-qiskit-0.37.0)
- qiskit.basicsim (PennyLane-qiskit-0.37.0)
- qiskit.ibmq (PennyLane-qiskit-0.37.0)
- qiskit.ibmq.circuit_runner (PennyLane-qiskit-0.37.0)
- qiskit.ibmq.sampler (PennyLane-qiskit-0.37.0)
- qiskit.remote (PennyLane-qiskit-0.37.0)
- default.clifford (PennyLane-0.38.0)
- default.gaussian (PennyLane-0.38.0)
- default.mixed (PennyLane-0.38.0)
- default.qubit (PennyLane-0.38.0)
- default.qubit.autograd (PennyLane-0.38.0)
- default.qubit.jax (PennyLane-0.38.0)
- default.qubit.legacy (PennyLane-0.38.0)
- default.qubit.tf (PennyLane-0.38.0)
- default.qubit.torch (PennyLane-0.38.0)
- default.qutrit (PennyLane-0.38.0)
- default.qutrit.mixed (PennyLane-0.38.0)
- default.tensor (PennyLane-0.38.0)
- null.qubit (PennyLane-0.38.0)
- lightning.qubit (PennyLane_Lightning-0.38.0)
- lightning.gpu (PennyLane_Lightning_GPU-0.35.1)
- qulacs.simulator (pennylane-qulacs-0.36.0)

What I need is a way to free the memory allocated by the QNode calls once it is no longer needed. I followed the suggestions from the earlier issue linked above, and I also tried initializing a new device and QNode in each forward call. However, neither approach worked, and I am not sure where the memory is actually being held. Any guidance on how to address this issue would be greatly appreciated.
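
For reference, the per-call re-initialization I mentioned looked roughly like the sketch below (a simplified version of QML_classifier.forward, not the exact code I ran; note that wrapping the fresh QNode in a new TorchLayer re-initializes the trainable parameters, so this variant was only meant to check whether the memory growth disappears, not to train with). It did not change the memory behaviour:

    def forward(self, x):
        # Rebuild the device and QNode on every call so that the previous
        # simulator state can, in principle, be garbage-collected.
        dev = qml.device("lightning.kokkos", wires=self.num_qubits)

        @qml.qnode(dev)
        def circuit(inputs, weights, bias):
            inputs = torch.reshape(inputs, self.weights_shape)
            qml.StronglyEntanglingLayers(
                weights=weights * inputs + bias, wires=range(self.num_qubits)
            )
            return [qml.expval(qml.PauliZ(i)) for i in range(self.output_dim)]

        # This resets the weights each call, so it only isolates the leak.
        qcircuit = qml.qnn.TorchLayer(
            qnode=circuit,
            weight_shapes={"weights": self.weights_shape, "bias": self.weights_shape},
        )
        inputs_stack = torch.hstack([x] * num_reup)
        return qcircuit(inputs_stack)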

Hi @DarthMalloc, thanks for reporting this. Do you see the same issue with PennyLane v0.38.1?

Thank you very much for getting back to me! I got another response on the PennyLane GitHub, so I am first checking whether the problem also occurs with default.qubit, and after that I will check with PennyLane v0.38.1.
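
(For the default.qubit check, the only change is the device line in __init__, roughly:

    # swap the simulator backend to rule out lightning.kokkos as the source of the growth
    self.device = qml.device("default.qubit", wires=self.num_qubits)

with everything else left as-is.)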

No problem. Yes, I noticed the issue there. Let’s keep the conversation there for now :slightly_smiling_face: