How to handle the batch dimension of classical layers in a hybrid quantum ML model through TorchLayer

How do I adapt the example in the section 'Creating non-sequential models' from the demo Turning quantum nodes into Torch Layers | PennyLane Demos when a batch dimension is also present in the input from the classical layer?

class HybridModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.clayer_1 = torch.nn.Linear(2, 4)
        self.qlayer_1 = qml.qnn.TorchLayer(qnode, weight_shapes)
        self.qlayer_2 = qml.qnn.TorchLayer(qnode, weight_shapes)
        self.clayer_2 = torch.nn.Linear(4, 2)
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, x):
        x = self.clayer_1(x)
        x_1, x_2 = torch.split(x, 2, dim=1)
        x_1 = self.qlayer_1(x_1)
        x_2 = self.qlayer_2(x_2)
        x = torch.cat([x_1, x_2], dim=1)
        x = self.clayer_2(x)
        return self.softmax(x)

model = HybridModel()

Also, consider weight_shapes = {"weights": (wire*2,)}


Hi @mukulgupta ,

The code in that demo is a bit complex, so I'll answer here with a simpler example.

Let’s say you want to create a model where each quantum TorchLayer has two layers of rotations that encode your inputs and two BasicEntanglerLayers that embed your weights.


If you expand the BasicEntanglerLayers, each layer applies one parametrized RX rotation per wire followed by a ring of CNOT entanglers.

If you don’t have a batch dimension for your inputs, you can simply flatten your inputs and feed them to your circuit.
If you do have a batch dimension for your inputs, you can make use of PennyLane's parameter broadcasting and run your gates for the full batch at once. For example, if you have a batch of 3 samples, you can send an input of size 3 to each RX gate instead of sending an input of size 1.

In the code below you'll see a QuantumCircuit class where I define my quantum function. The encoding of the inputs differs depending on whether or not there is a batch dimension. The key is that the inputs passed to each RX gate have a single dimension; if you try to pass a multi-dimensional tensor into RX you'll get an error.

At the end I test this for two sets of inputs, with and without a batch dimension.

import torch
import pennylane as qml

class QuantumCircuit:
    # Initialize your class with all of the attributes you'll need later
    def __init__(self, n_qubits, n_layers):
        self.n_qubits = n_qubits
        self.n_layers = n_layers
        self.weight_shapes = {"weights": (self.n_layers, self.n_qubits)}
        self.dev = qml.device('default.qubit', wires=self.n_qubits)

    # Define your quantum circuit with inputs and weights
    def quantum_circuit(self, inputs, weights):
        n_inputs = self.n_qubits * self.n_layers # number of input features per sample
        
        # Encode your inputs
        if len(inputs.shape) > 1: # use this if you have a batch dimension
            for i in range(n_inputs):
                qml.RX(inputs[:, i], wires=(i % self.n_qubits)) # the key is in this line
        else: # use this if you don't have a batch dimension
            for i in range(n_inputs):
                qml.RX(inputs[i], wires=(i % self.n_qubits))

        # Add your trainable layers
        qml.BasicEntanglerLayers(weights, wires=range(self.n_qubits))

        # Return an output with a size that's compatible with your following classical layer 
        return [qml.expval(qml.PauliZ(i)) for i in range(self.n_qubits)]

# Create your quantum TorchLayer
n_qubits = 4
n_layers = 2
QuantumCircuit1 = QuantumCircuit(n_qubits, n_layers)
qnode = qml.QNode(QuantumCircuit1.quantum_circuit, QuantumCircuit1.dev, diff_method='best')
qnn_layer = qml.qnn.TorchLayer(qnode, QuantumCircuit1.weight_shapes)

# Test that it works with and without batches

# Single input data
inputs = torch.randn(n_layers, n_qubits)
inputs = inputs.reshape(-1)
print('inputs 1: ',inputs)

output = qnn_layer(inputs)
print('output 1: ',output)

# Batch input data
n_batches = 3
inputs = torch.randn(n_batches, n_layers, n_qubits)
inputs = inputs.reshape((n_batches, -1)) # Flatten the inputs
print('inputs 2: ',inputs)

output = qnn_layer(inputs)
print('output 2: ',output)

With this in mind you should be able to adapt any code, including the demo you mentioned. It may require some work to fully understand the dimensions of your inputs and how to handle them, so I recommend using print statements, debugging tools, or starting with a smaller/easier example to get the hang of it.

I hope this helps!