Error when accessing device.state for a QNode in a torch.nn.Module subclass

Dear PennyLane Team,
I stumbled across an error that I don’t really understand yet.

What I want to do:
I want to be able to check the current quantum state at different stages of the process. This works perfectly fine if I do not work with classes:

```python
import pennylane as qml

dev = qml.device("default.qubit", wires=2, shots=None)

inp = [0., 1., 0., 0.]

@qml.qnode(device=dev)
def qcFkt(inputs):
    qml.QubitStateVector(inputs, wires=range(2))
    return qml.state()
```

I can now check the current state before actually applying the circuit:

```python
dev.state
```

which, as expected, yields `array([1.+0.j, 0.+0.j, 0.+0.j, 0.+0.j])`.

I can apply the circuit:

```python
qcFkt(inp)
```

and check the state again:

```python
dev.state
```

which yields `tensor([0.+0.j, 1.+0.j, 0.+0.j, 0.+0.j], requires_grad=True)`.

So far, so easy. However, I’m currently working with classes and the Torch interface.
Here is a minimal example of the class:

```python
import torch
import pennylane as qml

class QNet(torch.nn.Module):
    def __init__(self, nqubits,
                 device: str = "default.qubit", shots=None, **weights):

        super().__init__()
        self.device = qml.device(device, wires=nqubits, shots=shots)

        @qml.qnode(device=self.device, interface='torch', diff_method="best")
        def qnet(inputs, **weights):
            qml.QubitStateVector(inputs, wires=range(nqubits))
            return qml.expval(qml.PauliZ(0))

        self.quantumCircuit = qnet
        self.qlayer = qml.qnn.TorchLayer(self.quantumCircuit, weights)
        self.qnn = torch.nn.Sequential(self.qlayer)

    def forward(self, x):
        r"""
        Propagate input x forward through the defined QDNN.
        This method is required by pytorch to subclass torch.nn.Module

        :param x: network input
        :type x: pytorch tensor
        :return: network output
        :rtype: pytorch tensor
        """
        pred = self.qnn(x)
        return pred
```

Now I define an object `myqnet`:

```python
myqnet = QNet(nqubits=2)
```

I can check the current state with

```python
myqnet.device.state
```

and get the expected output: `array([1.+0.j, 0.+0.j, 0.+0.j, 0.+0.j])`.

But if I now call the object with the input:

```python
myqnet(inp)
```

and then try to get the state,

```python
myqnet.device.state
```

I get the following error:

```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
~\AppData\Local\Temp/ipykernel_5000/1001346351.py in <module>
----> 1 qnet.device.state

~\...\smarty-project\lib\site-packages\pennylane\devices\default_qubit.py in state(self)
    656     def state(self):
    657         dim = 2**self.num_wires
--> 658         batch_size = self._get_batch_size(self._pre_rotated_state, (2,) * self.num_wires, dim)
    659         # Do not flatten the state completely but leave the broadcasting dimension if there is one
    660         shape = (batch_size, dim) if batch_size is not None else (dim,)

~\...\lib\site-packages\pennylane\devices\default_qubit.py in _get_batch_size(self, tensor, expected_shape, expected_size)
    223         compared to an expected_shape."""
    224         size = self._size(tensor)
--> 225         if self._ndim(tensor) > len(expected_shape) or size > expected_size:
    226             return size // expected_size
    227 

TypeError: '>' not supported between instances of 'builtin_function_or_method' and 'int'
```

The error also appears with any other way of feeding input into the quantum circuit (e.g. calling `myqnet.quantumCircuit(inp)` and then checking `myqnet.device.state`).

Strangely enough, this error is only thrown with `shots=None`; with `shots=1024` (or any other integer), it works perfectly fine!

Please let me know if you have any idea why this is the case, or whether this is a bug and I should raise an issue.

Best regards
Pia

Hi @Pia ,

I’m unable to reproduce your error because I get an error in your forward function, probably because I have a different PennyLane version than you do. What is your output for `qml.about()`?

The error you mention doesn’t seem to be caused by using a class, but by using the Torch interface: you need to specify the device as `"default.qubit.torch"` (with your class, e.g. `QNet(nqubits=2, device="default.qubit.torch")`). Please let me know if this fixes your issue!
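
Why the Torch interface matters here can be sketched as follows; this is a plausible reading of the last two traceback frames, not a confirmed diagnosis. Assumption: the NumPy-based `default.qubit` computes sizes with `np.size`, while with the Torch interface the stored state is a torch tensor. `np.size(a)` returns `a.size` whenever that attribute exists, and on a torch tensor `.size` is a method rather than an int. The sketch reproduces that without Torch or PennyLane: `TensorLike` is a hypothetical stand-in for `torch.Tensor`, and `get_batch_size` is a hypothetical, simplified version of the check in the traceback, not PennyLane’s actual code.

```python
import numpy as np


class TensorLike:
    """Hypothetical stand-in for torch.Tensor. Like a torch tensor, it
    exposes .size as a method, whereas a NumPy array exposes .size as an int."""

    def __init__(self, values):
        self.values = list(values)

    def size(self):
        return (len(self.values),)


def get_batch_size(tensor, expected_size):
    """Hypothetical, simplified version of the size check in the traceback,
    using np.size the way a NumPy-based device would."""
    size = np.size(tensor)  # returns tensor.size if that attribute exists
    if size > expected_size:
        return size // expected_size
    return None


# A NumPy state vector works: np.size gives the int 4, so there is no batch dimension.
print(get_batch_size(np.array([0, 1, 0, 0], dtype=complex), 4))  # None

# A torch-like object fails: np.size returns the bound .size method,
# and comparing a method with an int raises TypeError.
try:
    get_batch_size(TensorLike([0, 1, 0, 0]), 4)
except TypeError as err:
    print("TypeError:", err)
```

With a real torch tensor, `np.size` would presumably return the bound `torch.Tensor.size`, whose type is `builtin_function_or_method`, which is exactly the type named in the reported TypeError. A device written for Torch tensors would need a size helper that understands them.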

Also, you might benefit from our mid-circuit snapshots feature. I haven’t tried it with your code, but give it a try and let me know whether it works for what you want to do.

Please let me know how it goes with these suggestions!

Dear Catalina,
Thank you for the fast reply.
Using “default.qubit.torch” does indeed solve the struggle :+1:.
And thank you for mentioning the mid-circuit snapshots. I will have a closer look at them :slight_smile:
Best regards,
Pia

I’m glad this solved your struggle, @Pia!

Enjoy using PennyLane!