I want to compute the gradient norm at every epoch, but when I add the qml.grad call, the model parameters stop updating and stay fixed. I tried different optimizers, but the problem persists.

main.py

    opt = qml.AdamOptimizer(0.05)

    # Load MNIST data
    data_train, label_train, data_valid, label_valid, data_test, label_test = load_mnist_data()
    # data_train, label_train, data_valid, label_valid, data_test, label_test = load_iris_data()
    # data_train, label_train, data_valid, label_valid, data_test, label_test = load_medmnist_data()

    # Train
    best_params = train_fn(model, data_train, label_train, data_valid, label_valid, records, opt, args.epochs)

utilis.py

    def train_fn(model, data_train, label_train, data_valid, label_valid, records, opt, epochs):
        best_acc = 0
        barren_threshold = 1e-5
        best_params = None
    
        for epoch in range(epochs):
            # Training step
            print("Parameters before step:", model.params)
            model.params = opt.step(lambda params: cost_fn(params, model, data_train, label_train, records), model.params)
            print("Parameters after step:", model.params)
    
            # Calculate the cost (loss)
            loss = cost_fn(model.params, model, data_train, label_train, records)
    
            # Compute the gradient of the cost function
            grad_fn = qml.grad(cost_fn)
            gradients = grad_fn(model.params, model, data_train, label_train, records)
    
            # Compute the total gradient norm
            total_norm = 0
            for g in gradients:
                param_norm = np.linalg.norm(g)
                total_norm += param_norm ** 2
    
            total_norm = total_norm ** 0.5
    
            if (epoch + 1) % 2 == 0:
                print(f'Epoch [{epoch + 1}/{epochs}], Loss: {loss}, Gradient Norm: {total_norm}')

Hi @Hridoy_Chandra_Das ,

Thanks for your questions.

Could you please provide the following information? It can help us understand the problem and find possible solutions:

  1. The output of qml.about()

  2. A minimal reproducible example (or minimal working example)
    This is the simplest version of the code that reproduces the problem. It should be self-contained, including all necessary imports, data, and functions, so that we can copy-paste the code and reproduce the problem. However, it shouldn’t contain anything unnecessary, such as gates, data, or functions that can be removed to simplify the code.

  3. The full error traceback (or the result you get in case there’s no error).

If you’re not sure what these mean, make sure to check out this video.

Let me know if you have any questions about this!