New to ML in general, need help figuring out why my gradients are so low here

I have mainly used qml.StronglyEntanglingLayers when dealing with variational circuits, and I am now trying to use Qiskit's TwoLocal. I implemented it in PennyLane, but when I tried to train a circuit that used it, the gradients were on the order of 10^-6 or even 10^-18. Changing the learning rate didn't do much, so I wanted to know what could be causing the small gradients and the lack of training.

Here is my Github repo: GitHub - MelonLord8/PennylaneDemo

To summarize what I have already attempted:

  • Switching between the Adam and SGD optimizers
  • Changing the learning rate
  • Changing the number of repetitions and the structure of my TwoLocal instance
  • Normalizing the inputs
  • Changing how I initialize the parameters: uniform random, orthogonal, and all ones
  • Changing the loss function from mean absolute error to mean squared error
  • Changing the final observable from acting on one qubit to acting on three
  • Recreating my two_local instance with circular entanglement specifically, to check whether the if statements in my library were interfering with the gradient calculation

Hi @MelonLord ,

Welcome to the Forum!

Should I be looking at the main.ipynb notebook in your repo? I noticed it throws an error.

It would help a lot if you could reduce your code to a minimal reproducible example. A minimal reproducible example (or minimal working example) is the simplest version of the code that reproduces the problem. It should be self-contained, including all necessary imports, data, functions, etc., so that we can copy-paste the code and reproduce the problem. However, it shouldn't contain any unnecessary data, functions, etc. (for example, gates and functions that can be removed to simplify the code). If you're not sure what this means, please make sure to check out this video.

It looks like you’ve made good attempts to solve the issue. Creating a minimal reproducible example, a “miniature” version of your problem, can be a very effective way to pinpoint the exact issue.

I hope this helps!
Let me know if you have any further questions.

Hello, yes, I can see now how confusingly the repo is structured; thank you for pointing that out. Here is something self-contained that should run without errors.

import pandas as pd
import numpy as np
from sklearn.datasets import make_classification
import pennylane as qml
import jax
from jax import numpy as jnp
from math import pi
import optax
from sklearn.preprocessing import MinMaxScaler

jax.config.update('jax_enable_x64', True) 
dev = qml.device("default.qubit", wires = 4)

X, y = make_classification(
    n_samples=200,
    n_features=4,    
    n_informative=4, 
    n_redundant=0,   
    n_repeated=0,    
    n_classes=2,     
    random_state=42  
)

scaler = MinMaxScaler(feature_range=(0,1))
scaled_x = scaler.fit_transform(X)

x_train = scaled_x[:150]
x_test = scaled_x[150:]

y_train = y[:150]
y_test = y[150:]

@qml.qnode(dev)
def circuit(params, x):
    for i in range(4):
        qml.H(wires = i)
        qml.PhaseShift(2*x[i], wires = i)
    for i in range(1,4):
        for j in range(i):
            qml.CNOT([j,i])
            qml.PhaseShift(2*(pi - x[i])*(pi - x[j]), wires = i)
            qml.CNOT([j,i])
    for i in range(4):
        qml.RY(params[i], wires = i)
    for layer in range(1, 6 + 1):
        for i in range(3):
            qml.CNOT([i,i+1])
        for i in range(4):
            qml.RY(params[layer*4 + i], wires = i)
    return qml.probs(wires = [0,1,2,3])

def get_output(params, x):
    probs = circuit(params,x)
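    # Even indices of the 16-outcome distribution have wire 3 in state |0>,
    # so this loop sums to the probability of measuring 0 on the last qubit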
    out = 0
    for i in range(0, 16, 2):
        out += probs[i]
    return out

@jax.jit
def BCE(params, data, target):
    out = get_output(params, data)
    return -1*(target * jnp.log(out) + (1 - target) * jnp.log(1 - out))

BCE_map = jax.vmap(BCE, (None, 0, 0))

@jax.jit
def loss_fn(params, data, target):
    return jnp.mean(BCE_map(params, data, target))

opt = optax.lbfgs(1e-8)
max_steps = 200

@jax.jit
def f(params):
    return loss_fn(params, x_train, y_train)
value_and_grad = optax.value_and_grad_from_state(f)

@jax.jit
def optimiser(params, print_training):
    param_history = jnp.zeros(((max_steps + 1),) + params.shape)  
    grad_history = jnp.zeros(((max_steps + 1),) + params.shape)  
    opt_state = opt.init(params)
    #Packages the arguments to be sent to the function update_step_jit 
    args = (params, opt_state, print_training, param_history, grad_history)
    #Loops max_steps + 1 times
    (params, opt_state,  _, param_history, grad_history) = jax.lax.fori_loop(0, max_steps+1, update_step_jit, args) 
    return params

@jax.jit
def update_step_jit(i,args):
    # Unpacks the arguments
    params, opt_state, print_training, param_history, grad_history = args
    param_history = param_history.at[i].set(params)
    # Gets the loss and the gradients to be applied to the parameters, by passing in the loss function and the parameters, to see how the parameters perform 
    loss_val, grads = value_and_grad(params, state = opt_state)
    grad_history = grad_history.at[i].set(grads)
    #Prints the loss every 50 steps if print_training is enabled
    def print_fn():
        jax.debug.print("Step: {i}  Loss: {loss_val}", i=i, loss_val= loss_val)
    jax.lax.cond((jnp.mod(i, 50) == 0 ) & print_training, print_fn, lambda: None)
    #Applies the param updates and updates the optimiser states
    updates, opt_state = opt.update(
            grads, opt_state, params, value=loss_val, grad=grads, value_fn=f
        )
    params =  optax.apply_updates(params, updates)
    #Returns the arguments to be resupplied in the next iteration
    return (params, opt_state, print_training, param_history, grad_history)

init_params = jnp.array(np.random.default_rng().random(size = (28,)))
optimal_params = optimiser(init_params, True)
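
For reference, this is a quick way to inspect the gradient magnitudes at the initial parameters (it reuses loss_fn, init_params, x_train and y_train from the script above); this is where the 10^-6-scale values I mentioned show up:

# Gradient of the loss with respect to the initial parameters
grads = jax.grad(loss_fn)(init_params, x_train, y_train)
print("max |grad|:", jnp.abs(grads).max(), "  mean |grad|:", jnp.abs(grads).mean())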

I would also like to mention that when I recreated the circuit in Qiskit it showed an improvement, so the problem might be with how I set up PennyLane and the optimiser.

Hi @MelonLord ,

I changed get_output and loss_fn just to make them suited for broadcasting (and easier for me to understand and compare).

I also changed some of the functions at the end to make them closer to our JAX-Optax demo (for comparison and debugging).

What I see is that things work nicely with the Adam optimizer but not with the L-BFGS optimizer, for some reason. Are you specifically interested in using that optimizer?
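
If you do want to keep experimenting with L-BFGS in Optax, one thing that might be worth trying (just a sketch of the usual pattern from the Optax docs, not a confirmed fix for your case) is letting optax.lbfgs choose its step size with its default linesearch instead of fixing learning_rate=1e-8, while keeping the value/grad/value_fn arguments you already pass to opt.update:

opt = optax.lbfgs()  # no fixed learning rate: the default linesearch picks the step size
value_and_grad = optax.value_and_grad_from_state(f)
# The rest of your update step can stay as it is, in particular:
# updates, opt_state = opt.update(
#         grads, opt_state, params, value=loss_val, grad=grads, value_fn=f
#     )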

On the other hand, maybe if you can share your Qiskit code I can think of ways of translating it to PennyLane. I can't promise anything, but it could bring new ideas.

import pandas as pd
import numpy as np
from sklearn.datasets import make_classification
import pennylane as qml
import jax
from jax import numpy as jnp
from math import pi
import optax
from sklearn.preprocessing import MinMaxScaler

jax.config.update('jax_enable_x64', True) 
dev = qml.device("default.qubit", wires = 4)

X, y = make_classification(
    n_samples=200,
    n_features=4,    
    n_informative=4, 
    n_redundant=0,   
    n_repeated=0,    
    n_classes=2,     
    random_state=42  
)

scaler = MinMaxScaler(feature_range=(0,1))
scaled_x = scaler.fit_transform(X)

x_train = scaled_x[:150]
x_test = scaled_x[150:]

y_train = y[:150]
y_test = y[150:]

@qml.qnode(dev)
def circuit(params, x):
    for i in range(4):
        qml.H(wires = i)
        qml.PhaseShift(2*x[i], wires = i)
    for i in range(1,4):
        for j in range(i):
            qml.CNOT([j,i])
            qml.PhaseShift(2*(pi - x[i])*(pi - x[j]), wires = i)
            qml.CNOT([j,i])
    for i in range(4):
        qml.RY(params[i], wires = i)
    for layer in range(1, 6 + 1):
        for i in range(3):
            qml.CNOT([i,i+1])
        for i in range(4):
            qml.RY(params[layer*4 + i], wires = i)
    return qml.probs(wires = [0,1,2,3])

# Modified this function
def get_output(params, x):
    probs = circuit(params,x)
    # We update these functions to allow for the use of the entire dataset instead of individual datapoints
    out = jnp.zeros(x.shape[1])
    for i in range(0, 16, 2):
        out += probs[:,i]
    return out

# Modified this function
@jax.jit
def loss_fn(params, data, targets):
    # To get the predictions for the entire dataset we need to use the transpose of the data for the dimensions to match
    predictions = get_output(params, data.T)
    loss = jnp.sum((targets - predictions) ** 2 / len(data))
    return loss

init_params = jnp.array(np.random.default_rng().random(size = (28,)))
print('Initial loss: ', loss_fn(init_params, x_train, y_train)) # added this print

# Define the optimizer we want to work with
opt = optax.adam(learning_rate=0.3)
#opt = optax.lbfgs(1e-8)

max_steps = 200

@jax.jit
def f(params):
    return loss_fn(params, x_train, y_train)
# value_and_grad = optax.value_and_grad_from_state(f) # removed this

# Changed the code below

@jax.jit
def update_step_jit(i, args):
    params, opt_state, data, targets, print_training = args

    loss_val, grads = jax.value_and_grad(loss_fn)(params, data, targets)
    updates, opt_state = opt.update(
            grads, opt_state, params, value=loss_val, grad=grads, value_fn=f
        )
    params = optax.apply_updates(params, updates)

    def print_fn():
        jax.debug.print("Step: {i}  Loss: {loss_val}", i=i, loss_val=loss_val)

    # if print_training=True, print the loss every 5 steps
    jax.lax.cond((jnp.mod(i, 5) == 0) & print_training, print_fn, lambda: None)

    return (params, opt_state, data, targets, print_training)

@jax.jit
def optimization_jit(params, data, targets, print_training=False):

    opt_state = opt.init(params)
    args = (params, opt_state, data, targets, print_training)
    (params, opt_state, _, _, _) = jax.lax.fori_loop(0, 100, update_step_jit, args)

    return params

params = init_params
data = x_train
targets = y_train

optimization_jit(params, data, targets, print_training=True)

I hope this helps!

Thank you for your help. I managed to get the from_qiskit() function to work (I needed to call Qiskit's decompose() on vqc._circuit before conversion), as well as use jaxopt's ScipyMinimize to be able to use L-BFGS-B.
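
In case it helps anyone else, here is a rough sketch of that pattern. The TwoLocal settings, the cost function, and the initial weights below are placeholders (my real circuit comes from vqc._circuit and my loss is the BCE above); the key points are the decompose() call before qml.from_qiskit and the jaxopt.ScipyMinimize wrapper for L-BFGS-B.

import numpy as np
import pennylane as qml
import jax
from jax import numpy as jnp
import jaxopt
from qiskit.circuit.library import TwoLocal

jax.config.update('jax_enable_x64', True)
dev = qml.device("default.qubit", wires=4)

# Placeholder ansatz settings; the real circuit is vqc._circuit
two_local = TwoLocal(4, rotation_blocks="ry", entanglement_blocks="cx",
                     entanglement="circular", reps=3)

# decompose() first (as mentioned above) so the circuit contains only basic gates
ansatz = qml.from_qiskit(two_local.decompose())

@qml.qnode(dev)
def circuit(weights):
    # Bind each Qiskit Parameter to the corresponding weight value
    ansatz(params=dict(zip(two_local.parameters, weights)))
    return qml.expval(qml.PauliZ(0))

def cost(weights):
    # Placeholder cost, just to show the optimiser wiring
    return (1.0 - circuit(weights)) ** 2

init_weights = jnp.array(
    np.random.default_rng(0).uniform(0, 2 * np.pi, size=len(two_local.parameters))
)

# jaxopt's ScipyMinimize wraps SciPy's optimisers, including L-BFGS-B
solver = jaxopt.ScipyMinimize(fun=cost, method="L-BFGS-B")
opt_weights, state = solver.run(init_weights)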

That’s great @MelonLord! Thanks for posting your solution here.