Hello!
I’m trying to correctly implement CVNeuralNetLayers with TorchLayer. For some reason I just can’t get it right.
Below is some example code pulled from the PennyLane website:
```python
import torch
import pennylane as qml
from sklearn.datasets import make_moons

# Toy dataset: 200 points, labels one-hot encoded into shape (200, 2)
X, y = make_moons(n_samples=200, noise=0.1)
y_ = torch.unsqueeze(torch.tensor(y), 1)
y_hot = torch.scatter(torch.zeros((200, 2)), 1, y_, 1)

dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev)
def qnode(inputs, weights):
    qml.AngleEmbedding(inputs, wires=range(2))
    qml.CVNeuralNetLayers(*weights, wires=range(2))
    return [qml.expval(qml.PauliZ(wires=i)) for i in range(2)]

# CVNeuralNetLayers.shape returns a list of 11 shape tuples,
# one per weight argument of the template
shapes = qml.CVNeuralNetLayers.shape(n_layers=2, n_wires=2)
weight_shapes = {"weights": [shapes]}

qlayer = qml.qnn.TorchLayer(qnode, weight_shapes)
clayer_1 = torch.nn.Linear(2, 2)
clayer_2 = torch.nn.Linear(2, 2)
softmax = torch.nn.Softmax(dim=1)
model = torch.nn.Sequential(clayer_1, qlayer, clayer_2, softmax)

opt = torch.optim.SGD(model.parameters(), lr=0.2)
loss = torch.nn.L1Loss()

X = torch.tensor(X, requires_grad=True).float()
y_hot = y_hot.float()

batch_size = 5
batches = 200 // batch_size

data_loader = torch.utils.data.DataLoader(
    list(zip(X, y_hot)), batch_size=batch_size, shuffle=True, drop_last=True
)

for epoch in range(2):
    running_loss = 0
    for xs, ys in data_loader:
        opt.zero_grad()
        loss_evaluated = loss(model(xs), ys)
        loss_evaluated.backward()
        opt.step()
        running_loss += loss_evaluated.item()
    avg_loss = running_loss / batches
    print("Average loss over epoch {}: {:.4f}".format(epoch + 1, avg_loss))
```
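For comparison: `TorchLayer` maps each QNode argument name in `weight_shapes` to one shape, whereas `{"weights": [shapes]}` packs all 11 shapes under a single name. Below is a minimal plain-Python sketch (no PennyLane needed) of a dict with one entry per template argument. It assumes two things: that the 11 shapes follow the ordering checked in `cv_neural_net.py` (the `expected` line in the traceback below), and that the argument names `theta_1` … `k` match the `CVNeuralNetLayers` signature. The helper `cv_layer_shapes` is hypothetical.

```python
# Hypothetical helper mirroring the ordering that cv_neural_net.py checks:
# expected = [n_if] * 2 + [n_wires] * 3 + [n_if] * 2 + [n_wires] * 4
def cv_layer_shapes(n_layers, n_wires):
    n_if = n_wires * (n_wires - 1) // 2  # beamsplitters per interferometer
    second_dims = [n_if] * 2 + [n_wires] * 3 + [n_if] * 2 + [n_wires] * 4
    return [(n_layers, d) for d in second_dims]

# Assumed names of the 11 CVNeuralNetLayers weight arguments, in order
ARG_NAMES = ["theta_1", "phi_1", "varphi_1", "r", "phi_r",
             "theta_2", "phi_2", "varphi_2", "a", "phi_a", "k"]

shapes = cv_layer_shapes(n_layers=2, n_wires=2)

# One entry per QNode argument instead of a single "weights" entry
weight_shapes = dict(zip(ARG_NAMES, shapes))
print(weight_shapes)
```

With a dict like this, the QNode would presumably take those 11 names as separate arguments and pass them in order to `qml.CVNeuralNetLayers`, instead of unpacking a single `weights` argument with `*weights`. That wiring is untested here.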
This is the last error in the traceback I get back:
```
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
...
/usr/local/lib/python3.7/dist-packages/pennylane/templates/layers/cv_neural_net.py in <listcomp>(.0)
    116
    117     # check second dimensions
--> 118     second_dims = [s[1] for s in shapes]
    119     expected = [n_if] * 2 + [n_wires] * 3 + [n_if] * 2 + [n_wires] * 4
    120     if not all(e == d for e, d in zip(expected, second_dims)):

IndexError: tuple index out of range
```
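The failing line takes `s[1]` of every weight shape, so any weight that reaches the template with fewer than two dimensions produces exactly this `IndexError`. Below is a minimal plain-Python sketch of that check. The function `check_second_dims` is hypothetical: it is a simplified re-implementation of the lines shown in the traceback, not PennyLane's actual code. The sketch also assumes the template receives 11 one-dimensional slices, for instance from `*weights` unpacking a single 2-D `weights` tensor.

```python
# Hypothetical, simplified re-implementation of the second-dimension check
# shown in the traceback; illustrative only, not PennyLane's code.
def check_second_dims(shapes, n_wires):
    n_if = n_wires * (n_wires - 1) // 2
    second_dims = [s[1] for s in shapes]  # IndexError if any shape is 1-D
    expected = [n_if] * 2 + [n_wires] * 3 + [n_if] * 2 + [n_wires] * 4
    return all(e == d for e, d in zip(expected, second_dims))

n_layers, n_wires = 2, 2

# 11 two-dimensional shapes in the expected order: the check passes
good = ([(n_layers, 1)] * 2 + [(n_layers, 2)] * 3
        + [(n_layers, 1)] * 2 + [(n_layers, 2)] * 4)
print(check_second_dims(good, n_wires))  # True

# Assumed failure scenario: 11 one-dimensional shapes such as (2,)
bad = [(2,)] * 11
try:
    check_second_dims(bad, n_wires)
except IndexError as err:
    print("IndexError:", err)  # IndexError: tuple index out of range
```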
I'm trying to figure out how to make this work nicely with TorchLayer, but I'm all out of ideas. Any thoughts on how to implement this correctly?
Any help or advice would be very much appreciated.