Reshape error when sending training batches to Keras layer

I have been following a book on quantum optimization and QML. One of the examples uses PennyLane on the breast cancer dataset from scikit-learn. Everything seems to go well until the actual fit step, where I get errors that seem to depend on the batch size of the training data. The full error message is included below.

from silence_tensorflow import silence_tensorflow
silence_tensorflow()
import pennylane as qml
import numpy as np
import tensorflow as tf
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MaxAbsScaler
from sklearn.decomposition import PCA
from itertools import combinations

seed = 4321
np.random.seed(seed)
tf.random.set_seed(seed)
tf.keras.backend.set_floatx('float64')

x, y = load_breast_cancer(return_X_y = True)

x_tr, x_test, y_tr, y_test = train_test_split(x, y, train_size = 0.8)
x_val, x_test, y_val, y_test = train_test_split(x_test, y_test, train_size = 0.5)

scaler = MaxAbsScaler()
x_tr = scaler.fit_transform(x_tr)

x_test = scaler.transform(x_test)
x_val = scaler.transform(x_val)

x_test = np.clip(x_test, 0, 1)
x_val = np.clip(x_val, 0, 1)

pca = PCA(n_components = 4)
xs_tr = pca.fit_transform(x_tr)
xs_test = pca.transform(x_test)
xs_val = pca.transform(x_val)

def ZZFeatureMap(nqubits, data):
    # Number of variables that we will load:
    # could be smaller than the number of qubits.
    nload = min(len(data), nqubits)
    for i in range(nload):
        qml.Hadamard(i)
        qml.RZ(2.0 * data[i], wires = i)
    for pair in list(combinations(range(nload), 2)):
        q0 = pair[0]
        q1 = pair[1]
        qml.CZ(wires = [q0, q1])
        qml.RZ(2.0 * (np.pi - data[q0]) * (np.pi - data[q1]), wires = q1)
        qml.CZ(wires = [q0, q1])

def TwoLocal(nqubits, theta, reps = 1):
    for r in range(reps):
        for i in range(nqubits):
            qml.RY(theta[r * nqubits + i], wires = i)
        for i in range(nqubits - 1):
            qml.CNOT(wires = [i, i + 1])
    for i in range(nqubits):
        qml.RY(theta[reps * nqubits + i], wires = i)

# M = |0><0|, the projector onto |0>, used as the measured observable below.
state_0 = np.array([[1], [0]])
M = state_0 @ np.conj(state_0).T

nqubits = 4
dev = qml.device("default.qubit", wires=nqubits)

@qml.qnode(dev, interface="tf")
def qnn_circuit(inputs, theta):
    ZZFeatureMap(nqubits, inputs)
    TwoLocal(nqubits = nqubits, theta = theta, reps = 1)
    return qml.expval(qml.Hermitian(M, wires = [0]))

'''
# Testing the results of using the QNode on the training data.
weights = tf.random.uniform(shape=[nqubits*2])
for i in range(len(xs_tr)):
    output = qnn_circuit(xs_tr[i], weights)
    result = 1 if output >= 0.5 else 0
    print(f"Result is {result}")
    print(f"Y value is {y_tr[i]}")
'''

weights = {"theta": (nqubits*2)}
qlayer = qml.qnn.KerasLayer(qnn, weights, output_dim=None)
model = tf.keras.models.Sequential([qlayer])
opt = tf.keras.optimizers.Adam(learning_rate = 0.005)
model.compile(opt, loss=tf.keras.losses.BinaryCrossentropy())
earlystop = tf.keras.callbacks.EarlyStopping(monitor = "val_loss", patience = 2, verbose = 1, restore_best_weights = True)

history = model.fit(xs_tr, y_tr, epochs = 50, shuffle = True, validation_data = (xs_val, y_val), batch_size = 20, callbacks = [earlystop])

Here is the full error message:

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
Cell In[105], line 2
      1 #history = model.fit(xs_tr, y_tr, epochs = 50, shuffle = True, validation_data = (xs_val, y_val), batch_size = 20, callbacks = [earlystop])
----> 2 history = model.fit(xs_tr, y_tr, epochs = 5)

File /usr/local/lib/python3.11/dist-packages/keras/src/utils/traceback_utils.py:70, in filter_traceback.<locals>.error_handler(*args, **kwargs)
     67     filtered_tb = _process_traceback_frames(e.__traceback__)
     68     # To get the full stack trace, call:
     69     # `tf.debugging.disable_traceback_filtering()`
---> 70     raise e.with_traceback(filtered_tb) from None
     71 finally:
     72     del filtered_tb

File /usr/local/lib/python3.11/dist-packages/pennylane/qnn/keras.py:389, in KerasLayer.call(self, inputs)
    386 if has_batch_dim:
    387     # pylint:disable=unexpected-keyword-arg,no-value-for-parameter
    388     new_shape = tf.concat([batch_dims, tf.shape(results)[1:]], axis=0)
--> 389     results = tf.reshape(results, new_shape)
    391 return results

InvalidArgumentError: Exception encountered when calling layer 'keras_layer_12' (type KerasLayer).

{{function_node __wrapped__Reshape_device_/job:localhost/replica:0/task:0/device:GPU:0}} Input to reshape is a tensor with 4 values, but the requested shape has 32 [Op:Reshape] name: 

Call arguments received by layer 'keras_layer_12' (type KerasLayer):
  • inputs=tf.Tensor(shape=(32, 4), dtype=float64)

It seems to depend on the batch size I pass into the Keras layer: with batch size 20, the tensor shape is (20, 4). I have been looking for whatever I am missing, but so far I have come up short, so I am hoping to find a solution here. I have tried the demo from here (Turning quantum nodes into Keras Layers | PennyLane Demos) and everything works, so I assume I am doing something when I create the QNode that Keras doesn't like?

Output of qml.about():
Name: PennyLane
Version: 0.35.1
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author:
Author-email:
License: Apache License 2.0
Location: /usr/local/lib/python3.11/dist-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, pennylane-lightning, requests, rustworkx, scipy, semantic-version, toml, typing-extensions
Required-by: PennyLane_Lightning

Platform info: Linux-5.15.146.1-microsoft-standard-WSL2-x86_64-with-glibc2.35
Python version: 3.11.0
Numpy version: 1.26.2
Scipy version: 1.12.0
Installed devices:

  • lightning.qubit (PennyLane_Lightning-0.35.1)
  • default.clifford (PennyLane-0.35.1)
  • default.gaussian (PennyLane-0.35.1)
  • default.mixed (PennyLane-0.35.1)
  • default.qubit (PennyLane-0.35.1)
  • default.qubit.autograd (PennyLane-0.35.1)
  • default.qubit.jax (PennyLane-0.35.1)
  • default.qubit.legacy (PennyLane-0.35.1)
  • default.qubit.tf (PennyLane-0.35.1)
  • default.qubit.torch (PennyLane-0.35.1)
  • default.qutrit (PennyLane-0.35.1)
  • null.qubit (PennyLane-0.35.1)

Hey @CthulhuUSN, welcome to the forum! :rocket:

The issue is with your ZZFeatureMap quantum function, specifically anywhere you index into data inside it. Your data, xs_tr, has shape (455, 4): it holds 455 samples, and each sample is a feature vector of length 4. When you write something like qml.RZ(data[i], wires=i), you intend your model to interpret this in one of two ways:

  1. i \in \{0, 1, 2, 3\} and data is a single feature vector of length 4. You want your model to insert that particular entry of that vector, data[i], into qml.RZ.

  2. i \in \{0, 1, 2, 3\} and data is a batch of feature vectors that are all length 4. Say the batch holds N vectors. You want your model to insert the i^{\text{th}} entry of all N vectors into qml.RZ; this is broadcasting, and the result of your circuit should have a leading dimension of N, matching your input batch.

However, when you pass your model data and data is a batch of N feature vectors, data[i] is the i^{\text{th}} vector, not the i^{\text{th}} entry of each vector. In other words, data[i].shape is (4,), and qml.RZ(data[i], wires=i) ends up broadcasting over the 4 entries of a single vector.
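To make the indexing concrete, here is a tiny NumPy check; the array batch is just a hypothetical stand-in for a batch like xs_tr:

import numpy as np

batch = np.arange(8.0).reshape(2, 4)  # toy batch: 2 feature vectors of length 4

print(batch[0].shape)     # (4,) -> one whole feature vector (what your circuit receives)
print(batch[:, 0].shape)  # (2,) -> the 0th entry of every vector (what you want to broadcast over)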

The broadcasting behaviour is illustrated well by calling qnn_circuit directly and giving it xs_tr:

qnn_circuit(xs_tr, tf.random.uniform((nqubits*2,), 0, 1))
tf.Tensor([0.41958319 0.13525734 0.50499195 0.30518346], shape=(4,), dtype=float64)

The leading dimension (what we’re broadcasting over) is 4, so out pops a length 4 result! This isn’t what you’re intending to do, but PennyLane will still see this as a valid input to broadcast over and not throw any errors.
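To see broadcasting working as intended, here is a minimal sketch with a one-qubit QNode; the circuit toy and its gates are made up purely for illustration and have nothing to do with your model:

import numpy as np
import pennylane as qml

dev = qml.device("default.qubit", wires=1)

@qml.qnode(dev)
def toy(x):
    qml.Hadamard(0)
    qml.RZ(x, wires = 0)
    return qml.expval(qml.PauliX(0))

# A batch of 3 parameters broadcasts to 3 results, one per input value.
print(toy(np.array([0.1, 0.2, 0.3])).shape)  # (3,)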

Now you might be wondering why your model, which just contains qnn_circuit, can’t handle having xs_tr as input:

model(xs_tr)
InvalidArgumentError: Exception encountered when calling layer 'keras_layer_2' (type KerasLayer).

{{function_node __wrapped__Reshape_device_/job:localhost/replica:0/task:0/device:CPU:0}} Input to reshape is a tensor with 4 values, but the requested shape has 455 [Op:Reshape] name: 

Call arguments received by layer 'keras_layer_2' (type KerasLayer):
  • inputs=tf.Tensor(shape=(455, 4), dtype=float64)

This is because TF/Keras expects your model to behave as described in points 1 and 2 above. More specifically, if you give your model a batch of data with a leading dimension of 455 (which xs_tr has), it expects the output to preserve that leading dimension: 455 inputs in, 455 outputs out.
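That is exactly the mismatch in your traceback: KerasLayer reshapes the circuit output to restore the batch dimension, and 4 values cannot be reshaped to 455 (or 32, or 20). A stripped-down sketch of the failing step, using nothing but plain TensorFlow:

import tensorflow as tf

results = tf.zeros(4)  # what the mis-indexed circuit returns: one value per feature

try:
    tf.reshape(results, (455,))  # KerasLayer's attempt to restore the 455-sample batch dim
except tf.errors.InvalidArgumentError as e:
    print(e)  # Input to reshape is a tensor with 4 values, but the requested shape has 455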

So… how do we remedy this? Just be super explicit with your quantum layer about how you want it to process and interpret data: where you wrote data[i], you really want data[:, i].

def ZZFeatureMap(nqubits, data):
    # Number of variables that we will load:
    # could be smaller than the number of qubits.
    # Use the feature dimension (the last axis); len(data) would be
    # the batch size whenever data is batched.
    nload = min(data.shape[-1], nqubits)
    for i in range(nload):
        qml.Hadamard(i)
        qml.RZ(2.0 * data[:, i], wires = i)
    for pair in list(combinations(range(nload), 2)):
        q0 = pair[0]
        q1 = pair[1]
        qml.CZ(wires = [q0, q1])
        qml.RZ(2.0 * (np.pi - data[:, q0]) * (np.pi - data[:, q1]), wires = q1)
        qml.CZ(wires = [q0, q1])

If I use this instead,

print(qnn_circuit(xs_tr, tf.random.uniform((nqubits*2,), 0, 1)).shape)
print(model(xs_tr).shape)
(455,)
(455,)

Everything is processed as intended.
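For completeness, and assuming the variable names from your post (qnn_circuit, nqubits, xs_tr), the rest of your pipeline should run unchanged once the feature map indexes with data[:, i]. A quick sanity check before the full fit might look like this:

weight_shapes = {"theta": nqubits * 2}  # TwoLocal with reps = 1 consumes 2*nqubits angles
qlayer = qml.qnn.KerasLayer(qnn_circuit, weight_shapes, output_dim=None)
model = tf.keras.models.Sequential([qlayer])
model.compile(tf.keras.optimizers.Adam(learning_rate = 0.005),
              loss=tf.keras.losses.BinaryCrossentropy())

print(model(xs_tr[:20]).shape)  # (20,) -> the batch dimension is preserved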

Hope this helps!

That helped out a lot, and now the code is working! Making the change in ZZFeatureMap to index the input data as data[:, i] did the trick. Thank you very much for the help.

My pleasure! Glad to help :slight_smile: