Memory leak upon switching from jax.jit to catalyst.qjit

Hi there!

I’m trying to use PennyLane-Catalyst to speed up training of my PennyLane + JAX hybrid quantum-classical QLSTM for time-series prediction.

With jax.jit, the code runs without issues, but training is quite slow.
After switching the decorator on the quantum circuit from jax.jit to catalyst.qjit, the code runs about 6 times faster, but memory usage keeps accumulating until an out-of-memory error eventually occurs.
I’ve tried calling Python’s garbage collector to work around this, but without success.
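To be concrete, the only change is the outer decorator on the QNode; everything else stays the same. A stripped-down illustration of the swap (the full code is below):

import jax
import pennylane as qml
import catalyst

n_qubits = 4

# Swapping this one decorator is the only change I make:
@jax.jit           # memory stays flat, but training is slow
# @catalyst.qjit   # ~6x faster, but memory keeps growing
@qml.qnode(qml.device("lightning.qubit", wires=n_qubits), interface="jax", diff_method="adjoint")
def quantum_circuit(params, inputs):
    ...  # circuit body unchanged (see the full code below)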

##CODE BELOW##


import pandas as pd
import numpy as np
import jax
import jax.numpy as jnp
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from jax import random, lax
import optax
import time
import threading
import psutil
import GPUtil
from tqdm import tqdm
import gc
import pennylane as qml
import catalyst


def monitor_resources(interval=5):
    while True:
        try:
            memory = psutil.virtual_memory()
            print(f"Memory Usage: {memory.percent}% | Available Memory: {memory.available / (1024 ** 3):.2f} GB")
            gpus = GPUtil.getGPUs()
            for gpu in gpus:
                print(f"GPU: {gpu.name}, Memory Used: {gpu.memoryUsed}MB / {gpu.memoryTotal}MB")
            time.sleep(interval)
        except Exception as e:
            print(f"Monitoring Error: {e}")


threading.Thread(target=monitor_resources, daemon=True).start()

# Load and preprocess data
df = pd.read_csv('electricity_prediction_1_96step.csv').drop(columns=['index', 'Unnamed: 0']).iloc[1900:2400].reset_index(drop=True)
df = df.iloc[:, :-72]  # prediction length 24

# Feature and target separation
target_columns = [col for col in df.columns if col.startswith('OT_lead')]
feature_columns = [col for col in df.columns if col not in target_columns]

# Split data
train_size = int(len(df) * 0.6)
df_train, df_val = df[:train_size].copy(), df[train_size:].copy()

# Scale data
scaler = StandardScaler()
train_scaled = pd.DataFrame(scaler.fit_transform(df_train), columns=df.columns)
val_scaled = pd.DataFrame(scaler.transform(df_val), columns=df.columns)

# PCA for dimensionality reduction
scaler_pca = StandardScaler()
train_features_scaled = scaler_pca.fit_transform(df_train[feature_columns])
pca = PCA(n_components='mle')
train_features_pca = pca.fit_transform(train_features_scaled)
important_features = sorted(set(df.columns[np.abs(pca.components_).argmax(axis=1)]))

# Prepare final datasets
train_inputs = jnp.array(train_scaled[important_features].values)
val_inputs = jnp.array(val_scaled[important_features].values)
train_targets = jnp.array(train_scaled[target_columns].values)
val_targets = jnp.array(val_scaled[target_columns].values)

# Function to create time-series sequences
def create_sequences(data, targets, sequence_length):
    inputs = [data[i:i + sequence_length] for i in range(len(data) - sequence_length)]
    target_values = [targets[i + sequence_length] for i in range(len(data) - sequence_length)]
    return jnp.array(inputs), jnp.array(target_values)

sequence_length = 24
train_seq_inputs, train_seq_labels = create_sequences(train_inputs, train_targets, sequence_length)
val_seq_inputs, val_seq_labels = create_sequences(val_inputs, val_targets, sequence_length)

# Initialize parameters
def initialize_params(shape, key):
    return random.normal(key, shape)

key = random.PRNGKey(0)
hidden_units, n_qubits, qlayers = 16, 4, 3
params = {
    'input': initialize_params((len(important_features) + hidden_units, n_qubits), key),
    'W_f': initialize_params((qlayers, 3), key),
    'W_i': initialize_params((qlayers, 3), key),
    'W_u': initialize_params((qlayers, 3), key),
    'W_o': initialize_params((qlayers, 3), key),
    'output': initialize_params((n_qubits, hidden_units), key),
    'final': initialize_params((hidden_units, 1), key)
}

learning_rate = 0.00004
optimizer = optax.adam(learning_rate)

# Unitary Ansatze for Convolutional Layer
def U_TTN(params, wires):  # 1 parameter
    qml.RY(params, wires=wires[0])
    qml.RY(params, wires=wires[1])
    qml.CNOT(wires=[wires[0], wires[1]])

def ansatz(params, wireset):
    U_TTN(params[0], wires=[wireset[0], wireset[1]])
    U_TTN(params[0], wires=[wireset[2], wireset[3]])
    U_TTN(params[1], wires=[wireset[0], wireset[2]])
    U_TTN(params[1], wires=[wireset[1], wireset[3]])
    U_TTN(params[2], wires=[wireset[1], wireset[2]])
    U_TTN(params[2], wires=[wireset[0], wireset[3]])

def encode_ry(i, ry_params):
    qml.Hadamard(wires=i)
    qml.RY(ry_params[i], wires=i)

# Quantum LSTM functions
# Here using @catalyst.qjit instead of @jax.jit results in an OOM issue

@jax.jit
@qml.qnode(qml.device("lightning.qubit", wires=n_qubits), interface="jax", diff_method="adjoint")
def quantum_circuit(params, inputs):
    ry_params = jnp.arctan(inputs)
    for i in range(n_qubits):
        qml.Hadamard(wires=i)
        qml.RY(ry_params[i], wires=i)
    qml.layer(ansatz, qlayers, params, wireset=range(n_qubits))
    return tuple(qml.expval(qml.PauliZ(wires=i)) for i in range(n_qubits))

@jax.jit
def apply_gates(params, v_t):

    f_t = jax.nn.sigmoid(jnp.dot(jnp.array(quantum_circuit(params["W_f"], v_t)), params["output"]))
    i_t = jax.nn.sigmoid(jnp.dot(jnp.array(quantum_circuit(params["W_i"], v_t)), params["output"]))
    u_t = jax.nn.tanh(jnp.dot(jnp.array(quantum_circuit(params["W_u"], v_t)), params["output"]))
    o_t = jax.nn.sigmoid(jnp.dot(jnp.array(quantum_circuit(params["W_o"], v_t)), params["output"]))

    return f_t, i_t, u_t, o_t

@jax.jit
def mse_loss_qlstm(params, inputs, targets):
    hidden_state, cell_state = jnp.zeros(hidden_units), jnp.zeros(hidden_units)
    predicted_values = jnp.zeros(inputs.shape[0])

    def lstm_step(time, carry):
        hidden, cell, predictions = carry
        x_t = inputs[time]
        v_t = jnp.dot(jnp.concatenate([x_t, hidden], axis=-1), params["input"])

        f_t, i_t, u_t, o_t = apply_gates(params, v_t)
        cell = f_t * cell + i_t * u_t
        hidden = o_t * jax.nn.tanh(cell)

        y = jnp.squeeze(jnp.dot(hidden, params["final"]))
        predictions = predictions.at[time].set(y)
        return hidden, cell, predictions

    hidden_state, cell_state, predicted_values = lax.fori_loop(0, inputs.shape[0], lstm_step, (hidden_state, cell_state, predicted_values))
    loss = jnp.mean((predicted_values - targets) ** 2)
    return loss


def train_epoch(params, train_inputs, train_labels, optimizer, opt_state, pbar):
    train_loss = 0.0

    for i in range(train_inputs.shape[0]):
        inputs = train_inputs[i]
        labels = train_labels[i]

        try:
            loss_value, grads = jax.block_until_ready(
                jax.value_and_grad(mse_loss_qlstm)(
                    params, inputs, labels
                )
            )
            updates, opt_state = optimizer.update(grads, opt_state)
            params = optax.apply_updates(params, updates)
            train_loss += loss_value

            pbar.set_postfix({"Train Loss": loss_value.item()})
            pbar.update(1)

        finally:
            del loss_value, grads, updates

    return params, opt_state, train_loss

def validate_epoch(params, val_inputs, val_labels, pbar):
    val_loss = 0.0

    for j in range(val_inputs.shape[0]):
        inputs = val_inputs[j]
        labels = val_labels[j]

        loss_value = jax.block_until_ready(mse_loss_qlstm(params, inputs, labels))
        val_loss += loss_value

        pbar.set_postfix({"Validation Loss": loss_value.item()})
        pbar.update(1)

        del loss_value

    return val_loss

# Train function
def train_qlstm(params, train_inputs, train_labels, val_inputs, val_labels, epochs=50):
    opt_state = optimizer.init(params)
    for epoch in range(epochs):
        with tqdm(desc=f"Epoch {epoch+1}/{epochs}", unit="batch") as pbar:
            params, opt_state, train_loss = train_epoch(params, train_inputs, train_labels, optimizer, opt_state, pbar)
            avg_train_loss = train_loss / train_inputs.shape[0]
            val_loss = validate_epoch(params, val_inputs, val_labels, pbar)
            avg_val_loss = val_loss / val_inputs.shape[0]
            print(f"Epoch {epoch+1} - Train Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}")

            gc.collect()
            jax.clear_caches()

    return params

# Start training
trained_params = train_qlstm(params, train_seq_inputs, train_seq_labels, val_seq_inputs, val_seq_labels)

##OUTPUT WHEN USING @jax.jit## (Memory usage does not accumulate)

Memory Usage: 2.7% | Available Memory: 122.43 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9320.0MB / 12288.0MB
Epoch 1/50: 1batch [00:08, 8.79s/batch, Train Loss=1.83]Memory Usage: 2.8% | Available Memory: 122.30 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9328.0MB / 12288.0MB
Epoch 1/50: 3batch [00:16, 5.04s/batch, Train Loss=1.79]Memory Usage: 2.8% | Available Memory: 122.30 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9328.0MB / 12288.0MB
Epoch 1/50: 4batch [00:20, 4.60s/batch, Train Loss=1.82]Memory Usage: 2.8% | Available Memory: 122.30 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9328.0MB / 12288.0MB
Epoch 1/50: 5batch [00:24, 4.30s/batch, Train Loss=1.85]Memory Usage: 2.8% | Available Memory: 122.30 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9328.0MB / 12288.0MB
Epoch 1/50: 7batch [00:32, 4.03s/batch, Train Loss=1.96]Memory Usage: 2.8% | Available Memory: 122.30 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9328.0MB / 12288.0MB
Epoch 1/50: 8batch [00:36, 3.99s/batch, Train Loss=1.89]Memory Usage: 2.8% | Available Memory: 122.30 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9328.0MB / 12288.0MB
Epoch 1/50: 9batch [00:40, 4.17s/batch, Train Loss=1.94]Memory Usage: 2.8% | Available Memory: 122.31 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9328.0MB / 12288.0MB
Epoch 1/50: 10batch [00:44, 4.14s/batch, Train Loss=2.01]Memory Usage: 2.8% | Available Memory: 122.31 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9328.0MB / 12288.0MB
Epoch 1/50: 12batch [00:52, 4.05s/batch, Train Loss=1.89]Memory Usage: 2.8% | Available Memory: 122.31 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9328.0MB / 12288.0MB
Epoch 1/50: 13batch [00:56, 4.01s/batch, Train Loss=1.78]Memory Usage: 2.8% | Available Memory: 122.33 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9328.0MB / 12288.0MB
Epoch 1/50: 14batch [01:00, 4.02s/batch, Train Loss=1.7]Memory Usage: 2.8% | Available Memory: 122.32 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9328.0MB / 12288.0MB
Epoch 1/50: 16batch [01:08, 3.95s/batch, Train Loss=1.52]Memory Usage: 2.8% | Available Memory: 122.33 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9328.0MB / 12288.0MB
Epoch 1/50: 17batch [01:12, 3.91s/batch, Train Loss=1.39]Memory Usage: 2.8% | Available Memory: 122.32 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9328.0MB / 12288.0MB
Epoch 1/50: 18batch [01:16, 3.91s/batch, Train Loss=1.34]Memory Usage: 2.8% | Available Memory: 122.32 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9328.0MB / 12288.0MB

##OUTPUT WHEN USING @catalyst.qjit## (Memory usage accumulates)

Epoch 1/50: 8batch [00:13, 1.20batch/s, Train Loss=2.36]Memory Usage: 3.4% | Available Memory: 121.50 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9324.0MB / 12288.0MB
Epoch 1/50: 16batch [00:18, 1.65batch/s, Train Loss=2.13]Memory Usage: 3.4% | Available Memory: 121.48 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9324.0MB / 12288.0MB
Epoch 1/50: 25batch [00:24, 1.67batch/s, Train Loss=1.82]Memory Usage: 3.4% | Available Memory: 121.60 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9324.0MB / 12288.0MB
Epoch 1/50: 33batch [00:29, 1.68batch/s, Train Loss=1.69]Memory Usage: 3.4% | Available Memory: 121.59 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9324.0MB / 12288.0MB
Epoch 1/50: 41batch [00:34, 1.59batch/s, Train Loss=1.9]Memory Usage: 3.4% | Available Memory: 121.57 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9324.0MB / 12288.0MB
Epoch 1/50: 50batch [00:39, 1.70batch/s, Train Loss=3.87]Memory Usage: 3.4% | Available Memory: 121.55 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9324.0MB / 12288.0MB
Epoch 1/50: 58batch [00:44, 1.65batch/s, Train Loss=3.8]Memory Usage: 3.4% | Available Memory: 121.52 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9324.0MB / 12288.0MB
Epoch 1/50: 66batch [00:49, 1.70batch/s, Train Loss=5.02]Memory Usage: 3.4% | Available Memory: 121.50 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9324.0MB / 12288.0MB
Epoch 1/50: 74batch [00:54, 1.54batch/s, Train Loss=3.94]Memory Usage: 3.4% | Available Memory: 121.48 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9324.0MB / 12288.0MB
Epoch 1/50: 83batch [00:59, 1.69batch/s, Train Loss=3.04]Memory Usage: 3.4% | Available Memory: 121.48 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9324.0MB / 12288.0MB
Epoch 1/50: 91batch [01:04, 1.70batch/s, Train Loss=2.88]Memory Usage: 3.5% | Available Memory: 121.45 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9324.0MB / 12288.0MB
Epoch 1/50: 99batch [01:09, 1.62batch/s, Train Loss=3.11]Memory Usage: 3.5% | Available Memory: 121.42 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9324.0MB / 12288.0MB
Epoch 1/50: 108batch [01:15, 1.69batch/s, Train Loss=2.83]Memory Usage: 3.5% | Available Memory: 121.40 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9324.0MB / 12288.0MB
Epoch 1/50: 116batch [01:20, 1.69batch/s, Train Loss=0.857]Memory Usage: 3.5% | Available Memory: 121.38 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9324.0MB / 12288.0MB
Epoch 1/50: 124batch [01:25, 1.47batch/s, Train Loss=1.14]Memory Usage: 3.5% | Available Memory: 121.36 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9324.0MB / 12288.0MB
Epoch 1/50: 132batch [01:30, 1.68batch/s, Train Loss=1.13]Memory Usage: 3.6% | Available Memory: 121.34 GB
GPU: NVIDIA TITAN Xp, Memory Used: 9324.0MB / 12288.0MB
Epoch 1/50: 136batch [01:32, 1.66batch/s, Train Loss=1.63]

Here’s a Google Drive link to the CSV file with the dataset:

https://drive.google.com/file/d/1cKA1rKkaINpCEqkTbCHIW1KHM9QNaG0g/view?usp=sharing

Hi @Sachin_Reddy, welcome to the Forum!

It’s great to see that you’re using Catalyst! Let me check with the team to see what might be happening.

Can you please post the output of qml.about()?
And are you able to share a minimal reproducible version of your code? This can help our team find the issue faster.
Thanks!

Here’s my qml.about():
repro.py (3.4 KB)

Author:
Author-email:
License: Apache License 2.0
Location: /home/sachin/miniconda3/envs/quantum_lstm_env/lib/python3.9/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, toml, typing-extensions
Required-by: PennyLane-Catalyst, PennyLane_Lightning

Platform info: Linux-6.8.0-48-generic-x86_64-with-glibc2.39
Python version: 3.9.20
Numpy version: 1.26.4
Scipy version: 1.12.0
Installed devices:

  • default.clifford (PennyLane-0.38.0)
  • default.gaussian (PennyLane-0.38.0)
  • default.mixed (PennyLane-0.38.0)
  • default.qubit (PennyLane-0.38.0)
  • default.qubit.autograd (PennyLane-0.38.0)
  • default.qubit.jax (PennyLane-0.38.0)
  • default.qubit.legacy (PennyLane-0.38.0)
  • default.qubit.tf (PennyLane-0.38.0)
  • default.qubit.torch (PennyLane-0.38.0)
  • default.qutrit (PennyLane-0.38.0)
  • default.qutrit.mixed (PennyLane-0.38.0)
  • default.tensor (PennyLane-0.38.0)
  • null.qubit (PennyLane-0.38.0)
  • nvidia.custatevec (PennyLane-Catalyst-0.8.1)
  • nvidia.cutensornet (PennyLane-Catalyst-0.8.1)
  • oqc.cloud (PennyLane-Catalyst-0.8.1)
  • softwareq.qpp (PennyLane-Catalyst-0.8.1)
  • lightning.qubit (PennyLane-Lightning-0.38.0)

Attached (repro.py) is a more minimal version of the code that shows the same memory leak.

Hi @Sachin_Reddy, thanks for sharing this info.

The team confirmed that this does look like a memory leak. We’ll look into it over the next few weeks. We’ll let you know when we have additional information.

Thanks again for pointing this out and helping improve Catalyst! Let us know if you have other feedback or ideas for improvement.


Hi @Sachin_Reddy, thanks for reporting the issue!

I took a look and experimented a bit, and noticed that the memory leak persists even without applying any JIT decorator at all, neither jax.jit nor catalyst.qjit. This is true regardless of whether the device is default.qubit or lightning.qubit.

Can you reproduce this behavior? Just so I can confirm it’s not a peculiarity of my machine. Thanks!
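In case it helps to compare, a minimal check along these lines should show the growth; this is only a sketch that reuses the names from your script (mse_loss_qlstm, params, train_seq_inputs, train_seq_labels), with the jit decorators removed, and watches the process RSS via psutil:

import os
import psutil

proc = psutil.Process(os.getpid())
n = train_seq_inputs.shape[0]
for step in range(200):
    # Repeatedly evaluate the (un-jitted) loss and watch whether host memory grows.
    _ = mse_loss_qlstm(params, train_seq_inputs[step % n], train_seq_labels[step % n])
    if step % 20 == 0:
        print(f"step {step}: RSS = {proc.memory_info().rss / 1024**2:.1f} MB")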

Hi @Sachin_Reddy, the team has been investigating this issue. We are unable to confirm whether this is a memory leak (unfreed, inaccessible allocations) or some form of internal caching that JAX does not clear (unfreed but still accessible allocations). To alleviate the memory usage, we suggest two potential workarounds:
(a) Use jax.clear_caches() after calling the jitted entry-point function to clear the compiled objects. To retain some performance, you may want to clear the caches once every few calls instead of after every call (see the sketch after this list).
(b) Split the computation into multiple scripts, so that all the used memory is forcibly released when a script finishes.
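For (a), a rough sketch adapted to the training loop in your script (the names come from your code; CLEAR_EVERY = 100 is just an example interval to tune against your memory budget):

CLEAR_EVERY = 100

for step in range(train_seq_inputs.shape[0]):
    loss_value, grads = jax.value_and_grad(mse_loss_qlstm)(
        params, train_seq_inputs[step], train_seq_labels[step]
    )
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

    if (step + 1) % CLEAR_EVERY == 0:
        jax.clear_caches()  # drop cached compiled objects every few calls
        gc.collect()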

If none of the above helps, it might also be worth investigating with JAX’s device memory profiler (see “Profiling device memory” in the JAX documentation).
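For example, adding something like this inside the training loop (e.g. the one sketched above) dumps snapshots you can inspect and compare with pprof; only jax.profiler.save_device_memory_profile is JAX API, the interval and file names are arbitrary:

import jax.profiler

# Inside the training loop: every 100 steps, dump a device-memory snapshot.
if (step + 1) % 100 == 0:
    jax.profiler.save_device_memory_profile(f"memory_{step + 1:04d}.prof")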

Please don’t hesitate to reach out if there are any other issues.
