I am aiming to use the “default.qubit” simulator, PennyLane, and PyTorch to create a basic quantum version of a recurrent neural network for sequential data. However, I’m having trouble understanding how the classical and quantum computing parts relate to one another.
Here is the base code for the classical version, which I’m attempting to follow:
# Simple RNN model for binary classification
class RNNModel(nn.Module):
    """Single-layer RNN followed by a linear head, for classification.

    Each sample is fed to the RNN as a sequence of length one: ``forward``
    inserts the sequence axis itself, so callers pass a 2-D tensor.

    Args:
        input_size: Number of features per sample.
        hidden_size: Width of the RNN hidden state.
        num_classes: Number of output logits.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return the logits for a batch ``x`` of shape (batch, input_size)."""
        # new_zeros inherits x's device and dtype, so the initial hidden state
        # stays compatible after model.to(device) / model.double().
        h0 = x.new_zeros(1, x.size(0), self.hidden_size)
        # unsqueeze(1) adds the *sequence* axis (seq_len = 1); the batch axis
        # is already axis 0 because the RNN was built with batch_first=True.
        out, _ = self.rnn(x.unsqueeze(1), h0)
        # Keep only the last time step's output and project it to the logits.
        return self.fc(out[:, -1, :])
And here is my code for the quantum version:
# Define the quantum circuit
def quantum_circuit(inputs, weights):
    """Variational circuit: angle-embed ``inputs``, entangle, measure Pauli-Z.

    Args:
        inputs: Rotation angles, one per qubit along the last axis. May carry
            a leading batch axis, i.e. shape (n_qubits,) or (batch, n_qubits).
        weights: Trainable parameters for ``BasicEntanglerLayers``.

    Returns:
        A list with one Pauli-Z expectation value per wire.
    """
    # The wire count must come from the LAST axis. len(inputs) is the size of
    # the first axis, which is the batch size (not the qubit count) whenever
    # the inputs arrive with a leading batch dimension.
    n_wires = qml.math.shape(inputs)[-1]
    wires = range(n_wires)
    qml.templates.AngleEmbedding(inputs, rotation='X', wires=wires)
    qml.templates.BasicEntanglerLayers(weights, wires=wires)
    return [qml.expval(qml.PauliZ(wire)) for wire in wires]
class QuantumRNNCell(nn.Module):
    """Recurrent cell whose state update goes through a variational circuit.

    At every time step the previous hidden state and the current input are
    concatenated, projected to ``n_qubits`` values by a linear layer, passed
    through the quantum layer, and projected back to ``hidden_size`` by a
    second linear layer followed by ``tanh``.

    Args:
        input_size: Accepted for signature compatibility; not used here.
        embedding_size: Feature size of each time step of the input.
        hidden_size: Width of the hidden state.
        n_qubits: Number of wires of the simulator device, and the width of
            the tensor handed to the quantum layer.
        n_qlayers: Number of entangling layers (first axis of the weights).
    """

    def __init__(self, input_size, embedding_size, hidden_size, n_qubits, n_qlayers):
        super().__init__()
        self.hidden_size = hidden_size
        # Width of the [hidden state, input] vector built at each time step.
        self.concat_size = hidden_size + embedding_size

        # Quantum part: a single QNode on the simulator, wrapped as a layer.
        self.device = qml.device('default.qubit', wires=n_qubits)
        self.qlayer = qml.QNode(quantum_circuit, self.device, interface="torch")
        self.weight_shapes = {"weights": (n_qlayers, n_qubits)}

        # Classical layers that adapt sizes before and after the circuit.
        self.clayer_in = torch.nn.Linear(self.concat_size, n_qubits)
        self.clayer_out = torch.nn.Linear(n_qubits, hidden_size)

        # Variational quantum circuit exposed as a torch module.
        self.VQC = qml.qnn.TorchLayer(self.qlayer, self.weight_shapes)

    def forward(self, x, h_t):
        """Run the cell over every time step of ``x``.

        Args:
            x: Tensor of shape (batch, seq_length, embedding_size).
            h_t: Initial hidden state of shape (batch, hidden_size), or
                ``None`` to start from zeros.

        Returns:
            A pair ``(h_t, h_t)``: the final hidden state, returned twice
            because this cell uses it as both output and hidden state.
        """
        batch_size, _, _ = x.size()
        if h_t is None:
            # Allocate directly on the input's device instead of copying.
            h_t = torch.zeros(batch_size, self.hidden_size, device=x.device)

        # unbind(dim=1) yields one (batch, embedding_size) slice per step.
        for x_t in x.unbind(dim=1):
            v_t = torch.cat((h_t, x_t), dim=1)  # [hidden state, input]
            y_t = self.clayer_in(v_t)  # project down to n_qubits angles
            # Cast defensively in case the simulator returns a different
            # floating dtype than the classical layers use; no-op otherwise.
            q_output = self.VQC(y_t).to(y_t.dtype)
            h_t = torch.tanh(self.clayer_out(q_output))
        return h_t, h_t
# Define the main QRNN model
class QRNN(nn.Module):
    """Binary classifier: embedding -> quantum RNN cell -> dropout -> linear.

    Args:
        input_size: Number of rows of the embedding table (vocabulary size).
        embedding_size: Dimension of each embedding vector.
        hidden_size: Width of the recurrent hidden state.
        n_qubits: Number of qubits used by the quantum cell.
        n_qlayers: Number of variational layers in the quantum cell.
    """

    def __init__(self, input_size, embedding_size, hidden_size, n_qubits, n_qlayers):
        super().__init__()
        self.embedding_layer = nn.Embedding(input_size, embedding_size)
        self.qrnn_layer = QuantumRNNCell(input_size, embedding_size, hidden_size, n_qubits, n_qlayers)
        self.dropout = nn.Dropout(0.2)
        # One output unit: probability of the positive class.
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x, hidden=None):
        """Return per-sample probabilities in (0, 1).

        Args:
            x: Integer index tensor; assumed to be (batch, seq_length), which
                is what the recurrent cell's 3-D unpacking requires after
                embedding.
            hidden: Optional initial hidden state forwarded to the cell.
        """
        # Convert input indices to embeddings. The previous debug print called
        # ``embedded_input.shape()``; ``shape`` is an attribute, not a method,
        # so that line raised TypeError on every forward pass.
        embedded_input = self.embedding_layer(x)
        qrnn_output, hidden = self.qrnn_layer(embedded_input, hidden)
        qrnn_output = self.dropout(qrnn_output)
        output = self.fc(qrnn_output)
        return torch.sigmoid(output)
My questions are:
Is each neuron replaced by an independent variational quantum circuit (VQC)? If yes, am I using 8 VQCs if I define hidden_size = 8?
Does the number of qubits depend on input_size, embedding_size, or both? How many qubits will I need if input_size = 12 and embedding_size = 4 (using AngleEmbedding and BasicEntanglerLayers in the VQC design)?
What should I take into account in my architecture to handle these samples effectively if I use batch_size = 3 (processing 3 samples at once)?