Change output dtypes with Torch interface

Hi there,

I am trying to do a quantum machine learning task using PennyLane and PyTorch on a Mac, with a quantum circuit that looks like the following:

import pennylane as qml
import torch

# Example values; the original post does not specify these
n_qubits = 5
depth = 4
dev = qml.device("default.qubit", wires=n_qubits)

@qml.qnode(dev, interface="torch", diff_method="parameter-shift")
def quantum_circuit(noise, weights):
    weights = weights.reshape(depth, n_qubits)

    # Initialise latent vectors
    for i in range(n_qubits):
        qml.RY(noise[i], wires=i)

    # Repeated layer
    for i in range(depth):
        # Parameterised layer
        for y in range(n_qubits):
            qml.RY(weights[i][y], wires=y)

        # Controlled-Z gates
        for y in range(n_qubits - 1):
            qml.CZ(wires=[y, y + 1])

    return qml.probs(wires=list(range(n_qubits)))

The output of this quantum circuit, when given the input variables, has dtype torch.float64.
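You can see this by checking the output directly:

out = quantum_circuit(noise, weights)
print(out.dtype)  # torch.float64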
During backpropagation I get the following error:

File "/Users/aaronmkts/miniconda3/lib/python3.11/site-packages/pennylane/math/single_dispatch.py", line 526, in _asarray_torch
    return torch.as_tensor(x, dtype=dtype, **kwargs)

TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

This error, as indicated, seems to be an issue with the MPS framework. To fix it, I need the computation/output of the circuit to use torch.float32, but I cannot find how to make that the case in PennyLane. The inputs (noise and weights) are both torch.float32 tensors. Any insight into how to adjust the output dtype would be greatly appreciated. I tried taking the output of the circuit and casting it to torch.float32 using .to(torch.float32), but this doesn't help, since backpropagation still has to run through the part that uses the wrong datatype.
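Roughly, the failed attempt looks like this (some_loss is a placeholder for my actual loss function):

probs = quantum_circuit(noise, weights).to(torch.float32)  # cast after the QNode
loss = some_loss(probs)  # some_loss stands in for my real loss
loss.backward()          # still fails: backprop reaches the float64 MPS cast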

Hey @Aaron_Thomas,

I don’t have your full code, so I can’t try to replicate what’s going on. That said, I’d make sure that you’re using up-to-date versions of PennyLane (currently v0.35.1) and torch (currently 2.2.2); I was able to reproduce something similar on my laptop with v0.34.0 and 2.1.2, so updating may be all you need.
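You can double-check what you have installed like this:

import pennylane as qml
import torch

qml.about()               # prints the PennyLane version, installed plugins, etc.
print(torch.__version__)  # prints the installed torch version

If updating doesn’t fix it, one workaround you could try, since the failure happens when PennyLane casts its float64 results onto the MPS device, is to evaluate the QNode on CPU tensors and only move the result over afterwards. A rough, untested sketch (the MPS device handling here is my assumption, not from your code):

# Untested sketch: run the QNode on CPU, where float64 is supported,
# and move the float32 result to the MPS device afterwards.
probs = quantum_circuit(noise.cpu(), weights.cpu())      # float64, on CPU
probs = probs.to(torch.float32).to(torch.device("mps"))  # cast, then move
# Backprop through .to() happens on the CPU side, so the float64
# parameter-shift gradients never need to live on the MPS device.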

Let me know if that helps!