Hi @kente , welcome to the Forum!
Yes, some users have recently run into issues when using `qml.qnn.TorchLayer` with batched inputs. One alternative is to create a class that inherits from `torch.nn.Module` and splits the batch into individual elements inside the `forward` pass. This removes the need for `TorchLayer`. The code below shows how to do this for your code example. Note that it works with `inputs = torch.randn(8,2,4)` and `inputs = torch.randn(1,2,4)`, but it will throw an error if you use `inputs = torch.randn(2,4)`, because the model always treats the first dimension of `inputs` as the batch dimension.
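If you need to pass a single, unbatched sample of shape `(2, 4)`, one simple workaround is to add a leading batch dimension of size 1 first, so the input matches the `(1, 2, 4)` case that works. A minimal sketch using only standard PyTorch calls:

```python
import torch

# A single, unbatched sample with shape (2, 4)
single_input = torch.randn(2, 4)

# Add a leading batch dimension so the model sees a batch containing one element
batched_input = single_input.unsqueeze(0)
print(batched_input.shape)  # torch.Size([1, 2, 4])
```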
```python
import torch

# QuantumCircuit1 and qnode are defined in your original code example

# Create a model that inherits from torch.nn.Module
class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Initialize the trainable circuit parameters
        self.n_qubits = QuantumCircuit1.n_qubits
        self.ch_in = QuantumCircuit1.ch_in
        self.q_params = torch.nn.Parameter(torch.randn(self.n_qubits, self.ch_in))

    def forward(self, inputs):
        # Start with an empty tensor whose second dimension matches the circuit output
        q_out = torch.empty(0, self.n_qubits)
        # Apply the quantum circuit to each element of the batch and append the result to q_out
        for elem in inputs:
            q_out_elem = torch.hstack(qnode(elem, self.q_params)).float().unsqueeze(0)
            q_out = torch.cat((q_out, q_out_elem))
        # Return the outputs for the whole batch, shape (batch_size, n_qubits)
        return q_out

# Create an instance of your model
my_model = MyModel()

# Run your model on a batch of 8 inputs
inputs = torch.randn(8, 2, 4)
print('Output: \n', my_model(inputs))
```
You can find a similar example in this forum thread.
I hope this helps you!