QNN with TorchLayer can't process multi-dimensional input data

I have two-dimensional input data, and I have created a quantum circuit to encode this type of input. Here is my code:

import pennylane as qml
import torch

class QuantumCircuit:
    def __init__(self, n_wqubits, ch_in):
        self.n_wqubits = n_wqubits
        self.n_qubits = n_wqubits + 1
        self.ch_in = ch_in
        self.weight_shapes = {"weights": (self.n_qubits, self.ch_in)}
        self.dev = qml.device('default.qubit', wires=self.n_qubits)

    def quantum_circuit(self, inputs, weights):
        qml.Hadamard(wires=0)
        # Use the instance attributes: ch_in and n_wqubits are not defined inside this method
        x = inputs.reshape(self.ch_in, self.n_wqubits)

        for k in range(self.ch_in):
            for i in range(self.n_wqubits):
                qml.RX(x[k][i], wires=i + 1)
        qml.Hadamard(wires=0)
        return [qml.expval(qml.PauliZ(i)) for i in range(self.n_qubits)]
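The reshape in quantum_circuit assumes each call receives exactly one sample of ch_in * n_wqubits values. Below is a minimal torch-only sketch of an alternative, assuming each sample is flattened to ch_in * n_wqubits features before it reaches the circuit. The angle for channel k on wire i + 1 is then read straight from the flat feature axis, and the leading ... keeps whatever batch axes are present. encoding_angle is a hypothetical helper, not part of PennyLane, and the sketch does not call PennyLane itself.

```python
import torch

ch_in, n_wqubits = 2, 4


def encoding_angle(flat_inputs, k, i):
    # Hypothetical helper: angle for channel k, wire i + 1, read from a flat
    # feature axis of length ch_in * n_wqubits (row-major, channel first).
    # "..." keeps any batch axes, so the result is a 0-d tensor for one
    # sample and a 1-D tensor (one angle per sample) for a batch.
    return flat_inputs[..., k * n_wqubits + i]


samples = torch.randn(8, ch_in, n_wqubits)
flat = samples.reshape(8, -1)  # shape (8, 8): one row per sample

# Batched: one angle per sample, equal to samples[:, 1, 2]
print(encoding_angle(flat, 1, 2).shape)  # torch.Size([8])

# Single sample: a 0-d tensor, equal to samples[0, 1, 2]
print(encoding_angle(flat[0], 1, 2).shape)  # torch.Size([])
```

Inside the circuit this would become qml.RX(encoding_angle(inputs, k, i), wires=i + 1). For a batch, RX then receives a 1-D tensor of angles, which relies on PennyLane's parameter broadcasting on default.qubit; this has not been checked against the 0.38 TorchLayer code path.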

Here is my test code:

QuantumCircuit1 = QuantumCircuit(n_wqubits=4, ch_in=2)
qnode = qml.QNode(QuantumCircuit1.quantum_circuit, QuantumCircuit1.dev, interface="torch", diff_method='best')
qnn_layer = qml.qnn.TorchLayer(qnode, QuantumCircuit1.weight_shapes)

# Single input data
inputs = torch.randn(2, 4)
output = qnn_layer(inputs)
print(output)

# Batch input data
inputs = torch.randn(8, 2, 4)
output = qnn_layer(inputs)
print(output)
# pennylane version 0.38.0
# torch-gpu version 2.5.1

When I use a single input (e.g. inputs = torch.randn(2, 4)), the code runs fine. However, when I use batched input (e.g. inputs = torch.randn(8, 2, 4), which is automatically reshaped to shape (16, 4) before it reaches the circuit), the code throws an error:
RuntimeError: shape '[2, -1]' is invalid for input of size 1
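To see what the circuit actually receives, here is a torch-only sketch of the reshape described above, where every axis except the last is merged into a single batch axis. It assumes that is what TorchLayer does with multi-dimensional input, as reported in the question; it does not call PennyLane.

```python
import torch

batch = torch.randn(8, 2, 4)  # 8 samples, each of shape (ch_in=2, n_wqubits=4)

# Merge every axis except the last into one batch axis
# (the reshape the question reports for TorchLayer).
merged = batch.reshape(-1, batch.shape[-1])
print(merged.shape)  # torch.Size([16, 4])

# Each row now holds only 4 values: row 0 is channel 0 of sample 0,
# row 1 is channel 1 of sample 0, so no single row is a whole sample.
print(merged[0].numel())  # 4

# Flattening each sample first keeps exactly one row per sample.
flat = batch.reshape(batch.shape[0], -1)
print(flat.shape)  # torch.Size([8, 8])
```

Feeding a (batch, 8) tensor like flat keeps one row per sample, at the cost of having to unflatten the features inside the circuit.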

Hi @kente, welcome to the Forum!

Yes, qml.qnn.TorchLayer has recently caused problems for some users working with batched inputs. One alternative is to write a class that inherits from torch.nn.Module and splits the batch into individual elements in the forward pass, which removes the need for TorchLayer. The code below shows how to do this for your example. Note that it works with inputs = torch.randn(8,2,4) and inputs = torch.randn(1,2,4), but it throws an error for inputs = torch.randn(2,4): iterating over a (2, 4) tensor yields two rows of shape (4,), and the circuit cannot reshape 4 values into shape (2, 4).

# Create a model that inherits from torch.nn.Module
class MyModel(torch.nn.Module):

    def __init__(self):
        super().__init__()
        # Initialize the parameters
        self.n_qubits = QuantumCircuit1.n_qubits
        self.ch_in = QuantumCircuit1.ch_in
        self.q_params = torch.nn.Parameter(torch.randn(self.n_qubits,self.ch_in))

    def forward(self, inputs):
        # Apply the quantum circuit to each element of the batch
        outputs = []
        for elem in inputs:
            # The QNode returns one expectation value per qubit; hstack joins them
            # into a 1-D tensor and .float() converts it to float32
            outputs.append(torch.hstack(qnode(elem, self.q_params)).float())

        # Stack into a single tensor of shape (batch_size, n_qubits)
        return torch.stack(outputs)

# Create an instance of your model
my_model = MyModel()

# Run your model
inputs = torch.randn(8,2,4)
print('Output: \n',my_model(inputs))
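If you also want a single sample of shape (2, 4) to work, one option is to add a leading batch axis when the input has only the dimensions of one sample. The sketch below does that and takes the circuit as an argument instead of using the global qnode. PerSampleModel, sample_dims and fake_circuit are hypothetical names; fake_circuit is a pure-torch stand-in for the QNode so the sketch runs without PennyLane.

```python
import torch


class PerSampleModel(torch.nn.Module):
    """Apply circuit(sample, params) to each sample and stack the results.

    circuit: any callable returning a sequence of scalar tensors (e.g. a QNode).
    sample_dims: number of axes of ONE sample (2 for shape (ch_in, n_wqubits)).
    """

    def __init__(self, circuit, params_shape, sample_dims=2):
        super().__init__()
        self.circuit = circuit
        self.sample_dims = sample_dims
        self.q_params = torch.nn.Parameter(torch.randn(*params_shape))

    def forward(self, inputs):
        # A single sample has no batch axis: add one so that iterating
        # over inputs yields whole samples rather than their rows.
        if inputs.dim() == self.sample_dims:
            inputs = inputs.unsqueeze(0)
        outputs = [
            torch.hstack(list(self.circuit(elem, self.q_params))).float()
            for elem in inputs
        ]
        return torch.stack(outputs)  # shape (batch_size, n_outputs)


def fake_circuit(sample, params):
    # Hypothetical stand-in for the QNode: one scalar per "qubit".
    n_qubits = params.shape[0]
    return [torch.cos(sample.sum() + params[q].sum()) for q in range(n_qubits)]


model = PerSampleModel(fake_circuit, params_shape=(5, 2))
print(model(torch.randn(8, 2, 4)).shape)  # torch.Size([8, 5])
print(model(torch.randn(2, 4)).shape)     # torch.Size([1, 5])
```

With the thread's code this would presumably be PerSampleModel(qnode, params_shape=QuantumCircuit1.weight_shapes["weights"]), though that combination has not been run here. Collecting outputs in a list and calling torch.stack once avoids reallocating the output tensor on every iteration.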

You can find a similar example in this forum thread.

I hope this helps you!