Input dimension mixes with batch dimension in qml.qnn.TorchLayer

Ah! Thank you. This clears things up a bit :+1:

There are a couple of different ways to make a custom embedding procedure work with broadcasting. One is to assume that the input has a leading (batch) dimension, like this:

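```python
import pennylane as qml

def weird_embedding(x):
    # Assumes x has shape (batch_size, 3): x[:, i] is column i across the whole batch
    qml.CRX(x[:, 0], wires=[0, 1])
    qml.T(wires=0)
    qml.CRY(x[:, 1], wires=[0, 1])
    qml.S(wires=1)
    qml.CRZ(x[:, 2], wires=[0, 1])
```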

You can then call it like this:

```python
from pennylane import numpy as np

dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev)
def circuit(x):
    weird_embedding(x)
    return [qml.expval(qml.PauliZ(i)) for i in range(2)]

x = np.random.uniform(0, np.pi, size=(50, 3))

circuit(x)
```

```
[tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1.], requires_grad=True),
 tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1.], requires_grad=True)]
```
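The reason this version needs a leading batch axis is plain array indexing: `x[:, 0]` asks for column 0 of a 2-D array. Here is a NumPy-only sketch (no PennyLane needed; the helper name `first_feature_batched` is hypothetical) showing that this indexing works for a `(50, 3)` batch but raises an `IndexError` for a single `(3,)` data point:

```python
import numpy as np

def first_feature_batched(x):
    # Mirrors x[:, 0] in the embedding above: assumes x has shape (batch_size, n_features)
    return x[:, 0]

batch = np.random.uniform(0, np.pi, size=(50, 3))
single = np.random.uniform(0, np.pi, size=(3,))

print(first_feature_batched(batch).shape)  # (50,)

single_failed = False
try:
    first_feature_batched(single)
except IndexError as err:
    # A 1-D array has only one axis, so the two-axis index [:, 0] is rejected
    single_failed = True
    print("single data point failed:", err)
```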

However, if `x` is a single data point (not a batch), this approach breaks, because `x[:, 0]` needs a second axis. Passing three separate variables, on the other hand, works whether each one is a single value or a batch:

```python
import pennylane as qml
from pennylane import numpy as np

def weird_embedding(a, b, c):
    # a, b and c can each be a single value or a 1-D batch of values
    qml.CRX(a, wires=[0, 1])
    qml.T(wires=0)
    qml.CRY(b, wires=[0, 1])
    qml.S(wires=1)
    qml.CRZ(c, wires=[0, 1])

dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev)
def circuit(a, b, c):
    weird_embedding(a, b, c)
    return [qml.expval(qml.PauliZ(i)) for i in range(2)]

a = np.random.uniform(0, np.pi, size=(50,))
b = np.random.uniform(0, np.pi, size=(50,))
c = np.random.uniform(0, np.pi, size=(50,))

circuit(a, b, c)
```
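Why do separate arguments generalize? Each gate only needs its angle to be either a scalar or a 1-D batch, and elementwise math handles both cases. As an illustration (a NumPy sketch of the idea, not how PennyLane implements broadcasting internally; `rx_matrix` is a hypothetical helper), here is the RX matrix, i.e. the target block of the `CRX` used above. A scalar angle gives shape `(2, 2)`, while a batch of 50 angles gives a stack of shape `(50, 2, 2)`:

```python
import numpy as np

def rx_matrix(theta):
    # RX(theta) = [[cos(theta/2), -i sin(theta/2)], [-i sin(theta/2), cos(theta/2)]]
    # theta may be a scalar or a 1-D batch; cos/sin act elementwise either way
    theta = np.asarray(theta)
    c = np.cos(theta / 2)
    s = -1j * np.sin(theta / 2)
    # Build rows on a new last axis, then stack rows on the axis before it,
    # so any leading batch axis is preserved: (2, 2) or (batch_size, 2, 2)
    return np.stack([np.stack([c, s], axis=-1), np.stack([s, c], axis=-1)], axis=-2)

angles = np.random.uniform(0, np.pi, size=(50,))

print(rx_matrix(0.3).shape)     # (2, 2)
print(rx_matrix(angles).shape)  # (50, 2, 2)
```

Each slice of the batched result is the same matrix you would get from the corresponding single angle, which is why the same circuit code serves both cases.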

In the context of a `TorchLayer`, though, you're limited to a single `inputs` argument; you can't split it into `input1`, `input2`, and so on. So you'll need some logic at the start of your quantum function to determine whether `inputs` is a batch or not, e.g.:

```python
import pennylane as qml

def weird_embedding(x):
    if len(x.shape) > 1:
        # x is a batch of data with shape (batch_size, 3)
        a, b, c = x[:, 0], x[:, 1], x[:, 2]
    else:
        # x is a single data point with shape (3,)
        a, b, c = x[0], x[1], x[2]

    qml.CRX(a, wires=[0, 1])
    qml.T(wires=0)
    qml.CRY(b, wires=[0, 1])
    qml.S(wires=1)
    qml.CRZ(c, wires=[0, 1])
```
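If you'd rather avoid the explicit branch, one alternative (a sketch, not taken from the PennyLane docs; `split_features` is a hypothetical name) is Ellipsis indexing: `x[..., i]` selects along the last axis, so it returns a scalar-shaped value for a single `(3,)` data point and a length-`batch_size` vector for a `(batch_size, 3)` batch. Torch tensors support the same `...` syntax, so the same line should also work on `inputs` inside a `TorchLayer`. Here it is with NumPy:

```python
import numpy as np

def split_features(x):
    # x[..., i] indexes the last axis, so the leading batch axis is optional
    return x[..., 0], x[..., 1], x[..., 2]

single = np.array([0.1, 0.2, 0.3])
batch = np.random.uniform(0, np.pi, size=(50, 3))

a, b, c = split_features(single)
print(a.shape, b.shape, c.shape)  # () () ()

a, b, c = split_features(batch)
print(a.shape, b.shape, c.shape)  # (50,) (50,) (50,)
```

Inside `weird_embedding`, the `if`/`else` would then collapse to `a, b, c = x[..., 0], x[..., 1], x[..., 2]`.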

You can find more info here: Templates — PennyLane 0.33.0 documentation

Let me know if this helps!
