Error faced when training a quantum network to estimate parameters

These are my initial parameters/conditions,

n_qubits = int(np.log2(X_21.shape[1]))
n_layer = 30
opt = optax.adam(learning_rate=0.05)
params = params_init(n_layer, n_qubits, Y_21)
opt_state = opt.init(params)
x_train, x_val, x_test, y_train, y_val, y_test = data_split(X_data_nf, Y, 0.2)
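
The helpers params_init and data_split (and the imports) are not shown in the post; here is a minimal, self-contained sketch of what they might look like, assuming StronglyEntanglingLayers-shaped weights, one bias per circuit output, and a simple hold-out split (all of this is an assumption, not the original code):

import numpy as np
import jax.numpy as jnp
import pennylane as qml

def params_init(n_layer, n_qubits, Y):
    # StronglyEntanglingLayers expects weights of shape (n_layers, n_wires, 3)
    shape = qml.StronglyEntanglingLayers.shape(n_layers=n_layer, n_wires=n_qubits)
    weights = jnp.asarray(np.random.uniform(0, 2 * np.pi, size=shape))
    bias = jnp.zeros(Y.shape[1])  # one trainable bias per target parameter
    return {"weights": weights, "bias": bias}

def data_split(X, Y, frac):
    # Hold out `frac` of the samples for validation and another `frac` for testing
    n = len(X)
    n_hold = int(n * frac)
    n_train = n - 2 * n_hold
    return (X[:n_train], X[n_train:n_train + n_hold], X[n_train + n_hold:],
            Y[:n_train], Y[n_train:n_train + n_hold], Y[n_train + n_hold:])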

So, these are my common functions,

def circ(weights, inputs):
    # Amplitude-encode the feature vector into n_qubits qubits, padding/normalizing as needed
    qml.AmplitudeEmbedding(features=inputs, wires=range(n_qubits), normalize=True, pad_with=0.5)
    qml.StronglyEntanglingLayers(weights, wires=range(n_qubits), imprimitive=qml.ops.CNOT)
    # One output per disjoint qubit pair: <Z_i + Z_{i+1}> for i = 0, 2, 4, ...
    out = [qml.expval(qml.PauliZ(i) + qml.PauliZ(i + 1)) for i in range(0, n_qubits - 1, 2)]
    return out
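
With 128 autoencoded features, n_qubits = int(np.log2(128)) = 7, so the measurement list covers the pairs (0, 1), (2, 3), (4, 5). A quick sanity check of the output dimensionality:

n_qubits = 7                             # int(np.log2(128))
pairs = list(range(0, n_qubits - 1, 2))  # [0, 2, 4]
print(len(pairs))                        # 3 circuit outputs; the bias and each target row must match this length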

def qnn(params,inputs):
    weights = params["weights"]
    bias = params["bias"]
    circ_out = circ(weights,inputs)
    out = []
    for i in range(len(circ_out)):
        out.append(circ_out[i] + bias[i])
    return out 
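
The per-element loop can also be written as a single array operation, which is a little friendlier to jax transformations; an equivalent sketch:

def qnn(params, inputs):
    # Stack the scalar expectation values and add the bias vector elementwise
    circ_out = jnp.stack(circ(params["weights"], inputs))
    return circ_out + params["bias"]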

def mse(observed, predictions):
    # Mean squared error, averaged over both samples and outputs
    return jnp.mean((observed - predictions) ** 2)
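
A quick check of the loss on toy arrays (values chosen arbitrarily):

observed = jnp.array([[1.0, 2.0], [3.0, 4.0]])
predictions = jnp.array([[1.5, 2.0], [2.0, 4.0]])
print(mse(observed, predictions))  # (0.25 + 0.0 + 1.0 + 0.0) / 4 = 0.3125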

def predict(params, features):
    # qnn_qjit is the (possibly qjit-compiled) qnn; see the two device setups below
    preds = qnn_qjit(params, features)
    return jnp.asarray(preds).T

batched_predict = jax.vmap(predict, in_axes=(None, 0))
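
jax.vmap maps predict over the leading (batch) axis of features while broadcasting params; for example (shapes assumed from the setup above):

x_batch = jnp.ones((512, 128))            # (batch_size, n_features), dummy inputs
preds = batched_predict(params, x_batch)  # shape (512, n_outputs), one row per sample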

def cost(params, features, observed):
    preds = batched_predict(params, features)
    return mse(observed, preds)

# JIT-compiled version of the same cost function
jit_cost = jax.jit(cost)

@jax.jit
def update_step(params, opt_state, features, observed):
    train_cost, grads = jax.value_and_grad(jit_cost)(params, features, observed)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, train_cost
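
Since update_step is itself wrapped in @jax.jit, differentiating the plain cost inside it compiles just as well; jitting jit_cost again is harmless but redundant. An equivalent sketch:

@jax.jit
def update_step(params, opt_state, features, observed):
    # The outer jit already compiles the whole step, so the un-jitted cost suffices
    train_cost, grads = jax.value_and_grad(cost)(params, features, observed)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, train_cost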

def fit(params, opt_state, x_train, y_train, x_val, y_val, epoch, batch_size):
    train_loss = []
    val_loss = []
    num_train = len(x_train)
    # Validation uses the full validation set, so convert it once outside the loop
    x_val = jnp.asarray(x_val)
    y_val = jnp.asarray(y_val)
    key = random.PRNGKey(np.random.randint(0, 10000))
    for i in range(epoch):
        # Draw a fresh random mini-batch of training samples each epoch
        key, key_batch = random.split(key)
        idx_train = random.choice(key_batch, num_train, shape=(batch_size,))
        x_train_batch = jnp.asarray(x_train[idx_train])
        y_train_batch = jnp.asarray(y_train[idx_train])
        start_time = time.time()
        params, opt_state, train_cost = update_step(params, opt_state, x_train_batch, y_train_batch)
        end_time = time.time()
        val_cost = jit_cost(params, x_val, y_val)
        epoch_time = end_time - start_time
        print("Epoch: {:5d} | Loss: {:0.7f} | Val_Loss: {:0.7f} | Time: {:0.4f} seconds".format(i + 1, train_cost, val_cost, epoch_time))
        train_loss.append(train_cost)
        val_loss.append(val_cost)
    return params, train_loss, val_loss

I run the below to train the circuit,

params, train_loss, val_loss = fit(params, opt_state, x_train, y_train, x_val, y_val, 300, 512)
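
After training, the recorded loss histories can be plotted to inspect convergence (a sketch, assuming matplotlib):

import matplotlib.pyplot as plt

plt.plot(train_loss, label="train")
plt.plot(val_loss, label="validation")
plt.xlabel("epoch")
plt.ylabel("MSE loss")
plt.legend()
plt.show()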

For JAX-JIT only, I use,

dev = qml.device("default.qubit", wires=n_qubits)
@qml.qnode(dev, interface="jax", diff_method="backprop")  # this decorator is applied to the circ function above
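
Assembled, the JAX-only variant looks like this (a sketch; here qnn_qjit is assumed to be just the plain qnn, since compilation happens through jax.jit on the cost and update step):

dev = qml.device("default.qubit", wires=n_qubits)

@qml.qnode(dev, interface="jax", diff_method="backprop")
def circ(weights, inputs):
    qml.AmplitudeEmbedding(features=inputs, wires=range(n_qubits), normalize=True, pad_with=0.5)
    qml.StronglyEntanglingLayers(weights, wires=range(n_qubits), imprimitive=qml.ops.CNOT)
    return [qml.expval(qml.PauliZ(i) + qml.PauliZ(i + 1)) for i in range(0, n_qubits - 1, 2)]

qnn_qjit = qnn  # assumption: predict simply calls the uncompiled qnn in this variant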

For Catalyst + JAX, I use the following (and yes, I still apply JAX-JIT to the cost function and the update-step function),

dev = qml.device("lightning.qubit", wires=n_qubits)
@qml.qnode(dev, diff_method="adjoint")
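
For the Catalyst variant, qnn_qjit would be the Catalyst-compiled model; a sketch (qml.qjit is PennyLane's entry point to Catalyst compilation):

# circ is redefined on the lightning.qubit device with diff_method="adjoint"
# (same circuit body as above), and the full model is compiled with Catalyst:
qnn_qjit = qml.qjit(qnn)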

For your reference, the input consists of 10,000 curves, each sampled at 1024 points (shape (10000, 1024)); an autoencoder then reduces this to (10000, 128). Each curve's target is the set of parameters that defines it, so the goal is to estimate those parameters from the compressed curves.
[Figure: 21cm_signal (the input curves)]