Hello! I’m trying to get a QML program running with JAX on a GPU. My issue is that it will hang on execution with 0% utilization on the GPU (despite allocating GPU memory). There is no error, but the computation doesn’t even start. I’m working on a AWS p3 instance with pre-installed cuquantum. I have gotten basic JAX commands to work with the GPU, showing high utilization. I have installed pennylane and jax with the following commands

```
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install pennylane pennylane-lightning-gpu custatevec-cu11
```

My code is copying from a pennylane tutorial. I’ve only changed the layout of the quantum circuit and set jax to use the GPU instead of CPU

```
import pennylane as qml
import jax
from jax import numpy as jnp
import jaxopt
import numpy as np
import math
# Use GPU
jax.config.update("jax_platform_name", "gpu")
print('Devices available:', jax.devices())
n_qubits = 15
n_layers = 3
# Random dataset to train one
n_samples = 100
n_features = 30
X_train = np.random.random(size=(n_samples,n_features))
y_train = np.random.choice(a=[0,1], size=n_samples, replace=True)
# Lightning gpu simulator
dev = qml.device("lightning.gpu", wires=n_qubits)
@qml.qnode(dev)
def circuit(data, weights):
"""Quantum circuit ansatz"""
# Starting locations for entangling gates
a = math.floor((n_qubits-.001)/2)
b = a + 1
# Use inputs to parameterize a "input" circuit
for i in range(data.shape[1]):
tqubit = i % n_qubits
if i % 3 == 0:
qml.RX(data[:,i], wires=i % n_qubits) # all inputs at index i
elif i % 3 == 1:
qml.RY(data[:,i], wires=i % n_qubits)
else:
qml.RZ(data[:,i], wires=i % n_qubits)
# Do entangling every n gates
if i != 0 and i % n_qubits == 0:
for j in range(a-1,-1,-1): # Backwards from a
qml.CNOT(wires=[j,j+1])
for j in range(b,n_qubits-1): # Forwards from b
qml.CNOT(wires=[j,j+1])
# Parameterized Rx, Ry, and Rz gates
for L in range(n_layers):
# Entangling CNOT gates
qml.CNOT(wires=[a,b])
for i in range(a-1,-1,-1): # Backwards from a
qml.CNOT(wires=[i,i+1])
for i in range(b,n_qubits-1): # Forwards from b
qml.CNOT(wires=[i,i+1])
for i in range(n_qubits):
qml.RX(weights[i*3*n_layers + 3*L+0], wires=i)
qml.RY(weights[i*3*n_layers + 3*L+1],wires=i)
qml.RZ(weights[i*3*n_layers + 3*L+2],wires=i)
# we use a sum of local Z's as an observable since a
# local Z would only be affected by params on that qubit.
return qml.expval(qml.sum(*[qml.PauliZ(i) for i in range(n_qubits)]))
# Setup copied from tutorial
def my_model(data, weights, bias):
return circuit(data, weights) + bias
@jax.jit
def loss_fn(params, data, targets):
predictions = my_model(data, params["weights"], params["bias"])
loss = jnp.sum((targets - predictions) ** 2 / len(data))
return loss
# Initial parameters
weights = jnp.ones([n_qubits*n_layers*3])
bias = jnp.array(0.)
params = {"weights": weights, "bias": bias}
def loss_and_grad(params, data, targets, print_training, i):
loss_val, grad_val = jax.value_and_grad(loss_fn)(params, data, targets)
def print_fn():
jax.debug.print("Step: {i} Loss: {loss_val}", i=i, loss_val=loss_val)
# if print_training=True, print the loss every 5 steps
jax.lax.cond((jnp.mod(i, 5) == 0) & print_training, print_fn, lambda: None)
return loss_val, grad_val
@jax.jit
def optimization_jit(params, data, targets, print_training=True):
opt = jaxopt.GradientDescent(loss_and_grad, stepsize=0.3, value_and_grad=True)
opt_state = opt.init_state(params)
def update(i, args):
params, opt_state = opt.update(*args, i)
return (params, opt_state, *args[2:])
args = (params, opt_state, data, targets, print_training)
(params, opt_state, _, _, _) = jax.lax.fori_loop(0, 100, update, args)
return params
# Run training
print('Running experiment with %d qubits' % n_qubits)
params = {"weights": weights, "bias": bias}
trained_params = optimization_jit(params, X_train, y_train, print_training=True)
```

The program will run and print

```
Devices available: [cuda(id=0)]
Running experiment with 15 qubits
```

But does not make any progress. The GPU memory is consumed but has 0% utilization (shown in attached figure). This is true even when running with 3 qubits. I’ve used an identical script using the CPU where it finishes in ~10 seconds.

qml.about prints

```
Name: PennyLane
Version: 0.34.0
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/PennyLaneAI/pennylane
Author:
Author-email:
License: Apache License 2.0
Location: /usr/local/lib/python3.10/dist-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, pennylane-lightning, requests, rustworkx, scipy, semantic-version, toml, typing-extensions
Required-by: PennyLane-Lightning, PennyLane-Lightning-GPU, PennyLane-SF
Platform info: Linux-5.19.0-1025-aws-x86_64-with-glibc2.35
Python version: 3.10.12
Numpy version: 1.23.5
Scipy version: 1.12.0
Installed devices:
- lightning.gpu (PennyLane-Lightning-GPU-0.34.0)
- lightning.qubit (PennyLane-Lightning-0.34.0)
- default.gaussian (PennyLane-0.34.0)
- default.mixed (PennyLane-0.34.0)
- default.qubit (PennyLane-0.34.0)
- default.qubit.autograd (PennyLane-0.34.0)
- default.qubit.jax (PennyLane-0.34.0)
- default.qubit.legacy (PennyLane-0.34.0)
- default.qubit.tf (PennyLane-0.34.0)
- default.qubit.torch (PennyLane-0.34.0)
- default.qutrit (PennyLane-0.34.0)
- null.qubit (PennyLane-0.34.0)
- strawberryfields.fock (PennyLane-SF-0.29.1)
- strawberryfields.gaussian (PennyLane-SF-0.29.1)
- strawberryfields.gbs (PennyLane-SF-0.29.1)
- strawberryfields.remote (PennyLane-SF-0.29.1)
- strawberryfields.tf (PennyLane-SF-0.29.1)
```

Thank you!