Lightning GPU never finishes and has 0% utilization

Hello! I’m trying to get a QML program running with JAX on a GPU. The problem is that execution hangs with 0% GPU utilization, even though GPU memory is allocated. There is no error; the computation just never seems to start. I’m working on an AWS p3 instance with cuQuantum pre-installed. Basic JAX commands do run on the GPU with high utilization. I installed PennyLane and JAX with the following commands:

pip install --upgrade "jax[cuda11_pip]" -f
pip install pennylane pennylane-lightning-gpu custatevec-cu11 

My code is copied from a PennyLane tutorial. I’ve only changed the layout of the quantum circuit and set JAX to use the GPU instead of the CPU:

import pennylane as qml
import jax
from jax import numpy as jnp
import jaxopt
import numpy as np
import math

# Use GPU
jax.config.update("jax_platform_name", "gpu")
print('Devices available:', jax.devices())

n_qubits = 15
n_layers = 3

# Random dataset to train on
n_samples = 100
n_features = 30
X_train = np.random.random(size=(n_samples,n_features))
y_train = np.random.choice(a=[0,1], size=n_samples, replace=True)

# Lightning gpu simulator
dev = qml.device("lightning.gpu", wires=n_qubits)

@qml.qnode(dev)  # QNode decorator (not visible in the post as shown); circuit() must run on dev
def circuit(data, weights):
    """Quantum circuit ansatz"""
    # Starting locations for entangling gates
    a = math.floor((n_qubits-.001)/2)
    b = a + 1
    # Use inputs to parameterize a "input" circuit
    for i in range(data.shape[1]):
        tqubit = i % n_qubits
        if i % 3 == 0:
            qml.RX(data[:,i], wires=i % n_qubits) # all inputs at index i
        elif i % 3 == 1:
            qml.RY(data[:,i], wires=i % n_qubits)
            qml.RZ(data[:,i], wires=i % n_qubits)
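        # Entangle after every n_qubits encoded features
        if i != 0 and i % n_qubits == 0:
            for j in range(a-1,-1,-1): # Backwards from a
                qml.CNOT(wires=[j + 1, j])  # assumed gate; loop body missing from the post
            for j in range(b,n_qubits-1): # Forwards from b
                qml.CNOT(wires=[j, j + 1])  # assumed gate; loop body missing from the post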
    # Parameterized Rx, Ry, and Rz gates
    for L in range(n_layers):
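        # Entangling CNOT gates
        for i in range(a-1,-1,-1): # Backwards from a
            qml.CNOT(wires=[i + 1, i])  # assumed gate; loop body missing from the post
        for i in range(b,n_qubits-1): # Forwards from b
            qml.CNOT(wires=[i, i + 1])  # assumed gate; loop body missing from the post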

        for i in range(n_qubits):
            qml.RX(weights[i*3*n_layers + 3*L+0], wires=i)
            qml.RY(weights[i*3*n_layers + 3*L+1], wires=i)
            qml.RZ(weights[i*3*n_layers + 3*L+2], wires=i)

    # we use a sum of local Z's as an observable since a
    # local Z would only be affected by params on that qubit.
    return qml.expval(qml.sum(*[qml.PauliZ(i) for i in range(n_qubits)]))

# Setup copied from tutorial
def my_model(data, weights, bias):
    return circuit(data, weights) + bias

def loss_fn(params, data, targets):
    predictions = my_model(data, params["weights"], params["bias"])
    loss = jnp.sum((targets - predictions) ** 2 / len(data))
    return loss

# Initial parameters
weights = jnp.ones([n_qubits*n_layers*3])
bias = jnp.array(0.)
params = {"weights": weights, "bias": bias}

def loss_and_grad(params, data, targets, print_training, i):
    loss_val, grad_val = jax.value_and_grad(loss_fn)(params, data, targets)

    def print_fn():
        jax.debug.print("Step: {i}  Loss: {loss_val}", i=i, loss_val=loss_val)

    # if print_training=True, print the loss every 5 steps
    jax.lax.cond((jnp.mod(i, 5) == 0) & print_training, print_fn, lambda: None)

    return loss_val, grad_val

def optimization_jit(params, data, targets, print_training=True):
    opt = jaxopt.GradientDescent(loss_and_grad, stepsize=0.3, value_and_grad=True)
    opt_state = opt.init_state(params)

    def update(i, args):
        params, opt_state = opt.update(*args, i)
        return (params, opt_state, *args[2:])

    args = (params, opt_state, data, targets, print_training)
    (params, opt_state, _, _, _) = jax.lax.fori_loop(0, 100, update, args)

    return params

# Run training
print('Running experiment with %d qubits' % n_qubits)
params = {"weights": weights, "bias": bias}
trained_params = optimization_jit(params, X_train, y_train, print_training=True)

The program runs and prints:

Devices available: [cuda(id=0)]
Running experiment with 15 qubits

But then it makes no progress. GPU memory is allocated, yet utilization stays at 0% (see the attached figure). The same happens with only 3 qubits. An otherwise identical script running on the CPU finishes in about 10 seconds.
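To put a number on how deep this ansatz is, the sketch below replays the loop structure of `circuit` in plain Python and counts the gates it would place, without PennyLane or a GPU. It assumes each iteration of an entangling loop places exactly one CNOT (the post does not confirm this), and it mirrors the encoding branches as written, so features with `i % 3 == 2` receive no gate. `count_gates` and `weight_indices` are hypothetical helper names; the second one checks that the indexing `i*3*n_layers + 3*L + k` uses every entry of the `n_qubits*n_layers*3` weights exactly once.

```python
import math
from collections import Counter


def count_gates(n_qubits, n_layers, n_features):
    """Count gates placed by the ansatz's loops (plain Python, no PennyLane).

    Assumption: every iteration of an entangling loop places exactly one CNOT.
    """
    # Same starting locations for the entangling chains as in the ansatz
    a = math.floor((n_qubits - .001) / 2)
    b = a + 1
    cnots_per_block = len(range(a - 1, -1, -1)) + len(range(b, n_qubits - 1))

    counts = Counter()
    # Data-encoding section, mirroring the branches as written in the post
    for i in range(n_features):
        if i % 3 == 0:
            counts["RX"] += 1
        elif i % 3 == 1:
            counts["RY"] += 1
            counts["RZ"] += 1
        if i != 0 and i % n_qubits == 0:
            counts["CNOT"] += cnots_per_block
    # Variational layers: one entangling block plus RX, RY, RZ on every wire
    for _ in range(n_layers):
        counts["CNOT"] += cnots_per_block
        for gate in ("RX", "RY", "RZ"):
            counts[gate] += n_qubits
    return dict(counts)


def weight_indices(n_qubits, n_layers):
    """All indices used as weights[i*3*n_layers + 3*L + k] in the ansatz, sorted."""
    return sorted(
        i * 3 * n_layers + 3 * L + k
        for L in range(n_layers)
        for i in range(n_qubits)
        for k in range(3)
    )


if __name__ == "__main__":
    for nq in (15, 3):
        gates = count_gates(nq, n_layers=3, n_features=30)
        print(nq, "qubits:", gates, "total:", sum(gates.values()))
    # Every weight entry is used exactly once
    assert weight_indices(15, 3) == list(range(15 * 3 * 3))
```

Under those assumptions, the 15-qubit, 3-layer setup with 30 features places 217 gates per circuit evaluation (52 of them CNOTs) and the 3-qubit version places 69, before any differentiation or batching over the 100 samples.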

qml.about() prints:

Name: PennyLane
Version: 0.34.0
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
License: Apache License 2.0
Location: /usr/local/lib/python3.10/dist-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, pennylane-lightning, requests, rustworkx, scipy, semantic-version, toml, typing-extensions
Required-by: PennyLane-Lightning, PennyLane-Lightning-GPU, PennyLane-SF

Platform info:           Linux-5.19.0-1025-aws-x86_64-with-glibc2.35
Python version:          3.10.12
Numpy version:           1.23.5
Scipy version:           1.12.0
Installed devices:
- lightning.gpu (PennyLane-Lightning-GPU-0.34.0)
- lightning.qubit (PennyLane-Lightning-0.34.0)
- default.gaussian (PennyLane-0.34.0)
- default.mixed (PennyLane-0.34.0)
- default.qubit (PennyLane-0.34.0)
- default.qubit.autograd (PennyLane-0.34.0)
- default.qubit.jax (PennyLane-0.34.0)
- default.qubit.legacy (PennyLane-0.34.0)
- (PennyLane-0.34.0)
- default.qubit.torch (PennyLane-0.34.0)
- default.qutrit (PennyLane-0.34.0)
- null.qubit (PennyLane-0.34.0)
- strawberryfields.fock (PennyLane-SF-0.29.1)
- strawberryfields.gaussian (PennyLane-SF-0.29.1)
- strawberryfields.gbs (PennyLane-SF-0.29.1)
- strawberryfields.remote (PennyLane-SF-0.29.1)
- (PennyLane-SF-0.29.1)

Thank you!

Hi @Resch92, welcome to the Forum!

Thank you for your question. I suspect the CUDA version might be the issue here. Our team will take a look and get back to you next week.

We will also have a new PennyLane release on Tuesday, so please let us know whether the problem persists after updating to it.
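One way to check the CUDA-version hypothesis is to list what is actually installed. The sketch below uses only the standard library's `importlib.metadata`; the default package names and the `cuNN`/`cudaNN` naming pattern it searches for are assumptions based on the install commands in the question (e.g. `custatevec-cu11`), and `core_versions`, `cuda_tagged_wheels` and `cuda_major_versions` are hypothetical helper names. If the CUDA-tagged wheels report more than one major version, that mix may be worth ruling out first.

```python
import importlib.metadata as md
import re

# Matches CUDA tags such as "cu11", "cu12" or "cuda12" inside distribution names
CUDA_TAG = re.compile(r"cu(?:da)?(\d+)", re.IGNORECASE)


def core_versions(names=("jax", "jaxlib", "PennyLane", "PennyLane-Lightning-GPU")):
    """Installed version of each named distribution, or None if it is missing."""
    versions = {}
    for name in names:
        try:
            versions[name] = md.version(name)
        except md.PackageNotFoundError:
            versions[name] = None
    return versions


def cuda_tagged_wheels():
    """Installed distributions whose names carry a CUDA tag, mapped to versions."""
    found = {}
    for dist in md.distributions():
        name = dist.metadata.get("Name") or ""
        if CUDA_TAG.search(name):
            found[name] = dist.version
    return found


def cuda_major_versions(wheel_names):
    """Sorted, de-duplicated CUDA major versions found in the given names."""
    majors = set()
    for name in wheel_names:
        match = CUDA_TAG.search(name)
        if match:
            majors.add(match.group(1))
    return sorted(majors)


if __name__ == "__main__":
    print(core_versions())
    wheels = cuda_tagged_wheels()
    for name, version in sorted(wheels.items()):
        print(f"{name}: {version}")
    print("CUDA major versions in wheel names:", cuda_major_versions(wheels))
```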

Hi @Resch92, your circuit looks fairly deep, in addition to using a relatively large number of qubits (which is good, since that is the range where Lightning starts to shine). Differentiating through an entire workflow like this usually incurs a substantial compilation cost. I’d first try to quantify that by timing the first invocation of a jitted function, which includes compilation, separately from subsequent evaluations. For example, add the following at the end of your script (but before the call to optimization_jit):

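import time

loss_fn_jit = jax.jit(loss_fn)

# First call: tracing + compilation + one execution.
# JAX dispatches asynchronously, so block on the result before reading the clock.
t0 = time.time()
loss_fn_jit(params, X_train, y_train).block_until_ready()
t1 = time.time()
print(f"compilation time loss_fn = {t1-t0}")

# Later calls reuse the compiled function; average over ~10 seconds
t0 = time.time()
t1 = 0
n_iter = 0
while t1 < 10:
    n_iter += 1
    loss_fn_jit(params, X_train, y_train).block_until_ready()
    t1 = time.time() - t0
print(f"execution time loss_fn = {t1 / n_iter}")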

On my fairly powerful CPU, compiling loss_fn takes more than a minute, and each execution after that takes about half a second. That’s for loss_fn alone. I would suggest repeating the measurement going up the stack (loss_and_grad, update, and optimization_jit) to see whether the compilation time diverges.
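Repeating that measurement for each function up the stack is easier with a small helper. The sketch below generalizes the timing loop above (`profile_call` and `profile_stack` are hypothetical names): it times the first call separately from the average of later calls and takes a `block` hook, so asynchronous results can be waited on before the clock is read. It uses only the standard library, so it can be tried on any Python callable.

```python
import time


def profile_call(fn, *args, min_seconds=10.0, block=lambda out: out):
    """Return (first_call_s, mean_later_call_s, n_later_calls) for fn(*args).

    For a jitted function the first call includes tracing and compilation.
    `block` is applied to each result before the clock is read.
    """
    t0 = time.perf_counter()
    block(fn(*args))
    first = time.perf_counter() - t0

    n_iter = 0
    elapsed = 0.0
    t0 = time.perf_counter()
    # Always run at least once so the mean is defined
    while n_iter == 0 or elapsed < min_seconds:
        block(fn(*args))
        n_iter += 1
        elapsed = time.perf_counter() - t0
    return first, elapsed / n_iter, n_iter


def profile_stack(calls, **kwargs):
    """Profile {label: (fn, args)} pairs in order; kwargs are passed to profile_call."""
    results = {}
    for label, (fn, args) in calls.items():
        first, mean, n_iter = profile_call(fn, *args, **kwargs)
        print(f"{label}: first call {first:.3f} s, later calls {mean:.4f} s (n={n_iter})")
        results[label] = (first, mean, n_iter)
    return results


if __name__ == "__main__":
    profile_stack(
        {"sum": (sum, (range(100_000),)), "sorted": (sorted, ([3, 1, 2] * 1000,))},
        min_seconds=0.2,
    )
```

With JAX you would pass `block=jax.block_until_ready`, for example `profile_call(jax.jit(loss_fn), params, X_train, y_train, block=jax.block_until_ready)`; for `loss_and_grad`, the extra `print_training` and `i` arguments have to be supplied as well. Listing the functions from the bottom of the stack upwards makes it easy to spot the level at which the first-call time jumps.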