Does the lightning.gpu device support gradient computation for circuits that return mutual information?

I want to use mutual information as part of the loss function to optimize a quantum neural network. Initially I used default.qubit, but it took too much time to run, so I switched to lightning.gpu. Now the gradient computation fails. Why is that? My source code is below.

import csv
import datetime
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from tqdm import tqdm

import pennylane as qml


train_num = 125
test_num = 50
batch_size = 25
epochnum = 100
classnum = 10

transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.FashionMNIST(root="./data", train=True, download=True, transform=transform)
test_dataset = datasets.FashionMNIST(root="./data", train=False, download=True, transform=transform)
# train_dataset = datasets.CIFAR10(root="./data", train=True, download=False, transform=transform)
# test_dataset = datasets.CIFAR10(root="./data", train=False, download=False, transform=transform)


train_targets = np.array(train_dataset.targets)
test_targets = np.array(test_dataset.targets)

# Keep only the first train_num / test_num examples of each class.
indices = []
for i in range(classnum):
    idx = np.where(train_targets == i)[0][:train_num]
    indices.extend(idx)
selected_train_subset = Subset(train_dataset, indices)

indices = []
for i in range(classnum):
    idx = np.where(test_targets == i)[0][:test_num]
    indices.extend(idx)
selected_test_subset = Subset(test_dataset, indices)

train_loader = DataLoader(selected_train_subset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(selected_test_subset, batch_size=batch_size, shuffle=True)

print(len(train_loader))
print(len(test_loader))

n_qubits = 4
n_layers = 2

dev = qml.device("lightning.gpu", wires=n_qubits)


def circuit(inputs, weights):
    # Angle-encode the flattened patch (one pixel per qubit).
    inputs = inputs * np.pi
    for qub in range(n_qubits):
        qml.Hadamard(wires=qub)
        qml.RY(inputs[qub], wires=qub)

    # Variational layers: a ring of CRZ entanglers followed by RY rotations.
    for layer in range(n_layers):
        for i in range(n_qubits):
            qml.CRZ(weights[layer, i], wires=[i, (i + 1) % n_qubits])
        for j in range(n_qubits, 2 * n_qubits):
            qml.RY(weights[layer, j], wires=j % n_qubits)

    # Return the mutual information of every qubit pair, followed by the
    # Pauli-Z expectation value of each qubit. The QNode itself is built
    # inside Quanvolution below.
    measurements = []
    for i in range(n_qubits - 1):
        for j in range(i + 1, n_qubits):
            measurements.append(qml.mutual_info(wires0=[i], wires1=[j]))
    for i in range(n_qubits):
        measurements.append(qml.expval(qml.PauliZ(i)))
    return measurements


class Quanvolution(nn.Module):
    def __init__(self, channel_in, channel_out, kernel_size, stride, padding):
        super().__init__()
        weight_shapes = {"weights": (n_layers, 2 * n_qubits)}
        self.channel_in = channel_in
        self.channel_out = channel_out
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.qnode = qml.QNode(circuit, dev, interface="torch", diff_method="best")
        self.ql1 = qml.qnn.TorchLayer(self.qnode, weight_shapes)

    def forward(self, inputs):
        N, Filter_in, H, W = inputs.shape
        Filter_out = self.channel_out
        HH = WW = self.kernel_size  # square kernel
        p = self.padding
        s = self.stride
        H_new = 1 + (H + 2 * p - HH) // s
        W_new = 1 + (W + 2 * p - WW) // s
        inputs_padded = F.pad(inputs, pad=(p, p, p, p), mode="constant", value=0)
        out = torch.zeros((N, Filter_out, H_new, W_new))

        n_pairs = n_qubits * (n_qubits - 1) // 2
        mi_terms = []

        for i in range(N):  # ith image
            for j in range(H_new):
                for k in range(W_new):
                    for t in range(Filter_in):
                        patch = inputs_padded[i, t, j * s:HH + j * s, k * s:WW + k * s]
                        q_result = self.ql1(torch.flatten(patch))
                        # The first n_pairs entries are the pairwise mutual informations.
                        mi_terms.append(q_result[:n_pairs])
                        # The remaining entries are the Pauli-Z expectations -> feature maps.
                        for c in range(Filter_out):
                            out[i, c, j, k] = q_result[c + n_pairs]

        MIL = torch.mean(torch.stack(mi_terms, dim=0))
        return out, MIL


class HybridModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.qc = Quanvolution(channel_in=1, channel_out=4, kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(196, 20)
        self.fc2 = nn.Linear(20, 10)

    def forward(self, x):
        x = F.avg_pool2d(x, kernel_size=2)
        x, MIL = self.qc(x)
        x = x.view(x.shape[0], -1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x, MIL


device = torch.device("cpu")
model = HybridModel().to(device)
learning_rate = 0.001

opt = torch.optim.Adam(model.parameters(), lr=learning_rate)

loss_func = torch.nn.CrossEntropyLoss().to(device)


def train(model, opt, train_loader, test_loader):
    current_date = datetime.datetime.now().strftime("%Y-%m-%d %H_%M_%S")
    filename = f"./MI_ADAM(0.3)_FashionMNIST10_{n_layers}_{current_date}.csv"

    header = ["Epoch", "Train Loss", "Train Accuracy", "Test Loss", "Test Accuracy"]
    with open(filename, mode="w", newline="") as csv_file:
        csv.writer(csv_file).writerow(header)

    best_validation_accuracy = 0.0 
    best_model_state = None

    print("start training...")
    model.train()
    for epoch in range(epochnum):

        start_time = time.time()

        train_losses = 0
        correct_train = 0

        for batch_id, (data, label) in enumerate(tqdm(train_loader)):
            data, target = data.to(device), label.to(device)

            outputs, MIL = model(data)
            loss = loss_func(outputs, target)
            # Subtract the mutual-information term so that maximizing MI lowers the loss.
            loss = loss - MIL

            train_losses += loss.item()
            pred = torch.argmax(F.softmax(outputs, dim=1), dim=1)

            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()
            correct_train += pred.eq(target).sum().item()

        train_loss_avg = train_losses / len(train_loader)
        accuracy_train = correct_train / (len(train_loader) * batch_size)
    

        model.eval()
        losses = 0
        correct = 0

        with torch.no_grad():  # no gradients needed during validation
            for batch_id, (data, label) in enumerate(tqdm(test_loader)):
                data, target = data.to(device), label.to(device)

                outputs, MIL = model(data)
                pred = torch.argmax(F.softmax(outputs, dim=1), dim=1)
                loss = loss_func(outputs, target) - MIL

                correct += pred.eq(target).sum().item()
                losses += loss.item()

        accuracy = correct / (len(test_loader) * batch_size)
        loss_avg = losses / len(test_loader)

        print("[train] epoch/accuracy/loss: {:d}/{:.4f}/{:.4f}".format(epoch + 1, accuracy_train, train_loss_avg))

        print("[validation] epoch/accuracy/loss: {:d}/{:.4f}/{:.4f}".format(epoch + 1, accuracy, loss_avg))

        if accuracy > best_validation_accuracy:
            best_validation_accuracy = accuracy
            best_model_state = model.state_dict()

        model.train()
        epoch_time = time.time() - start_time
        print("epoch time: {:.1f} s".format(epoch_time))
  
        row = (epoch + 1, train_loss_avg, accuracy_train, loss_avg, accuracy)

        with open(filename, mode='a', newline='') as csv_file:
            csv_writer = csv.writer(csv_file)
            csv_writer.writerow(row)
    
    torch.save(best_model_state, 'best_model45{:.4f}_layer_{:d}.pth'.format(best_validation_accuracy, n_layers))
    print("the best validation accuracy: {:.4f}".format(best_validation_accuracy))


train(model, opt, train_loader, test_loader)

The full error message is below:

ValueError: Computing the gradient of circuits that return the state with the parameter-shift rule gradient transform is not supported, as it is a hardware-compatible method.

Hi @HotFrog, that's right.

In the documentation for qml.mutual_info you will see the following note:

Note

Calculating the derivative of mutual_info() is currently supported when using the classical backpropagation differentiation method (diff_method="backprop") with a compatible device and finite differences (diff_method="finite-diff").

This means that you need to set a differentiation method that is compatible with both the device and the measurement if you want to use it for training. For lightning.gpu I don't think that diff_method="backprop" will work, but you can try. Alternatively, you can try diff_method="finite-diff", which is slower and less accurate but may allow you to run on lightning.gpu.
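Here is a minimal sketch of what I mean, using a small toy circuit rather than your full model (untested on my end, so treat it as a starting point rather than a benchmark):

import pennylane as qml
import torch

n_qubits = 4
dev = qml.device("lightning.gpu", wires=n_qubits)

# finite-diff estimates gradients from extra circuit executions, so it does not
# require the device to differentiate through the returned state
@qml.qnode(dev, interface="torch", diff_method="finite-diff")
def circuit(weights):
    for i in range(n_qubits):
        qml.RY(weights[i], wires=i)
    for i in range(n_qubits - 1):
        qml.CNOT(wires=[i, i + 1])
    return qml.mutual_info(wires0=[0], wires1=[1])

weights = torch.randn(n_qubits, requires_grad=True)
circuit(weights).backward()  # differentiates via finite differences, no ValueError
print(weights.grad)

In your code this would just mean passing diff_method="finite-diff" instead of diff_method="best" when you construct the QNode in Quanvolution.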

I hope this helps!

Let me know if you have any further questions.