Good afternoon PennyLane team, I'm working through the “Generalization from few training data” tutorial to reuse the QCNN architecture you have built, which looks great!
Here I would like to focus on the functions used to compute metrics such as the accuracy and the loss. Below I share code that produces something suspicious: it is 99% likely my own mistake, but I cannot see or understand where. So I am here to ask for your help, please.
I am not fully convinced by the loss function and the accuracy used in the tutorial; for classification problems one usually uses a cross-entropy loss. In any case, here is what I have written, again on MNIST downsampled to 256 features (not on sklearn.datasets.load_digits); a rough sketch of the preprocessing is shown right after the imports below.
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, accuracy_score
import jax
jax.config.update('jax_platform_name', 'cpu')
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import optax # optimization using jax
import pennylane as qml
import pennylane.numpy as pnp
from keras.datasets import mnist, fashion_mnist
import tensorflow as tf
seed = 0
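For context, here is roughly how the data are prepared (a minimal sketch only: the two digit classes, the sample counts, the 16x16 resize, and the per-sample normalisation are assumptions standing in for the exact preprocessing omitted from this post):

# Assumed preprocessing sketch: keep two digit classes, resize 28x28 -> 16x16 (256 features),
# flatten, normalise each sample, and relabel the two classes as 0/1.
(train_X, train_y), (test_X, test_y) = mnist.load_data()

def prepare(images, labels, classes=(0, 1), n_samples=200):
    mask = np.isin(labels, classes)
    images, labels = images[mask][:n_samples], labels[mask][:n_samples]
    images = tf.image.resize(images[..., np.newaxis] / 255.0, (16, 16)).numpy()
    features = images.reshape(len(images), -1)                            # 16 * 16 = 256 features
    features = features / np.linalg.norm(features, axis=1, keepdims=True)  # unit norm for AmplitudeEmbedding
    targets = (labels == classes[1]).astype(int)                           # binary labels {0, 1}
    return jnp.array(features), jnp.array(targets)

X_train, y_train = prepare(train_X, train_y)
X_test, y_test = prepare(test_X, test_y, n_samples=400)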
def convolutional_layer(weights, wires, skip_first_layer=True):
    """Adds a convolutional layer to a circuit.
    Args:
        weights (np.array): 1D array with 15 weights of the parametrized gates.
        wires (list[int]): Wires where the convolutional layer acts on.
        skip_first_layer (bool): Skips the first two U3 gates of a layer.
    """
    n_wires = len(wires)
    assert n_wires >= 2, "this circuit is too small!"

    for p in [0, 1, 2]:
        for indx, w in enumerate(wires):
            if indx % 2 == p and indx < n_wires - 1:
                if indx % 2 == 0 and not skip_first_layer:
                    qml.U3(*weights[:3], wires=[w])
                    qml.U3(*weights[3:6], wires=[wires[indx + 1]])
                qml.IsingXX(weights[6], wires=[w, wires[indx + 1]])
                qml.IsingYY(weights[7], wires=[w, wires[indx + 1]])
                qml.IsingZZ(weights[8], wires=[w, wires[indx + 1]])
                qml.U3(*weights[9:12], wires=[w])
                qml.U3(*weights[12:], wires=[wires[indx + 1]])
def pooling_layer(weights, wires):
    """Adds a pooling layer to a circuit.
    Args:
        weights (np.array): Array with the weights of the conditional U3 gate.
        wires (list[int]): List of wires to apply the pooling layer on.
    """
    n_wires = len(wires)
    assert len(wires) >= 2, "this circuit is too small!"

    for indx, w in enumerate(wires):
        if indx % 2 == 1 and indx < n_wires:
            m_outcome = qml.measure(w)
            qml.cond(m_outcome, qml.U3)(*weights, wires=wires[indx - 1])
def conv_and_pooling(kernel_weights, n_wires, skip_first_layer=True):
    """Apply both the convolutional and pooling layer."""
    convolutional_layer(kernel_weights[:15], n_wires, skip_first_layer=skip_first_layer)
    pooling_layer(kernel_weights[15:], n_wires)


def dense_layer(weights, wires):
    """Apply an arbitrary unitary gate to a specified set of wires."""
    qml.ArbitraryUnitary(weights, wires)
num_wires = 8  # 2**8 = 256 features
device = qml.device("default.qubit", wires=num_wires)
@qml.qnode(device)
def conv_net_LONG(weights, last_layer_weights, features):
    """Define the QCNN circuit.
    Args:
        weights (np.array): Parameters of the convolution and pool layers.
        last_layer_weights (np.array): Parameters of the last dense layer.
        features (np.array): Input data to be embedded using AmplitudeEmbedding.
    """
    layers = weights.shape[1]
    wires = list(range(num_wires))

    # embeds the input features as amplitudes
    qml.AmplitudeEmbedding(features=features, wires=wires, pad_with=0.5)
    qml.Barrier(wires=wires, only_visual=True)

    # adds convolutional and pooling layers
    for j in range(layers):
        conv_and_pooling(weights[:, j], wires, skip_first_layer=(not j == 0))
        wires = wires[::2]
        qml.Barrier(wires=wires, only_visual=True)

    assert last_layer_weights.size == 4 ** (len(wires)) - 1, (
        "The size of the last layer weights vector is incorrect!"
        f" \n Expected {4 ** (len(wires)) - 1}, Given {last_layer_weights.size}"
    )
    dense_layer(last_layer_weights, wires)

    # return qml.expval(qml.PauliZ(wires=0))
    return qml.probs(wires=0)
fig, ax = qml.draw_mpl(conv_net_LONG, style="sketch")(
    np.random.rand(18, 3), np.random.rand(4 ** 1 - 1), np.random.rand(2 ** num_wires)
)
plt.show()
@jax.jit
def compute_out(weights, weights_last, features, labels):
    """Computes the probability the QCNN assigns to the true label of each sample."""
    cost = lambda weights, weights_last, feature, label: conv_net_LONG(weights, weights_last, feature)[
        label
    ]
    return jax.vmap(cost, in_axes=(None, None, 0, 0), out_axes=0)(
        weights, weights_last, features, labels
    )
# No remapping of the outputs is needed here, since they are already probabilities in [0, 1].
def binary_crossentropy(weights, weights_last, features, labels):
    """Binary cross-entropy (base-10 log) between the labels and the probability of the true label."""
    ytrue = jnp.array(labels)
    out = jnp.array(compute_out(weights, weights_last, features, labels))
    epsilon = 1e-8
    ypred_bounded = jnp.clip(out, epsilon, 1 - epsilon)  # avoid log(0)
    loss = ytrue * jnp.log10(ypred_bounded) + (1 - ytrue) * jnp.log10(1 - ypred_bounded)
    return -jnp.mean(loss)


def calculate_accuracy(weights, weights_last, features, labels):
    """Thresholds the outputs at 0.5 and compares the result against the labels."""
    out = np.array(compute_out(weights, weights_last, features, labels))
    y_predicted = jnp.where(out <= 0.5, 0, 1)
    accuracy = accuracy_score(y_true=labels, y_pred=y_predicted)
    return accuracy
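For completeness, init_weights and the jitted value-and-gradient function are not shown above; they essentially follow the tutorial. A minimal sketch, assuming the (18, 3) and 4**1 - 1 shapes used in the drawing call above:

def init_weights():
    """Initializes random weights for the QCNN model (shapes assumed from the drawing call above)."""
    weights = pnp.random.normal(loc=0, scale=1, size=(18, 3), requires_grad=True)
    weights_last = pnp.random.normal(loc=0, scale=1, size=4 ** 1 - 1, requires_grad=True)
    return jnp.array(weights), jnp.array(weights_last)

# Jitted value-and-gradient of the loss, as in the tutorial, but using the binary cross-entropy above.
value_and_grad = jax.jit(jax.value_and_grad(binary_crossentropy, argnums=[0, 1]))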
n_epochs = 200
weights, weights_last = init_weights()
np.random.seed(7)
# learning rate decay
cosine_decay_scheduler = optax.cosine_decay_schedule(0.1, decay_steps=n_epochs, alpha=0.95)
optimizer = optax.adam(learning_rate=cosine_decay_scheduler)
opt_state = optimizer.init((weights, weights_last))
# data containers
train_cost_epochs, train_acc_epochs = [], []
for step in range(n_epochs):
    # Training step with the (Adam) optimizer
    train_cost, grad_circuit = value_and_grad(weights, weights_last, X_train, y_train)
    updates, opt_state = optimizer.update(grad_circuit, opt_state)
    weights, weights_last = optax.apply_updates((weights, weights_last), updates)
    train_cost_epochs.append(train_cost)

    # compute accuracy on training data
    train_acc = calculate_accuracy(weights, weights_last, X_train, y_train)
    train_acc_epochs.append(train_acc)
    print(f"Epoch {step}:", "---Train loss:", train_cost, "---Train acc.:", train_acc)

# Save the optimal weights
optimal_weights = weights
optimal_last_weights = weights_last
And here are the results, which I am not convinced by at all:
Epoch 0: ---Train loss: 0.25562302862628694 ---Train acc.: 1.0
Epoch 1: ---Train loss: 0.10043462121607755 ---Train acc.: 1.0
Epoch 2: ---Train loss: 0.07562060084377388 ---Train acc.: 1.0
Epoch 3: ---Train loss: 0.06272983254153912 ---Train acc.: 1.0
Epoch 4: ---Train loss: 0.047359607727645685 ---Train acc.: 1.0
Epoch 5: ---Train loss: 0.04271474614977181 ---Train acc.: 1.0
Epoch 6: ---Train loss: 0.04542359568846337 ---Train acc.: 1.0
Epoch 7: ---Train loss: 0.047489800861275966 ---Train acc.: 1.0
Epoch 8: ---Train loss: 0.042356685193502955 ---Train acc.: 1.0
Epoch 9: ---Train loss: 0.03308038205266399 ---Train acc.: 1.0
Epoch 10: ---Train loss: 0.026231809437399337 ---Train acc.: 1.0
Epoch 11: ---Train loss: 0.02460413246078036 ---Train acc.: 1.0
Epoch 12: ---Train loss: 0.02722732315876212 ---Train acc.: 1.0
Epoch 13: ---Train loss: 0.029903901220307238 ---Train acc.: 1.0
Epoch 14: ---Train loss: 0.029678124089959984 ---Train acc.: 1.0
Epoch 15: ---Train loss: 0.026717622036112207 ---Train acc.: 1.0
Epoch 16: ---Train loss: 0.022597290301865646 ---Train acc.: 1.0
Epoch 17: ---Train loss: 0.019179374443462318 ---Train acc.: 1.0
Epoch 18: ---Train loss: 0.017661101994214733 ---Train acc.: 1.0
Epoch 19: ---Train loss: 0.01752062683214453 ---Train acc.: 1.0
Epoch 20: ---Train loss: 0.018047777441789363 ---Train acc.: 1.0
Epoch 21: ---Train loss: 0.018676961174444638 ---Train acc.: 1.0
Epoch 22: ---Train loss: 0.018538257030111038 ---Train acc.: 1.0
Epoch 23: ---Train loss: 0.017574251387971206 ---Train acc.: 1.0
Epoch 24: ---Train loss: 0.01644080597210095 ---Train acc.: 1.0
(I have dropped the remaining epochs.)
accuracy_test = calculate_accuracy(optimal_weights, optimal_last_weights, X_test, y_test)
print("Accuracy on test set: ", accuracy_test)
Accuracy on test set: 1.0
ypredicted = jnp.where(compute_out(optimal_weights, optimal_last_weights, X_test, y_test) <= 0.5, 0, 1)
test_cm = confusion_matrix(y_true=y_test, y_pred=ypredicted)
ConfusionMatrixDisplay(test_cm).plot()
plt.tight_layout()
array([[199,   0],
       [  0, 201]])
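For completeness, the corresponding test loss can also be computed with the same binary cross-entropy defined above (sketch only, output omitted here):

test_loss = binary_crossentropy(optimal_weights, optimal_last_weights, X_test, y_test)
print("Binary cross-entropy on test set:", test_loss)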
I am not convinced because, if I compare with a simple CNN written in PyTorch, it also reaches 100% accuracy, but its loss keeps decreasing down to around 10^{-6}, very close to 0.
Here you can see it:
import torch
from torch import nn, optim

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        """
        Conv params: ((w * h * d) + 1) * k,
        where:
            w stands for the kernel width,
            h stands for the kernel height,
            d stands for the previous layer's number of filters,
            k stands for this layer's number of filters.
        A worked count for each layer is shown right after this class.
        """
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=12, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=12, out_channels=24, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.fc1 = nn.Linear(in_features=2 * 2 * 24, out_features=60)
        self.fc2 = nn.Linear(in_features=60, out_features=20)
        self.fc3 = nn.Linear(in_features=20, out_features=2)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.view(out.size(0), -1)
        out = self.fc1(out)
        out = self.fc2(out)
        out = self.fc3(out)
        return out
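For reference, plugging each layer into the ((w * h * d) + 1) * k formula from the docstring gives these counts (plain arithmetic):

# Parameter counts per layer, using ((w * h * d) + 1) * k for the conv layers
conv1_params = (3 * 3 * 1 + 1) * 12    # 120
conv2_params = (5 * 5 * 12 + 1) * 24   # 7224
fc1_params = 96 * 60 + 60              # 5820
fc2_params = 60 * 20 + 20              # 1220
fc3_params = 20 * 2 + 2                # 42
print(conv1_params + conv2_params + fc1_params + fc2_params + fc3_params)  # 14426 trainable parameters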
cnn_net = CNN()
loss_function = nn.CrossEntropyLoss() ## loss function
optimizer = optim.Adam(cnn_net.parameters(), lr=0.001)
CNN(
  (layer1): Sequential(
    (0): Conv2d(1, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (layer2): Sequential(
    (0): Conv2d(12, 24, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc1): Linear(in_features=96, out_features=60, bias=True)
  (fc2): Linear(in_features=60, out_features=20, bias=True)
  (fc3): Linear(in_features=20, out_features=2, bias=True)
)
### Training
prediction_list = []
accuracy_list = []
loss_list = []

for epoch in range(n_epochs):
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    for images, labels in data_loader:
        optimizer.zero_grad()
        outputs = cnn_net(images)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()

        _, predicted = torch.max(outputs, 1)
        correct_predictions += (predicted == labels).sum().item()
        total_samples += labels.size(0)
        prediction_list.extend(predicted.tolist())
        running_loss += loss.item()

    epoch_loss = running_loss / len(data_loader)
    epoch_accuracy = correct_predictions / total_samples

    # Append loss and accuracy for plotting
    loss_list.append(epoch_loss)
    accuracy_list.append(epoch_accuracy)

    # Print epoch statistics
    print(f"Epoch {epoch + 1}: Loss: {epoch_loss}, Accuracy: {epoch_accuracy}")
# After training, you can plot the loss and accuracy lists
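For example, a minimal plotting sketch for those two lists:

import matplotlib.pyplot as plt

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
ax1.plot(loss_list)
ax1.set_xlabel("Epoch")
ax1.set_ylabel("Cross-entropy loss")
ax2.plot(accuracy_list)
ax2.set_xlabel("Epoch")
ax2.set_ylabel("Accuracy")
plt.tight_layout()
plt.show()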
Epoch 1: Loss: 0.5580346845090389, Accuracy: 0.76
Epoch 2: Loss: 0.22074300035601482, Accuracy: 0.955
Epoch 3: Loss: 0.10980158725869842, Accuracy: 0.96
Epoch 4: Loss: 0.08428646697138902, Accuracy: 0.97
Epoch 5: Loss: 0.05880510069628144, Accuracy: 0.975
Epoch 6: Loss: 0.05095558602570236, Accuracy: 0.98
Epoch 7: Loss: 0.03941345617458865, Accuracy: 0.985
Epoch 8: Loss: 0.02388183739858505, Accuracy: 0.995
Epoch 9: Loss: 0.04132281154779775, Accuracy: 0.975
Epoch 10: Loss: 0.027125994058746983, Accuracy: 0.99
Epoch 11: Loss: 0.008764596322464513, Accuracy: 1.0
Epoch 12: Loss: 0.005491497802381673, Accuracy: 1.0
Epoch 13: Loss: 0.003788663222833577, Accuracy: 1.0
Epoch 14: Loss: 0.002373977152736728, Accuracy: 1.0
Epoch 15: Loss: 0.002041028591252214, Accuracy: 1.0
Epoch 16: Loss: 0.0024530677900159504, Accuracy: 1.0
Epoch 17: Loss: 0.0011964650286728328, Accuracy: 1.0
Epoch 18: Loss: 0.0009337588966435107, Accuracy: 1.0
Epoch 19: Loss: 0.0006957614569481407, Accuracy: 1.0
Epoch 20: Loss: 0.0005514739908187849, Accuracy: 1.0
Epoch 21: Loss: 0.0004759820359014455, Accuracy: 1.0
Epoch 22: Loss: 0.00040661833953112845, Accuracy: 1.0
Epoch 23: Loss: 0.0003674088685538202, Accuracy: 1.0
Epoch 24: Loss: 0.00030712640227985587, Accuracy: 1.0
Epoch 25: Loss: 0.00027320761861062695, Accuracy: 1.0
...
Epoch 197: Loss: 3.1173123016259296e-07, Accuracy: 1.0
Epoch 198: Loss: 3.051747382798453e-07, Accuracy: 1.0
Epoch 199: Loss: 2.944460034015606e-07, Accuracy: 1.0
Epoch 200: Loss: 2.8967762801812567e-07, Accuracy: 1.0
As you can see above, the classical CNN's loss keeps dropping by orders of magnitude, while mine plateaus. So I am not convinced about my results, nor about the metric functions used in the original code. Could you please help me?
Thanks in advance!