Here is the code, thanks.
import random
import time
import numpy as np
import torch.optim as optim
import numpy
import torch
import json
import torch
import pennylane as qml
import torch.nn as nn
from PIL import Image
from pennylane.templates import AmplitudeEmbedding
iterations = 100       # optimisation steps per epoch
batch_size = 32        # pairs sampled per step; K must be a multiple of this (see Net._dequeue_and_enqueue)
N_QUBITS = 11          # AmplitudeEmbedding can hold up to 2 ** N_QUBITS input features
total_epochs = 3
learning_rate = 0.01
K = 320                # number of key columns kept in the negative queue
dim = N_QUBITS         # queue feature size must equal the number of PauliZ expectation values (one per qubit)
N_LAYERS = 2           # must not exceed the number of CNOT patterns defined for `layer`
m = 0.99               # momentum coefficient for the key-encoder update
T = 0.07               # softmax temperature applied to the contrastive logits
# ===================================================================
def train_data(anchor, net, device, criterion, opt):
    """Run one optimisation step on a single batch.

    Args:
        anchor: batch passed unchanged to ``net``.
        net: model called as ``net(anchor)`` returning ``(logits, labels, Queue)``.
            The logits are expected to be temperature-scaled already
            (``Net.forward`` divides them by ``self.T``).
        device: unused; kept so existing callers keep working.
        criterion: loss function applied to ``(logits, labels)``.
        opt: optimiser stepped once after back-propagation.

    Returns:
        ``(train_loss, Queue)``: the detached scalar loss tensor and the queue
        returned by ``net``.
    """
    # Start from clean gradients so nothing left over from earlier steps is added in.
    opt.zero_grad()
    logits, labels, Queue = net(anchor)
    # Do not divide by T again: the model already applied the temperature,
    # and a second division would make the effective temperature T ** 2.
    loss = criterion(logits, labels)
    loss.backward()
    opt.step()
    return loss.detach(), Queue
def train_network(net=None, TRAIN_DATASET=None, device=None, epochs=total_epochs, optimizer=None,
                  criterion=None, idx=None, n_gene=None):
    """Train ``net`` for ``epochs`` x ``iterations`` steps on random mini-batches.

    Each step samples ``batch_size`` (query, key) pairs from ``TRAIN_DATASET``,
    runs one optimiser step through ``train_data`` and prints the step time
    and loss. Every 10th step of an epoch, the model ``state_dict`` is saved to
    ``result/model_train_parameters_{n_gene}_{idx}.pth``.

    Returns:
        The queue returned by the last training step, or ``None`` if no step ran.
    """
    import os  # only needed to make sure the checkpoint directory exists

    net = net.to(device)
    # torch.save does not create missing directories.
    os.makedirs("result", exist_ok=True)
    Queue = None
    net.train()
    for _epoch in range(epochs):
        for i in range(iterations):
            t_start = time.time()
            pairs = random.sample(TRAIN_DATASET, batch_size)
            # The inputs are data, not trainable values: requesting their gradient
            # would only make the QNode differentiate w.r.t. every amplitude.
            anchor = torch.tensor(numpy.array(pairs), dtype=torch.float32)
            optimizer.zero_grad()  # clear gradients before this step's backward pass
            train_loss, Queue = train_data(anchor, net, device, criterion, optimizer)
            net.Queue = Queue
            print('batch({}):{:.4f}| train_loss:{:.4f} '.format(
                i, time.time() - t_start, train_loss))
            if i % 10 == 0:
                torch.save(net.state_dict(), f"result/model_train_parameters_{n_gene}_{idx}.pth")
    return Queue
dev = qml.device("lightning.kokkos", wires=N_QUBITS)


@qml.qnode(dev)
def qnode_q(inputs, weights):
    """Encoder circuit: amplitude-embed ``inputs``, apply the variational layers, measure.

    Args:
        inputs: feature vector embedded with ``AmplitudeEmbedding`` (normalised,
            padded with 0.01 up to ``2 ** N_QUBITS`` amplitudes).
        weights: per-layer rotation angles; ``weights[l]`` is passed to ``layer``
            together with the layer index ``l``.

    Returns:
        One ``PauliZ`` expectation value per qubit (``N_QUBITS`` values).
    """
    # Diagnostic only: all-zero or oversized inputs cannot be amplitude-embedded.
    # np.shape works for tensors, arrays and lists alike (torch has no `array`).
    if torch.all(inputs == 0) or len(inputs) > 2 ** N_QUBITS:
        print('something error', np.shape(inputs), inputs)
    AmplitudeEmbedding(features=inputs, wires=range(N_QUBITS), normalize=True, pad_with=0.01)
    for count, W in enumerate(weights):
        layer(W, count)
    return [qml.expval(qml.PauliZ(i)) for i in range(N_QUBITS)]
# Entangling pattern per layer index: layer ``count`` applies the pairs in
# _CNOT_PAIRS[count]. Each pair is (control, target), matching the wire order
# PennyLane's CNOT expects.
_CNOT_PAIRS = (
    ((0, 10), (1, 10), (2, 10), (8, 10), (10, 7)),
    ((5, 10), (6, 10), (7, 10), (8, 10), (7, 4), (9, 8)),
)


