Hi there!
I’m trying to use PennyLane-Catalyst to speed up training of my PennyLane + JAX hybrid quantum-classical QLSTM for time-series prediction.
With JAX’s jit decorator the code runs without any issues, but training is quite slow.
After switching the decorator on the quantum circuit method from jax.jit to catalyst.qjit, the code runs about 6 times faster; the problem is that memory usage keeps accumulating until an Out Of Memory error eventually occurs.
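In isolation, the change is just the decorator on the QNode; here is a minimal sketch of what I mean (dev and tiny_circuit are illustrative placeholders, not the actual model in the full code below):

import jax
import pennylane as qml
import catalyst

dev = qml.device("lightning.qubit", wires=4)

# Slow but stable variant:
# @jax.jit
# ~6x faster variant, but memory keeps growing during training:
@catalyst.qjit
@qml.qnode(dev, diff_method="adjoint")
def tiny_circuit(angle):
    qml.RY(angle, wires=0)
    return qml.expval(qml.PauliZ(0))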
I’ve tried Python’s garbage collection (the gc.collect() and jax.clear_caches() calls at the end of each epoch in the training loop below), but the memory still accumulates.
##CODE BELOW##
import pandas as pd
import numpy as np
import jax
import jax.numpy as jnp
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from jax import random, lax
import optax
import time
import threading
import psutil
import GPUtil
from tqdm import tqdm
import gc
import pennylane as qml
import catalyst
def monitor_resources(interval=5):
    while True:
        try:
            memory = psutil.virtual_memory()
            print(f"Memory Usage: {memory.percent}% | Available Memory: {memory.available / (1024 ** 3):.2f} GB")
            gpus = GPUtil.getGPUs()
            for gpu in gpus:
                print(f"GPU: {gpu.name}, Memory Used: {gpu.memoryUsed}MB / {gpu.memoryTotal}MB")
            time.sleep(interval)
        except Exception as e:
            print(f"Monitoring Error: {e}")
threading.Thread(target=monitor_resources, daemon=True).start()
# Load and preprocess data
df = pd.read_csv('electricity_prediction_1_96step.csv').drop(columns=['index', 'Unnamed: 0']).iloc[1900:2400].reset_index(drop=True)
df = df.iloc[:, :-72] # prediction length 24
# Feature and target separation
target_columns = [col for col in df.columns if col.startswith('OT_lead')]
feature_columns = [col for col in df.columns if col not in target_columns]
# Split data
train_size = int(len(df) * 0.6)
df_train, df_val = df[:train_size].copy(), df[train_size:].copy()
# Scale data
scaler = StandardScaler()
train_scaled = pd.DataFrame(scaler.fit_transform(df_train), columns=df.columns)
val_scaled = pd.DataFrame(scaler.transform(df_val), columns=df.columns)
# PCA for dimensionality reduction
scaler_pca = StandardScaler()
train_features_scaled = scaler_pca.fit_transform(df_train[feature_columns])
pca = PCA(n_components='mle')
train_features_pca = pca.fit_transform(train_features_scaled)
important_features = sorted(set(df.columns[np.abs(pca.components_).argmax(axis=1)]))
# Prepare final datasets
train_inputs = jnp.array(train_scaled[important_features].values)
val_inputs = jnp.array(val_scaled[important_features].values)
train_targets = jnp.array(train_scaled[target_columns].values)
val_targets = jnp.array(val_scaled[target_columns].values)
# Function to create time-series sequences
def create_sequences(data, targets, sequence_length):
    inputs = [data[i:i + sequence_length] for i in range(len(data) - sequence_length)]
    target_values = [targets[i + sequence_length] for i in range(len(data) - sequence_length)]
    return jnp.array(inputs), jnp.array(target_values)
sequence_length = 24
train_seq_inputs, train_seq_labels = create_sequences(train_inputs, train_targets, sequence_length)
val_seq_inputs, val_seq_labels = create_sequences(val_inputs, val_targets, sequence_length)
# Initialize parameters
def initialize_params(shape, key):
    return random.normal(key, shape)
key = random.PRNGKey(0)
hidden_units, n_qubits, qlayers = 16, 4, 3
params = {
    'input': initialize_params((len(important_features) + hidden_units, n_qubits), key),
    'W_f': initialize_params((qlayers, 3), key),
    'W_i': initialize_params((qlayers, 3), key),
    'W_u': initialize_params((qlayers, 3), key),
    'W_o': initialize_params((qlayers, 3), key),
    'output': initialize_params((n_qubits, hidden_units), key),
    'final': initialize_params((hidden_units, 1), key)
}
learning_rate = 0.00004
optimizer = optax.adam(learning_rate)
# Unitary Ansatze for Convolutional Layer
def U_TTN(params, wires):  # 1 param
    qml.RY(params, wires=wires[0])
    qml.RY(params, wires=wires[1])
    qml.CNOT(wires=[wires[0], wires[1]])
def ansatz(params, wireset):
    U_TTN(params[0], wires=[wireset[0], wireset[1]])
    U_TTN(params[0], wires=[wireset[2], wireset[3]])
    U_TTN(params[1], wires=[wireset[0], wireset[2]])
    U_TTN(params[1], wires=[wireset[1], wireset[3]])
    U_TTN(params[2], wires=[wireset[1], wireset[2]])
    U_TTN(params[2], wires=[wireset[0], wireset[3]])
def encode_ry(i, ry_params):
    qml.Hadamard(wires=i)
    qml.RY(ry_params[i], wires=i)
# Quantum LSTM functions
# Here using @catalyst.qjit instead of @jax.jit results in an OOM issue
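# (Swapping the @jax.jit line below for the commented-out decorator is the only change
# between the two runs whose outputs are shown at the bottom of this post:)
# @catalyst.qjit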
@jax.jit
@qml.qnode(qml.device("lightning.qubit", wires=n_qubits), interface="jax", diff_method="adjoint")
def quantum_circuit(params, inputs):
    ry_params = jnp.arctan(inputs)
    for i in range(n_qubits):
        qml.Hadamard(wires=i)
        qml.RY(ry_params[i], wires=i)
    qml.layer(ansatz, qlayers, params, wireset=range(n_qubits))
    return tuple(qml.expval(qml.PauliZ(wires=i)) for i in range(n_qubits))
@jax.jit
def apply_gates(params, v_t):
    f_t = jax.nn.sigmoid(jnp.dot(jnp.array(quantum_circuit(params["W_f"], v_t)), params["output"]))
    i_t = jax.nn.sigmoid(jnp.dot(jnp.array(quantum_circuit(params["W_i"], v_t)), params["output"]))
    u_t = jax.nn.tanh(jnp.dot(jnp.array(quantum_circuit(params["W_u"], v_t)), params["output"]))
    o_t = jax.nn.sigmoid(jnp.dot(jnp.array(quantum_circuit(params["W_o"], v_t)), params["output"]))
    return f_t, i_t, u_t, o_t
@jax.jit
def mse_loss_qlstm(params, inputs, targets):
    hidden_state, cell_state = jnp.zeros(hidden_units), jnp.zeros(hidden_units)
    predicted_values = jnp.zeros(inputs.shape[0])
    def lstm_step(time, carry):
        hidden, cell, predictions = carry
        x_t = inputs[time]
        v_t = jnp.dot(jnp.concatenate([x_t, hidden], axis=-1), params["input"])
        f_t, i_t, u_t, o_t = apply_gates(params, v_t)
        cell = f_t * cell + i_t * u_t
        hidden = o_t * jax.nn.tanh(cell)
        y = jnp.squeeze(jnp.dot(hidden, params["final"]))
        predictions = predictions.at[time].set(y)
        return hidden, cell, predictions
    hidden_state, cell_state, predicted_values = lax.fori_loop(0, inputs.shape[0], lstm_step, (hidden_state, cell_state, predicted_values))
    loss = jnp.mean((predicted_values - targets) ** 2)
    return loss
def train_epoch(params, train_inputs, train_labels, optimizer, opt_state, pbar):
    train_loss = 0.0
    for i in range(train_inputs.shape[0]):
        inputs = train_inputs[i]
        labels = train_labels[i]
        try:
            loss_value, grads = jax.block_until_ready(
                jax.value_and_grad(mse_loss_qlstm)(
                    params, inputs, labels
                )
            )
            updates, opt_state = optimizer.update(grads, opt_state)
            params = optax.apply_updates(params, updates)
            train_loss += loss_value
            pbar.set_postfix({"Train Loss": loss_value.item()})
            pbar.update(1)
        finally:
            del loss_value, grads, updates
    return params, opt_state, train_loss
def validate_epoch(params, val_inputs, val_labels, pbar):
    val_loss = 0.0
    for j in range(val_inputs.shape[0]):
        inputs = val_inputs[j]
        labels = val_labels[j]
        loss_value = jax.block_until_ready(mse_loss_qlstm(params, inputs, labels))
        val_loss += loss_value
        pbar.set_postfix({"Validation Loss": loss_value.item()})
        pbar.update(1)
        del loss_value
    return val_loss
# Train function
def train_qlstm(params, train_inputs, train_labels, val_inputs, val_labels, epochs=50):
    opt_state = optimizer.init(params)
    for epoch in range(epochs):
        with tqdm(desc=f"Epoch {epoch+1}/{epochs}", unit="batch") as pbar:
            params, opt_state, train_loss = train_epoch(params, train_inputs, train_labels, optimizer, opt_state, pbar)
            avg_train_loss = train_loss / train_inputs.shape[0]
            val_loss = validate_epoch(params, val_inputs, val_labels, pbar)
            avg_val_loss = val_loss / val_inputs.shape[0]
        print(f"Epoch {epoch+1} - Train Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}")
        gc.collect()
        jax.clear_caches()
    return params
# Start training
trained_params = train_qlstm(params, train_seq_inputs, train_seq_labels, val_seq_inputs, val_seq_labels)
##OUTPUT WHEN USING @jax.jit## (Memory usage does not accumulate)
…
Memory Usage: 2.7% | Available Memory: 122.43 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9320.0MB / 12288.0MB
Epoch 1/50: 1batch [00:08, 8.79s/batch, Train Loss=1.83]Memory Usage: 2.8% | Available Memory: 122.30 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9328.0MB / 12288.0MB
Epoch 1/50: 3batch [00:16, 5.04s/batch, Train Loss=1.79]Memory Usage: 2.8% | Available Memory: 122.30 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9328.0MB / 12288.0MB
Epoch 1/50: 4batch [00:20, 4.60s/batch, Train Loss=1.82]Memory Usage: 2.8% | Available Memory: 122.30 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9328.0MB / 12288.0MB
Epoch 1/50: 5batch [00:24, 4.30s/batch, Train Loss=1.85]Memory Usage: 2.8% | Available Memory: 122.30 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9328.0MB / 12288.0MB
Epoch 1/50: 7batch [00:32, 4.03s/batch, Train Loss=1.96]Memory Usage: 2.8% | Available Memory: 122.30 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9328.0MB / 12288.0MB
Epoch 1/50: 8batch [00:36, 3.99s/batch, Train Loss=1.89]Memory Usage: 2.8% | Available Memory: 122.30 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9328.0MB / 12288.0MB
Epoch 1/50: 9batch [00:40, 4.17s/batch, Train Loss=1.94]Memory Usage: 2.8% | Available Memory: 122.31 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9328.0MB / 12288.0MB
Epoch 1/50: 10batch [00:44, 4.14s/batch, Train Loss=2.01]Memory Usage: 2.8% | Available Memory: 122.31 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9328.0MB / 12288.0MB
Epoch 1/50: 12batch [00:52, 4.05s/batch, Train Loss=1.89]Memory Usage: 2.8% | Available Memory: 122.31 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9328.0MB / 12288.0MB
Epoch 1/50: 13batch [00:56, 4.01s/batch, Train Loss=1.78]Memory Usage: 2.8% | Available Memory: 122.33 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9328.0MB / 12288.0MB
Epoch 1/50: 14batch [01:00, 4.02s/batch, Train Loss=1.7]Memory Usage: 2.8% | Available Memory: 122.32 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9328.0MB / 12288.0MB
Epoch 1/50: 16batch [01:08, 3.95s/batch, Train Loss=1.52]Memory Usage: 2.8% | Available Memory: 122.33 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9328.0MB / 12288.0MB
Epoch 1/50: 17batch [01:12, 3.91s/batch, Train Loss=1.39]Memory Usage: 2.8% | Available Memory: 122.32 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9328.0MB / 12288.0MB
Epoch 1/50: 18batch [01:16, 3.91s/batch, Train Loss=1.34]Memory Usage: 2.8% | Available Memory: 122.32 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9328.0MB / 12288.0MB
…
##OUTPUT WHEN USING @catalyst.qjit## (Memory usage accumulates)
…
Epoch 1/50: 8batch [00:13, 1.20batch/s, Train Loss=2.36]Memory Usage: 3.4% | Available Memory: 121.50 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9324.0MB / 12288.0MB
Epoch 1/50: 16batch [00:18, 1.65batch/s, Train Loss=2.13]Memory Usage: 3.4% | Available Memory: 121.48 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9324.0MB / 12288.0MB
Epoch 1/50: 25batch [00:24, 1.67batch/s, Train Loss=1.82]Memory Usage: 3.4% | Available Memory: 121.60 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9324.0MB / 12288.0MB
Epoch 1/50: 33batch [00:29, 1.68batch/s, Train Loss=1.69]Memory Usage: 3.4% | Available Memory: 121.59 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9324.0MB / 12288.0MB
Epoch 1/50: 41batch [00:34, 1.59batch/s, Train Loss=1.9]Memory Usage: 3.4% | Available Memory: 121.57 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9324.0MB / 12288.0MB
Epoch 1/50: 50batch [00:39, 1.70batch/s, Train Loss=3.87]Memory Usage: 3.4% | Available Memory: 121.55 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9324.0MB / 12288.0MB
Epoch 1/50: 58batch [00:44, 1.65batch/s, Train Loss=3.8]Memory Usage: 3.4% | Available Memory: 121.52 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9324.0MB / 12288.0MB
Epoch 1/50: 66batch [00:49, 1.70batch/s, Train Loss=5.02]Memory Usage: 3.4% | Available Memory: 121.50 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9324.0MB / 12288.0MB
Epoch 1/50: 74batch [00:54, 1.54batch/s, Train Loss=3.94]Memory Usage: 3.4% | Available Memory: 121.48 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9324.0MB / 12288.0MB
Epoch 1/50: 83batch [00:59, 1.69batch/s, Train Loss=3.04]Memory Usage: 3.4% | Available Memory: 121.48 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9324.0MB / 12288.0MB
Epoch 1/50: 91batch [01:04, 1.70batch/s, Train Loss=2.88]Memory Usage: 3.5% | Available Memory: 121.45 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9324.0MB / 12288.0MB
Epoch 1/50: 99batch [01:09, 1.62batch/s, Train Loss=3.11]Memory Usage: 3.5% | Available Memory: 121.42 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9324.0MB / 12288.0MB
Epoch 1/50: 108batch [01:15, 1.69batch/s, Train Loss=2.83]Memory Usage: 3.5% | Available Memory: 121.40 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9324.0MB / 12288.0MB
Epoch 1/50: 116batch [01:20, 1.69batch/s, Train Loss=0.857]Memory Usage: 3.5% | Available Memory: 121.38 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9324.0MB / 12288.0MB
Epoch 1/50: 124batch [01:25, 1.47batch/s, Train Loss=1.14]Memory Usage: 3.5% | Available Memory: 121.36 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9324.0MB / 12288.0MB
Epoch 1/50: 132batch [01:30, 1.68batch/s, Train Loss=1.13]Memory Usage: 3.6% | Available Memory: 121.34 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9324.0MB / 12288.0MB
Epoch 1/50: 136batch [01:32, 1.66batch/s, Train Loss=1.63]
…
Here’s my qml.about():
Home-page: https://github.com/PennyLaneAI/pennylane
Author:
Author-email:
License: Apache License 2.0
Location: /home/sachin/miniconda3/envs/quantum_lstm_env/lib/python3.9/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, toml, typing-extensions
Required-by: PennyLane-Catalyst, PennyLane_Lightning
Platform info: Linux-6.8.0-48-generic-x86_64-with-glibc2.39
Python version: 3.9.20
Numpy version: 1.26.4
Scipy version: 1.12.0
Installed devices:
- default.clifford (PennyLane-0.38.0)
- default.gaussian (PennyLane-0.38.0)
- default.mixed (PennyLane-0.38.0)
- default.qubit (PennyLane-0.38.0)
- default.qubit.autograd (PennyLane-0.38.0)
- default.qubit.jax (PennyLane-0.38.0)
- default.qubit.legacy (PennyLane-0.38.0)
- default.qubit.tf (PennyLane-0.38.0)
- default.qubit.torch (PennyLane-0.38.0)
- default.qutrit (PennyLane-0.38.0)
- default.qutrit.mixed (PennyLane-0.38.0)
- default.tensor (PennyLane-0.38.0)
- null.qubit (PennyLane-0.38.0)
- nvidia.custatevec (PennyLane-Catalyst-0.8.1)
- nvidia.cutensornet (PennyLane-Catalyst-0.8.1)
- oqc.cloud (PennyLane-Catalyst-0.8.1)
- softwareq.qpp (PennyLane-Catalyst-0.8.1)
- lightning.qubit (PennyLane-Lightning-0.38.0)