Why is running a hybrid neural network prohibitively slow?

Hi,

I am trying to train a hybrid neural network using the Torch interface. When I run it on the Lightning devices (I tried lightning.gpu and lightning.kokkos, and compared them with lightning.qubit on the CPU, which was also significantly slower than default.qubit), it becomes prohibitively slow: it doesn’t finish a single epoch in over an hour.
Is there anything that can be done about it? I’m running fairly large Torch models with a small quantum circuit as a single layer, and this way the model has to run on CPUs instead of GPUs.

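Settings and the QLayer’s QNode:

```python
import pennylane as qml

n_qubits = 5
n_layers = 1

# Device under test: any of lightning.gpu, lightning.kokkos, lightning.qubit or default.qubit
qdev = qml.device("lightning.qubit", wires=n_qubits)


@qml.qnode(qdev, interface="torch")
def qnode(inputs, weights):
    qml.AngleEmbedding(inputs, wires=range(n_qubits))
    qml.StronglyEntanglingLayers(weights, wires=range(n_qubits))
    # One Pauli-Z expectation value per qubit
    return [qml.expval(qml.PauliZ(wires=i)) for i in range(n_qubits)]
```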
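Network architecture:

```python
from torch import nn

# Inside the model's __init__; QLayer is the quantum layer built around the QNode above
self.model = nn.Sequential(
    nn.Linear(2, 64),
    nn.Tanh(),
    nn.Linear(64, 64),
    nn.Tanh(),
    nn.Linear(64, 64),
    nn.Tanh(),
    nn.Linear(64, n_qubits),
    nn.Tanh(),
    QLayer(n_qubits, n_layers),  # Quantum layer
    nn.Linear(n_qubits, 2),
)
```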

Thanks in advance!

Hi @ZivChen, unfortunately I don’t think there’s anything you can do.

I’d be curious to know, though: how much slower is it to run this model than the purely classical model on the same CPU?

Hi Catalina,

Currently it is infinitely slower, because I haven’t seen it finish even one epoch… I’ll run it tonight and update you in the morning.

Best,
Ziv
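Rather than waiting overnight for a full epoch, the epoch time can be estimated by timing the first few batches and scaling by the number of batches. This is a plain PyTorch sketch: `estimate_epoch_seconds` is a hypothetical helper, and the data, batch size, Adam optimizer, MSE loss and small classical model are placeholders; passing the hybrid model and the real `DataLoader` instead would give its estimate.

```python
import time

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def estimate_epoch_seconds(model, loader, n_batches=5):
    """Time the first n_batches training steps and scale up to a full epoch."""
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    start = time.perf_counter()
    timed = 0
    for x, y in loader:
        optimizer.zero_grad()
        nn.functional.mse_loss(model(x), y).backward()
        optimizer.step()
        timed += 1
        if timed == n_batches:
            break
    per_batch = (time.perf_counter() - start) / timed
    return per_batch * len(loader)  # len(loader) = number of batches per epoch


# Placeholder data and model
data = TensorDataset(torch.rand(10_000, 2), torch.rand(10_000, 2))
loader = DataLoader(data, batch_size=32, shuffle=True)
model = nn.Sequential(nn.Linear(2, 64), nn.Tanh(), nn.Linear(64, 2))
print(f"estimated epoch time: {estimate_epoch_seconds(model, loader):.1f} s")
```

The estimate is rough (it includes start-up overhead from the first batches), but it turns "didn’t finish" into a number that can be compared across devices and against the classical model.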

@ZivChen, if it’s infinitely slow, then something else might be going on. Could you share your full code so I can see whether anything stands out?