QNN in JAX slow even though QNN alone is fast

I found that my QNN by itself is fast, but backpropagation through that QNN using JAX is slow.

Given this code

import pennylane as qml
from pennylane import numpy as np
import time
from pennylane import X, Y, Z

def run_circuit(n):
    dev_lightning = qml.device(
        "lightning.qubit", wires=n,
    )
    dev = qml.device(
        "default.qubit", wires=n,
    )

    def circuit(input_data):
        qml.AmplitudeEmbedding(input_data, wires=range(n), normalize=True)
        coeffs = [0.5]
        obs = [X(0) @ X(1) @ X(2) @ Z(3) @ Y(5) @ X(4) @ X(6) @ Y(7) @ Z(8) @ Z(9) @ X(10) @ Y(11)]
        hamiltonian = qml.Hamiltonian(coeffs, obs)
        print(f"Hamiltonian {n} qubits: {hamiltonian}")
        qml.exp(hamiltonian)
        return qml.state()


    input_data = np.random.rand(2**n)
    start_time = time.time()
    state = qml.QNode(circuit, dev)(input_data)
    # state = qml.QNode(circuit, dev_lightning)(input_data) # << THIS WILL FAIL
    end_time = time.time()

    execution_time = end_time - start_time
    return execution_time, state


n = 12
execution_time, state = run_circuit(n)
print(f"Execution time for n={n}: {execution_time:.4f} seconds")

It works by itself (completes in ~20 s), but when I put it into a JAX setup with a loss function, backprop, etc., it gets very, very slow. Scaling n down to a smaller number helps, so I am guessing it has something to do with JAX and not PennyLane.

Thinking it had to do with JAX, I then tried jit and qjit, only to find out that they don't support SparseHamiltonian.

So I don't jit the network and switch to lightning.qubit instead, but it doesn't support the operation I'm using.

So far I face 3 issues:

  1. lightning.qubit doesn’t support the operator I am using
  2. Cannot jit a QNN involving a sparse hamiltonian
  3. Backprop seems to be slow with large numbers of qubits.

Do you have any thoughts on speeding this up? My first guess is to just switch to GPU without using lightning.qubit, and hope the non-PennyLane part (backprop) gets a speedup. Not sure though.

Hi @mchau ,

qml.AmplitudeEmbedding uses qml.StatePrep under the hood.

If the StatePrep operation is not supported natively on the target device, PennyLane will attempt to decompose the operation using the method developed by Möttönen et al. (Quantum Info. Comput., 2005).

This method is inefficient, slow, and in general, not differentiable.

default.qubit works nicely because it supports StatePrep natively. However, the lightning suite doesn't, so it will be slower (GPUs won't fix the problem; they may make it worse).

My recommendation would be to keep using default.qubit, or switch to using a different embedding if possible.

Regarding your issues:

  1. Which is the operator you say doesn’t work with lightning.qubit?
  2. As far as I know SparseHamiltonian has no issues with jitting
  3. It’s not backprop that’s slow, it’s finding the gradients over your embedding (which is not generally differentiable)

I hope this helps.


hey Catalina, thanks for the answer.

  1. lightning.qubit doesn’t seem to support decomposition of [X(0) @ X(1) @ X(2) @ Z(3) @ Y(5) @ X(4) @ X(6) @ Y(7) @ Z(8) @ Z(9) @ X(10) @ Y(11)]
  2. This line (pennylane/pennylane/ops/qubit/observables.py at 7c677222e39e8b83a0167bff796658f071698ab2 · PennyLaneAI/pennylane · GitHub) makes it fail: with jitting, the SparseMatrix is wrapped inside JAX's jit tracer object.
  3. Is there a way to ask PennyLane or JAX not to find the gradients over the embedding? It would make the QNN less expressive, like having one fewer layer, but I think that's fine if we get a big boost in speed.

Hi @mchau ,

The code you shared takes around 20 s to run because you have 2^{12} input datapoints, which you're loading into your embedding. That's 4096 numbers in input_data. Quantum computers (and simulators) are known to be bad at handling large amounts of data, so your algorithm will inevitably run slowly unless you compile it in an efficient way, which would be a whole different research project.

On the other hand, lightning seemed to be having issues with qml.exp(hamiltonian), not with your operator itself. You may want to try using qml.evolve() instead.

Based on the code you shared before, here’s how you can use evolve:

import pennylane as qml
from pennylane import numpy as np
import time
from pennylane import X, Y, Z
np.random.seed(42)


def run_circuit(n, input_data, dev_name):
    dev = qml.device(
        dev_name, wires=n,
    )

    def circuit(input_data):
        qml.AmplitudeEmbedding(input_data, wires=range(n), normalize=True)
        coeffs = [0.5]
        obs = [X(0) @ X(1) @ X(2) @ Z(3) @ Y(5) @ X(4) @ X(6) @ Y(7) @ Z(8) @ Z(9) @ X(10) @ Y(11)]
        hamiltonian = qml.Hamiltonian(coeffs, obs)
        print(f"Hamiltonian {n} qubits: {hamiltonian}")
        # qml.exp(hamiltonian)  # not supported on lightning.qubit
        qml.evolve(hamiltonian, 1)
        return qml.state()


    
    start_time = time.time()
    state = qml.QNode(circuit, dev)(input_data)
    end_time = time.time()

    execution_time = end_time - start_time
    return execution_time, state


n = 12
input_data = np.random.rand(2**n)

dev_lightning = "lightning.qubit"
dev_default = "default.qubit"

# Choose the device to use
execution_time, state = run_circuit(n,input_data,dev_lightning)
print(f"Execution time for n={n}: {execution_time:.4f} seconds")

Regarding your questions about JAX, JIT and SparseMatrix it’s hard to help you without seeing your code and your error message. If you can make a minimal example I can try to see if I can help you.

I hope this helps.

Thank you for your message. Your code works :slight_smile:

What about my third point (asking PennyLane or JAX not to compute gradients over the embedding)?

Hi @mchau,

If you’re using JAX you can use argnums to specify the trainable arguments. See for example the docs for the JAX interface.

If you’re using a different interface then you would use argnum.

In this specific case since your only argument to the circuit is input_data you could also use diff_method=None when you create your QNode. See the QNode docs for more details or let me know if you have any specific issues with this.

I hope this helps!