Hello! I’m trying to get a QML program running with JAX on a GPU. My issue is that it hangs on execution with 0% GPU utilization (despite allocating GPU memory). There is no error, but the computation never even starts. I’m working on an AWS p3 instance with cuQuantum pre-installed. I have gotten basic JAX commands to work on the GPU with high utilization. I installed PennyLane and JAX with the following commands:
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install pennylane pennylane-lightning-gpu custatevec-cu11
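For context, a pure-JAX check along these lines runs with high GPU utilization for me (a minimal sketch of the kind of commands I tested; the matrix size is arbitrary):
import jax
import jax.numpy as jnp

# Large matmul under jit; nvidia-smi shows high GPU utilization while it runs
x = jax.random.normal(jax.random.PRNGKey(0), (8000, 8000))
f = jax.jit(lambda a: (a @ a).sum())
print(f(x).block_until_ready())  # executes on the default backend
print(jax.default_backend())     # 'gpu'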
My code is copied from a PennyLane tutorial; I’ve only changed the layout of the quantum circuit and set JAX to use the GPU instead of the CPU:
import pennylane as qml
import jax
from jax import numpy as jnp
import jaxopt
import numpy as np
import math
# Use GPU
jax.config.update("jax_platform_name", "gpu")
print('Devices available:', jax.devices())
n_qubits = 15
n_layers = 3
# Random dataset to train on
n_samples = 100
n_features = 30
X_train = np.random.random(size=(n_samples,n_features))
y_train = np.random.choice(a=[0,1], size=n_samples, replace=True)
# Lightning gpu simulator
dev = qml.device("lightning.gpu", wires=n_qubits)
@qml.qnode(dev)
def circuit(data, weights):
    """Quantum circuit ansatz"""
    # Starting locations for the entangling gates
    a = math.floor((n_qubits - .001) / 2)
    b = a + 1
    # Use the inputs to parameterize an "input" circuit
    for i in range(data.shape[1]):
        tqubit = i % n_qubits  # target qubit for feature i
        if i % 3 == 0:
            qml.RX(data[:, i], wires=tqubit)  # all inputs at index i
        elif i % 3 == 1:
            qml.RY(data[:, i], wires=tqubit)
        else:
            qml.RZ(data[:, i], wires=tqubit)
        # Do entangling every n_qubits gates
        if i != 0 and i % n_qubits == 0:
            for j in range(a - 1, -1, -1):  # backwards from a
                qml.CNOT(wires=[j, j + 1])
            for j in range(b, n_qubits - 1):  # forwards from b
                qml.CNOT(wires=[j, j + 1])
    # Parameterized RX, RY, and RZ gates
    for L in range(n_layers):
        # Entangling CNOT gates
        qml.CNOT(wires=[a, b])
        for i in range(a - 1, -1, -1):  # backwards from a
            qml.CNOT(wires=[i, i + 1])
        for i in range(b, n_qubits - 1):  # forwards from b
            qml.CNOT(wires=[i, i + 1])
        for i in range(n_qubits):
            qml.RX(weights[i * 3 * n_layers + 3 * L + 0], wires=i)
            qml.RY(weights[i * 3 * n_layers + 3 * L + 1], wires=i)
            qml.RZ(weights[i * 3 * n_layers + 3 * L + 2], wires=i)
    # We use a sum of local Z's as the observable, since a single
    # local Z would only be affected by the parameters on that qubit.
    return qml.expval(qml.sum(*[qml.PauliZ(i) for i in range(n_qubits)]))
# Setup copied from tutorial
def my_model(data, weights, bias):
    return circuit(data, weights) + bias

@jax.jit
def loss_fn(params, data, targets):
    predictions = my_model(data, params["weights"], params["bias"])
    loss = jnp.sum((targets - predictions) ** 2 / len(data))
    return loss
# Initial parameters
weights = jnp.ones([n_qubits*n_layers*3])
bias = jnp.array(0.)
params = {"weights": weights, "bias": bias}
def loss_and_grad(params, data, targets, print_training, i):
    loss_val, grad_val = jax.value_and_grad(loss_fn)(params, data, targets)

    def print_fn():
        jax.debug.print("Step: {i}  Loss: {loss_val}", i=i, loss_val=loss_val)

    # If print_training=True, print the loss every 5 steps
    jax.lax.cond((jnp.mod(i, 5) == 0) & print_training, print_fn, lambda: None)
    return loss_val, grad_val
@jax.jit
def optimization_jit(params, data, targets, print_training=True):
    opt = jaxopt.GradientDescent(loss_and_grad, stepsize=0.3, value_and_grad=True)
    opt_state = opt.init_state(params)

    def update(i, args):
        params, opt_state = opt.update(*args, i)
        return (params, opt_state, *args[2:])

    args = (params, opt_state, data, targets, print_training)
    (params, opt_state, _, _, _) = jax.lax.fori_loop(0, 100, update, args)
    return params
# Run training
print('Running experiment with %d qubits' % n_qubits)
params = {"weights": weights, "bias": bias}
trained_params = optimization_jit(params, X_train, y_train, print_training=True)
When run, the program prints
Devices available: [cuda(id=0)]
Running experiment with 15 qubits
but then makes no further progress. GPU memory is consumed, but utilization stays at 0% (shown in the attached figure). This is true even when running with only 3 qubits. An otherwise identical script on the CPU (differing only in the backend selection, roughly as sketched below) finishes in ~10 seconds.
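For reference, the CPU variant looked roughly like this (sketching from memory; the CPU run used a CPU simulator device in place of lightning.gpu):
# CPU variant: only these two lines differ
jax.config.update("jax_platform_name", "cpu")
dev = qml.device("default.qubit", wires=n_qubits)  # CPU simulator instead of lightning.gpu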
qml.about() prints:
Name: PennyLane
Version: 0.34.0
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/PennyLaneAI/pennylane
Author:
Author-email:
License: Apache License 2.0
Location: /usr/local/lib/python3.10/dist-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, pennylane-lightning, requests, rustworkx, scipy, semantic-version, toml, typing-extensions
Required-by: PennyLane-Lightning, PennyLane-Lightning-GPU, PennyLane-SF
Platform info: Linux-5.19.0-1025-aws-x86_64-with-glibc2.35
Python version: 3.10.12
Numpy version: 1.23.5
Scipy version: 1.12.0
Installed devices:
- lightning.gpu (PennyLane-Lightning-GPU-0.34.0)
- lightning.qubit (PennyLane-Lightning-0.34.0)
- default.gaussian (PennyLane-0.34.0)
- default.mixed (PennyLane-0.34.0)
- default.qubit (PennyLane-0.34.0)
- default.qubit.autograd (PennyLane-0.34.0)
- default.qubit.jax (PennyLane-0.34.0)
- default.qubit.legacy (PennyLane-0.34.0)
- default.qubit.tf (PennyLane-0.34.0)
- default.qubit.torch (PennyLane-0.34.0)
- default.qutrit (PennyLane-0.34.0)
- null.qubit (PennyLane-0.34.0)
- strawberryfields.fock (PennyLane-SF-0.29.1)
- strawberryfields.gaussian (PennyLane-SF-0.29.1)
- strawberryfields.gbs (PennyLane-SF-0.29.1)
- strawberryfields.remote (PennyLane-SF-0.29.1)
- strawberryfields.tf (PennyLane-SF-0.29.1)
Thank you!