A subclass of KerasLayer, Keras running errors

Haha, actually there was still a slight problem with that code: weight 0 never updated. But I rewrote the code using PyTorch and everything worked fine.

Now I’m trying to optimise the image processing time for the QCNN. I found this and followed it.

It works well in this simple example:

import pennylane as qml
from pennylane import numpy as np
dev = qml.device("lightning.qubit", wires=5)

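def test_circuit(inputs):
    # Each slice below has shape (5, 4): 5 rows act as the broadcast (batch)
    # dimension, and the 4 features are embedded on wires 0-3.
    qml.AngleEmbedding(inputs[10:15], wires=range(0,4), rotation='X')  # rows 10-14
    qml.AngleEmbedding(inputs[5:10], wires=range(0,4), rotation='X')   # rows 5-9
    qml.AngleEmbedding(inputs[0:5], wires=range(0,4), rotation='X')    # rows 0-4
    qml.AngleEmbedding(inputs[0:5], wires=range(0,4), rotation='X')    # rows 0-4 again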
    
@qml.qnode(dev)
def circuit(inputs):
    print(inputs.shape)
    test_circuit(inputs)

    return qml.expval(qml.PauliZ(0))

x = np.array(range(4*16)).reshape(16,4)
print(qml.draw(circuit)(x))
print(circuit(x))

Next, I want to replace dask.compute to speed up the computation (it always runs single-threaded, and with scheduler="processes" weight 0 disappears in PyTorch).

Similarly, in class QcnnLayer(nn.Module), inside def _quanv(self, image), I create an array of shape (81, 4) for a single channel, and an RGB image has 3 channels. Combining them with torch.cat() finally gives an array of shape (243, 4).
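The shape bookkeeping here can be checked with a small plain-Python sketch. It assumes the (2, 2) kernel, stride 1 and (10, 10) image size quoted further down in the post. `extract_patches` and the dummy pixel values are hypothetical stand-ins, and list concatenation plays the role of torch.cat() along the first dimension. It is not the actual _quanv implementation, which isn't shown.

```python
def extract_patches(channel, kernel=2, stride=1):
    """Return one flattened kernel x kernel patch per output position."""
    height, width = len(channel), len(channel[0])
    patches = []
    for top in range(0, height - kernel + 1, stride):
        for left in range(0, width - kernel + 1, stride):
            # Flatten the patch row by row into a list of kernel * kernel values.
            patch = [channel[top + dy][left + dx]
                     for dy in range(kernel) for dx in range(kernel)]
            patches.append(patch)
    return patches


# Hypothetical 10x10 image with three channels filled with dummy values.
image = {name: [[float(r * 10 + c) for c in range(10)] for r in range(10)]
         for name in "RGB"}

per_channel = {name: extract_patches(ch) for name, ch in image.items()}

# Stand-in for torch.cat(..., dim=0): stack the channels along the first axis.
stacked = per_channel["R"] + per_channel["G"] + per_channel["B"]

# Row ranges of each channel inside the stacked array.
offsets = {"R": (0, 81), "G": (81, 162), "B": (162, 243)}

print(len(per_channel["R"]), len(per_channel["R"][0]))  # rows and columns per channel
print(len(stacked), len(stacked[0]))                    # rows and columns after stacking
```

With these assumptions each channel yields 9 x 9 = 81 patches of 4 values, so the stacked array has 243 rows, and the `offsets` ranges match the R/G/B index comments in the error report.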

But I get an error after passing it to self.circuit():

# kernel_size = (2, 2), stride = 1, image_size = (10, 10), target_size = (9, 9)
# R[0, 81], G[81, 162], B[162, 243]
# Use the qnode Parameter broadcasting method
# /python3.9/site-packages/pennylane/tape/qscript.py:519, in QuantumScript._update_batch_size(self)
    517 if candidate:
    518     if op_batch_size != candidate:
--> 519         raise ValueError(
    520             "The batch sizes of the quantum script operations do not match, they include "
    521             f"{candidate} and {op_batch_size}."
    522         )
    523 else:
    524     candidate = op_batch_size

ValueError: The batch sizes of the quantum script operations do not match, they include 81 and 1.
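The check in this traceback can be modelled in a few lines of plain Python, which may help when tracking down which operation carries the odd batch size. This is a simplified sketch, not PennyLane's actual code. `batch_size_of` and `common_batch_size` are hypothetical helpers, skipping non-broadcast operations is an assumption, and the (81, 4) and (1, 4) shapes are only one guess at how "81 and 1" could arise, since the failing inputs aren't shown.

```python
def batch_size_of(shape):
    """Hypothetical helper: batch size of a feature tensor for an operation
    that expects a 1-D feature vector (as AngleEmbedding does).
    A 1-D shape is not broadcast (None); a 2-D shape is batched over axis 0."""
    return shape[0] if len(shape) == 2 else None


def common_batch_size(op_batch_sizes):
    """Simplified model of the check in the traceback: every broadcast
    operation must agree on a single batch size."""
    candidate = None
    for op_batch_size in op_batch_sizes:
        if op_batch_size is None:  # assumption: unbatched operations are skipped
            continue
        if candidate:
            if op_batch_size != candidate:
                raise ValueError(
                    "The batch sizes of the quantum script operations do not match, "
                    f"they include {candidate} and {op_batch_size}."
                )
        else:
            candidate = op_batch_size
    return candidate


# The simple example: four slices of shape (5, 4) all agree on batch size 5.
simple_shapes = [(5, 4), (5, 4), (5, 4), (5, 4)]
print(common_batch_size([batch_size_of(s) for s in simple_shapes]))

# One guessed failing case: an (81, 4) tensor next to a (1, 4) tensor.
guessed_shapes = [(81, 4), (1, 4)]
try:
    common_batch_size([batch_size_of(s) for s in guessed_shapes])
except ValueError as err:
    print(err)
```

If the real QNode behaves like this model, printing the shape of every tensor passed to an operation inside self.circuit() should reveal one whose leading dimension is 1 rather than 81.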

This confuses me: the simple code works fine, but the error occurs when the same approach is used in PyTorch. I’ve been trying to solve this problem for a while now, but no luck…

What should I do to solve this problem? @isaacdevlugt Thanks a lot.