QNN on MNIST with JAX

Good morning, PennyLane team. I would like to discuss a problem with image classification on the MNIST dataset. I am using JAX for the optimization, and I have built a "data-reuploading" quantum circuit: it applies the feature map twice, encoding the same features each time (I downsampled MNIST from 784 to 64 features, i.e. (8, 8)), plus 2 different variational layers that learn separate parameters. The circuit returns qml.expval(qml.PauliZ(0)), and I defined the binary cross-entropy loss as follows:

import jax.numpy as jnp

def binary_crossentropy(X, y, theta):
    ytrue = jnp.array(y)
    ypred = mapping(qnn(X, theta))  # qnn is the model I made; mapping rescales its output to [0, 1]
    epsilon = 1e-8  # keeps ypred away from 0 and 1 so that log(0) = -inf never occurs
    ypred_bounded = jnp.clip(ypred, epsilon, 1 - epsilon)
    # Natural log as in the standard definition (log10 would only rescale the loss by 1/ln 10)
    loss = ytrue * jnp.log(ypred_bounded) + (1 - ytrue) * jnp.log(1 - ypred_bounded)
    return -jnp.mean(loss)

and the function mapping is defined as

def mapping(y):
    return (y+1)/2

since the expectation values lie in the interval [-1, 1] and have to be remapped into [0, 1] before being fed to the cross-entropy. I also compute the accuracy during training:
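The remapping has a direct physical reading. For qubit 0, ⟨Z⟩ = P(0) − P(1) and P(0) + P(1) = 1, so (⟨Z⟩ + 1)/2 = P(0): mapping turns the expectation value of PauliZ(0) into the (marginal) probability of measuring qubit 0 in |0⟩. A minimal NumPy check on an arbitrary single-qubit state (the angle and variable names are illustrative, not taken from the original code):

```python
import numpy as np

def mapping(y):
    # Same remapping as above: [-1, 1] -> [0, 1]
    return (y + 1) / 2

# Hypothetical single-qubit state cos(theta/2)|0> + sin(theta/2)|1>
theta = 1.2
state = np.array([np.cos(theta / 2), np.sin(theta / 2)])

probs = np.abs(state) ** 2                            # [P(0), P(1)]
pauli_z = np.diag([1.0, -1.0])
expval_z = np.real(state.conj() @ pauli_z @ state)    # <Z> = P(0) - P(1)

p0_from_expval = mapping(expval_z)
print(probs[0], p0_from_expval)  # the two numbers agree up to floating-point error
```

So thresholding mapping(⟨Z⟩) at 0.5 is the same as predicting the more likely measurement outcome of qubit 0.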

def calculate_accuracy(theta, X, y):
    y_pred = mapping(qnn(X, theta))  # X is the training set here
    thresholded_y_pred = jnp.where(y_pred <= 0.5, 0, 1)  # threshold at 0.5 because the values lie in [0, 1]
    accuracy = jnp.mean(thresholded_y_pred == y)
    return accuracy
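To sanity-check the loss and accuracy formulas without running the circuit, they can be evaluated on a couple of hand-picked expectation values. This sketch uses NumPy (jax.numpy offers the same clip, log, where and mean calls) and the natural logarithm of the standard cross-entropy definition; a base-10 log only rescales the loss by the constant 1/ln 10. The `_from_expvals` functions are hypothetical stand-ins that take the circuit outputs directly instead of calling qnn:

```python
import numpy as np

def mapping(y):
    return (y + 1) / 2

def binary_crossentropy_from_expvals(expvals, y, epsilon=1e-8):
    # Same formula as binary_crossentropy, applied to given circuit outputs
    ypred = np.clip(mapping(expvals), epsilon, 1 - epsilon)
    loss = y * np.log(ypred) + (1 - y) * np.log(1 - ypred)
    return -np.mean(loss)

def accuracy_from_expvals(expvals, y):
    # Same thresholding as calculate_accuracy
    ypred = mapping(expvals)
    thresholded = np.where(ypred <= 0.5, 0, 1)
    return np.mean(thresholded == y)

expvals = np.array([0.8, -0.6])   # made-up <Z_0> values for two samples
labels = np.array([1, 0])
print(binary_crossentropy_from_expvals(expvals, labels))  # ≈ 0.1643
print(accuracy_from_expvals(expvals, labels))             # 1.0
```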

I would like to know whether this is correct. Training seems to work fine (training accuracy 0.93), and after training for 500 epochs the accuracy on the test set, computed from

ypred = mapping(qnn(X_test, opt_params))  # again, remap the values into [0, 1]

is 0.89, so there is little overfitting. I could tune the hyperparameters on a validation set (number of layers, gates per layer, entangling scheme, ...), but my question is more about the physics. In Qiskit you run the circuit with shots, measure, and compute the accuracy from the resulting probabilities, roughly as follows:

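shots = 1024
job = execute(qc, backend, shots=shots)  # execute() is the pre-1.0 Qiskit API
result = job.result()
counts = result.get_counts()
probs = []
for key in counts.keys():
    probs.append(counts[key] / shots)  # relative frequency of each measured bitstring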

(or something along those lines). But I am using expval rather than shots, counts and the scheme above. Is my approach still right? Can I use the expectation value for classification, and should I change anything in my code?
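On the physics side, a shot-based workflow like the Qiskit one estimates the same quantity: the fraction of shots in which qubit 0 reads 0 is an unbiased estimate of P(0) = (1 + ⟨Z⟩)/2, with standard error sqrt(P(0)(1 − P(0))/shots), and the analytic expval is its infinite-shot limit. A NumPy sketch that simulates 1024 shots from an assumed exact ⟨Z⟩ (the value 0.3 and the seed are arbitrary):

```python
import numpy as np

expval_z = 0.3                      # assumed exact <Z_0> returned by the circuit
p0 = (expval_z + 1) / 2             # exact probability of reading 0 on qubit 0
shots = 1024

rng = np.random.default_rng(seed=0)
n_zeros = rng.binomial(n=shots, p=p0)   # number of shots that returned 0
p0_estimate = n_zeros / shots           # counts-based estimate of P(0)
expval_estimate = 2 * p0_estimate - 1   # <Z> = P(0) - P(1) = 2 P(0) - 1

std_error = np.sqrt(p0 * (1 - p0) / shots)
print(p0, p0_estimate, std_error)
```

With no shots set, default.qubit returns the exact expectation value, so the loss is free of this sampling noise; with finite shots the same thresholding at 0.5 applies to the estimate.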
Last question (sorry to bother you): how do I set up multiclass classification based on the functions above? Suppose I have 4 classes: I would measure 2 qubits and change the loss, and then what?
Thanks in advance guys, I need your help!!
#qml #quantumcomputing #pennylane

Hey @checcopo, welcome to the forum!

From what you posted, what you’re doing seems fine to me (and it’s nice that your results confirm that :+1:). For multiclass classification, I suggest making your model hybrid by adding a Softmax layer at the end (Softmax function - Wikipedia); that is the activation function typically used at the end of a network for this kind of problem. Alternatively, you could do the post-processing manually and partition the [-1, 1] interval into 4 equal bins, one for each of your 4 labels.
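A minimal sketch of the softmax suggestion, assuming the circuit returns one real score per class for each sample (for example 4 numbers per sample for 4 classes). It is written with NumPy; jax.nn.softmax and jax.nn.log_softmax provide the same operations in JAX. The scores are made up:

```python
import numpy as np

def log_softmax(scores):
    # Numerically stable log-softmax along the class axis
    shifted = scores - scores.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def categorical_crossentropy(scores, labels):
    # labels are integer class indices; pick the log-probability of the true class
    log_probs = log_softmax(scores)
    return -np.mean(log_probs[np.arange(len(labels)), labels])

def multiclass_accuracy(scores, labels):
    # Predicted class = largest score (softmax does not change the argmax)
    return np.mean(np.argmax(scores, axis=-1) == labels)

scores = np.array([[0.9, -0.2, 0.1, -0.5],    # hypothetical outputs, 2 samples x 4 classes
                   [-0.3, 0.4, 0.8, 0.0]])
labels = np.array([0, 2])
probs = np.exp(log_softmax(scores))           # softmax probabilities, each row sums to 1
print(categorical_crossentropy(scores, labels))
print(multiclass_accuracy(scores, labels))    # 1.0
```

Because expectation values lie in [-1, 1], a softmax applied directly to them can never assign more than e²/(e² + 3) ≈ 0.71 to one of 4 classes; multiplying the scores by a trainable factor or passing them through a classical linear layer removes that ceiling.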

Let me know if that helps!

Good morning, thanks for replying.
Here I attach the rest of the code:
Data loading

import numpy as np
import tensorflow as tf
from keras.datasets import mnist

def downsample(x_array, size):
    newsize = (size, size)
    x_array = np.reshape(x_array, (x_array.shape[0], x_array.shape[1], 1))  # add a channel axis for tf.image.resize
    new_array = tf.image.resize(x_array, newsize)
    return new_array.numpy()

def data(num_train_samples=None, num_test_samples=None, shuffle=False, resize=None):
    (train_X, train_y), (test_X, test_y) = mnist.load_data()
    # Keep only the digits 0 and 1 (binary classification)
    X_train_filtered = train_X[np.isin(train_y, [0, 1])]
    y_train_filtered = train_y[np.isin(train_y, [0, 1])]
    X_test_filtered = test_X[np.isin(test_y, [0, 1])]
    y_test_filtered = test_y[np.isin(test_y, [0, 1])]

    # Scale pixel values to [0, 1]
    X_train_filtered = X_train_filtered.astype('float32') / 255
    X_test_filtered = X_test_filtered.astype('float32') / 255
    if resize is not None:
        if resize > 28:
            raise ValueError("The new size must not exceed the original MNIST size of 28!")
        X_train_new = np.array([downsample(train, resize) for train in X_train_filtered])
        X_test_new = np.array([downsample(test, resize) for test in X_test_filtered])
    else:
        X_train_new, X_test_new = X_train_filtered, X_test_filtered

    ### shuffle: apply the same permutation to images and labels
    train_indices = np.arange(len(X_train_new))
    test_indices = np.arange(len(X_test_new))
    if shuffle:
        np.random.shuffle(train_indices)
        np.random.shuffle(test_indices)
    X_train_new, y_train_filtered = X_train_new[train_indices], y_train_filtered[train_indices]
    X_test_new, y_test_filtered = X_test_new[test_indices], y_test_filtered[test_indices]

    # Optionally keep only the first samples (slicing with None keeps everything)
    X_train_ = X_train_new[:num_train_samples]
    y_train_filtered = y_train_filtered[:num_train_samples]
    X_test_ = X_test_new[:num_test_samples]
    y_test_filtered = y_test_filtered[:num_test_samples]

    # Flatten each image into a feature vector (8x8 -> 64 features)
    X_train_ = X_train_.reshape(X_train_.shape[0], -1)
    X_test_ = X_test_.reshape(X_test_.shape[0], -1)
    return X_train_, y_train_filtered, X_test_, y_test_filtered

new_shape = 8
X_train, y_train, X_test, y_test = data(num_train_samples=200, num_test_samples=400, shuffle=True, resize=new_shape)

Quantum circuit

def feature_map_basic(X, wires=n_qubits):
    idx = 0
    for i in range(wires):
        qml.Rot(phi=X[idx+0], theta=X[idx+1], omega=X[idx+2], wires=i)
        idx +=3
        qml.Rot(phi=X[idx+0], theta=X[idx+1], omega=X[idx+2], wires=i)
        idx +=3
        qml.Rot(phi=X[idx+0], theta=X[idx+1], omega=X[idx+2], wires=i)
        idx +=3

def qlayer__(X, params):
    idx = 0
    for i in range(n_qubits):
        qml.Rot(phi=params[0+idx], theta=params[1+idx], omega=params[2+idx], wires=i)
        idx +=3
    for i in range(n_qubits):
        qml.Rot(phi=params[0+idx], theta=params[1+idx], omega=params[2+idx], wires=i)
        idx +=3

def qlayer2__(X, params):
    idx = 0
    for i in range(n_qubits):
        qml.Rot(phi=params[0+idx], theta=params[1+idx], omega=params[2+idx], wires=i)
        idx +=3
    for i in range(n_qubits):
        qml.Rot(phi=params[0+idx], theta=params[1+idx], omega=params[2+idx], wires=i)
        idx +=3

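dev2 = qml.device("default.qubit", wires=n_qubits)

@qml.qnode(dev2, interface="jax")
def qnode2(Xval, params):
    n_layer_params = 2 * 3 * n_qubits  # each layer uses 2 Rot gates x 3 angles per qubit
    # Data re-uploading: encode the same features before each variational layer,
    # and give the two layers separate slices of params so they learn different weights
    feature_map_basic(Xval)
    qlayer__(Xval, params[:n_layer_params])
    feature_map_basic(Xval)
    qlayer2__(Xval, params[n_layer_params:])
    return qml.expval(qml.PauliZ(0))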
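Counting how many entries each function reads makes the indexing easy to check. In the code above, feature_map_basic applies three Rot gates (3 angles each) per qubit, so it reads 9 features per qubit, and each of qlayer__ / qlayer2__ applies two Rot gates per qubit, so each reads 6 parameters per qubit. The helper names below are hypothetical:

```python
def features_read(n_qubits):
    # feature_map_basic: 3 Rot gates x 3 angles on every qubit
    return 3 * 3 * n_qubits

def params_per_layer(n_qubits):
    # qlayer__ / qlayer2__: 2 Rot gates x 3 angles on every qubit
    return 2 * 3 * n_qubits

n_features = 8 * 8                              # downsampled 8x8 MNIST image
max_qubits = n_features // features_read(1)
print(max_qubits, features_read(max_qubits))    # 7 63
print(params_per_layer(max_qubits) * 2)         # 84
```

So with 64 input features, at most 7 qubits fit without indexing past the end of X (JAX clamps out-of-bounds indices silently instead of raising), one feature stays unused, and two layers with separate parameters need 12 × n_qubits trainable angles.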

Training and useful functions

import jax
import jax.numpy as jnp
import optax

def optimizer_update(opt_state, params, x, y):
    loss_value, grads = jax.value_and_grad(lambda theta: binary_crossentropy(x, y, theta))(params)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_value

epochs = 500
batch_size = 10
seed = 199
qnn_batched = jax.vmap(qnode2, (0, None))  # vectorize over samples, share the parameters
qnn = jax.jit(qnn_batched)

# Lists to save data
costs = []
val_costs = []
train_per_epoch = []
val_per_epoch = []
acc_per_epoch = []

# Creating the initial random parameters for the QNN
key = jax.random.PRNGKey(seed)
initial_params = jax.random.normal(key, shape=(param_per_gate*n_qubits*gate_per_layer*layers,))
key = jax.random.split(key)[0]
params = jnp.copy(initial_params)

# Optimizer initialization
optimizer = optax.adam(learning_rate=0.01)
opt_state = optimizer.init(initial_params)

X_train, y_train = jnp.asarray(X_train), jnp.asarray(y_train)
for epoch in range(1, epochs + 1):
    # Random permutation of the training indices, split into mini-batches
    idxs_dataset = jax.random.permutation(key, X_train.shape[0])
    key = jax.random.split(key)[0]
    for start in range(0, X_train.shape[0], batch_size):
        batch = idxs_dataset[start:start + batch_size]
        params, opt_state, _ = optimizer_update(opt_state, params, X_train[batch], y_train[batch])

    # Metrics on the full training and test sets after this epoch
    cost = binary_crossentropy(X_train, y_train, params)
    acc = calculate_accuracy(params, X_train, y_train)
    val_cost = binary_crossentropy(X_test, y_test, params)
    costs.append(cost)
    val_costs.append(val_cost)
    acc_per_epoch.append(acc)


Of course, you also need the functions I posted in my first message.
I hope this makes it clear.
My question is whether you could go into more detail about the multiclass implementation. If I understood correctly, should I modify the cross-entropy like this?

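import jax
import jax.numpy as jnp

def hybrid_layer(ypredicted):
    # Apply the softmax directly in JAX: a Keras layer object (and .numpy())
    # cannot be traced or differentiated by jax.grad
    return jax.nn.softmax(ypredicted, axis=-1)

def cate_cross(X, y, theta, num_classes=4):
    ytrue = jax.nn.one_hot(jnp.array(y), num_classes)  # integer labels -> one-hot, shape (batch, num_classes)
    ypred = qnn(X, theta)  # must provide num_classes values per sample, shape (batch, num_classes)
    ypred_multiclass = hybrid_layer(ypred)
    epsilon = 1e-8  # avoid log(0) = -inf
    ypred_bounded = jnp.clip(ypred_multiclass, epsilon, 1 - epsilon)
    loss = jnp.sum(ytrue * jnp.log(ypred_bounded), axis=-1)  # sum over classes
    return -jnp.mean(loss)  # average over samples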

I assume I need to measure two qubits, and probably reshape the output of the expval into a (1, 4) array.
It may be unnecessary to map the circuit outputs into [0, 1], since the softmax already makes everything positive, right? How should I compute the accuracy in this case? Is this correct? And how can I implement your second suggestion?
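The second suggestion (binning) needs no extra measurements: split [-1, 1] into 4 equal-width bins and assign one class per bin. A NumPy sketch, assuming classes 0 to 3 are assigned to bins in increasing order of ⟨Z⟩ (an arbitrary choice); training towards the bin centres with a squared-error loss is likewise an assumption, not something stated in the thread:

```python
import numpy as np

def expval_to_class(expvals, num_classes=4):
    # Bins for 4 classes: [-1, -0.5), [-0.5, 0), [0, 0.5), [0.5, 1]
    edges = np.linspace(-1.0, 1.0, num_classes + 1)
    # digitize against the interior edges returns 0 .. num_classes - 1
    return np.digitize(expvals, edges[1:-1])

def bin_targets(labels, num_classes=4):
    # Centre of each bin, usable as a regression target for <Z>
    edges = np.linspace(-1.0, 1.0, num_classes + 1)
    centres = (edges[:-1] + edges[1:]) / 2
    return centres[labels]

expvals = np.array([-0.9, -0.2, 0.1, 0.75, 1.0])
print(expval_to_class(expvals))              # [0 1 2 3 3]
print(bin_targets(np.array([0, 1, 2, 3])))
```

Because all classes share one real number, neighbouring bins are implicitly treated as "close" classes, which the softmax approach does not assume.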
Anyway, thanks a lot!!!

Hey @checcopo,

It’s a little difficult to say with certainty what’s correct and what isn’t here, but it might be worth checking out PyTorch’s documentation on multiclass classification problems (I know you’re working with JAX, but PyTorch’s documentation is nicer IMO). E.g., for accuracy and loss functions you can look at the cross-entropy loss: CrossEntropyLoss — PyTorch 2.2 documentation. You’ll be dealing a lot with categorical distributions, so this might also be worth a look: Probability distributions - torch.distributions — PyTorch 2.2 documentation

Let me know if that helps!

Thanks for the help. I tried to implement what you suggested, but the accuracy on the test set is too low (67%). I know this isn’t a Torch-specific problem; I would have preferred to use JAX, but I couldn’t find a solution there.
Anyway, here is my code, using the same quantum circuits I described above. I hope it is correct.

import pennylane as qml
import torch
from sklearn.metrics import confusion_matrix

@qml.qnode(dev2, interface="torch")  # dev2 as defined in the JAX version
def qnode(inputs, weights):
    # ... same feature map and variational layers as above ...
    return [qml.expval(qml.PauliZ(i)) for i in range(2)]

layers = 3
gate_per_layer = 3
param_per_gate = 3
weight_shapes = {"weights": (param_per_gate*(n_qubits) + param_per_gate*7 + gate_per_layer*n_qubits, layers)}

QUANTUM_LAYER = qml.qnn.TorchLayer(qnode, weight_shapes)
# in_features must equal the number of values QUANTUM_LAYER returns (2 expvals here); out_features = 4 classes
CLASSICAL_DENSE_LAYER = torch.nn.Linear(2, 4)
# Use only to turn logits into probabilities at inference, not inside MODEL:
# CrossEntropyLoss expects raw logits and applies log-softmax itself
SOFTMAX = torch.nn.Softmax(dim=1)
LYS = [QUANTUM_LAYER, CLASSICAL_DENSE_LAYER]
MODEL = torch.nn.Sequential(*LYS)
opt = torch.optim.Adam(MODEL.parameters(), lr=7e-4)
loss = torch.nn.CrossEntropyLoss()
batch_size = 5
X = torch.tensor(X_train, dtype=torch.float32)  # inputs do not need gradients
Xval = torch.tensor(X_val, dtype=torch.float32)
y_tensor = torch.tensor(y_train, dtype=torch.long)  # CrossEntropyLoss expects integer class labels

data_loader = torch.utils.data.DataLoader(
    list(zip(X, y_tensor)), batch_size=batch_size, shuffle=True, drop_last=True
)

def training():
    epochs = 100
    for epoch in range(epochs):
        loss_value = 0
        for xs, ys in data_loader:
            opt.zero_grad(set_to_none=True)  # reset gradients (setting them to None is slightly faster)
            loss_value = loss(MODEL(xs), ys)
            loss_value.backward()
            opt.step()

        print(f"Epoch n° {epoch+1}:", loss_value.item())

    ### Evaluation
    with torch.no_grad():
        y_pred = MODEL(Xval)
    predictions = torch.argmax(y_pred, dim=1).numpy()
    correct_predictions = [1 if p == p_true else 0 for p, p_true in zip(predictions, y_val)]
    accuracy = sum(correct_predictions) / len(correct_predictions)
    print(f"Accuracy: {accuracy * 100}%")
    cm_test = confusion_matrix(y_true=y_val, y_pred=predictions)

if __name__ == "__main__":
    training()
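torch.nn.CrossEntropyLoss expects unnormalized logits and applies log-softmax internally, so a Softmax layer placed before it applies softmax twice. The loss inputs are then confined to [0, 1] and the loss can no longer approach zero: even a fully confident one-hot output gives log(1 + 3/e) ≈ 0.744 for 4 classes. A NumPy sketch of that floor (it reproduces the arithmetic, not the Torch API):

```python
import numpy as np

def cross_entropy_from_logits(logits, label):
    # What CrossEntropyLoss computes for one sample: -log softmax(logits)[label]
    shifted = logits - logits.max()
    log_probs = shifted - np.log(np.exp(shifted).sum())
    return -log_probs[label]

num_classes = 4
confident_probs = np.eye(num_classes)[0]  # softmax output fully confident in class 0

# Feeding already-softmaxed probabilities into a loss that expects logits
double_softmax_loss = cross_entropy_from_logits(confident_probs, label=0)
print(double_softmax_loss, np.log(1 + (num_classes - 1) / np.e))  # both ≈ 0.744

# Feeding raw logits instead: the loss goes towards 0 as the true-class logit grows
print(cross_entropy_from_logits(np.array([10.0, 0.0, 0.0, 0.0]), label=0))
```

The gradients are correspondingly weak, which can slow training noticeably.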

Hey @checcopo,

The “correctness” of your code is hard for me to evaluate because it’s your project / research. But, if you have any questions about how to use PennyLane and plugins for it (e.g., bugs, overarching questions, or implementation questions), I’d be happy to help!

I found the mistake in the Torch version: I’m measuring the expval, but I need the probabilities (measuring 2 qubits gives 4 probabilities, which I can connect to a flattened Torch layer). I’ll modify the code accordingly later.
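This is the key difference: for two qubits, qml.probs(wires=[0, 1]) returns the 4 probabilities of |00⟩, |01⟩, |10⟩, |11⟩ (wire 0 as the leftmost bit), one per class, while two PauliZ expectation values only carry the single-qubit marginals. A NumPy sketch on an arbitrary random 2-qubit state:

```python
import numpy as np

rng = np.random.default_rng(seed=1)
amps = rng.normal(size=4) + 1j * rng.normal(size=4)
state = amps / np.linalg.norm(amps)          # arbitrary normalized 2-qubit state

probs = np.abs(state) ** 2                   # [P(00), P(01), P(10), P(11)]: one value per class

# <Z_0> and <Z_1> depend only on the marginal probabilities of each qubit
z = np.array([1.0, -1.0])
expval_z0 = probs @ np.kron(z, [1.0, 1.0])   # P(0x) - P(1x)
expval_z1 = probs @ np.kron([1.0, 1.0], z)   # P(x0) - P(x1)

predicted_class = int(np.argmax(probs))      # class = most likely basis state
print(probs, expval_z0, expval_z1, predicted_class)

# Two different class distributions with identical Z expectation values
for p in (np.array([0.25, 0.25, 0.25, 0.25]), np.array([0.5, 0.0, 0.0, 0.5])):
    print(p @ np.kron(z, [1.0, 1.0]), p @ np.kron([1.0, 1.0], z))  # 0.0 0.0 in both cases
```

Since different 4-class distributions can share the same pair (⟨Z₀⟩, ⟨Z₁⟩), the 4 probabilities are the natural input for a 4-way softmax or linear layer.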
Thanks for the help. I’ll let you know about other questions.
