# 2D batching of TorchLayer input

Hello! I’m trying to write a TorchLayer capable of handling multidimensional batches. I’d like to split the batch in the second dimension and perform some separate computations on the two halves. A toy example of a similar computation in PyTorch looks like this:

``````import torch

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)

    def forward(self, x):
        # Split the batch along the second dimension
        x1 = x[:, 0]
        x2 = x[:, 1]
        # Do some computation separately on each half
        y1 = self.linear(x1)
        y2 = self.linear(x2)
        return y1 + y2

mm = MyModule()
x = torch.zeros((8, 2, 2))

output = mm(x)
``````

When I try to pass batched input to a TorchLayer, it gets flattened so that I end up having only two dimensions to work with.

``````import pennylane as qml
import torch

dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev, interface="torch")
def qnode(inputs, weights_0, weight_1):
    print(f'This is {inputs.shape} but should be (8, 2, 2)')

    # Some computation
    qml.RX(inputs[:, 0], wires=0)
    qml.RX(inputs[:, 1], wires=1)
    qml.Rot(*weights_0, wires=0)
    qml.RY(weight_1, wires=1)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliZ(1))

weight_shapes = {"weights_0": 3, "weight_1": 1}
x = torch.zeros((8, 2, 2))

qlayer = qml.qnn.TorchLayer(qnode, weight_shapes)
output = qlayer(x)

# Output:
# This is torch.Size([16, 2]) but should be (8, 2, 2)
``````

My question therefore is: how can I stop the circuit from flattening the first two dimensions of my input? I’d like to avoid workarounds like reshaping the tensor inside the QNode.

Thanks a lot!

Hey @akatief! Welcome to the forum!

I tried running your quantum circuit, and I got

``````This is torch.Size([2]) but should be (8, 2, 2)
``````

I’m using the most up-to-date versions of PennyLane and PyTorch (at the time of writing, v0.31 for PennyLane and 2.0.1 for Torch).

That said, `inputs[:, 0]` has shape `(8, 2)`, which `RX` can’t broadcast over: it needs to see a one-dimensional vector of parameters, e.g. `qml.RX([0.1, 0.2, 0.3], wires=0)`. You may just need to cleverly access the elements of `inputs` that you want to broadcast over.
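To illustrate the shapes involved (a minimal numpy sketch of the indexing, not a PennyLane call): slicing one index deeper turns the two-dimensional slice into the flat parameter vector that broadcasting expects.

``````python
import numpy as np

# Stand-in for the (8, 2, 2) batch that the layer receives
inputs = np.zeros((8, 2, 2))

# inputs[:, 0] keeps two dimensions -- RX can't broadcast over this
print(inputs[:, 0].shape)     # (8, 2)

# Indexing one level deeper gives a flat parameter vector
print(inputs[:, 0, 0].shape)  # (8,)
``````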

Let me know if this helps!

Quite worryingly, I get a different output than yours despite having the same versions of PyTorch and PennyLane; I wonder why that is.

I understand your point about broadcasting, though. I also found some lines in `TorchLayer.forward` that explicitly reshape the input I give.
I guess a workaround could be to reshape the input tensor outside the QNode and access specific elements inside of it. I wonder, is that the intended way?

> I also found some lines in `TorchLayer.forward` that explicitly reshape the input I give.

Hmm … If you’re creating a hybrid model that inherits from `nn.Module`, then `forward` is user-defined. In other words, most likely there’s something that you’re doing that is causing things to be reshaped.

> I guess a workaround could be to reshape the input tensor outside the QNode and access specific elements inside of it. I wonder, is that the intended way?

It really depends on your application. Sometimes the forward pass of a model can be non-trivial and may involve some clever intermediary data processing between layers. Do you have a small code example that replicates what you’re seeing / having to do?

I’m still talking about the toy example from before, same code. Stepping through with the debugger, I can see that `output = qlayer(x)` calls `TorchLayer.forward` inside the file `torch.py`, which runs line 398:

``````# inputs is our (8,2,2) tensor
inputs = torch.reshape(inputs, (-1, inputs.shape[-1]))
``````

From this, I understand that PennyLane really doesn’t want inputs with more than one batch dimension. Is this the same behavior you get?
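The effect of that line can be reproduced in isolation (numpy used here purely to illustrate the shapes; `TorchLayer` does the equivalent with `torch.reshape`):

``````python
import numpy as np

inputs = np.zeros((8, 2, 2))

# Same reshape as in TorchLayer.forward: collapse all leading
# dimensions into a single batch dimension, keep the feature axis
flattened = np.reshape(inputs, (-1, inputs.shape[-1]))

print(flattened.shape)  # (16, 2) -- the (8, 2) batch dims are merged
``````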

Again, to make sure we’re running the same versions, I restarted the kernel, ran `pip freeze`, and got:

``````torch==2.0.1
torchsummary==1.5.1
torchvision==0.15.2
PennyLane==0.31.0
PennyLane-Lightning==0.31.0
PennyLane-qiskit==0.31.0
``````

I feel like my complete use case has some other issues that may require a separate thread; for now I’d just be happy to sort this one out.

Ah! Apologies, you’re right: if you’re passing in a batch of inputs, it will be flattened such that `len(inputs.shape) == 2` (that’s the reshape in `TorchLayer.forward` you found).

> most likely there’s something that you’re doing that is causing things to be reshaped.

Just clarifying that this earlier comment of mine was misleading at best.

So, in your case, if you must have inputs that have `len(inputs.shape) > 2`, you should be aware of how `TorchLayer` will reshape it, and adjust things if desired.
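One possible adjustment, sketched here with numpy shapes only (the dimension names and the zero-filled stand-in for the layer’s output are hypothetical, not part of the `TorchLayer` API): since the layer will merge the leading dimensions anyway, you can restore the original batch structure on the output and then split along the second dimension, as in the toy PyTorch module above.

``````python
import numpy as np

batch, groups, features = 8, 2, 2
x = np.zeros((batch, groups, features))

# TorchLayer flattens to (batch * groups, features) internally;
# doing it explicitly keeps the bookkeeping in your hands
flat_in = np.reshape(x, (-1, features))          # (16, 2)

# Stand-in for the layer's output: one expectation value per wire
flat_out = np.zeros((flat_in.shape[0], 2))       # (16, 2)

# Restore the original batch structure on the way out
out = np.reshape(flat_out, (batch, groups, -1))  # (8, 2, 2)

# Now the two halves along the second dimension can be combined,
# e.g. summed as in the original toy PyTorch module
result = out[:, 0] + out[:, 1]                   # (8, 2)
print(result.shape)
``````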