Hello! I’m trying to write a TorchLayer that can handle multidimensional batches. I’d like to split the batch along the second dimension and perform separate computations on the two halves. A toy version of a similar computation in plain PyTorch looks like this:

```
import torch

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)

    def forward(self, x):
        # Split the batch along the second dimension
        x1 = x[:, 0]
        x2 = x[:, 1]
        # Do some computation on each half separately
        y1 = self.linear(x1)
        y2 = self.linear(x2)
        return y1 + y2

mm = MyModule()
x = torch.zeros((8, 2, 2))
output = mm(x)
```

When I pass batched input to a TorchLayer, however, the first two dimensions get merged into one, so inside the QNode I only have two dimensions to work with:
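The printed shape looks consistent with TorchLayer collapsing every leading dimension into one, i.e. something like `inputs.reshape(-1, inputs.shape[-1])`. That is an assumption inferred from the `(16, 2)` output, not taken from the TorchLayer source. The plain-PyTorch sketch below reproduces that shape and shows where the two halves end up after such a reshape:

```python
import torch

# Same batch shape as in the question: 8 samples, 2 halves, 2 features each
x = torch.arange(8 * 2 * 2, dtype=torch.float32).reshape(8, 2, 2)

# Assumed flattening: merge all leading dims, keep the last (feature) dim
flat = x.reshape(-1, x.shape[-1])
print(flat.shape)  # torch.Size([16, 2])

# Row-major order interleaves the halves: sample i, half j -> row 2*i + j
first_half = flat[0::2]   # same values as x[:, 0]
second_half = flat[1::2]  # same values as x[:, 1]
print(torch.equal(first_half, x[:, 0]), torch.equal(second_half, x[:, 1]))  # True True
```

Under this assumption, `inputs[:, 0]` and `inputs[:, 1]` inside the QNode select the two *features* of each of the 16 flattened rows, not the two halves of the original batch.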

```
import pennylane as qml
import torch

dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev, interface="torch")
def qnode(inputs, weights_0, weight_1):
    print(f'This is {inputs.shape} but should be (8, 2, 2)')
    # Some computation
    qml.RX(inputs[:, 0], wires=0)
    qml.RX(inputs[:, 1], wires=1)
    qml.Rot(*weights_0, wires=0)
    qml.RY(weight_1, wires=1)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliZ(1))

weight_shapes = {"weights_0": 3, "weight_1": 1}
x = torch.zeros((8, 2, 2))
qlayer = qml.qnn.TorchLayer(qnode, weight_shapes)
output = qlayer(x)
# Output:
# This is torch.Size([16, 2]) but should be (8, 2, 2)
```

So my question is: how can I stop TorchLayer from flattening the first two dimensions of my input? I’d like to avoid workarounds such as reshaping the tensor inside the QNode.

Thanks a lot!