def layer(W, count):
    """Apply one variational layer: a Rot gate on every qubit, then a fixed CNOT pattern.

    Args:
        W: rotation angles indexed as ``W[qubit, 0..2]`` (one ``(N_QUBITS, 3)``
            slice of the ``weights`` described by ``weight_shapes``).
        count: layer index selecting the CNOT pattern.

    Raises:
        IndexError: if no CNOT pattern is defined for ``count``.
    """
    if count >= len(_CNOT_PAIRS):
        raise IndexError(
            f"no CNOT pattern for layer {count}; only {len(_CNOT_PAIRS)} are defined "
            "(N_LAYERS must not exceed this)")
    for wire in range(N_QUBITS):
        qml.Rot(W[wire, 0], W[wire, 1], W[wire, 2], wires=wire)
    qml.Barrier(only_visual=False)
    for control, target in _CNOT_PAIRS[count]:
        qml.CNOT(wires=[control, target])
    qml.Barrier(only_visual=False)
# TorchLayer weight spec for qnode_q: N_LAYERS slices of (N_QUBITS, 3) Rot angles.
weight_shapes = dict(weights=(N_LAYERS, N_QUBITS, 3))
class Net(torch.nn.Module):
    """MoCo-style contrastive model built from two quantum encoders.

    ``encoder_q`` (query encoder) is trained by the optimiser; ``encoder_k``
    (key encoder) starts as a copy of it, receives no gradients and follows it
    by momentum. ``Queue`` stores past keys, one per column, used as negatives.

    NOTE: ``Queue`` and ``queue_ptr`` are plain attributes, not registered
    buffers, so they are neither saved in ``state_dict()`` nor moved by ``.to()``.
    """

    def __init__(self, Queue):
        super(Net, self).__init__()
        self.encoder_q = qml.qnn.TorchLayer(qnode_q, weight_shapes)
        self.encoder_k = qml.qnn.TorchLayer(qnode_q, weight_shapes)
        self.K = K
        self.m = m
        self.T = T
        # Key encoder starts as an exact, frozen copy of the query encoder.
        for params_q, params_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            params_k.data.copy_(params_q.data)
            params_k.requires_grad = False
        # Columns are keys: give each one unit length.
        self.Queue = torch.nn.functional.normalize(Queue, dim=0)
        self.queue_ptr = torch.zeros(1, dtype=torch.long)

    @staticmethod
    def _encode(encoder, x):
        """Run one sample through ``encoder`` and return a flat float32 feature vector."""
        # as_tensor keeps the autograd graph (unlike the legacy torch.Tensor constructor).
        return torch.as_tensor(encoder(x), dtype=torch.float32).flatten()

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """Update every key-encoder parameter in place: k = m * k + (1 - m) * q."""
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.mul_(self.m).add_(param_q.data, alpha=1.0 - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """Write ``keys`` (one key per row) into the next columns of the ring-buffer queue."""
        n_keys = keys.shape[0]
        ptr = int(self.queue_ptr)
        # With a constant batch size this keeps the write window from running past column K.
        assert self.K % n_keys == 0
        self.Queue[:, ptr:ptr + n_keys] = keys.T.to(self.Queue.device)
        self.queue_ptr[0] = (ptr + n_keys) % self.K

    def forward(self, inputs):
        """Compute contrastive logits for a batch of (query view, key view) pairs.

        Args:
            inputs: batch where ``inputs[i][0]`` is the query view and
                ``inputs[i][1]`` the key view of sample ``i`` (``train_network``
                passes a float tensor of shape (N, 2, D)).

        Returns:
            ``(logits, labels, Queue)``: ``logits`` has shape
            (N, 1 + Queue.shape[1]) and is already divided by ``T``, with the
            positive pair in column 0; ``labels`` is all zeros (the positive
            index); ``Queue`` is the queue after this batch's keys were enqueued.
        """
        q = torch.stack([self._encode(self.encoder_q, a[0]) for a in inputs])
        # Normalise each query (row, dim=1) so dot products are cosine similarities.
        q = torch.nn.functional.normalize(q, dim=1)  # queries: NxC
        self._momentum_update_key_encoder()
        with torch.no_grad():  # keys are targets only; no gradient flows through them
            k = torch.stack([self._encode(self.encoder_k, a[1]) for a in inputs])
            k = torch.nn.functional.normalize(k, dim=1)  # keys: NxC
        # Positive logits: each query against its own key, shape (N, 1).
        l_pos = torch.einsum("nc,nc->n", q, k).unsqueeze(-1)
        # Negative logits against every queued key; snapshot the queue because it is
        # overwritten in place below, and follow q's device instead of a global.
        que_k = self.Queue.detach().clone().to(q.device)
        l_neg = torch.einsum("nc,ck->nk", q, que_k)
        logits = torch.cat([l_pos, l_neg], dim=1) / self.T
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
        self._dequeue_and_enqueue(k)
        return logits, labels, self.Queue
if __name__ == "__main__":
    # Fall back to the CPU so the script still runs on machines without CUDA.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Seed every RNG the script draws from: `random` (batch sampling), numpy
    # (demo image) and torch (initial queue / weights; manual_seed covers all devices).
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)
    Queue = torch.randn(dim, K)
    net = Net(Queue)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.encoder_q.parameters(), lr=learning_rate)
    # Synthetic demo data: one random 26x26 RGB image used as both query and key view.
    fig = np.random.randint(0, 256, (26, 26, 3), dtype=np.uint8)
    TRAIN_DATASET = [(fig.flatten(), fig.flatten()) for _ in range(100)]
    Queue = train_network(net, TRAIN_DATASET, device, total_epochs,
                          optimizer, criterion)