Hi! I’m experiencing a deadlock when using the lightning.gpu backend in a JAX training setup.
The program hangs indefinitely (GPU at 100% utilization, no output), but the same script runs fine with lightning.qubit.
The minimal script below defines a 14-qubit circuit with 1000 observables. The large observable count is only for demonstration, not real training.
Still, this setup consistently causes the GPU backend to freeze.
Can anyone help? Thanks!
import pennylane as qml
import jax, jax.numpy as jnp, numpy as np, os
# --- GPU env setup ---
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
# --- circuit config ---
N_QUBITS = 14
N_OBS = 1000
# --- device ---
dev = qml.device("lightning.gpu", wires=N_QUBITS)
# --- dummy observable set (for test only) ---
ALL_OBSERVABLES = [qml.PauliZ(i % N_QUBITS) for i in range(N_OBS)]
@qml.qnode(dev, interface="jax")
def circuit(params):
    # single-qubit rotations on every wire
    for i in range(N_QUBITS):
        qml.RX(params[i], wires=i)
        qml.RY(params[i] * 0.5, wires=i)
    # linear chain of entangling gates
    for i in range(N_QUBITS - 1):
        qml.CNOT(wires=[i, i + 1])
    return [qml.expval(obs) for obs in ALL_OBSERVABLES]
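# minval/maxval passed by keyword: the third positional argument of jax.random.uniform is dtype
params = jax.random.uniform(jax.random.PRNGKey(0), (N_QUBITS,), minval=0.0, maxval=3.14)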
print("Running 14-qubit circuit with 1000 observables (demo only)...")
out = circuit(params)
print("Output shape:", jnp.shape(out))
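
To narrow down where the hang happens, something like the following might help. It is only a minimal sketch using Python's standard faulthandler module: the helper name run_with_watchdog and the timeout value are made up for illustration and are not part of PennyLane or JAX. It does not interrupt the call. It only prints the Python traceback of every thread if the call is still running after the timeout.

```python
import faulthandler
import sys
import time


def run_with_watchdog(fn, *args, timeout_s=60.0, **kwargs):
    """Call fn(*args, **kwargs) and return (result, elapsed_seconds).

    Hypothetical helper: if the call is still running after timeout_s
    seconds, the Python traceback of all threads is written to stderr.
    The call itself is not cancelled.
    """
    # sys.__stderr__ is the original stderr stream, which has a real file
    # descriptor (faulthandler needs one).
    faulthandler.dump_traceback_later(timeout_s, exit=False, file=sys.__stderr__)
    try:
        start = time.perf_counter()
        result = fn(*args, **kwargs)
        elapsed = time.perf_counter() - start
    finally:
        # Cancel the pending dump whether the call returned or raised.
        faulthandler.cancel_dump_traceback_later()
    return result, elapsed
```

With the script above, the call would look like out, seconds = run_with_watchdog(circuit, params, timeout_s=120.0). The dump only contains Python frames, so if the hang is inside native device code, the last frame shown would be the Python call that entered it.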