# 2D batching of TorchLayer input

Hello! I’m trying to write a TorchLayer capable of handling multidimensional batches. I’d like to split the batch along the second dimension and perform separate computations on the two halves. A toy example of a similar computation in PyTorch looks like this:

```python
import torch

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)

    def forward(self, x):
        # Split the batch along the second dimension
        x1 = x[:, 0]  # shape (8, 2)
        x2 = x[:, 1]  # shape (8, 2)
        # Do some computation separately on each half
        y1 = self.linear(x1)
        y2 = self.linear(x2)
        return y1 + y2

mm = MyModule()
x = torch.zeros((8, 2, 2))

output = mm(x)  # shape (8, 2)
```

When I try to pass batched input to a TorchLayer, it gets flattened so that I end up having only two dimensions to work with.

```python
import pennylane as qml
import torch

dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev, interface="torch")
def qnode(inputs, weights_0, weight_1):
    print(f'This is {inputs.shape} but should be (8, 2, 2)')

    # Some computation
    qml.RX(inputs[:, 0], wires=0)
    qml.RX(inputs[:, 1], wires=1)
    qml.Rot(*weights_0, wires=0)
    qml.RY(weight_1, wires=1)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliZ(1))

weight_shapes = {"weights_0": 3, "weight_1": 1}
x = torch.zeros((8, 2, 2))

qlayer = qml.qnn.TorchLayer(qnode, weight_shapes)
output = qlayer(x)

# Output:
# This is torch.Size([16, 2]) but should be (8, 2, 2)
```

My question, therefore, is: how can I keep the first two dimensions of my input from being flattened before they reach the circuit? I’d like to avoid workarounds like reshaping the tensor inside `qnode`.

Thanks a lot!

Hey @akatief! Welcome to the forum!

I tried running your quantum circuit, and I got:

```
This is torch.Size([2]) but should be (8, 2, 2)
```

I’m using the most up-to-date versions of PennyLane and PyTorch (at the time of writing, v0.31 for PennyLane and 2.0.1 for Torch).

That said, `inputs[:, 0]` has shape `(8, 2)`, which `RX` can’t broadcast over: it needs a one-dimensional vector of parameters, e.g., `qml.RX([0.1, 0.2, 0.3], wires=0)`. You may just need to cleverly access the elements of `inputs` that you want to broadcast over.
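To make the broadcasting point concrete without needing PennyLane installed, here is a torch-only sketch. It assumes a hypothetical stand-in, `rx_expval_z(theta)`, which (like `RX` under parameter broadcasting) accepts only a scalar or a 1D vector of angles and returns cos(θ), the value of ⟨Z⟩ after RX(θ) acts on |0⟩. A 2D block of angles is flattened to 1D, evaluated once, and reshaped back.

```python
import torch

def rx_expval_z(theta: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for a broadcast circuit: <Z> after RX(theta)|0> is cos(theta).

    Mimics the restriction discussed above by rejecting anything that is not
    a scalar or a one-dimensional vector of angles.
    """
    if theta.dim() > 1:
        raise ValueError(f"expected a scalar or 1D vector of angles, got shape {tuple(theta.shape)}")
    return torch.cos(theta)

x = torch.rand((8, 2, 2))
angles = x[:, 0]                                   # shape (8, 2): too many dims to broadcast over
flat = angles.reshape(-1)                          # shape (16,): a single broadcast dimension
values = rx_expval_z(flat).reshape(angles.shape)   # evaluate once, then restore shape (8, 2)
print(values.shape)  # torch.Size([8, 2])
```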

Let me know if this helps!

Worryingly, I get a different output from yours despite having the same versions of PyTorch and PennyLane. I wonder why that is.

I understand your point about broadcasting, though. I also found some lines in `TorchLayer.forward` that explicitly reshape the input I give.
I guess a workaround could be to reshape the input vector outside the qnode and access specific elements inside of it. I wonder, is that the intended way?
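The workaround of reshaping outside the QNode can be sketched as a small wrapper `nn.Module` that reproduces the `MyModule` toy example. This is a sketch under assumptions: `SplitSecondDim` is a hypothetical name, and the wrapped `layer` is assumed to map inputs of shape `(N, F)` to outputs of shape `(N, O)`. A `torch.nn.Linear` stands in for the quantum layer so the example runs with PyTorch alone.

```python
import torch

class SplitSecondDim(torch.nn.Module):
    """Hypothetical wrapper: apply `layer` to every slice x[:, j] and sum the results.

    Reproduces the MyModule toy example with a single call to `layer`, which is
    assumed to map (N, F) -> (N, O).
    """

    def __init__(self, layer: torch.nn.Module):
        super().__init__()
        self.layer = layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, n_split, n_features = x.shape           # e.g. (8, 2, 2)
        flat = x.reshape(batch * n_split, n_features)  # (16, 2): only 2D reaches the layer
        out = self.layer(flat)                         # (16, O)
        out = out.reshape(batch, n_split, -1)          # (8, 2, O): undo the flattening
        return out.sum(dim=1)                          # same as y1 + y2 when n_split == 2


# Usage, with torch.nn.Linear standing in for the quantum layer
linear = torch.nn.Linear(2, 2)
model = SplitSecondDim(linear)
x = torch.rand((8, 2, 2))
output = model(x)
print(output.shape)  # torch.Size([8, 2])
```

With the `qlayer` from the question in place of `linear`, the QNode should only ever see 2D input, so no reshaping would be needed inside `qnode` itself.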

> I also found some lines in `TorchLayer.forward` that explicitly reshape the input I give.

Hmm … If you’re creating a hybrid model that inherits from `nn.Module`, then `forward` is user-defined. In other words, most likely there’s something that you’re doing that is causing things to be reshaped.

> I guess a workaround could be to reshape the input vector outside the qnode and access specific elements inside of it. I wonder, is that the intended way?

It really depends on your application. Sometimes the forward pass of a model can be non-trivial and may involve some clever intermediary data processing between layers. Do you have a small code example that replicates what you’re seeing / having to do?

I’m still talking about the toy example from before, same code. Using the debugger, I can see that `output = qlayer(x)` calls `TorchLayer.forward` in the file `torch.py`, which runs line 398:

```python
# inputs is our (8, 2, 2) tensor
inputs = torch.reshape(inputs, (-1, inputs.shape[-1]))
```

From this, I understand that PennyLane really doesn’t want inputs with more than one batch dimension. Is this the same behavior you get?
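The effect of the quoted line can be reproduced with plain PyTorch. This sketch applies the same `torch.reshape(inputs, (-1, inputs.shape[-1]))` call to tensors of different ranks: a 2D batch passes through unchanged, while the leading dimensions of a 3D batch are merged into one, so row `2 * b + j` of the result is `x[b, j]`. The helper name `flatten_like_torchlayer` is hypothetical, and it reproduces only that single line, not the rest of `TorchLayer.forward`.

```python
import torch

def flatten_like_torchlayer(inputs: torch.Tensor) -> torch.Tensor:
    # Same call as the line quoted from TorchLayer.forward above
    return torch.reshape(inputs, (-1, inputs.shape[-1]))

x2d = torch.arange(16.0).reshape(8, 2)
x3d = torch.arange(32.0).reshape(8, 2, 2)

print(flatten_like_torchlayer(x2d).shape)  # torch.Size([8, 2]): unchanged
print(flatten_like_torchlayer(x3d).shape)  # torch.Size([16, 2]): first two dims merged

flat = flatten_like_torchlayer(x3d)
# Row-major order: row 2*b + j of the flattened tensor is x3d[b, j]
b, j = 5, 1
print(torch.equal(flat[2 * b + j], x3d[b, j]))  # True
```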

Again, to make sure we’re running the same versions, I restarted the kernel, ran `pip freeze`, and got:

```
torch==2.0.1
torchsummary==1.5.1
torchvision==0.15.2
PennyLane==0.31.0
PennyLane-Lightning==0.31.0
PennyLane-qiskit==0.31.0
```

I feel like my complete use case has some other issues that may require a separate thread; for now, I’d be happy just to sort this one out.

Ah! Apologies. You’re right: if you pass in a batch of inputs, it will be flattened so that `len(inputs.shape) == 2`, as you can see here.

> most likely there’s something that you’re doing that is causing things to be reshaped.

Just to clarify: this is misleading at best.

So, in your case, if your inputs must have `len(inputs.shape) > 2`, you should be aware of how `TorchLayer` will reshape them and adjust things accordingly.
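Because the reshape is row-major, the two entries along the original second dimension end up interleaved in the flattened tensor: even rows come from `x[:, 0]` and odd rows from `x[:, 1]`. The plain-PyTorch sketch below, assuming exactly two entries along that dimension, shows that strided slicing recovers them, and that reshaping a flattened per-row result with the saved leading dimensions restores the original batch layout. Whether this can be applied inside a QNode depends on how the rest of the circuit broadcasts, so treat it as shape bookkeeping rather than a drop-in fix.

```python
import torch

x = torch.rand((8, 2, 2))
batch_dims = x.shape[:-1]                    # torch.Size([8, 2]), saved before flattening
flat = torch.reshape(x, (-1, x.shape[-1]))   # (16, 2), the shape the QNode ends up seeing

# Strided slicing picks out the two halves without reshaping
first_half = flat[0::2]    # equals x[:, 0], shape (8, 2)
second_half = flat[1::2]   # equals x[:, 1], shape (8, 2)

# A hypothetical per-row output of shape (16, O) can be mapped back to (8, 2, O)
per_row = flat.sum(dim=-1, keepdim=True)                      # shape (16, 1)
restored = per_row.reshape(*batch_dims, per_row.shape[-1])    # shape (8, 2, 1)
```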