An error when I use qml.draw_mpl

I want to draw my circuit with qml.draw_mpl, but I get a KeyError.

```python
import pennylane as qml
import tensorflow as tf
from tensorflow.keras.layers import Dense

n_qubits = 4
n_layers = 3
S = [[0, 1], [1, 2], [2, 3], [3, 0]]  # (control, target) pairs for the CRY gates
dev = qml.device("default.qubit", wires=n_qubits)

@qml.qnode(dev)
def circuit(inputs, weights):
    # Hadamard on every qubit
    for h in range(n_qubits):
        qml.Hadamard(wires=h)
    # Encode the inputs as controlled-RY rotations around a ring of qubits
    for k1 in S:
        qml.CRY(phi=inputs[k1[0]], wires=(k1[0], k1[1]))

    # Trainable RY rotations, indexed into a (n_layers, n_qubits) weight array
    qml.RY(weights[0,0], wires=0)
    qml.RY(weights[1,1], wires=1)
    qml.RY(weights[1,0], wires=2)
    qml.RY(weights[0,1], wires=3)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliZ(1))

# Tells KerasLayer the shape of the trainable "weights" argument
weight_shapes = {"weights": (n_layers, n_qubits)}
qlayer = qml.qnn.KerasLayer(circuit, weight_shapes, output_dim=2)

input   = tf.keras.layers.Input(shape=(24,))
Dense_1 = Dense(5, activation="relu")
QNN_1   = qml.qnn.KerasLayer(circuit, weight_shapes, output_dim=2, name="QNN_1")
output  = Dense(2, activation="softmax")

model = tf.keras.models.Sequential()
model.add(input)
model.add(Dense_1)
model.add(QNN_1)
model.add(output)
model(X_train[:20])  # X_train is my training data (not shown here)

fig, ax = qml.draw_mpl(circuit)(inputs=input[0], weights=weight_shapes)
fig.show()
```

Here is the error.

```text
Traceback (most recent call last):
  File "E:\Postgraduate\Projects\AMC\AMR-Benchmark-main\HCQDNN3\main_pennelane.py", line 70, in <module>
    fig, ax = qml.draw_mpl(circuit)(inputs=input[0], weights=weight_shapes)
  File "D:\Anaconda\envs\Tensorflow-gpu-python3.8\lib\site-packages\pennylane\drawer\draw.py", line 440, in wrapper
    qnode.construct(args, kwargs_qnode)
  File "D:\Anaconda\envs\Tensorflow-gpu-python3.8\lib\site-packages\pennylane\qnode.py", line 711, in construct
    self._tape = make_qscript(self.func)(*args, **kwargs)
  File "D:\Anaconda\envs\Tensorflow-gpu-python3.8\lib\site-packages\pennylane\tape\qscript.py", line 1346, in wrapper
    result = fn(*args, **kwargs)
  File "E:\Postgraduate\Projects\AMC\AMR-Benchmark-main\HCQDNN3\main_pennelane.py", line 45, in circuit
    qml.RY(weights[0,0], wires=0)
KeyError: (0, 0)
```

What should I do to solve the problem?

Hey @AHHil!

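The issue is in the drawing call, where you pass `weights=weight_shapes`.

`weight_shapes` is just a dictionary that states the shape of `weights`; it doesn't store any actual values. Inside the circuit, `weights[0,0]` on a dictionary looks up the key `(0, 0)`, which doesn't exist, hence `KeyError: (0, 0)`. The drawer needs an array of actual numbers with shape `(n_layers, n_qubits)` instead.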
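Here's a minimal sketch of what goes wrong and what to pass instead, assuming NumPy is available. The random values and the names `rng`, `inputs`, and `weights` are placeholders for drawing only, not part of your model. Likewise, `input[0]` is a slice of the symbolic Keras `Input` placeholder rather than real data, so I'd pass plain numbers for `inputs` too.

```python
import numpy as np

n_qubits = 4
n_layers = 3
weight_shapes = {"weights": (n_layers, n_qubits)}

# weight_shapes only describes a shape. Subscripting it with [0, 0]
# looks up the dictionary key (0, 0), which does not exist.
try:
    weight_shapes[0, 0]
except KeyError as err:
    caught_key = err.args[0]
    print("KeyError:", err)  # KeyError: (0, 0)

# Placeholder values, for drawing only (not trained weights).
rng = np.random.default_rng(seed=0)
weights = rng.uniform(0, 2 * np.pi, size=weight_shapes["weights"])  # shape (3, 4)
inputs = rng.uniform(0, np.pi, size=n_qubits)  # the circuit reads inputs[0] to inputs[3]

print(weights.shape)  # (3, 4)
```

With these arrays, the drawing call becomes `fig, ax = qml.draw_mpl(circuit)(inputs=inputs, weights=weights)`. If you'd rather draw the trained values, the standard Keras method `QNN_1.get_weights()` returns the layer's weights as a list of NumPy arrays; assuming `weights` is the layer's only weight, it would be the first entry.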

Also, be careful when using `input` as a variable name! `input` is a built-in function in Python (see the Python documentation for the built-in `input()` function), so I'd play it safe and use a different name :slight_smile:.
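To see why the name matters, here's a small pure-Python sketch. The list is just a stand-in for something like your Keras `Input` layer; once `input` is rebound, calling `input(...)` anywhere in that module no longer reaches the built-in.

```python
import builtins

# Rebinding "input" at module level hides Python's built-in input().
input = [0.1, 0.2, 0.3]  # stand-in for e.g. a Keras Input layer

try:
    input("Enter a value: ")  # calls the list, not the built-in
except TypeError as err:
    error_message = str(err)
    print(error_message)  # 'list' object is not callable

# Deleting the module-level name makes the built-in visible again.
del input
print(input is builtins.input)  # True
```

Renaming the Keras tensor to something like `inputs_layer` (a hypothetical name; any non-built-in name works) avoids the problem entirely.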