Hi @mass_of_15, I modified the quantum GAN demo from PennyLane to adapt the code to my dataset, which consists of 300 RGB images. I resized the images to 64x64 with batch_size=32, but the training cell either runs indefinitely or crashes from exhausting all available RAM, both on Google Colab and locally. So I reduced the dataset from 300 items to 30 and the batch size from 32 to 1, and training then completed, but after 200 epochs the output images are black. Could you tell me if there are errors in the code or in the visualization of the images? Do you have an example implementation of a quantum GAN with RGB images?
for k in range(len(test_images)):
    fig, axs = plt.subplots(1, 1, sharey=False, tight_layout=True, figsize=(2, 2), facecolor='white')
    # axs.matshow(np.squeeze(test_images[k].permute(1, 2, 0)))
    axs.matshow(test_images[k].T)
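For reference, here is a quick diagnostic (a rough sketch; it assumes test_images holds detached CPU tensors of shape (3, 64, 64)) to check whether the generated values are truly near zero or merely outside the display range, since matplotlib clips float RGB data to [0, 1]:

img = test_images[0]
print("min:", img.min().item(), "max:", img.max().item(), "mean:", img.mean().item())
# Rescale per image to [0, 1] and move channels last before plotting
img = (img - img.min()) / (img.max() - img.min() + 1e-8)
plt.imshow(img.permute(1, 2, 0).numpy())
plt.axis("off")
plt.show()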
!pip install pennylane custatevec-cu11 pennylane-lightning-gpu
# Library imports
import math
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import pennylane as qml
# Pytorch imports
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
# Set the random seed for reproducibility
seed = 999
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
# Number of workers for dataloader
workers = 2
# Batch size during training
batch_size = 32
# Spatial size of training images. All images will be resized to this
# size using a transformer.
image_size = 64
# Number of channels in the training images. For color images this is 3
nc = 3
# Size of z latent vector (i.e. size of generator input)
nz = 13
# Size of feature maps in generator
ngf = 64
# Size of feature maps in discriminator
ndf = 64
# Number of training epochs
num_epochs = 100
# Learning rate for optimizers
lr = 0.0002
# Beta1 hyperparam for Adam optimizers
beta1 = 0.5
# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1
from google.colab import drive
drive.mount('/content/drive')
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")
import torchvision.datasets as dset
ImgLocation = '/content/drive/MyDrive/Colab Notebooks/Subset_Dil_Bos/'
dataset = dset.ImageFolder(
    root=ImgLocation,
    transform=transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]),
)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=workers)
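A small sanity check on the pipeline (a sketch): with the Normalize above, real pixels land roughly in [-1, 1], whereas the quantum generator below emits probabilities rescaled to [0, 1], which is worth keeping in mind when comparing real and fake batches.

real_batch, _ = next(iter(dataloader))
print(real_batch.shape)                    # expected: torch.Size([32, 3, 64, 64])
print(real_batch.min(), real_batch.max())  # roughly -1 and 1 after Normalize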
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
class Discriminator(nn.Module):
    """Fully connected classical discriminator"""

    def __init__(self, ngpu):
        super().__init__()
        self.ngpu = ngpu
        self.model = nn.Sequential(
            # Input layer (3 * image_size * image_size -> 64)
            nn.Linear(3 * image_size * image_size, 64, device=device),
            nn.ReLU(),
            # First hidden layer (64 -> 64)
            nn.Linear(64, 64, device=device),
            nn.ReLU(),
            # Second hidden layer (64 -> 16)
            nn.Linear(64, 16, device=device),
            nn.ReLU(),
            # Output layer (16 -> 1)
            nn.Linear(16, 1, device=device),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.model(x)
discriminator = Discriminator(ngpu).to(device)
# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
discriminator = nn.DataParallel(discriminator, list(range(ngpu)))
# Apply the weights_init function to randomly initialise all weights
# to mean=0, stdev=0.02.
discriminator.apply(weights_init)
# Print the model
print(discriminator)
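A one-off shape check on a dummy batch (a sketch) to confirm the discriminator expects flattened 3x64x64 inputs:

dummy = torch.randn(4, 3 * image_size * image_size, device=device)
print(discriminator(dummy).shape)  # expected: torch.Size([4, 1])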
# Quantum variables
n_qubits = 13 # Total number of qubits / N
n_a_qubits = 2 # Number of ancillary qubits / N_A
q_depth = 6 # Depth of the parameterised quantum circuit / D
n_generators = 6 # Number of subgenerators for the patch method / N_G
dev = qml.device("lightning.gpu", wires=n_qubits)
@qml.qnode(dev, interface="torch") #, diff_method="parameter-shift"
def quantum_circuit(noise, weights):
    weights = weights.reshape(q_depth, n_qubits)

    # Initialise latent vectors
    for i in range(n_qubits):
        qml.RY(noise[i], wires=i)

    # Repeated layer
    for i in range(q_depth):
        # Parameterised layer
        for y in range(n_qubits):
            qml.RY(weights[i][y], wires=y)

        # Control Z gates
        for y in range(n_qubits - 1):
            qml.CZ(wires=[y, y + 1])

    return qml.probs(wires=list(range(n_qubits)))
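To inspect the circuit layout, qml.draw can print a text diagram of the QNode for sample inputs (a sketch; example_noise and example_weights are just random placeholder tensors):

example_noise = torch.rand(n_qubits) * math.pi / 2
example_weights = torch.rand(q_depth * n_qubits)
print(qml.draw(quantum_circuit)(example_noise, example_weights))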
def partial_measure(noise, weights):
    # Non-linear Transform
    probs = quantum_circuit(noise, weights)
    probsgiven0 = probs[: (2 ** (n_qubits - n_a_qubits))]
    probsgiven0 /= torch.sum(probs)

    # Post-Processing
    probsgiven = probsgiven0 / torch.max(probsgiven0)
    return probsgiven
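As a check on the patch arithmetic: each sub-generator outputs 2**(13 - 2) = 2048 probabilities, and 6 sub-generators give 6 * 2048 = 12288 = 3 * 64 * 64, i.e. exactly one flattened RGB image:

patch_size = 2 ** (n_qubits - n_a_qubits)                         # 2048 values per sub-generator
assert n_generators * patch_size == nc * image_size * image_size  # 6 * 2048 == 12288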
class PatchQuantumGenerator(nn.Module):
    """Quantum generator class for the patch method"""

    def __init__(self, n_generators, q_delta=1):
        """
        Args:
            n_generators (int): Number of sub-generators to be used in the patch method.
            q_delta (float, optional): Spread of the random distribution for parameter initialisation.
        """
        super().__init__()
        self.q_params = nn.ParameterList(
            [
                nn.Parameter(q_delta * torch.rand(q_depth * n_qubits), requires_grad=True)
                for _ in range(n_generators)
            ]
        )
        self.n_generators = n_generators

    def forward(self, x):
        # Size of each sub-generator output
        patch_size = 2 ** (n_qubits - n_a_qubits)

        # Create a Tensor to 'catch' a batch of images from the for loop. x.size(0) is the batch size.
        images = torch.Tensor(x.size(0), 0).to(device)

        # Iterate over all sub-generators
        for params in self.q_params:
            # Create a Tensor to 'catch' a batch of the patches from a single sub-generator
            patches = torch.Tensor(0, patch_size).to(device)
            for elem in x:
                q_out = partial_measure(elem, params).float().unsqueeze(0)
                patches = torch.cat((patches, q_out))

            # Each batch of patches is concatenated with each other to create a batch of images
            images = torch.cat((images, patches), 1)

        return images
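A quick usage sketch for a single forward pass (a tiny batch keeps the 13-qubit simulation cheap):

g = PatchQuantumGenerator(n_generators).to(device)
z = torch.rand(2, n_qubits, device=device) * math.pi / 2
print(g(z).shape)  # expected: torch.Size([2, 12288])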
lrG = 0.3       # Learning rate for the generator
lrD = 0.01      # Learning rate for the discriminator
num_iter = 100  # Number of training iterations
discriminator = Discriminator(ngpu).to(device)
generator = PatchQuantumGenerator(n_generators).to(device)
# Binary cross entropy
criterion = nn.BCELoss()
# Optimisers
optD = optim.SGD(discriminator.parameters(), lr=lrD)
optG = optim.SGD(generator.parameters(), lr=lrG)
real_labels = torch.full((batch_size,), 1.0, dtype=torch.float, device=device)
fake_labels = torch.full((batch_size,), 0.0, dtype=torch.float, device=device)
# Fixed noise allows us to visually track the generated images throughout training
fixed_noise = torch.rand(image_size, n_qubits, device=device) * math.pi / 2
# Iteration counter
counter = 0
# Collect images for plotting later
results = []
while True:
    for i, (data, _) in enumerate(dataloader):

        # Data for training the discriminator
        data = data.reshape(-1, image_size * image_size * 3)  # data=(32, 12288)
        real_data = data.to(device)                           # real_data=(32, 12288)

        # Training the discriminator
        # Noise following a uniform distribution in range [0, pi/2)
        noise = torch.rand(batch_size, n_qubits, device=device) * math.pi / 2  # noise=(32, 13)
        fake_data = generator(noise)                                           # fake_data=(32, 12288)

        discriminator.zero_grad()
        outD_real = discriminator(real_data).view(-1)            # outD_real=(32,)
        errD_real = criterion(outD_real, real_labels)
        outD_fake = discriminator(fake_data.detach()).view(-1)   # outD_fake=(32,)
        errD_fake = criterion(outD_fake, fake_labels)

        # Propagate gradients
        errD_real.backward()
        errD_fake.backward()

        errD = errD_real + errD_fake
        optD.step()

        # Training the generator
        generator.zero_grad()
        outD_fake = discriminator(fake_data).view(-1)  # outD_fake=(32,)
        errG = criterion(outD_fake, real_labels)
        errG.backward()
        optG.step()

        counter += 1

        # Show loss values
        if counter % 10 == 0:
            print(f'Iteration: {counter}, Discriminator Loss: {errD:0.3f}, Generator Loss: {errG:0.3f}')
            test_images = generator(fixed_noise).view(image_size, 3, image_size, image_size).cpu().detach()

            # Save images every 50 iterations
            if counter % 50 == 0:
                results.append(test_images)

        if counter == num_iter:
            break
    if counter == num_iter:
        break
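After training, the saved snapshots can be plotted roughly like this (a sketch; each entry of results has shape (image_size, 3, image_size, image_size) from the view() above, and values are rescaled per image because imshow clips floats outside [0, 1]):

for i, images in enumerate(results):
    img = images[0]
    img = (img - img.min()) / (img.max() - img.min() + 1e-8)
    ax = plt.subplot(1, len(results), i + 1)
    ax.imshow(img.permute(1, 2, 0).numpy())
    ax.axis("off")
plt.show()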