Ah! Thank you, this clears things up a bit.

There are a couple of different ways to make a custom embedding procedure work with parameter broadcasting. One is to assume that the input always has a leading (batch) dimension, like this:
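```python
import pennylane as qml
from pennylane import numpy as np

def weird_embedding(x):
    # x has shape (batch_size, 3); each column feeds one rotation
    qml.CRX(x[:, 0], wires=[0, 1])
    qml.T(wires=0)
    qml.CRY(x[:, 1], wires=[0, 1])
    qml.S(wires=1)
    qml.CRZ(x[:, 2], wires=[0, 1])
```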
You can then call it like this:
```python
dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev)
def circuit(x):
    weird_embedding(x)
    return [qml.expval(qml.PauliZ(i)) for i in range(2)]

# a batch of 50 data points with 3 features each
x = np.random.uniform(0, np.pi, size=(50, 3))
circuit(x)
```

```
[tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1.], requires_grad=True),
 tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1.], requires_grad=True)]
```
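Every expectation value above is exactly 1, which is expected rather than a bug: wire 0 is the control of all three controlled rotations and starts in |0⟩, so CRX/CRY/CRZ act as the identity, and T and S leave |00⟩ unchanged because both are diagonal with a 1 in the |0⟩ entry. The plain-NumPy sketch below (not PennyLane; the helpers `rx`, `ry`, `rz`, `controlled`, `apply` and `simulate` are hypothetical) simulates the same gate sequence with a leading batch axis, assuming PennyLane's convention that wire 0 is the most significant qubit.

```python
import numpy as np

I2 = np.eye(2)

def rx(t):
    # Batched RX(t): shape t.shape + (2, 2)
    c, s = np.cos(t / 2), np.sin(t / 2)
    return np.stack([np.stack([c, -1j * s], -1), np.stack([-1j * s, c], -1)], -2)

def ry(t):
    c, s = np.cos(t / 2), np.sin(t / 2)
    return np.stack([np.stack([c, -s], -1), np.stack([s, c], -1)], -2)

def rz(t):
    e = np.exp(-0.5j * t)
    z = np.zeros_like(e)
    return np.stack([np.stack([e, z], -1), np.stack([z, np.conj(e)], -1)], -2)

def controlled(u):
    # Control on wire 0 (most significant qubit), target on wire 1
    out = np.zeros(u.shape[:-2] + (4, 4), dtype=complex)
    out[..., 0, 0] = out[..., 1, 1] = 1
    out[..., 2:, 2:] = u
    return out

T0 = np.kron(np.diag([1, np.exp(0.25j * np.pi)]), I2)  # T on wire 0
S1 = np.kron(I2, np.diag([1, 1j]))                      # S on wire 1

def apply(gate, state):
    # Works for batched (batch, 4, 4) gates and shared (4, 4) gates alike
    return (gate @ state[..., None])[..., 0]

def simulate(x):
    # x has shape (batch_size, 3), as in the batched weird_embedding
    a, b, c = x[:, 0], x[:, 1], x[:, 2]
    state = np.zeros((x.shape[0], 4), dtype=complex)
    state[:, 0] = 1.0  # every batch entry starts in |00>
    for gate in (controlled(rx(a)), T0, controlled(ry(b)), S1, controlled(rz(c))):
        state = apply(gate, state)
    p = np.abs(state) ** 2
    z0 = p[:, 0] + p[:, 1] - p[:, 2] - p[:, 3]  # <Z> on wire 0
    z1 = p[:, 0] - p[:, 1] + p[:, 2] - p[:, 3]  # <Z> on wire 1
    return z0, z1

x = np.random.uniform(0, np.pi, size=(50, 3))
z0, z1 = simulate(x)
print(z0.shape, np.allclose(z0, 1), np.allclose(z1, 1))  # (50,) True True
```

Starting with wire 0 in |1⟩ instead would make the rotation angles actually matter.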
But if `x` is just one data point (not a batch), `x[:, 0]` has no second axis to index, so this approach doesn't generalize the way it would with three separate variables:
```python
def weird_embedding(a, b, c):
    # a, b, c can each be a single angle or a 1D batch of angles
    qml.CRX(a, wires=[0, 1])
    qml.T(wires=0)
    qml.CRY(b, wires=[0, 1])
    qml.S(wires=1)
    qml.CRZ(c, wires=[0, 1])

dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev)
def circuit(a, b, c):
    weird_embedding(a, b, c)
    return [qml.expval(qml.PauliZ(i)) for i in range(2)]

# each argument is a batch of 50 angles; plain scalars would work too
a = np.random.uniform(0, np.pi, size=(50,))
b = np.random.uniform(0, np.pi, size=(50,))
c = np.random.uniform(0, np.pi, size=(50,))
circuit(a, b, c)
```
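The difference comes down to plain array indexing: `x[:, 0]` needs a second axis, so it raises an `IndexError` on a single data point of shape `(3,)`, while separate arguments are just scalars or 1D arrays, both of which are valid gate parameters. A small NumPy check, independent of PennyLane (`first_angle` is a hypothetical helper that mimics the batched embedding's indexing):

```python
import numpy as np

def first_angle(x):
    # Same indexing as the batched weird_embedding
    return x[:, 0]

batch = np.random.uniform(0, np.pi, size=(50, 3))
single = np.random.uniform(0, np.pi, size=(3,))

print(first_angle(batch).shape)  # (50,)

try:
    first_angle(single)
    single_ok = True
except IndexError:
    # a 1D array has no second axis for [:, 0] to index
    single_ok = False
print(single_ok)  # False
```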
When this is in the context of a torch layer (`qml.qnn.TorchLayer`), you're stuck with a single `inputs` argument; you can't split it up into `input1`, `input2`, etc. So you'll have to write some logic at the start of your quantum function to determine whether `inputs` is a batch or not. E.g.:
```python
def weird_embedding(x):
    if len(x.shape) > 1:
        # x is a batch of data, shape (batch_size, 3)
        a, b, c = x[:, 0], x[:, 1], x[:, 2]
    else:
        # x is a single data point, shape (3,)
        a, b, c = x[0], x[1], x[2]
    qml.CRX(a, wires=[0, 1])
    qml.T(wires=0)
    qml.CRY(b, wires=[0, 1])
    qml.S(wires=1)
    qml.CRZ(c, wires=[0, 1])
```
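One possible way to avoid the `if`/`else`, offered as a suggestion rather than anything from the docs linked below: index the *last* axis with `...` (Ellipsis). `x[..., 0]` gives a 0-d value for a point of shape `(3,)` and a length-`batch_size` vector for shape `(batch_size, 3)`, and PyTorch tensors support the same `...` indexing as NumPy arrays. So `a, b, c = x[..., 0], x[..., 1], x[..., 2]` should be able to replace the branch in `weird_embedding`. The check below only exercises the indexing in NumPy; `split_angles` is a hypothetical helper name.

```python
import numpy as np

def split_angles(x):
    # Index the last axis, so any leading batch axis is preserved
    return x[..., 0], x[..., 1], x[..., 2]

single = np.array([0.1, 0.2, 0.3])
batch = np.random.uniform(0, np.pi, size=(50, 3))

a, b, c = split_angles(single)
print(np.shape(a))  # ()

a_b, b_b, c_b = split_angles(batch)
print(a_b.shape)    # (50,)
```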
You can find more info here: Templates — PennyLane 0.33.0 documentation
Let me know if this helps!