Facing issues with JAX-jitting the optimization loop

I am working on a regression problem, and after several months of trial, error, and research I was able to make my task time-efficient using JAX-JIT. While going through the demo "How to optimize a QML model using JAX and Optax", I came across jitting the optimization loop. The notebook ran seamlessly in my Python environment, and jitting the optimization loop reduced the run time further. I tried to apply the same method to my own code but ran into a few errors. Moreover, in that tutorial the data embedding is vectorized, so the whole circuit execution is vectorized, which is different from what I am doing. Is that necessary in order to jit the optimization loop?

To give an idea of my input and output: my input has size (10000, 128), which I split into training, validation, and test sets. The input consists of 10000 different realizations of some curve (say, a Gaussian curve), and the output is the set of parameters characterizing these curves, with size (10000, 3), i.e., 3 parameters for each curve.

MY MODEL

dev = qml.device("default.qubit", wires=n_qubits)

@qml.qnode(dev,interface="jax",diff_method='backprop')
def circ(weights,inputs):
    qml.AmplitudeEmbedding(features=inputs, wires=range(n_qubits),normalize=True,pad_with=0.5)
    StronglyEntanglingLayers(weights, wires=range(n_qubits),imprimitive=qml.ops.CNOT)
    out =  [qml.expval(qml.PauliZ(i)+qml.PauliZ(i+1)) for i in range(n_qubits-1)[0::2]]
    return out

def qnn(params,inputs):
    weights = params["weights"]
    bias = params["bias"]
    circ_out = circ(weights,inputs)
    out = []
    for i in range(len(circ_out)):
        out.append(circ_out[i] + bias[i])
    return out 

def mse(observed,predictions):
  loss = jnp.sum((observed - predictions) ** 2 / len(observed))
  return jnp.mean(loss)

def predict(params,features):
    preds = qnn(params,features)
    preds_np = jnp.asarray(preds).T
    return preds_np

batched_predict = jax.vmap(predict, in_axes=(None, 0))

def cost(params, features, observed):
    preds = batched_predict(params,features)
    cost = mse(observed, preds)
    return cost

@jax.jit
def jit_cost(params, features, observed):
    preds = batched_predict(params,features)
    cost = mse(observed, preds)
    return cost

APPROACH-1

num_train = len(x_train)
num_val =len(x_val)
@jax.jit
def update_step_jit(i, args):
    key_t,params, opt_state, x_train, y_train, x_val, y_val,opt,batch_size,print_training = args
    key_t, key_v = random.split(key_t)
    idx_train = random.choice(key_t,num_train,shape=(batch_size,))
    x_train_batch = jnp.asarray(x_train[idx_train])
    y_train_batch = jnp.asarray(y_train[idx_train])
    x_val_batch = jnp.asarray(x_val)
    y_val_batch = jnp.asarray(y_val)
    start_time = time.time()
    train_cost, grads = jax.value_and_grad(jit_cost)(params, x_train_batch, y_train_batch)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    end_time = time.time()
    val_cost = jit_cost(params,x_val_batch,y_val_batch)
    epoch_time = end_time - start_time
    jax.debug.print("Epoch: {:5d} | Training_Loss: {:0.7f} | Val_Loss: {:0.7f} ".format(i+1, train_cost,train_cost))

    train_loss = jnp.append(train_loss,train_cost)
    val_loss = jnp.append(val_loss,val_cost)
    return (params, opt_state,train_loss,train_loss)

@jax.jit
def fit_jit(params, features, observed,features_val, observed_val, batch_size, print_training=False):
    opt = optax.adam(learning_rate=0.01)
    opt_state = opt.init(params)
    key_t = random.PRNGKey(np.random.randint(0,1e4))
    args = (key_t,params, opt_state, features, observed,features_val, observed_val,opt, batch_size, print_training)
    (params, opt_state,train_loss,val_loss) = jax.lax.fori_loop(0, 300, update_step_jit, args)

    return params, train_loss, val_loss

I run the fit_jit function to train my circuit:

params,train_loss,val_loss = fit_jit(params,x_train,y_train,x_val,y_val,512,print_training=True)

Error faced:

----> 1 params,train_loss,val_loss = fit_jit(params,x_train,y_train,x_val,y_val,512,print_training=True)

    [... skipping hidden 12 frame]

/home/akashg/QNN21cm/Codes/AE_QNN-JAX.ipynb Cell 35 line 3
     33 key_t = random.PRNGKey(np.random.randint(0,1e4))
     34 args = (key_t,params, opt_state, features, observed,features_val, observed_val,opt, batch_size, print_training)
---> 35 (params, opt_state,train_loss,val_loss) = jax.lax.fori_loop(0, 300, update_step_jit, args)
     37 return params, train_loss, val_loss

    [... skipping hidden 7 frame]

File ~/anaconda3/envs/py3.10/lib/python3.10/site-packages/jax/_src/core.py:1472, in concrete_aval(x)
   1470 if hasattr(x, '__jax_array__'):
   1471   return concrete_aval(x.__jax_array__())
-> 1472 raise TypeError(f"Value {x!r} with type {type(x)} is not a valid JAX "
   1473                  "type")

TypeError: Value <function chain.<locals>.init_fn at 0x7f838159a4d0> with type <class 'function'> is not a valid JAX type

APPROACH-2
I tried this approach to check whether I would get the same error, but it turns out I get a different one, and I don’t understand what the issue is.

@jax.jit
def update_step_jit(i, args):
    params, opt_state, data, targets = args

    train_cost, grads = jax.value_and_grad(jit_cost)(params, data, targets)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    jax.debug.print("Step: {i}  Loss: {loss_val}", i=i, loss_val=train_cost)

    return (params, opt_state )

@jax.jit
def optimization_jit(params, data, targets):
    opt = optax.adam(learning_rate=0.1)
    opt_state = opt.init(params)

    args = (params, opt_state, data, targets)
    (params, opt_state) = jax.lax.fori_loop(0, 100, update_step_jit, args)

    return params

params_new = optimization_jit(params,x_train,y_train)

Error faced:

----> 1 params_new = optimization_jit(params,x_train,y_train)

    [... skipping hidden 12 frame]

/home/akashg/QNN21cm/Codes/AE_QNN-JAX.ipynb Cell 33 line 2
     20 opt_state = opt.init(params)
     22 args = (params, opt_state, data, targets)
---> 23 (params, opt_state) = jax.lax.fori_loop(0, 100, update_step_jit, args)
     25 return params

    [... skipping hidden 4 frame]

File ~/anaconda3/envs/py3.10/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py:330, in _check_scan_carry_type(body_fun, in_carry, out_carry_tree, out_avals)
    324   else:
    325     differences = '\n'.join(
    326         f'  * {component(path)} is a {thing1} but the corresponding component '
    327         f'of the carry output is a {thing2}, so {explanation}\n'
    328         for path, thing1, thing2, explanation
    329         in equality_errors(in_carry, out_carry))
--> 330   raise TypeError(
    331       "Scanned function carry input and carry output must have the same "
    332       "pytree structure, but they differ:\n"
    333       f"{differences}\n"
    334       "Revise the scanned function so that its output is a pair where the "
    335       "first element has the same pytree structure as the first argument."
    336   )
    337 if not all(_map(core.typematch, in_avals, out_avals)):
    338   differences = '\n'.join(
    339       f'  * {component(path)} has type {in_aval.str_short()}'
    340       ' but the corresponding output carry component has type '
    341       f'{out_aval.str_short()}{_aval_mismatch_extra(in_aval, out_aval)}\n'
    342       for path, in_aval, out_aval in zip(paths, in_avals, out_avals)
    343       if not core.typematch(in_aval, out_aval))

TypeError: Scanned function carry input and carry output must have the same pytree structure, but they differ:
  * the input carry component loop_carry[1] is a tuple of length 4 but the corresponding component of the carry output is a tuple of length 2, so the lengths do not match

Revise the scanned function so that its output is a pair where the first element has the same pytree structure as the first argument.

This is the output of my qml.about():

Name: PennyLane
Version: 0.35.0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author: 
Author-email: 
License: Apache License 2.0
Location: /home/akashg/anaconda3/envs/py3.10/lib/python3.10/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, pennylane-lightning, requests, rustworkx, scipy, semantic-version, toml, typing-extensions
Required-by: PennyLane_Lightning

Platform info:           Linux-3.10.0-1160.88.1.el7.x86_64-x86_64-with-glibc2.17
Python version:          3.10.13
Numpy version:           1.26.4
Scipy version:           1.12.0
Installed devices:
- lightning.qubit (PennyLane_Lightning-0.35.0)
- default.clifford (PennyLane-0.35.0)
- default.gaussian (PennyLane-0.35.0)
- default.mixed (PennyLane-0.35.0)
- default.qubit (PennyLane-0.35.0)
- default.qubit.autograd (PennyLane-0.35.0)
- default.qubit.jax (PennyLane-0.35.0)
- default.qubit.legacy (PennyLane-0.35.0)
- default.qubit.tf (PennyLane-0.35.0)
- default.qubit.torch (PennyLane-0.35.0)
- default.qutrit (PennyLane-0.35.0)
- null.qubit (PennyLane-0.35.0)

Note: I am completely new to JAX-JIT so I might have overlooked something, so kindly go through and help me out. Thank you.

Hey @G_Akash,

Can you attach the full code so that I can copy-paste it and try to replicate the issue on my side? In the meantime, maybe you can try out PennyLane-Catalyst in place of JAX’s native jitting! See: Introducing Catalyst: quantum just-in-time compilation | PennyLane Blog

The complete code:

import time

import pennylane as qml
from pennylane import numpy as pnp
from pennylane import StronglyEntanglingLayers  # used unqualified in circ below

import jax
from jax import numpy as jnp
from jax import random
import optax

jax.config.update('jax_platform_name', 'gpu')
jax.config.update("jax_enable_x64", True)

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mp
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle  # needed by data_split below

#Example data:
X = np.random.rand(10000,128)
Y = np.random.rand(10000,3)

def data_split(X,Y,test_size=0.2):
    x_train,x_1,y_train,y_1 = train_test_split(X,Y,test_size = test_size,random_state = 2)
    x_val,x_test,y_val,y_test = train_test_split(x_1,y_1,test_size = 0.5,random_state = 2)
    x_train, y_train = shuffle(x_train, y_train, random_state=2)
    x_train = jnp.array(x_train)
    y_train = jnp.array(y_train)
    x_val = jnp.array(x_val)
    y_val = jnp.array(y_val)
    x_test = jnp.array(x_test)
    y_test = jnp.array(y_test)
    return x_train,x_val,x_test,y_train,y_val,y_test

def params_init(n_layer,n_qubits,Y):
    key = jax.random.PRNGKey(np.random.randint(0,1e4))
    var_init = jax.random.uniform(key,(n_layer,n_qubits,3),minval=0,maxval=1)
    bias_init = jnp.zeros(Y.shape[1])

    params = {"weights": var_init, "bias": bias_init}
    return params

n_qubits = int(np.log2(X.shape[1]))
n_layer = 30
opt = optax.adam(learning_rate=0.05)
params = params_init(n_layer,n_qubits,Y) 
opt_state = opt.init(params)
x_train,x_val,x_test,y_train,y_val,y_test = data_split(X,Y,0.2)

dev = qml.device("default.qubit", wires=n_qubits)

@qml.qnode(dev,interface="jax",diff_method='backprop')
def circ(weights,inputs):
    qml.AmplitudeEmbedding(features=inputs, wires=range(n_qubits),normalize=True,pad_with=0.5)
    StronglyEntanglingLayers(weights, wires=range(n_qubits),imprimitive=qml.ops.CNOT)
    out =  [qml.expval(qml.PauliZ(i)+qml.PauliZ(i+1)) for i in range(n_qubits-1)[0::2]]
    return out

def qnn(params,inputs):
    weights = params["weights"]
    bias = params["bias"]
    circ_out = circ(weights,inputs)
    out = []
    for i in range(len(circ_out)):
        out.append(circ_out[i] + bias[i])
    return jnp.array(out) 

def mse(observed,predictions):
  loss = jnp.sum((observed - predictions) ** 2 / len(observed))
  return jnp.mean(loss)

def predict(params,features):
    preds = qnn(params,features)
    preds_np = jnp.asarray(preds).T
    return preds_np

batched_predict = jax.vmap(predict, in_axes=(None, 0))

def cost(params, features, observed):
    preds = batched_predict(params,features)
    cost = mse(observed, preds)
    return cost

@jax.jit
def jit_cost(params, features, observed):
    preds = batched_predict(params,features)
    cost = mse(observed, preds)
    return cost

num_train = len(x_train)
num_val =len(x_val)
@jax.jit
def update_step_jit(i, args):
    key_t,params, opt_state, x_train, y_train, x_val, y_val,opt,batch_size = args
    key_t, key_v = random.split(key_t)
    idx_train = random.choice(key_t,num_train,shape=(batch_size,))
    x_train_batch = jnp.asarray(x_train[idx_train])
    y_train_batch = jnp.asarray(y_train[idx_train])
    x_val_batch = jnp.asarray(x_val)
    y_val_batch = jnp.asarray(y_val)
    start_time = time.time()
    train_cost, grads = jax.value_and_grad(jit_cost)(params, x_train_batch, y_train_batch)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    end_time = time.time()
    val_cost = jit_cost(params,x_val_batch,y_val_batch)
    epoch_time = end_time - start_time
    jax.debug.print("Epoch: {i} | ETA: {epoch_time} | Train_Loss: {train_cost} | Val_Loss: {val_cost}", i=i, epoch_time=epoch_time, train_cost=train_cost,val_loss=val_cost)

    train_loss = jnp.append(train_loss,train_cost)
    val_loss = jnp.append(val_loss,val_cost)
    return (params, opt_state,train_loss,train_loss)

@jax.jit
def fit_jit(params, features, observed,features_val, observed_val, batch_size):
    opt = optax.adam(learning_rate=0.01)
    opt_state = opt.init(params)
    key_t = random.PRNGKey(np.random.randint(0,1e4))
    args = (key_t,params, opt_state, features, observed,features_val, observed_val,opt, batch_size)
    (params, opt_state,train_loss,val_loss) = jax.lax.fori_loop(0, 300, update_step_jit, args)

    return params, train_loss, val_loss

params,train_loss,val_loss = fit_jit(params,x_train,y_train,x_val,y_val,512)

Also, regarding the Catalyst implementation, I am dealing with that in another discussion thread :grin:

Hi @G_Akash! Your second approach looks to be on the right track. However, the body function passed to jax.lax.fori_loop (here update_step_jit) must return a carry with the same structure as the one it receives, i.e., its output must match args in length, nesting, shapes, and dtypes. I managed to get things working by updating the following.

  • The return of update_step_jit:

    return (params, opt_state, data, targets)
    
  • The jax.lax.fori_loop line:

    (params, opt_state, data, targets) = jax.lax.fori_loop(0, 100, update_step_jit, args)
    

Let us know if this solves your problem!
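
To illustrate the rule above outside of PennyLane, here is a minimal sketch of a fori_loop body whose returned carry has the same structure as its input. It is not the code from this thread: the linear model, the loss function, and the plain gradient-descent update (used in place of optax.adam to keep the example dependency-free) are assumptions made for illustration only.

```python
import jax
import jax.numpy as jnp

def loss(params, data, targets):
    # Hypothetical stand-in for jit_cost: a simple linear model with MSE loss
    preds = data @ params["w"] + params["b"]
    return jnp.mean((preds - targets) ** 2)

def update_step(i, carry):
    params, data, targets = carry
    grads = jax.grad(loss)(params, data, targets)
    # Plain gradient descent, swapped in for the optax update
    params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    # Return every element of the carry, in the same order, with unchanged
    # shapes and dtypes, even the ones that were not modified
    return (params, data, targets)

@jax.jit
def optimize(params, data, targets):
    carry = (params, data, targets)
    params, _, _ = jax.lax.fori_loop(0, 200, update_step, carry)
    return params

key = jax.random.PRNGKey(0)
data = jax.random.normal(key, (64, 3))
targets = data @ jnp.array([1.0, -2.0, 0.5]) + 0.3
params = {"w": jnp.zeros(3), "b": jnp.zeros(())}

fitted = optimize(params, data, targets)
final_loss = loss(fitted, data, targets)
print(final_loss)
```

An alternative, under the same assumptions, is to define update_step inside optimize so that data and targets are closed over rather than carried; then only params needs to appear in the carry.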

Hey @Tom_Bromley, thanks for the suggestion. It worked perfectly. So only the shapes must match, not necessarily the input and output variables, right? Below is a normal optimization loop that uses a Python for loop without jitting the loop itself; I want to replicate the same functionality and return the same outputs.

    @jax.jit
    def update_step(params, opt_state, features, observed):
        train_cost, grads = jax.value_and_grad(jit_cost)(params, features, observed)
        updates, opt_state = opt.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        return params, opt_state, train_cost

    def fit(params,opt_state,x_train,y_train,x_val,y_val,epoch,batch_size):    
        train_loss = []
        val_loss = []
        training_time = []
        num_train = len(x_train)
        num_val =len(x_val)
        key_t = random.PRNGKey(np.random.randint(0,1e4))
        for i in range(epoch):    
            key_t, key_v = random.split(key_t)
            idx_train = random.choice(key_t,num_train,shape=(batch_size,))
            idx_val = random.choice(key_v,num_val,shape=(batch_size,))
            x_train_batch = jnp.asarray(x_train[idx_train])
            y_train_batch = jnp.asarray(y_train[idx_train])
            x_val_batch = jnp.asarray(x_val)
            y_val_batch = jnp.asarray(y_val)
            start_time = time.time()
            params, opt_state, train_cost = update_step(params,opt_state, x_train_batch, y_train_batch)
            end_time = time.time()
            val_cost = jit_cost(params,x_val_batch,y_val_batch)
            epoch_time = end_time - start_time
            print("Epoch: {:5d} | Loss: {:0.7f} | Val_Loss: {:0.7f} | Time: {:0.4f} seconds".format(i+1, train_cost,val_cost,epoch_time))
            training_time.append(epoch_time)
            train_loss.append(train_cost)
            val_loss.append(val_cost)
        tot_time = np.sum(training_time)
        minutes =  tot_time// 60
        remaining_seconds = tot_time % 60        
        print(f" Training time = {minutes} minutes {round(remaining_seconds)} seconds.")    
        return params, train_loss,val_loss    

Could you give me an idea of how to convert this into a jitted optimization loop? I will also try on my side and let you know. Thank you.
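
One possible way to approach this conversion, sketched under stated assumptions rather than taken from the thread: jax.lax.scan can play the role of the Python for loop and return the per-epoch losses as stacked arrays, which replaces the list appends (arrays that grow inside a jitted loop are not allowed because shapes must stay fixed). The linear model, the loss, the learning rate LR, and the plain gradient-descent update (in place of optax) are hypothetical stand-ins. Per-epoch timing is left out because time.time() inside a jitted function only runs while the function is being traced, not on every iteration.

```python
from functools import partial

import jax
import jax.numpy as jnp

LR = 0.1  # hypothetical learning rate for the stand-in update rule

def loss(params, x, y):
    # Hypothetical stand-in for jit_cost
    return jnp.mean((x @ params["w"] + params["b"] - y) ** 2)

@partial(jax.jit, static_argnames=("epochs", "batch_size"))
def fit(key, params, x_train, y_train, x_val, y_val, epochs, batch_size):
    num_train = x_train.shape[0]

    def step(carry, _):
        key, params = carry
        key, subkey = jax.random.split(key)
        # Sample a minibatch; batch_size is static, so it can be used as a shape
        idx = jax.random.choice(subkey, num_train, shape=(batch_size,))
        train_cost, grads = jax.value_and_grad(loss)(params, x_train[idx], y_train[idx])
        params = jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)
        val_cost = loss(params, x_val, y_val)
        # First output is the new carry; second is stacked across iterations
        return (key, params), (train_cost, val_cost)

    (_, params), (train_loss, val_loss) = jax.lax.scan(
        step, (key, params), None, length=epochs
    )
    return params, train_loss, val_loss

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (400, 3))
y = x @ jnp.array([1.0, -2.0, 0.5]) + 0.3
params0 = {"w": jnp.zeros(3), "b": jnp.zeros(())}

params, train_loss, val_loss = fit(
    jax.random.PRNGKey(1), params0, x[:300], y[:300], x[300:], y[300:],
    epochs=300, batch_size=32,
)
print(train_loss.shape, val_loss.shape)
```

Total wall-clock time can still be measured outside the jitted call, for example by wrapping the call to fit in time.time() and calling jax.block_until_ready on the result before stopping the clock.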

Hey, so I changed the code a bit such that it replicates the one I posted in the above message.

num_train = len(x_train)
num_val =len(x_val)
@jax.jit
def update_step_jit(i, args):
    key_t,params, opt_state, x_train, y_train, x_val, y_val,opt,batch_size = args
    key_t, key_v = random.split(key_t)
    idx_train = random.choice(key_t,num_train,shape=(batch_size,))
    # idx_val = random.choice(key_v,num_val,shape=(batch_size,))
    x_train_batch = jnp.asarray(x_train[idx_train])
    y_train_batch = jnp.asarray(y_train[idx_train])
    x_val_batch = jnp.asarray(x_val)
    y_val_batch = jnp.asarray(y_val)
    start_time = time.time()
    train_cost, grads = jax.value_and_grad(jit_cost)(params, x_train_batch, y_train_batch)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    end_time = time.time()
    val_cost = jit_cost(params,x_val_batch,y_val_batch)
    epoch_time = end_time - start_time
    # def print_fn():
    jax.debug.print("Epoch: {i} | ETA: {epoch_time} | Train_Loss: {train_cost} | Val_Loss: {val_cost}", i=i, epoch_time=epoch_time, train_cost=train_cost,val_loss=val_cost)

    training_time = np.append(training_time,epoch_time)
    train_loss = jnp.append(train_loss,train_cost)
    val_loss = jnp.append(val_loss,val_cost)
    return (params, opt_state,train_loss,val_loss,training_time,x_train, y_train, x_val, y_val)

@jax.jit
def fit_jit(params, x_train, y_train, x_val, y_val, epochs, batch_size):
    opt = optax.adam(learning_rate=0.05)
    opt_state = opt.init(params)
    key_t = random.PRNGKey(np.random.randint(0,1e4))
    args = (key_t,params, opt_state, x_train, y_train, x_val, y_val,opt, batch_size)
    (params, opt_state,train_loss,val_loss,training_time,x_train, y_train, x_val, y_val) = jax.lax.fori_loop(0, epochs, update_step_jit, args)
    tot_time = np.sum(training_time)
    minutes =  tot_time// 60
    remaining_seconds = tot_time % 60        
    print(f" Training time = {minutes} minutes {round(remaining_seconds)} seconds.")
    return params, train_loss, val_loss

params,train_loss,val_loss = fit_jit(params,x_train,y_train,x_val,y_val,epochs=300,batch_size=512)

However, this leads to an error as well:

TypeError                                 Traceback (most recent call last)
Cell In[32], line 1
----> 1 params,train_loss,val_loss = fit_jit(params,x_train,y_train,x_val,y_val,epochs=300,batch_size=512)

    [... skipping hidden 12 frame]

Cell In[31], line 34
     32 key_t = random.PRNGKey(np.random.randint(0,1e4))
     33 args = (key_t,params, opt_state, x_train, y_train, x_val, y_val,opt, batch_size)
---> 34 (params, opt_state,train_loss,val_loss,training_time,x_train, y_train, x_val, y_val) = jax.lax.fori_loop(0, epochs, update_step_jit, args)
     35 tot_time = np.sum(training_time)
     36 minutes =  tot_time// 60

    [... skipping hidden 7 frame]

File ~/anaconda3/envs/py3.10/lib/python3.10/site-packages/jax/_src/core.py:1472, in concrete_aval(x)
   1470 if hasattr(x, '__jax_array__'):
   1471   return concrete_aval(x.__jax_array__())
-> 1472 raise TypeError(f"Value {x!r} with type {type(x)} is not a valid JAX "
   1473                  "type")

TypeError: Value <function chain.<locals>.init_fn at 0x7ff1b4453250> with type <class 'function'> is not a valid JAX type
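
A likely reading of this traceback, offered as an assumption rather than a confirmed diagnosis: the value it rejects is a function, and the only function-bearing object in args is opt, since optax.adam returns a pair of init and update functions. A fori_loop carry may contain only arrays (and containers of arrays), so functions and Python configuration values have to reach the body another way. The sketch below keeps them out of the carry by closing over them, and marks epochs and batch_size as static arguments so that batch_size can be used as a shape. The linear model and sgd_update (a stand-in for opt.update) are hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Hypothetical stand-in for jit_cost
    return jnp.mean((x @ w - y) ** 2)

def sgd_update(grads, lr=0.1):
    # Hypothetical stand-in for opt.update: a plain Python function,
    # which is exactly the kind of object that cannot go in the carry
    return jax.tree_util.tree_map(lambda g: -lr * g, grads)

@partial(jax.jit, static_argnames=("epochs", "batch_size"))
def fit(key, w, x, y, epochs, batch_size):
    num_train = x.shape[0]

    def body(i, carry):
        key, w = carry  # the carry holds arrays only
        key, subkey = jax.random.split(key)
        idx = jax.random.choice(subkey, num_train, shape=(batch_size,))
        grads = jax.grad(loss)(w, x[idx], y[idx])
        # sgd_update, x, y and batch_size are closed over, not carried
        w = w + sgd_update(grads)
        return (key, w)

    _, w = jax.lax.fori_loop(0, epochs, body, (key, w))
    return w

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (256, 3))
true_w = jnp.array([1.0, -2.0, 0.5])
y = x @ true_w

w_fit = fit(jax.random.PRNGKey(1), jnp.zeros(3), x, y, epochs=300, batch_size=32)
print(w_fit)
```

With optax the same pattern would apply: opt stays outside the carry (defined at module level or inside fit), and only params and opt_state, which are made of arrays, are carried through the loop.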

Hey @Tom_Bromley, just wanted to let you know that I have rectified the issue, and it is working now. The other problem I am facing is slow compilation with increasing qubits, and no amount of GPU is helping in the reduction of the compilation time, on the contrary, the compilation time increases with an increase in GPU performance. Do you have any suggestions on how to rectify that?

Also, I would like to know how to modify my circuit architecture to make it less computationally intensive while simultaneously increasing its ability to understand the complex features in the dataset. Thank you

Hey @G_Akash,

Apologies for the wait :sweat_smile:.

… slow compilation with increasing qubits, and no amount of GPU is helping in the reduction of the compilation time, on the contrary, the compilation time increases with an increase in GPU performance.

One thing that might be causing this is the lack of jax control flow operations (see here: Writing TPU kernels with Pallas — JAX documentation). I would try using the native control flow operations and see how that goes! The other thing I can recommend is to use Catalyst. Might be worth revisiting the post you made here?
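
As a rough illustration of the control-flow point above, the sketch below compares a Python for loop, which JAX unrolls while tracing, with jax.lax.fori_loop, which keeps the loop as a single operation in the traced program. A smaller traced program may compile faster, although that is not measured here. The tanh layer is a hypothetical stand-in, not a PennyLane circuit, and whether the layers inside a QNode can be rolled up this way is not established in this thread.

```python
import jax
import jax.numpy as jnp

def layer(x, w):
    # Hypothetical stand-in for one layer of a model
    return jnp.tanh(w @ x)

def unrolled(x, ws):
    # Python loop: every iteration is traced separately
    for k in range(ws.shape[0]):
        x = layer(x, ws[k])
    return x

def rolled(x, ws):
    # JAX-native loop: the body is traced once
    return jax.lax.fori_loop(0, ws.shape[0], lambda k, x: layer(x, ws[k]), x)

ws = jnp.stack([jnp.eye(4) * 0.5] * 30)  # 30 layers of 4x4 weights
x0 = jnp.ones(4)

# Count the top-level equations in each traced program
n_unrolled = len(jax.make_jaxpr(unrolled)(x0, ws).jaxpr.eqns)
n_rolled = len(jax.make_jaxpr(rolled)(x0, ws).jaxpr.eqns)
print(n_unrolled, n_rolled)
```

Both functions compute the same result; only the size of the program handed to the compiler differs, and the gap grows with the number of layers.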

I would like to know how to modify my circuit architecture to make it less computationally intensive while simultaneously increasing its ability to understand the complex features in the dataset.

That’s the million-dollar question! Unfortunately, I don’t have a one-size-fits-all answer.