Optimization over batches in PennyLane

Hello, I have this subclass of nn.Module that creates a QNode in __init__() and then uses that quantum circuit to generate the logits in forward():

import numpy as np
import pennylane as qml
import torch
import torch.nn as nn


class QRLAgentDQN(nn.Module):
    def __init__(self, envs, config):
        super().__init__()
        self.num_features = np.array(envs.single_observation_space.shape).prod()
        self.num_actions = envs.single_action_space.n
        self.num_qubits = config["num_qubits"]
        self.num_layers = config["num_layers"]
        self.ansatz = config["ansatz"]
        self.dynamic_meas = config["dynamic_meas"]
        self.init_method = config["init_method"]
        self.observables = config["observables"]
        if self.observables == "global":
            self.S = self.num_qubits
        elif self.observables == "local":
            self.S = self.num_qubits // 2
        self.wires = range(self.num_qubits)

        if self.ansatz == "hwe":
            # Input and output scaling weights are always initialized as ones
            self.register_parameter(name="input_scaling_actor", param=nn.Parameter(torch.ones(self.num_layers, self.num_qubits), requires_grad=True))
            self.register_parameter(name="output_scaling_actor", param=nn.Parameter(torch.ones(self.num_actions), requires_grad=True))

            # The variational weights are initialized according to the config file
            if self.init_method == "uniform":
                self.register_parameter(name="variational_actor", param=nn.Parameter(torch.rand(self.num_layers, self.num_qubits * 2) * 2 * torch.pi - torch.pi, requires_grad=True))

        dev = qml.device("lightning.qubit", wires=self.wires)
        if self.ansatz == "hwe":
            self.qc = qml.QNode(ansatz_hwe, dev, diff_method="adjoint", interface="torch")

    def forward(self, x):
        # Tile the input features so they fill all qubits
        x = x.repeat(1, len(self.wires) // len(x[0]) + 1)[:, :len(self.wires)]
        logits = self.qc(x, self._parameters["input_scaling_actor"], self._parameters["variational_actor"], self.wires, self.num_layers, "actor", self.dynamic_meas, self.measured_qubits, self.observables)
        if x.shape[0] == 1:
            logits = logits.reshape(x.shape[0], logits.shape[0])
        logits_scaled = logits * self._parameters["output_scaling_actor"]
        return logits_scaled

My question is the following: assuming the input x to the forward pass has shape [batch_size, num_features], is the QNode evaluated over the whole batch at once, or is each sample processed sequentially? It seems to me that it's being done sequentially, but I'm not entirely sure. And if that's the case, is there any way to vectorize the evaluation over the batch?

Hi @rc17782 ,

I don’t have all of the info needed to test your code locally, so I’m not sure how the optimization is happening.

If it helps, I made this code example for someone else, where I broke the batch down into individual inputs within the forward pass, but the optimization was still done over the whole batch.

If this doesn’t answer your question, would you be able to share a minimal reproducible example of your code?

I hope this helps!