Currently, I am working on a regression problem, and after several months of trial and error and research I was able to perform my task efficiently (in terms of runtime) using JAX-JIT. However, while going through How to optimize a QML model using JAX and Optax, I came across jitting the whole optimization loop. When I ran that notebook in my Python environment it ran seamlessly, and jitting the optimization loop gave a further decrease in runtime. I tried to apply the same method to my problem, but I ended up facing a few errors. Moreover, in that tutorial the data embedding is vectorized, so the whole circuit execution is vectorized, which is different from what I am doing (is that necessary in order to jit the optimization loop?).
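For reference, this is roughly the pattern I understood from that tutorial and am trying to reproduce (a minimal sketch with a placeholder cost_fn and toy shapes, not my actual model): the optimizer is created once outside the jitted functions, and the fori_loop carry keeps the same structure from input to output.

import jax
import jax.numpy as jnp
import optax

def cost_fn(params, data, targets):
    # placeholder quadratic cost, only so the pattern is runnable
    preds = data @ params["w"] + params["b"]
    return jnp.mean((preds - targets) ** 2)

opt = optax.adam(learning_rate=0.3)  # created once, outside the jitted functions

def update_step(i, args):
    params, opt_state, data, targets = args
    loss_val, grads = jax.value_and_grad(cost_fn)(params, data, targets)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    jax.debug.print("Step: {i}  Loss: {loss}", i=i, loss=loss_val)
    return (params, opt_state, data, targets)  # carry matches the input structure

@jax.jit
def optimization_jit(params, data, targets):
    opt_state = opt.init(params)
    args = (params, opt_state, data, targets)
    params, opt_state, _, _ = jax.lax.fori_loop(0, 100, update_step, args)
    return params

# toy usage
params = {"w": jnp.zeros((128, 3)), "b": jnp.zeros(3)}
data, targets = jnp.ones((32, 128)), jnp.ones((32, 3))
params = optimization_jit(params, data, targets)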
To give an idea of my input and output: my input has size (10000, 128), which I split into train, validation and test sets. The input consists of 10000 different realizations of some curve (say, a Gaussian curve), and the output is the set of parameters characterizing these curves, of size (10000, 3), i.e. 3 parameters per curve.
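For concreteness, a toy dataset with the same shapes could be generated like this (purely illustrative Gaussian curves and parameter ranges; my actual data comes from elsewhere):

import numpy as np

n_samples, n_points = 10000, 128
x_grid = np.linspace(-1.0, 1.0, n_points)

# 3 parameters per curve: amplitude, mean, width (illustrative ranges)
y = np.column_stack([
    np.random.uniform(0.5, 2.0, n_samples),   # amplitude
    np.random.uniform(-0.5, 0.5, n_samples),  # mean
    np.random.uniform(0.05, 0.3, n_samples),  # width
])
x = y[:, 0, None] * np.exp(-(x_grid[None, :] - y[:, 1, None]) ** 2 / (2 * y[:, 2, None] ** 2))

print(x.shape, y.shape)  # (10000, 128) (10000, 3)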
MY MODEL
import pennylane as qml
import jax
import jax.numpy as jnp
from pennylane.templates import StronglyEntanglingLayers

dev = qml.device("default.qubit", wires=n_qubits)

@qml.qnode(dev, interface="jax", diff_method="backprop")
def circ(weights, inputs):
    qml.AmplitudeEmbedding(features=inputs, wires=range(n_qubits), normalize=True, pad_with=0.5)
    StronglyEntanglingLayers(weights, wires=range(n_qubits), imprimitive=qml.ops.CNOT)
    out = [qml.expval(qml.PauliZ(i) + qml.PauliZ(i + 1)) for i in range(n_qubits - 1)[0::2]]
    return out
def qnn(params, inputs):
    weights = params["weights"]
    bias = params["bias"]
    circ_out = circ(weights, inputs)
    out = []
    for i in range(len(circ_out)):
        out.append(circ_out[i] + bias[i])
    return out
def mse(observed, predictions):
    loss = jnp.sum((observed - predictions) ** 2 / len(observed))
    return jnp.mean(loss)
def predict(params, features):
    preds = qnn(params, features)
    preds_np = jnp.asarray(preds).T
    return preds_np
batched_predict = jax.vmap(predict, in_axes=(None, 0))
def cost(params, features, observed):
    preds = batched_predict(params, features)
    cost = mse(observed, preds)
    return cost

@jax.jit
def jit_cost(params, features, observed):
    preds = batched_predict(params, features)
    cost = mse(observed, preds)
    return cost
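For completeness, this is roughly how the parameters are built and how I sanity-check the cost on a small batch. The values n_qubits = 7 and n_layers = 4 here are placeholders (2**7 = 128 amplitudes matches the 128-point input; both are of course defined before the circuit in my actual notebook), and the weights shape comes from qml.StronglyEntanglingLayers.shape:

import numpy as np

n_qubits = 7   # assumption: 2**7 = 128 amplitudes for the 128-point input
n_layers = 4   # placeholder circuit depth
n_out = len(range(n_qubits - 1)[0::2])  # 3 measured observables -> 3 regression targets

weights_shape = qml.StronglyEntanglingLayers.shape(n_layers=n_layers, n_wires=n_qubits)
params = {
    "weights": jnp.asarray(0.1 * np.random.normal(size=weights_shape)),
    "bias": jnp.zeros(n_out),
}

print(jit_cost(params, x_train[:8], y_train[:8]))  # quick shape/cost check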
APPROACH-1
import time

import numpy as np
import optax
from jax import random

num_train = len(x_train)
num_val = len(x_val)

@jax.jit
def update_step_jit(i, args):
    key_t, params, opt_state, x_train, y_train, x_val, y_val, opt, batch_size, print_training = args
    key_t, key_v = random.split(key_t)
    idx_train = random.choice(key_t, num_train, shape=(batch_size,))
    x_train_batch = jnp.asarray(x_train[idx_train])
    y_train_batch = jnp.asarray(y_train[idx_train])
    x_val_batch = jnp.asarray(x_val)
    y_val_batch = jnp.asarray(y_val)
    start_time = time.time()
    train_cost, grads = jax.value_and_grad(jit_cost)(params, x_train_batch, y_train_batch)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    end_time = time.time()
    val_cost = jit_cost(params, x_val_batch, y_val_batch)
    epoch_time = end_time - start_time
    jax.debug.print("Epoch: {:5d} | Training_Loss: {:0.7f} | Val_Loss: {:0.7f} ".format(i + 1, train_cost, val_cost))
    train_loss = jnp.append(train_loss, train_cost)
    val_loss = jnp.append(val_loss, val_cost)
    return (params, opt_state, train_loss, val_loss)

@jax.jit
def fit_jit(params, features, observed, features_val, observed_val, batch_size, print_training=False):
    opt = optax.adam(learning_rate=0.01)
    opt_state = opt.init(params)
    key_t = random.PRNGKey(np.random.randint(0, 1e4))
    args = (key_t, params, opt_state, features, observed, features_val, observed_val, opt, batch_size, print_training)
    (params, opt_state, train_loss, val_loss) = jax.lax.fori_loop(0, 300, update_step_jit, args)
    return params, train_loss, val_loss
I run the fit_jit function to train my circuit:
params,train_loss,val_loss = fit_jit(params,x_train,y_train,x_val,y_val,512,print_training=True)
Error faced:
----> 1 params,train_loss,val_loss = fit_jit(params,x_train,y_train,x_val,y_val,512,print_training=True)
[... skipping hidden 12 frame]
/home/akashg/QNN21cm/Codes/AE_QNN-JAX.ipynb Cell 35 line 3
33 key_t = random.PRNGKey(np.random.randint(0,1e4))
34 args = (key_t,params, opt_state, features, observed,features_val, observed_val,opt, batch_size, print_training)
---> 35 (params, opt_state,train_loss,val_loss) = jax.lax.fori_loop(0, 300, update_step_jit, args)
37 return params, train_loss, val_loss
[... skipping hidden 7 frame]
File ~/anaconda3/envs/py3.10/lib/python3.10/site-packages/jax/_src/core.py:1472, in concrete_aval(x)
1470 if hasattr(x, '__jax_array__'):
1471 return concrete_aval(x.__jax_array__())
-> 1472 raise TypeError(f"Value {x!r} with type {type(x)} is not a valid JAX "
1473 "type")
TypeError: Value <function chain.<locals>.init_fn at 0x7f838159a4d0> with type <class 'function'> is not a valid JAX type
APPROACH-2
I tried this approach to check if I would get the same error, but it turns out I get a different one, and I don't understand what the issue is.
@jax.jit
def update_step_jit(i, args):
    params, opt_state, data, targets = args
    train_cost, grads = jax.value_and_grad(jit_cost)(params, data, targets)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    jax.debug.print("Step: {i} Loss: {loss_val}", i=i, loss_val=train_cost)
    return (params, opt_state)

@jax.jit
def optimization_jit(params, data, targets):
    opt = optax.adam(learning_rate=0.1)
    opt_state = opt.init(params)
    args = (params, opt_state, data, targets)
    (params, opt_state) = jax.lax.fori_loop(0, 100, update_step_jit, args)
    return params
params_new = optimization_jit(params,x_train,y_train)
Error Faced:
----> 1 params_new = optimization_jit(params,x_train,y_train)
[... skipping hidden 12 frame]
/home/akashg/QNN21cm/Codes/AE_QNN-JAX.ipynb Cell 33 line 2
20 opt_state = opt.init(params)
22 args = (params, opt_state, data, targets)
---> 23 (params, opt_state) = jax.lax.fori_loop(0, 100, update_step_jit, args)
25 return params
[... skipping hidden 4 frame]
File ~/anaconda3/envs/py3.10/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py:330, in _check_scan_carry_type(body_fun, in_carry, out_carry_tree, out_avals)
324 else:
325 differences = '\n'.join(
326 f' * {component(path)} is a {thing1} but the corresponding component '
327 f'of the carry output is a {thing2}, so {explanation}\n'
328 for path, thing1, thing2, explanation
329 in equality_errors(in_carry, out_carry))
--> 330 raise TypeError(
331 "Scanned function carry input and carry output must have the same "
332 "pytree structure, but they differ:\n"
333 f"{differences}\n"
334 "Revise the scanned function so that its output is a pair where the "
335 "first element has the same pytree structure as the first argument."
336 )
337 if not all(_map(core.typematch, in_avals, out_avals)):
338 differences = '\n'.join(
339 f' * {component(path)} has type {in_aval.str_short()}'
340 ' but the corresponding output carry component has type '
341 f'{out_aval.str_short()}{_aval_mismatch_extra(in_aval, out_aval)}\n'
342 for path, in_aval, out_aval in zip(paths, in_avals, out_avals)
343 if not core.typematch(in_aval, out_aval))
TypeError: Scanned function carry input and carry output must have the same pytree structure, but they differ:
* the input carry component loop_carry[1] is a tuple of length 4 but the corresponding component of the carry output is a tuple of length 2, so the lengths do not match
Revise the scanned function so that its output is a pair where the first element has the same pytree structure as the first argument.
This is my qml.about():
Name: PennyLane
Version: 0.35.0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author:
Author-email:
License: Apache License 2.0
Location: /home/akashg/anaconda3/envs/py3.10/lib/python3.10/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, pennylane-lightning, requests, rustworkx, scipy, semantic-version, toml, typing-extensions
Required-by: PennyLane_Lightning
Platform info: Linux-3.10.0-1160.88.1.el7.x86_64-x86_64-with-glibc2.17
Python version: 3.10.13
Numpy version: 1.26.4
Scipy version: 1.12.0
Installed devices:
- lightning.qubit (PennyLane_Lightning-0.35.0)
- default.clifford (PennyLane-0.35.0)
- default.gaussian (PennyLane-0.35.0)
- default.mixed (PennyLane-0.35.0)
- default.qubit (PennyLane-0.35.0)
- default.qubit.autograd (PennyLane-0.35.0)
- default.qubit.jax (PennyLane-0.35.0)
- default.qubit.legacy (PennyLane-0.35.0)
- default.qubit.tf (PennyLane-0.35.0)
- default.qubit.torch (PennyLane-0.35.0)
- default.qutrit (PennyLane-0.35.0)
- null.qubit (PennyLane-0.35.0)
Note: I am completely new to JAX-JIT, so I might have overlooked something; kindly go through it and help me out. Thank you.