QGAN application on Dask

Hello! If applicable, put your complete code example down below. Make sure that your code:

  • is 100% self-contained — someone can copy-paste exactly what is here and run it to
    reproduce the behaviour you are observing
  • includes comments
# Put code here
# Library imports
import math
import os
import random
import socket

import distributed
import numpy as np
import pandas as pd
import pennylane as qml
# Pytorch imports
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset


class DigitsDataset(Dataset):
    """Pytorch dataloader for the Optical Recognition of Handwritten Digits Data Set"""

    def __init__(self, csv_file, label=0, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.csv_file = csv_file
        self.transform = transform
        self.df = self.filter_by_label(label)

    def filter_by_label(self, label):
        # Use pandas to return a dataframe containing only the requested label (zeros by default)
        df = pd.read_csv(self.csv_file)
        df = df.loc[df.iloc[:, -1] == label]
        return df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        image = self.df.iloc[idx, :-1] / 16
        image = np.array(image)
        image = image.astype(np.float32).reshape(8, 8)

        if self.transform:
            image = self.transform(image)

        # Return image and label
        return image, 0


# Set the random seed for reproducibility
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

image_size = 8  # Height / width of the square images
batch_size = 1

transform = transforms.Compose([transforms.ToTensor()])
dataset = DigitsDataset(csv_file="quantum_gans/optdigits.tra", transform=transform)
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, shuffle=True, drop_last=True
)


class Discriminator(nn.Module):
    """Fully connected classical discriminator"""

    def __init__(self):
        super().__init__()

        self.model = nn.Sequential(
            # Inputs to first hidden layer (num_input_features -> 64)
            nn.Linear(image_size * image_size, 64),
            nn.ReLU(),
            # First hidden layer (64 -> 16)
            nn.Linear(64, 16),
            nn.ReLU(),
            # Second hidden layer (16 -> output)
            nn.Linear(16, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.model(x)


# Given Ng sub-generators that make up the entire generator, N qubits, and Na ancillary qubits, the size of
# each sub-generator's output (patch_size) is pow(2, N - Na). So, when all of the sub-generators are recombined /
# concatenated together, that cumulative output needs to be 8x8 = 64 values.
#
# In the tutorial, we have Ng = 4, N = 5, and Na = 1. Each sub-generator creates a feature whose length is
# pow(2, 5 - 1) = 16, and 4 groups of 16 gives us 64, exactly the dimensionality we need!
#
# However, we could also do Ng = 4, N = 6, and Na = 2 and the math still checks out (you can try this). You could
# even do Ng = 16, N = 5, and Na = 3. Here is the magic formula: pow(2, N - Na) = 64 / Ng. If that equation holds,
# then those values of Ng, N, and Na will work.
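
# A hypothetical sanity check (not part of the original tutorial) for the formula above:
# verify that a chosen (Ng, N, Na) combination reconstructs the full 8x8 = 64-pixel image
# once all the patches are concatenated.
def patch_config_is_valid(ng, n, na, total_pixels=64):
    # Each sub-generator emits 2 ** (n - na) values; ng of them are concatenated.
    return ng * 2 ** (n - na) == total_pixels


assert patch_config_is_valid(4, 5, 1)   # the tutorial configuration used below
assert patch_config_is_valid(4, 6, 2)   # the alternative mentioned above
assert patch_config_is_valid(16, 5, 3)  # another valid combination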


# Quantum variables
n_qubits = 5  # Total number of qubits / N
n_a_qubits = 1  # Number of ancillary qubits / N_A
q_depth = 6  # Depth of the parameterised quantum circuit / D
n_generators = 4  # Number of subgenerators for the patch method / N_G


class PatchQuantumGenerator(nn.Module):
    """Quantum generator class for the patch method"""

    def __init__(self, n_generators, q_delta=1):
        """
        Args:
            n_generators (int): Number of sub-generators to be used in the patch method.
            q_delta (float, optional): Spread of the random distribution for parameter initialisation.
        """

        super().__init__()

        self.q_params = nn.ParameterList(
            [
                nn.Parameter(q_delta * torch.rand(q_depth * n_qubits), requires_grad=True)
                for _ in range(n_generators)
            ]
        )
        self.n_generators = n_generators

    def forward(self, x):
        # For further info on how the non-linear transform is implemented in PennyLane, see:
        # https://discuss.pennylane.ai/t/ancillary-subsystem-measurement-then-trace-out/1532
        def partial_measure(noise, weights):
            # Non-linear Transform
            probs = quantum_circuit(noise, weights)
            probsgiven0 = probs[: (2 ** (n_qubits - n_a_qubits))]
            probsgiven0 /= torch.sum(probs)

            # Post-Processing
            probsgiven = probsgiven0 / torch.max(probsgiven0)
            return probsgiven

        # Quantum simulator
        dev = qml.device("lightning.gpu", wires=n_qubits)

        @qml.qnode(dev, interface="torch", diff_method="parameter-shift")
        def quantum_circuit(noise, weights):
            weights = weights.reshape(q_depth, n_qubits)

            # Initialise latent vectors
            for i in range(n_qubits):
                qml.RY(noise[i], wires=i)

            # Repeated layer
            for i in range(q_depth):
                # Parameterised layer
                for y in range(n_qubits):
                    qml.RY(weights[i][y], wires=y)

                # Control Z gates
                for y in range(n_qubits - 1):
                    qml.CZ(wires=[y, y + 1])

            return qml.probs(wires=list(range(n_qubits)))

        # Size of each sub-generator output
        patch_size = 2 ** (n_qubits - n_a_qubits)

        # Create a Tensor to 'catch' a batch of images from the for loop. x.size(0) is the batch size.
        # Enable CUDA device if available
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        images = torch.Tensor(x.size(0), 0).to(device)

        # Iterate over all sub-generators
        for params in self.q_params:

            # Create a Tensor to 'catch' a batch of the patches from a single sub-generator
            patches = torch.Tensor(0, patch_size).to(device)
            for elem in x:
                q_out = partial_measure(elem, params).float().unsqueeze(0)
                patches = torch.cat((patches, q_out))

            # Each batch of patches is concatenated with each other to create a batch of images
            images = torch.cat((images, patches), 1)

        return images


if __name__ == "__main__":
    dask_client = None
    try:
        dask_client = distributed.Client()
        dask_client.scheduler_info()

        print(dask_client.gather(dask_client.map(lambda a: a * a, range(10))))
        print(dask_client.gather(dask_client.map(lambda a: socket.gethostname(), range(10))))


        def run_training():

            # Enable CUDA device if available
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

            lrG = 0.3  # Learning rate for the generator
            lrD = 0.01  # Learning rate for the discriminator
            num_iter = 500  # Number of training iterations

            num_gpus = torch.cuda.device_count()
            print("Number of GPUS %s on the host %s" % (num_gpus, socket.gethostname()))

            # Wrap the models with DataParallel
            discriminator = nn.DataParallel(Discriminator().to(device), device_ids=range(num_gpus))
            generator = nn.DataParallel(PatchQuantumGenerator(n_generators).to(device), device_ids=range(num_gpus))

            # Non parallel invocation of the model
            # discriminator = Discriminator().to(device)
            # generator = PatchQuantumGenerator(n_generators).to(device)

            # Binary cross entropy
            criterion = nn.BCELoss()

            # Optimisers
            optD = optim.SGD(discriminator.parameters(), lr=lrD)
            optG = optim.SGD(generator.parameters(), lr=lrG)

            real_labels = torch.full((batch_size,), 1.0, dtype=torch.float, device=device)
            fake_labels = torch.full((batch_size,), 0.0, dtype=torch.float, device=device)

            # Fixed noise allows us to visually track the generated images throughout training
            fixed_noise = torch.rand(8, n_qubits, device=device) * math.pi / 2

            # Iteration counter
            counter = 0

            # Collect images for plotting later
            results = []

            while True:
                for i, (data, _) in enumerate(dataloader):

                    # Data for training the discriminator
                    data = data.reshape(-1, image_size * image_size)
                    real_data = data.to(device)

                    # Noise following a uniform distribution in the range [0, pi/2)
                    noise = torch.rand(batch_size, n_qubits, device=device) * math.pi / 2
                    fake_data = generator(noise)

                    # Training the discriminator
                    discriminator.zero_grad()
                    outD_real = discriminator(real_data).view(-1)
                    outD_fake = discriminator(fake_data.detach()).view(-1)

                    errD_real = criterion(outD_real, real_labels)
                    errD_fake = criterion(outD_fake, fake_labels)
                    # Propagate gradients
                    errD_real.backward()
                    errD_fake.backward()

                    errD = errD_real + errD_fake
                    optD.step()

                    # Training the generator
                    generator.zero_grad()
                    outD_fake = discriminator(fake_data).view(-1)
                    errG = criterion(outD_fake, real_labels)
                    errG.backward()
                    optG.step()

                    counter += 1

                    # Show loss values
                    if counter % 10 == 0:
                        print(f'Iteration: {counter}, Discriminator Loss: {errD:0.3f}, Generator Loss: {errG:0.3f}')
                        test_images = generator(fixed_noise).view(8, 1, image_size, image_size).cpu().detach()

                        # Save images every 50 iterations
                        if counter % 50 == 0:
                            results.append(test_images)

                    if counter == num_iter:
                        break
                if counter == num_iter:
                    break


        training_future = dask_client.submit(run_training)
        training_future.result()
    finally:
        if dask_client:
            # Shut down the Dask client cleanly (Client.cancel() requires a list of
            # futures, so it cannot be called without arguments).
            dask_client.close()

If you want help with diagnosing an error, please put the full error message below:

# Put full error message here
[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
['nid008688', 'nid008688', 'nid008688', 'nid008688', 'nid008688', 'nid008688', 'nid008688', 'nid008688', 'nid008688', 'nid008688']
Number of GPUS 4 on the host nid008688
Iteration: 10, Discriminator Loss: 1.361, Generator Loss: 0.596
2023-07-31 19:03:49,676 - distributed.worker - WARNING - Compute Failed
Key:       run_training-95aa27f38c0648a3855f20fa07c3e5fa
Function:  run_training
args:      ()
kwargs:    {}
Exception: RuntimeError('Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/global/common/software/m4408/prmantha/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
    output = module(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/global/common/software/m4408/prmantha/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/global/u1/p/prmantha/pilot-streaming/examples/scripts/nersc/test.py", line 186, in forward
    patches = torch.cat((patches, q_out))
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!
(when checking argument for argument tensors in method wrapper_CUDA_cat)')

And, finally, make sure to include the versions of your packages. Specifically, show us the output of qml.about().

Name: PennyLane
Version: 0.30.0
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/XanaduAI/pennylane
Author: 
Author-email: 
License: Apache License 2.0
Location: /global/common/software/m4408/prmantha/lib/python3.11/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, pennylane-lightning, requests, rustworkx, scipy, semantic-version, toml
Required-by: PennyLane-Lightning, PennyLane-Lightning-GPU

Platform info:           Linux-5.14.21-150400.24.46_12.0.73-cray_shasta_c-x86_64-with-glibc2.31
Python version:          3.11.4
Numpy version:           1.23.5
Scipy version:           1.11.1
Installed devices:
- default.gaussian (PennyLane-0.30.0)
- default.mixed (PennyLane-0.30.0)
- default.qubit (PennyLane-0.30.0)
- default.qubit.autograd (PennyLane-0.30.0)
- default.qubit.jax (PennyLane-0.30.0)
- default.qubit.tf (PennyLane-0.30.0)
- default.qubit.torch (PennyLane-0.30.0)
- default.qutrit (PennyLane-0.30.0)
- null.qubit (PennyLane-0.30.0)
- lightning.qubit (PennyLane-Lightning-0.31.0)
- lightning.gpu (PennyLane-Lightning-GPU-0.31.0)

Hi @QuantumMan, that looks like you’re trying to make a call without having all of your tensors on the same device.

Did you try checking on which devices your tensors live?
For all their benefits, dealing with multiple GPUs is never straightforward. :slight_smile:
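
For instance, a minimal (purely illustrative) check is to print the .device attribute of the two tensors right before the torch.cat call that fails inside PatchQuantumGenerator.forward:

# Hypothetical debugging lines, not part of the original script. Place them immediately
# before: patches = torch.cat((patches, q_out))
print("patches lives on:", patches.device)
print("q_out lives on:", q_out.device)

With nn.DataParallel, each replica runs on its own GPU, so it is worth checking both the buffers allocated inside forward and the tensors returned by the QNode; either one can end up on a different device from the other.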

Same issue here. But instead of using lightning.qubit I used default.qubit.torch. I guess this is not supported yet.

Hey @Daniel_Wang,

Unfortunately Dask doesn’t really jive well with PennyLane and is hard for us to support. Having said that, if you have access to a GPU you can parallelize things directly in PennyLane (see here: PennyLane v0.31 released | PennyLane Blog). Let me know if that helps!

Hey @Daniel_Wang, one small note to add to my last response:

default.qubit now has support for basic parallelization (not via Dask) by setting max_workers when loading the device. Scroll down to the section titled Accelerate calculations with multiprocessing in the blog post linked above.
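
For reference, loading the device that way could look roughly like this (a sketch only; it assumes a PennyLane version where default.qubit exposes the max_workers keyword, and the value 4 is just a placeholder):

# Assumes a PennyLane release where default.qubit accepts max_workers
# (see the "Accelerate calculations with multiprocessing" section mentioned above).
# In the script above, this would replace the lightning.gpu device used by quantum_circuit.
dev = qml.device("default.qubit", wires=n_qubits, max_workers=4)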