QNode layer's weights cannot work on GPU when using PyTorch and PennyLane together

Hi, we are using PennyLane, PennyLane-Lightning, and PyTorch on Ubuntu 22.04. We set the device as follows:

dev = qml.device("lightning.kokkos", wires=N_QUBITS)

and we create a TorchLayer:

class Net(torch.nn.Module):
    def __init__(self, Queue):
        super(Net, self).__init__()
        self.encoder_q = qml.qnn.TorchLayer(qnode_q, weight_shapes)

and qnode_q is defined as:

@qml.qnode(dev)
def qnode_q(inputs, weights):
    if torch.all(inputs == 0) or len(inputs) > 2 ** N_QUBITS:
        print('something error', torch.as_tensor(inputs).shape, inputs)
    AmplitudeEmbedding(features=inputs, wires=range(N_QUBITS), normalize=True, pad_with=0.01)
    count = 0
    for W in weights:
        layer(W, count)
        count += 1
    result = [qml.expval(qml.PauliZ(i)) for i in range(N_QUBITS)]
    return result
  • But the code runs slowly and tends to get stuck in “backward()”. To find the reason, we set a breakpoint after
    self.encoder_k = qml.qnn.TorchLayer(qnode_q, weight_shapes)
    and checked the weights in self.encoder_q, which show is_cpu=True and is_cuda=False. Does that mean the code is not running on the GPU correctly?

  • Another problem: when we use lightning.gpu, sometimes the code runs properly, but sometimes the process stops and returns Process finished with exit code 139 (interrupted by signal 11: SIGSEGV).
    We are curious about the reason and would like to know which device is better to use, lightning.gpu or lightning.kokkos?

We hope to receive a prompt response to these questions. Thanks!

Hey @StillWaterJ! Welcome to the forum :sunglasses:

PennyLane Lightning GPU / Kokkos doesn’t integrate with how PyTorch dispatches to GPUs. They’re separate entities at the moment. While they might still work together in some cases, it doesn’t surprise me that you’re getting some errors.
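To illustrate the distinction, here is a minimal sketch (the toy circuit, shapes, and names below are illustrative, not taken from your code): the weights inside a qml.qnn.TorchLayer are ordinary torch parameters, so they live wherever PyTorch puts the module, regardless of which simulator device executes the circuit.

import torch
import pennylane as qml

n_qubits = 2
dev = qml.device("default.qubit", wires=n_qubits)

@qml.qnode(dev)
def circuit(inputs, weights):
    qml.AngleEmbedding(inputs, wires=range(n_qubits))
    qml.BasicEntanglerLayers(weights, wires=range(n_qubits))
    return [qml.expval(qml.PauliZ(i)) for i in range(n_qubits)]

qlayer = qml.qnn.TorchLayer(circuit, {"weights": (3, n_qubits)})
print(next(qlayer.parameters()).is_cuda)      # False: the weights start on the CPU

# .to() moves the torch parameters only; the circuit itself is still
# simulated by whatever PennyLane device "dev" points at.
if torch.cuda.is_available():
    qlayer = qlayer.to("cuda:0")
    print(next(qlayer.parameters()).is_cuda)  # True

So is_cpu=True on the TorchLayer weights just means PyTorch hasn’t been told to keep its tensors on CUDA; it doesn’t tell you anything about where lightning.kokkos runs the simulation.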

Do you mind posting a full & complete example so that I can copy-paste your code and try to replicate what’s going on? If there’s no immediate solution, I can try to help provide a workaround :slight_smile:

Here is the code, thanks.

import random
import time
import json
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import pennylane as qml
from PIL import Image
from pennylane.templates import AmplitudeEmbedding
iterations = 100
batch_size = 32
N_QUBITS = 11
total_epochs = 3
learning_rate = 0.01
K = 320
dim = 11
N_LAYERS = 2
m = 0.99
T = 0.07
# ===================================================================

def train_data(anchor, net, device, criterion, opt):
    # anchor = anchor.to(device)
    logits, labels, Queue = net(anchor)
    loss = criterion(logits / T, labels)
    loss.backward()
    opt.step()
    train_loss = loss.data
    return train_loss, Queue

def train_network(net=None, TRAIN_DATASET=None, device=None, epochs=total_epochs, optimizer=None,
                  criterion=None, idx=None, n_gene=None):
    net = net.to(device)
    train_losses = []
    Queue = None
    net.train()
    for epoch in range(epochs):
        for i in range(iterations):
            t2 = time.time()
            anchor = random.sample(TRAIN_DATASET, batch_size)
            anchor = torch.tensor(np.array(anchor), dtype=torch.float32, requires_grad=True)
            train_loss, Queue = train_data(anchor, net, device, criterion, optimizer)
            net.Queue = Queue
            optimizer.zero_grad()
            train_losses.append(train_loss)
            print('batch({}):{:.4f}| train_loss:{:.4f} '.format(
                i, time.time() - t2, train_loss))

            if i % 10 == 0:
                torch.save(net.state_dict(), f"result/model_train_parameters_{n_gene}_{idx}.pth")
    return Queue

dev = qml.device("lightning.kokkos", wires=N_QUBITS)
@qml.qnode(dev)
def qnode_q(inputs, weights):
    if torch.all(inputs == 0) or len(inputs) > 2 ** N_QUBITS:
        print('something error', torch.as_tensor(inputs).shape, inputs)
    AmplitudeEmbedding(features=inputs, wires=range(N_QUBITS), normalize=True, pad_with=0.01)
    count = 0
    for W in weights:
        layer(W, count)
        count += 1
    result = [qml.expval(qml.PauliZ(i)) for i in range(N_QUBITS)]
    return result

def layer(W, count):
    clayer = [[[0, 10], [1, 10], [2, 10], [8, 10], [10, 7]], [[5, 10], [6, 10], [7, 10], [8, 10], [7, 4], [9, 8]]]
    for i in range(N_QUBITS):
        qml.Rot(W[i, 0], W[i, 1], W[i, 2], wires=i)
    qml.Barrier(only_visual=False)
    for element in clayer[count]:
        wire, ctrl = element
        qml.CNOT(wires=[wire, ctrl])
    qml.Barrier(only_visual=False)

weight_shapes = {"weights": (N_LAYERS, N_QUBITS, 3)}

class Net(torch.nn.Module):
    def __init__(self, Queue):
        super(Net, self).__init__()
        self.encoder_q = qml.qnn.TorchLayer(qnode_q, weight_shapes)
        self.encoder_k = qml.qnn.TorchLayer(qnode_q, weight_shapes)

        self.K = K
        self.m = m
        self.T = T
        self.Queue = Queue
        for params_q, params_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            params_k.data.copy_(params_q.data)
            params_k.requires_grad = False

        self.Queue = torch.nn.functional.normalize(self.Queue, dim=0)
        self.queue_ptr = torch.zeros(1, dtype=torch.long)

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        for param_q, param_k in zip(
                self.encoder_q.parameters(), self.encoder_k.parameters()
        ):
            param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m)
            param_k.requires_grad = False

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        batch_size = keys.shape[0]
        ptr = int(self.queue_ptr)
        assert self.K % batch_size == 0
        self.Queue[:, ptr: ptr + batch_size] = keys.T
        ptr = (ptr + batch_size) % self.K
        self.queue_ptr[0] = ptr

    def forward(self, inputs):
        im_q = [a[0] for a in inputs]
        im_k = [a[1] for a in inputs]
        q = None
        for idx, anchor in enumerate(im_q):
            q_exp = torch.Tensor(self.encoder_q(anchor)).flatten().unsqueeze(0)
            if q is None:
                q = q_exp
            else:
                q = torch.cat((q, q_exp))
        q = torch.nn.functional.normalize(q, dim=0)  # queries: N x C (batch_size x N_QUBITS)

        self._momentum_update_key_encoder()
        k = None
        for key in im_k:
            k_exp = torch.Tensor(self.encoder_k(key)).flatten().unsqueeze(0)
            if k is None:
                k = k_exp
            else:
                k = torch.cat((k, k_exp))
        k = torch.nn.functional.normalize(k, dim=0).detach()
        l_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1)
        que_k = self.Queue.clone().detach()
        que_k = que_k.to(device)
        l_neg = torch.einsum("nc,ck->nk", [q, que_k])
        logits = torch.cat([l_pos, l_neg], dim=1)
        logits /= self.T
        labels = torch.zeros(logits.shape[0], dtype=torch.long)
        self._dequeue_and_enqueue(k)
        return logits, labels, self.Queue

device = torch.device("cuda:0")
torch.cuda.manual_seed(42)

Queue = torch.randn(dim, K)
net = Net(Queue)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.encoder_q.parameters(), lr=learning_rate)
fig = np.random.randint(0, 256, (26, 26, 3), dtype=np.uint8)
TRAIN_DATASET = []
for i in range(100):
    TRAIN_DATASET.append((np.array(fig).flatten(), np.array(fig).flatten()))

Queue = train_network(net, TRAIN_DATASET, device, total_epochs,
                      optimizer, criterion)

Thank you! I’ll get back to you shortly with a response.

Hey @StillWaterJ,

I added torch.set_default_tensor_type('torch.cuda.FloatTensor') to the top of your code, and when I run it I get hung up on

Queue = train_network(net, TRAIN_DATASET, device, total_epochs,
                      optimizer, criterion)

specifically

            train_loss, Queue = train_data(anchor, net, device, criterion, optimizer)

Try adding torch.set_default_tensor_type('torch.cuda.FloatTensor') right after your imports to see if that fixes anything. And, just to make sure we’re running the same versions, I’m using PennyLane v0.36 and Torch 2.3.0+cu121.
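For concreteness, here is roughly where that line would go (a sketch against the top of your script, with only the relevant parts shown):

import torch
import pennylane as qml
# ... rest of your imports ...

# Make newly created tensors default to CUDA floats, so the TorchLayer
# weights and the tensors built during training land on the GPU.
torch.set_default_tensor_type('torch.cuda.FloatTensor')
# Note: on recent PyTorch releases this call is deprecated; the equivalent is
# torch.set_default_device("cuda") together with torch.set_default_dtype(torch.float32).

N_QUBITS = 11
dev = qml.device("lightning.kokkos", wires=N_QUBITS)
# ... QNode, Net, and training code as before ...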

If that isn’t helping, can you reduce your code down to something less computationally intensive? That will help the debugging process :slight_smile:. Hope this helps!


Thank you for your answer!


My pleasure! Let us know if you have any further questions :slight_smile: