Parameter broadcasting problem with torch node

Hello @ludmilaaasb ,

I am reposting the entire code that I am using. The updated PennyLane version sometimes behaves very strangely.

import pennylane as qml
from pennylane import numpy as np
import time as time
import torch
import torch.nn as nn

# Single-wire "default.qubit" device on which the QNode below is executed.
dev = qml.device("default.qubit", wires=1)
@qml.qnode(dev, interface = 'torch')
def simple_qubit_circuit(inputs, theta):
    """Apply RX(inputs) then RY(theta) to wire 0 and measure PauliZ.

    Args:
        inputs: rotation angle passed to the RX gate.
        theta: rotation angle passed to the RY gate.

    Returns:
        The expectation value of PauliZ on wire 0.
    """
    # Gate order matters: the RX rotation is applied before the RY rotation.
    qml.RX(inputs, wires=0)
    qml.RY(theta, wires=0)
    return qml.expval(qml.PauliZ(0))
class QNet(nn.Module):
    """Torch module that wraps ``simple_qubit_circuit`` in a ``TorchLayer``.

    The layer is built with the weight-shape mapping ``{"theta": 1}``, so the
    ``theta`` argument of the circuit is owned by the layer itself.
    """

    def __init__(self):
        super().__init__()
        # Draw one angle from a normal distribution (mean 0, std pi) and
        # register it as a float32 parameter of this module.
        # NOTE(review): `forward` only calls `self.q`, so this parameter is
        # never used by the forward pass -- confirm whether it is intentional.
        initial_angle = np.random.normal(0, np.pi)
        angle_tensor = torch.tensor(
            initial_angle, dtype=torch.float32, requires_grad=True
        )
        self.quantum_weights = nn.parameter.Parameter(angle_tensor)
        self.q = qml.qnn.TorchLayer(simple_qubit_circuit, {"theta": 1})

    def forward(self, input_value):
        """Pass ``input_value`` through the quantum layer and return its output."""
        layer_output = self.q(input_value)
        return layer_output

# x_train = np.array([0.2, 0.1, 0.2, 0.14, 0.11, 0.41, 0.55, 0.3, 0.31, 0.6])
# x_train = torch.tensor(x_train).reshape(10,1)

# Ten random samples, each mapped through atan.
x_train = torch.atan(torch.rand(10))
model = QNet()

# One call on the whole tensor.
t1 = time.time()
out = model(x_train)
print("time taken for batch operations: ", time.time()-t1)

# One call per element; `extend` appends results one by one as they arrive.
out2 = []
t2 = time.time()
out2.extend(model(x).item() for x in x_train)
print("time taken for sequential operations: ", time.time()-t2)

print(out)
print(out2)

I am getting the following error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
~\AppData\Local\Temp\ipykernel_22300\3277396942.py in <cell line: 9>()
      7 model = QNet()
      8 t1 = time.time()
----> 9 out = model(x_train)
     10 print("time taken for batch operations: ", time.time()-t1)
     11 out2 = []

~\Miniconda3\envs\qns\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used
   1196         full_backward_hooks, non_full_backward_hooks = [], []

~\AppData\Local\Temp\ipykernel_22300\1881420835.py in forward(self, input_value)
     23 
     24     def forward(self, input_value):
---> 25         return self.q(input_value)

~\Miniconda3\envs\qns\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used
   1196         full_backward_hooks, non_full_backward_hooks = [], []

~\AppData\Roaming\Python\Python38\site-packages\pennylane\qnn\torch.py in forward(self, inputs)
    406         else:
    407             # calculate the forward pass as usual
--> 408             results = self._evaluate_qnode(inputs)
    409 
    410         # reshape to the correct number of batch dims

~\AppData\Roaming\Python\Python38\site-packages\pennylane\qnn\torch.py in _evaluate_qnode(self, x)
    427             **{arg: weight.to(x) for arg, weight in self.qnode_weights.items()},
    428         }
--> 429         res = self.qnode(**kwargs)
    430 
    431         if isinstance(res, torch.Tensor):

~\AppData\Roaming\Python\Python38\site-packages\pennylane\qnode.py in __call__(self, *args, **kwargs)
    948                 self.execute_kwargs.pop("mode")
    949             # pylint: disable=unexpected-keyword-arg
--> 950             res = qml.execute(
    951                 [self.tape],
    952                 device=self.device,

~\AppData\Roaming\Python\Python38\site-packages\pennylane\interfaces\execution.py in execute(tapes, device, gradient_fn, interface, grad_on_execution, gradient_kwargs, cache, cachesize, max_diff, override_shots, expand_fn, max_expansion, device_batch_transform)
    640             _execute = _get_jax_execute_fn(interface, tapes)
    641 
--> 642         res = _execute(
    643             tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1, max_diff=max_diff
    644         )

~\AppData\Roaming\Python\Python38\site-packages\pennylane\interfaces\torch.py in execute(tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n, max_diff)
    496     }
    497 
--> 498     return ExecuteTapes.apply(kwargs, *parameters)
    499 
    500 

~\AppData\Roaming\Python\Python38\site-packages\pennylane\interfaces\torch.py in new_apply(*inp)
    260         # Inputs already flat
    261         out_struct_holder = []
--> 262         flat_out = orig_apply(out_struct_holder, *inp)
    263         return pytree.tree_unflatten(flat_out, out_struct_holder[0])
    264 

~\AppData\Roaming\Python\Python38\site-packages\pennylane\interfaces\torch.py in new_forward(ctx, out_struct_holder, *inp)
    264 
    265     def new_forward(ctx, out_struct_holder, *inp):
--> 266         out = orig_fw(ctx, *inp)
    267         flat_out, out_struct = pytree.tree_flatten(out)
    268         ctx._out_struct = out_struct

~\AppData\Roaming\Python\Python38\site-packages\pennylane\interfaces\torch.py in forward(ctx, kwargs, *parameters)
    341 
    342         unwrapped_tapes = tuple(convert_to_numpy_parameters(t) for t in ctx.tapes)
--> 343         res, ctx.jacs = ctx.execute_fn(unwrapped_tapes, **ctx.gradient_kwargs)
    344 
    345         # if any input tensor uses the GPU, the output should as well

~\AppData\Roaming\Python\Python38\site-packages\pennylane\interfaces\execution.py in wrapper(tapes, **kwargs)
    285             # execute all unique tapes that do not exist in the cache
    286             # convert to list as new device interface returns a tuple
--> 287             res = list(fn(execution_tapes.values(), **kwargs))
    288 
    289         final_res = []

~\AppData\Roaming\Python\Python38\site-packages\pennylane\interfaces\execution.py in fn(tapes, **kwargs)
    208         def fn(tapes: Sequence[QuantumTape], **kwargs):  # pylint: disable=function-redefined
    209             tapes = [expand_fn(tape) for tape in tapes]
--> 210             return original_fn(tapes, **kwargs)
    211 
    212     @wraps(fn)

~\Miniconda3\envs\qns\lib\contextlib.py in inner(*args, **kwds)
     73         def inner(*args, **kwds):
     74             with self._recreate_cm():
---> 75                 return func(*args, **kwds)
     76         return inner
     77 

~\AppData\Roaming\Python\Python38\site-packages\pennylane\_qubit_device.py in batch_execute(self, circuits)
    601             self.reset()
    602 
--> 603             res = self.execute(circuit)
    604             results.append(res)
    605 

~\AppData\Roaming\Python\Python38\site-packages\pennylane\_qubit_device.py in execute(self, circuit, **kwargs)
    322         # generate computational basis samples
    323         if self.shots is not None or circuit.is_sampled:
--> 324             self._samples = self.generate_samples()
    325 
    326         # compute the required statistics

~\AppData\Roaming\Python\Python38\site-packages\pennylane\_qubit_device.py in generate_samples(self)
   1158         rotated_prob = self.analytic_probability()
   1159 
-> 1160         samples = self.sample_basis_states(number_of_states, rotated_prob)
   1161         return self.states_to_binary(samples, self.num_wires)
   1162 

~\AppData\Roaming\Python\Python38\site-packages\pennylane\_qubit_device.py in sample_basis_states(self, number_of_states, state_probability)
   1186             # np.random.choice does not support broadcasting as needed here.
   1187             return np.array(
-> 1188                 [np.random.choice(basis_states, shots, p=prob) for prob in state_probability]
   1189             )
   1190 

~\AppData\Roaming\Python\Python38\site-packages\pennylane\_qubit_device.py in <listcomp>(.0)
   1186             # np.random.choice does not support broadcasting as needed here.
   1187             return np.array(
-> 1188                 [np.random.choice(basis_states, shots, p=prob) for prob in state_probability]
   1189             )
   1190 

mtrand.pyx in numpy.random.mtrand.RandomState.choice()

ValueError: probabilities do not sum to 1

It looks like this is a bug in PennyLane, which @Maria_Schuld pointed out in this post.