Multiple batched amplitude embeddings

Hello! As suggested in another thread, I’m asking for advice on a more complicated use case. I’d like to build a TorchLayer that performs two separate amplitude embeddings on different parts of its input. These inputs are 2D tensors of size (batch_size, embedding_dim).

To make it work, I'm using merge_amplitude_embedding. The problem is that merge_amplitude_embedding performs a Kronecker product under the hood over all axes, so it expands the batch dimension as well as the embedding dimension. The result is an intermediate tensor of size (batch_size^2, embedding_dim^2) instead of (batch_size, embedding_dim^2).
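A minimal pure-torch sketch of the shape problem, assuming the merged embedding behaves like torch.kron applied to the whole batched tensors. This only mimics the behaviour described above; it is not PennyLane's actual implementation. The sizes (9 samples, 4 features per embedding) match the example below.

```python
import torch

batch_size, embedding_dim = 9, 4
inputs1 = torch.rand(batch_size, embedding_dim)
inputs2 = torch.rand(batch_size, embedding_dim)

# Kronecker product over *both* axes: the batch axis gets multiplied as well
full_kron = torch.kron(inputs1, inputs2)
print(full_kron.shape)  # torch.Size([81, 16])

# What we actually want: one Kronecker product per sample, batch axis untouched
per_sample = torch.stack([torch.kron(u, v) for u, v in zip(inputs1, inputs2)])
print(per_sample.shape)  # torch.Size([9, 16])

# Reshaping the over-expanded tensor back to (batch_size, 16) fails,
# with the same kind of error as in the traceback below
try:
    full_kron.reshape(batch_size, embedding_dim ** 2)
except RuntimeError as err:
    print(err)
```

81 * 16 = 1296, which is exactly the input size reported in the RuntimeError.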

Is there some workaround to prevent this behavior?

import pennylane as qml
import torch

dev = qml.device("default.qubit", wires=4)

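@qml.qnode(dev, interface="torch")
@qml.transforms.merge_amplitude_embedding
def qnode(inputs, weight):
    # `weight` is not used in this minimal example
    inputs1 = inputs[:, :4]  # first 4 features -> wires 0, 1
    inputs2 = inputs[:, 4:]  # last 4 features -> wires 2, 3
    qml.AmplitudeEmbedding(inputs1, wires=[0, 1], normalize=True, pad_with=0.0)
    qml.AmplitudeEmbedding(inputs2, wires=[2, 3], normalize=True, pad_with=0.0)
    return qml.state()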

weight_shapes = {'weight' : 1}
# Batch of 9 samples with 8 features each: shape (9, 8)
x = torch.ones(9*2*4).reshape((9, 2, 4))
x = torch.flatten(x, start_dim=1)

qlayer = qml.qnn.TorchLayer(qnode, weight_shapes)

output = qlayer(x) # Will crash with RuntimeError: shape '[9, 16]' is invalid for input of size 1296

Might also be useful to look at this line:

Running on:


Thanks for your help!

Hey @akatief!

This looks like a good candidate for a feature request issue. Maybe einsum would be better here, since it lets us be explicit about which axes get combined? Anyway, let's see what our dev team says :grin:. Great catch!
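For reference, a batch-aware Kronecker product along the lines of the einsum idea could look like the sketch below. `batched_kron` is a hypothetical helper, not part of PennyLane or torch. It assumes every input has shape (batch_size, d_i) and that the first factor belongs to the lowest-numbered wires (PennyLane treats wire 0 as the most significant qubit).

```python
import functools

import torch


def batched_kron(*tensors):
    """Hypothetical helper: per-sample Kronecker product of the last axes.

    Each tensor has shape (batch_size, d_i); the result has shape
    (batch_size, d_1 * d_2 * ...). The batch axis is shared, never expanded.
    """
    def pair(a, b):
        # 'nk,nl->nkl' keeps n (the batch axis) shared instead of multiplying it
        return torch.einsum("nk,nl->nkl", a, b).reshape(a.shape[0], -1)

    return functools.reduce(pair, tensors)


x = torch.rand(9, 4)
y = torch.rand(9, 2)
z = torch.rand(9, 8)
out = batched_kron(x, y, z)
print(out.shape)  # torch.Size([9, 64])

# Check against an explicit per-sample loop
reference = torch.stack(
    [torch.kron(torch.kron(u, v), w) for u, v, w in zip(x, y, z)]
)
print(torch.allclose(out, reference))  # True
```

Reducing pairwise keeps the einsum subscripts fixed, so the same helper works for any number of embeddings.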

Thanks! Opening the request right away :slight_smile:

In the meantime, I found a workaround: build the tensor product yourself and feed it directly to a circuit with a single AmplitudeEmbedding. When computing the tensor product you need to be careful with padding; I think this is the correct way to do it:

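import torch


def pad_to_next_power_of_two(tensor):
    # Zero-pad the last axis up to the next power of two (e.g. 3 -> 4, 5 -> 8, 4 -> 4)
    curr_size = tensor.size(-1)
    next_power_of_two = 2 ** (curr_size - 1).bit_length()
    padding_size = next_power_of_two - curr_size

    padded_tensor = torch.nn.functional.pad(tensor, (0, padding_size))
    return padded_tensor

# a and b are the two 2D tensors, of shapes (batch_size, dim_a) and (batch_size, dim_b),
# that you want to encode on separate sets of wires
a = pad_to_next_power_of_two(a)
b = pad_to_next_power_of_two(b)
# Tensor product along the feature axis only: c[n] = kron(a[n], b[n]),
# so the batch axis keeps its size instead of being squared
c = torch.einsum('nk,nl->nkl', a, b).reshape(a.shape[0], -1)
circuit(c)  # circuit contains a single AmplitudeEmbedding over all the wires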
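A small torch-only check of why the padding order matters, assuming each factor would otherwise be zero-padded by its own AmplitudeEmbedding (pad_with=0.0). With 3 features per input, padding each factor to 4 before the product places the amplitudes on different basis states than padding the 9-element product to 16. It also checks that normalize=True on the product gives the product of the separately normalized states. The helpers `batched_kron` and `unit` are illustrative names, not library functions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)


def pad_to_next_power_of_two(tensor):
    curr_size = tensor.size(-1)
    next_power_of_two = 2 ** (curr_size - 1).bit_length()
    return F.pad(tensor, (0, next_power_of_two - curr_size))


def batched_kron(a, b):
    # Per-sample Kronecker product of the last axes (batch axis left alone)
    return torch.einsum("nk,nl->nkl", a, b).reshape(a.shape[0], -1)


def unit(t):
    # L2 normalization of each sample, as done by normalize=True
    return t / t.norm(dim=-1, keepdim=True)


a = torch.rand(5, 3)  # 3 features -> padded to 4 amplitudes (2 qubits)
b = torch.rand(5, 3)

# Correct: pad each factor to its own register size, then combine
right = batched_kron(pad_to_next_power_of_two(a), pad_to_next_power_of_two(b))
# Wrong: combine first, then pad the 9-element product up to 16
wrong = pad_to_next_power_of_two(batched_kron(a, b))
print(right.shape, wrong.shape)  # torch.Size([5, 16]) torch.Size([5, 16])
print(torch.allclose(right, wrong))  # False

# The norm of a Kronecker product is the product of the norms,
# so normalizing once after the product is equivalent to normalizing each factor
separate = batched_kron(unit(pad_to_next_power_of_two(a)), unit(pad_to_next_power_of_two(b)))
print(torch.allclose(unit(right), separate))  # True
```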

Hope this is useful to someone!

Interesting! Do you mind linking the feature request you made on the PennyLane GitHub? :pray: