Hello, I have this subclass of `nn.Module` that creates a QNode in `__init__()` and then uses this quantum circuit to generate the logits in `forward()`:
class QRLAgentDQN(nn.Module):
    """Hybrid quantum-classical DQN agent.

    Builds a PennyLane QNode once in ``__init__`` and evaluates it in
    ``forward`` to produce per-action logits.

    Args:
        envs: vectorized gym environments exposing
            ``single_observation_space`` and ``single_action_space``.
        config: dict with keys ``num_qubits``, ``num_layers``, ``ansatz``,
            ``dynamic_meas``, ``init_method`` and ``observables``.

    Raises:
        ValueError: if ``observables``, ``ansatz`` or ``init_method`` holds
            an unsupported value (the original code silently skipped setup
            in those cases and crashed later with ``AttributeError``).
    """

    def __init__(self, envs, config):
        super().__init__()
        self.num_features = np.array(envs.single_observation_space.shape).prod()
        self.num_actions = envs.single_action_space.n
        self.num_qubits = config["num_qubits"]
        self.num_layers = config["num_layers"]
        self.ansatz = config["ansatz"]
        self.dynamic_meas = config["dynamic_meas"]
        self.init_method = config["init_method"]
        self.observables = config["observables"]

        if self.observables == "global":
            self.S = self.num_qubits
        elif self.observables == "local":
            self.S = self.num_qubits // 2
        else:
            # Fail fast instead of leaving self.S undefined.
            raise ValueError(f"Unknown observables option: {self.observables!r}")

        self.wires = range(self.num_qubits)

        # BUG FIX: forward() reads self.measured_qubits, but it was never
        # assigned, which raises AttributeError on the first forward pass.
        # NOTE(review): assuming the value comes from the config — confirm
        # against the caller; the default measures every wire.
        self.measured_qubits = config.get("measured_qubits", list(self.wires))

        if self.ansatz != "hwe":
            raise ValueError(f"Unknown ansatz: {self.ansatz!r}")

        # Input and output scaling weights are always initialized as 1s.
        self.register_parameter(
            name="input_scaling_actor",
            param=nn.Parameter(
                torch.ones(self.num_layers, self.num_qubits), requires_grad=True
            ),
        )
        self.register_parameter(
            name="output_scaling_actor",
            param=nn.Parameter(torch.ones(self.num_actions), requires_grad=True),
        )

        # The variational weights are initialized according to the config file.
        if self.init_method == "uniform":
            # Uniform in [-pi, pi).
            self.register_parameter(
                name="variational_actor",
                param=nn.Parameter(
                    torch.rand(self.num_layers, self.num_qubits * 2) * 2 * torch.pi
                    - torch.pi,
                    requires_grad=True,
                ),
            )
        else:
            raise ValueError(f"Unknown init_method: {self.init_method!r}")

        dev = qml.device("lightning.qubit", wires=self.wires)
        self.qc = qml.QNode(ansatz_hwe, dev, diff_method="adjoint", interface="torch")

    def forward(self, x):
        """Compute scaled logits for a batch of observations.

        Args:
            x: tensor of shape ``(batch_size, num_features)``.

        Returns:
            Tensor of shape ``(batch_size, num_actions)``.

        NOTE(review): ``lightning.qubit`` has no native parameter
        broadcasting, so PennyLane unrolls the batch and simulates each
        sample sequentially — this loop is NOT vectorized over the batch.
        For parallel batch evaluation, confirm broadcasting support on the
        chosen device or dispatch the batch via ``torch.vmap``/a device
        with native broadcasting (e.g. ``default.qubit``).
        """
        # Tile the feature vector until it covers all qubits, then truncate
        # to exactly num_qubits columns.
        x = x.repeat(1, len(self.wires) // len(x[0]) + 1)[:, : len(self.wires)]
        logits = self.qc(
            x,
            self._parameters["input_scaling_actor"],
            self._parameters["variational_actor"],
            self.wires,
            self.num_layers,
            "actor",
            self.dynamic_meas,
            self.measured_qubits,
            self.observables,
        )
        # A single-sample batch comes back 1-D from the QNode; restore the
        # leading batch dimension so the output is always 2-D.
        if x.shape[0] == 1:
            logits = logits.reshape(x.shape[0], logits.shape[0])
        return logits * self._parameters["output_scaling_actor"]
My question is the following: is this vectorized over the batch, assuming the input to the forward pass x has shape [batch_size, num_features]? Or will each sample in the batch be processed sequentially? It seems to me like it is being done sequentially, but I am not entirely sure. And if that is the case, is there any way to vectorize the computation over the batch?