Using JAX and PennyLane in Quantum Convolutional Neural Networks

Hello, I followed the tutorial to create my first QCNN. What I would like to do now is make the circuit trainable inside the training loop. I have tried many different approaches, but I just can't manage it. Can someone help me? I really don't know how to proceed.

I am currently using the MNIST dataset.

Some of the JAX functions I am using are probably wrong, but I can't see where the mistake is. Can someone provide a correct implementation?

import pennylane as qml
from pennylane import numpy as np
import jax
import jax.numpy as jnp
from jax import random
from tensorflow import keras
from tensorflow.keras import layers, Model
import tensorflow as tf
from tensorflow.keras.optimizers import Adam

# Check if JAX is using the GPU
print(jax.default_backend())  # It should return 'gpu' if CUDA is properly configured.

# Parameters
n_qubits = 9    # Number of qubits (equivalent to the number of pixels for a 3x3 kernel)
n_layers = 6    # Number of layers in the circuit
kernel_size = 3 # Size of the filter (kernel)
stride = 1      # Stride of the filter
n_train = 20    # Number of training samples
n_test = 1      # Number of test samples

# Load and preprocess the MNIST dataset
mnist_dataset = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist_dataset.load_data()

# Add an extra dimension for the convolutional channels
train_images = jnp.array(train_images[..., None])  # None adds a channel axis (equivalent to tf.newaxis)
test_images = jnp.array(test_images[..., None])

# Reduce the dataset size for testing
train_images = train_images[:n_train]
train_labels = train_labels[:n_train]
test_images = test_images[:n_test]
test_labels = test_labels[:n_test]

# Normalize pixel values to between 0 and 1
train_images = train_images / 255.0
test_images = test_images / 255.0

print(train_images.shape)
print(train_labels.shape)

# Quantum device
dev = qml.device("default.qubit", wires=n_qubits)

# Quantum circuit
@jax.jit
@qml.qnode(dev, interface="jax")
def circuit(inputs, weights):
    # Data embedding
    qml.AngleEmbedding(inputs, wires=range(n_qubits), rotation="Y")
    # Entanglement layers
    qml.BasicEntanglerLayers(weights, wires=range(n_qubits))
    # Measurement of PauliZ operators
    return [qml.expval(qml.PauliZ(wires=i)) for i in range(n_qubits)]

# Quantum convolutional layer class
class QuantumConv2D:
    def __init__(self, kernel_size, stride, n_qubits, n_layers):
        self.kernel_size = kernel_size
        self.stride = stride
        self.n_qubits = n_qubits
        self.n_layers = n_layers
        self.weight_shapes = {"weights": (n_layers, n_qubits)}

    def __call__(self, X, weights):
        batch_size, height, width, channels = X.shape
        output_height = (height - self.kernel_size) // self.stride + 1
        output_width = (width - self.kernel_size) // self.stride + 1

        # Ensure the shape is concrete
        outputs = jnp.empty((batch_size, output_height, output_width, self.n_qubits), dtype=jnp.float32)

        # Iterate over the image patches
        for i in range(0, height - self.kernel_size + 1, self.stride):
            for j in range(0, width - self.kernel_size + 1, self.stride):
                # Extract a patch of the image
                patch = X[:, i:i+self.kernel_size, j:j+self.kernel_size, :]
                patch = patch.reshape((batch_size, -1))  # Flatten to adapt to qubits

                # Calculate the output for each example in the batch
                for b in range(batch_size):
                    outputs = outputs.at[b, i // self.stride, j // self.stride, :].set(
                        circuit(patch[b], weights)
                    )
        return outputs


# Defining the QCNN model
class QuantumCNNModel(Model):
    def __init__(self, quantum_conv, n_classes=10):
        super(QuantumCNNModel, self).__init__()
        self.quantum_conv = quantum_conv
        self.flatten = layers.Flatten()
        self.dense = layers.Dense(n_classes, activation='softmax')

    def call(self, inputs, training=False):
        x = self.quantum_conv(inputs, weights=self.quantum_weights)  # Quantum layer
        x = self.flatten(x)  # Flatten the quantum output
        x = self.dense(x)    # Classical layer for classification
        return x

# Define the quantum convolutional layer
quantum_conv = QuantumConv2D(kernel_size=kernel_size, stride=stride, n_qubits=n_qubits, n_layers=n_layers)

# Initialize the model
model = QuantumCNNModel(quantum_conv)

# Initialize quantum weights (trainable)
key = random.PRNGKey(42)  # Random key
initial_weights = random.normal(key, (n_layers, n_qubits))  # Use jax.random.normal to generate random numbers

# Add quantum weights as trainable variables
model.quantum_weights = tf.Variable(initial_weights, dtype=tf.float32)

# Ensure the model is compiled correctly
model.compile(optimizer=Adam(), loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Train the model
model.fit(train_images, train_labels, epochs=5, batch_size=4)

Hi @Alessandro_Castelli, welcome to the Forum!

I ran your code and I'm seeing an error arising from the fact that your data is a SymbolicTensor. This most likely happens because Keras traces your model's call() with symbolic tensors during fit(), so the JAX QNode never receives concrete values.

Which demo were you using as an example?
Because your code is complex, it's hard to pinpoint how to fully solve it.

Maybe you can try an approach like the one I shared in this August 5th post.
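
In the meantime, here's a rough sketch of what a pure-JAX training loop for your quantum weights could look like. This is not the exact code from that post; the use of optax, the toy data, and the simple mean-squared-error loss are assumptions for illustration only:

# Rough sketch only (not the exact code from the linked post).
# Assumptions for illustration: optax as the optimizer and a simple
# mean-squared-error loss on a single expectation value.
import pennylane as qml
import jax
import jax.numpy as jnp
import optax

n_qubits = 4
n_layers = 2
dev = qml.device("default.qubit", wires=n_qubits)

@qml.qnode(dev, interface="jax")
def circuit(inputs, weights):
    qml.AngleEmbedding(inputs, wires=range(n_qubits), rotation="Y")
    qml.BasicEntanglerLayers(weights, wires=range(n_qubits))
    return qml.expval(qml.PauliZ(0))

def loss_fn(weights, X, y):
    # Evaluate the circuit on every sample and compare with the targets
    preds = jnp.array([circuit(x, weights) for x in X])
    return jnp.mean((preds - y) ** 2)

key = jax.random.PRNGKey(42)
weights = jax.random.normal(key, (n_layers, n_qubits))
X = jax.random.uniform(key, (8, n_qubits))  # toy data just for the sketch
y = jnp.ones(8)

optimizer = optax.adam(learning_rate=0.01)
opt_state = optimizer.init(weights)

@jax.jit
def step(weights, opt_state, X, y):
    # Compute the loss and its gradient with respect to the quantum weights
    loss, grads = jax.value_and_grad(loss_fn)(weights, X, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    weights = optax.apply_updates(weights, updates)
    return weights, opt_state, loss

for epoch in range(20):
    weights, opt_state, loss = step(weights, opt_state, X, y)
    print(f"Epoch {epoch}: loss = {loss:.4f}")

The key point is that the weights, the loss, and the optimizer update all stay inside JAX, so jax.grad and jax.jit can see the whole computation, instead of mixing JAX arrays with Keras' fit().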

Otherwise, if you can share a small reproducible example of what you want to do, that can help us help you.

I hope this helps!