Hi @MelonLord,
I changed get_output and loss_fn just to make them better suited for broadcasting (and easier for me to understand and compare).
I also changed some of the functions at the end to make them closer to our JAX-Optax demo (for comparison and debugging).
What I see is that things work nicely with the Adam optimizer but, for some reason, not with the L-BFGS optimizer. Are you specifically interested in using that optimizer?
On the other hand, if you can share your Qiskit code, maybe I can think of ways to translate it to PennyLane. I can't promise anything, but it could bring new ideas.
import pandas as pd
import numpy as np
from sklearn.datasets import make_classification
import pennylane as qml
import jax
from jax import numpy as jnp
from math import pi
import optax
from sklearn.preprocessing import MinMaxScaler
jax.config.update('jax_enable_x64', True)
dev = qml.device("default.qubit", wires = 4)
X, y = make_classification(
    n_samples=200,
    n_features=4,
    n_informative=4,
    n_redundant=0,
    n_repeated=0,
    n_classes=2,
    random_state=42,
)
scaler = MinMaxScaler(feature_range=(0,1))
scaled_x = scaler.fit_transform(X)
x_train = scaled_x[:150]
x_test = scaled_x[150:]
y_train = y[:150]
y_test = y[150:]
@qml.qnode(dev)
def circuit(params, x):
    # Data-encoding feature map: Hadamards, single-qubit phases, and pairwise ZZ-style phases
    for i in range(4):
        qml.H(wires=i)
        qml.PhaseShift(2 * x[i], wires=i)
    for i in range(1, 4):
        for j in range(i):
            qml.CNOT([j, i])
            qml.PhaseShift(2 * (pi - x[i]) * (pi - x[j]), wires=i)
            qml.CNOT([j, i])
    # Variational ansatz: an initial RY layer followed by 6 entangling + RY layers
    for i in range(4):
        qml.RY(params[i], wires=i)
    for layer in range(1, 6 + 1):
        for i in range(3):
            qml.CNOT([i, i + 1])
        for i in range(4):
            qml.RY(params[layer * 4 + i], wires=i)
    return qml.probs(wires=[0, 1, 2, 3])
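# Optional debugging aid (just a suggestion, not part of the fix): print the circuit
# for a single sample to double-check the structure; jnp.zeros(28) is a placeholder
# parameter vector.
# print(qml.draw(circuit)(jnp.zeros(28), x_train[0]))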
# Modified this function
def get_output(params, x):
    # We updated these functions to work on the entire dataset at once (via parameter
    # broadcasting) instead of on individual datapoints.
    # Here x has shape (4, n_samples), so probs has shape (n_samples, 16).
    probs = circuit(params, x)
    out = jnp.zeros(x.shape[1])
    # Sum the probabilities of the even-indexed basis states (last qubit in |0>)
    for i in range(0, 16, 2):
        out += probs[:, i]
    return out
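# Note: the loop over the even-indexed probabilities is equivalent to a single
# vectorized reduction, e.g. out = probs[:, ::2].sum(axis=1). Just an alternative,
# not required.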
# Modified this function
@jax.jit
def loss_fn(params, data, targets):
    # To get the predictions for the entire dataset we pass the transpose of the data
    # so that the dimensions match what get_output expects.
    predictions = get_output(params, data.T)
    # Mean squared error over the dataset
    loss = jnp.sum((targets - predictions) ** 2 / len(data))
    return loss
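# Quick shape check (illustrative only): data has shape (150, 4), so data.T has shape
# (4, 150); PennyLane's parameter broadcasting then makes circuit() return a batch of
# probability vectors of shape (150, 16), and get_output() returns one prediction per
# sample, e.g.
# get_output(jnp.zeros(28), x_train.T).shape  # expected: (150,)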
init_params = jnp.array(np.random.default_rng().random(size = (28,)))
print('Initial loss: ', loss_fn(init_params, x_train, y_train)) # added this print
# Define the optimizer we want to work with
opt = optax.adam(learning_rate=0.3)
#opt = optax.lbfgs(1e-8)
max_steps = 200
@jax.jit
def f(params):
    return loss_fn(params, x_train, y_train)
# value_and_grad = optax.value_and_grad_from_state(f) # removed this
# Changed the code below
@jax.jit
def update_step_jit(i, args):
    params, opt_state, data, targets, print_training = args
    loss_val, grads = jax.value_and_grad(loss_fn)(params, data, targets)
    updates, opt_state = opt.update(
        grads, opt_state, params, value=loss_val, grad=grads, value_fn=f
    )
    params = optax.apply_updates(params, updates)

    def print_fn():
        jax.debug.print("Step: {i} Loss: {loss_val}", i=i, loss_val=loss_val)

    # If print_training is True, print the loss every 5 steps
    jax.lax.cond((jnp.mod(i, 5) == 0) & print_training, print_fn, lambda: None)
    return (params, opt_state, data, targets, print_training)
@jax.jit
def optimization_jit(params, data, targets, print_training=False):
    opt_state = opt.init(params)
    args = (params, opt_state, data, targets, print_training)
    # Run for max_steps (defined above) rather than a hard-coded number of iterations
    (params, opt_state, _, _, _) = jax.lax.fori_loop(0, max_steps, update_step_jit, args)
    return params
params = init_params
data = x_train
targets = y_train
trained_params = optimization_jit(params, data, targets, print_training=True)
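In case you still want to try L-BFGS: below is a rough sketch of the update pattern the Optax documentation suggests for its linesearch-based optimizers, i.e. the optax.value_and_grad_from_state idea I removed above, which lets the linesearch reuse function evaluations. I haven't checked that it compiles cleanly with this QNode, so please treat it as a starting point rather than a working solution (the 50-step count is arbitrary).
# Sketch only (not verified with this circuit): L-BFGS driver pattern from the Optax
# docs, reusing f and init_params defined above.
opt_lbfgs = optax.lbfgs()
value_and_grad = optax.value_and_grad_from_state(f)

def lbfgs_step(carry, _):
    params, opt_state = carry
    value, grad = value_and_grad(params, state=opt_state)
    updates, opt_state = opt_lbfgs.update(
        grad, opt_state, params, value=value, grad=grad, value_fn=f
    )
    params = optax.apply_updates(params, updates)
    return (params, opt_state), value

(lbfgs_params, _), lbfgs_losses = jax.lax.scan(
    lbfgs_step, (init_params, opt_lbfgs.init(init_params)), length=50
)
If it still misbehaves, comparing the first few entries of lbfgs_losses against the Adam run above might show whether the linesearch is the part that goes wrong.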
I hope this helps!