Parameter broadcasting problem with torch node

Problem 1: If I use a tensor created with the torch.rand() method, I get the following error. However, the error does not occur when I use a tensor created from a NumPy vector. I am not sure why this happens.

So, I made a minor modification to the code and it works fine for me! :slightly_smiling_face:

import pennylane as qml
import time
import torch
import torch.nn as nn

dev = qml.device("default.qubit", wires=1)

@qml.qnode(dev, interface="torch")
def simple_qubit_circuit(inputs, theta):
    # Encode the classical input, apply the trainable rotation,
    # and return the PauliZ expectation value on the single wire.
    qml.RX(inputs, wires=0)
    qml.RY(theta, wires=0)
    return qml.expval(qml.PauliZ(0))

class QNet(nn.Module):
    def __init__(self):
        super().__init__()
        # "theta" is a scalar trainable weight; "inputs" is the data argument.
        shapes = {"theta": ()}
        self.q = qml.qnn.TorchLayer(simple_qubit_circuit, shapes)

    def forward(self, input_value):
        return self.q(input_value)

# Batch of ten inputs evaluated in a single broadcasted call
x_train = torch.rand(10)
x_train = torch.atan(x_train)

model = QNet()

t1 = time.time()
out = model(x_train)
print("time taken for batch operations: ", time.time() - t1)

# The same inputs evaluated one at a time, for comparison
out2 = []
t2 = time.time()
for x in x_train:
    out2.append(model(x).item())
print("time taken for sequential operations: ", time.time() - t2)

print(out)
print(out2)

And the output:

time taken for batch operations:  0.001728057861328125
time taken for sequential operations:  0.00826883316040039
tensor([0.3852, 0.4101, 0.4733, 0.4276, 0.5313, 0.5150, 0.4444, 0.4483, 0.5314,
        0.5327], grad_fn=<ToCopyBackward0>)
[0.3851933479309082, 0.41005653142929077, 0.4732654094696045, 0.42755088210105896, 0.5312792658805847, 0.5150039792060852, 0.4443522095680237, 0.4482826292514801, 0.5314198732376099, 0.5326558351516724]
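
Since the question is about parameter broadcasting through a TorchLayer, one extra check (not in the original snippet) can be reassuring: listing the module's parameters should show that only theta was registered as a trainable weight, while inputs is treated as the data argument that gets broadcast over the batch.

# Sanity-check sketch: TorchLayer turns each entry of the weight-shapes
# dict into a trainable nn.Parameter; `inputs` stays a data argument.
for name, param in model.named_parameters():
    print(name, tuple(param.shape))  # expected: something like "q.theta ()"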

Problem 3:
The output tensors calculated with the batch operation vs. the sequential operation, for the same input tensor and the same weight, are different, which should not be the case. Again, I am not sure what's wrong here.

They look the same to me. So I think the solution is basically just updating to the latest version of PennyLane and it should be fine! :grin:
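
If you would rather compare the two results programmatically than by eye, a quick check along these lines (reusing out and out2 from the snippet above) should confirm they agree up to floating-point tolerance, and printing the installed PennyLane version makes it easier to rule out an outdated install:

# Compare the batched and sequential results numerically
# (out is a torch tensor, out2 is a plain Python list of floats).
print(torch.allclose(out.detach(), torch.tensor(out2), atol=1e-6))  # expect: True

# Parameter broadcasting support depends on the installed PennyLane version.
print(qml.version())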