I am trying to draw the circuit in PennyLane for the kernel code at: https://github.com/thubregtsen/qhack/blob/master/paper/noiseless_MNIST.py
I am able to draw the circuit, but I cannot understand much of it.
Please guide me on how to understand this circuit. Also, is there a better way to draw the circuit?
from keras.datasets import mnist
from keras.datasets import mnist
# import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import pennylane as qml
from pennylane import numpy as np
## device and circuit characteristics
# width: number of qubits the circuit acts on.
# depth: presumably the number of ansatz layers (the params array below has 8).
width = 4
depth = 8
# Wire labels 0 .. width-1, shared by the ansatz, the kernel and the projector.
wires = [*range(width)]
# Noiseless state-vector simulator with `width` qubits.
dev = qml.device("default.qubit", wires=width)
print("The value of wires= ", wires)
# Our Ansatz
def ring_wire_pairs(wires):
    """Return the (control, target) wire pairs of a closed ring over ``wires``.

    This reproduces the ``pattern="ring"`` connectivity of the former
    ``qml.broadcast`` helper:

    * fewer than two wires -> no pairs;
    * exactly two wires -> a single pair (the ring would otherwise couple the
      same two qubits twice);
    * otherwise ``(w[k], w[(k + 1) % n])`` for every ``k``, closing the ring.
    """
    wire_list = list(wires)
    n_wires = len(wire_list)
    if n_wires < 2:
        return []
    if n_wires == 2:
        return [(wire_list[0], wire_list[1])]
    return [(wire_list[k], wire_list[(k + 1) % n_wires]) for k in range(n_wires)]


def layer(x1, params, wires, i0=0, inc=1):
    """Building block of the embedding Ansatz.

    For each wire, in order, apply a Hadamard, then an RZ rotation whose angle
    is a feature of ``x1``, then an RY rotation whose angle comes from
    ``params``. Afterwards, neighbouring wires are entangled with a ring of
    CRZ gates.

    Args:
        x1: feature vector. Feature indices wrap around modulo ``len(x1)``, so
            ``x1`` may be shorter than the number of encoding rotations.
        params: parameters of this layer. ``params[0, j]`` is the RY angle on
            the j-th wire; ``params[1]`` holds one CRZ angle per ring edge
            (see ``ring_wire_pairs``).
        wires: wire labels the layer acts on.
        i0: index of the first feature of ``x1`` to encode.
        inc: step between the feature indices used on consecutive wires.

    Raises:
        ValueError: if ``params[1]`` does not contain exactly one angle per
            ring edge (the same check ``qml.broadcast`` used to perform).
    """
    i = i0
    for j, wire in enumerate(wires):
        qml.Hadamard(wires=[wire])
        qml.RZ(x1[i % len(x1)], wires=[wire])
        i += inc
        qml.RY(params[0, j], wires=[wire])

    # Explicit replacement for the removed
    # qml.broadcast(unitary=qml.CRZ, pattern="ring", ...).
    pairs = ring_wire_pairs(wires)
    crz_angles = params[1]
    if len(crz_angles) != len(pairs):
        raise ValueError(
            f"Parameters must contain entries for {len(pairs)} CRZ gates; "
            f"got {len(crz_angles)} entries"
        )
    for angle, (control, target) in zip(crz_angles, pairs):
        qml.CRZ(angle, wires=[control, target])
# The former @qml.template decorator is not used; the plain function is passed
# directly to qml.adjoint in the kernel.
def ansatz(x1, params, wires):
    """The embedding Ansatz U(x1): one ``layer`` per entry of ``params``.

    Layer ``k`` receives ``params[k]`` and starts reading features of ``x1`` at
    index ``k * len(wires)``. Successive layers therefore pick up where the
    previous one stopped (``layer`` wraps indices modulo ``len(x1)``).
    """
    stride = len(wires)
    feature_offset = 0
    for layer_params in params:
        layer(x1, layer_params, wires, i0=feature_offset)
        feature_offset += stride
# init the embedding kernel
@qml.qnode(dev)
def kernel(x1, x2, params):
    """Overlap kernel between two embedded data points.

    The circuit prepares U(x1)|0...0> with ``ansatz`` and then applies the
    inverse embedding U(x2)^dagger. The expectation value of the projector onto
    |0...0> is therefore |<0|U(x2)^dagger U(x1)|0>|^2, i.e. the squared overlap
    of the two embedded states.
    """
    # Embed x1.
    ansatz(x1, params, wires)
    # Undo the embedding of x2.
    inverse_embedding = qml.adjoint(ansatz)
    inverse_embedding(x2, params, wires)  # type: ignore
    # Probability of returning to the all-zeros basis state.
    all_zeros = [0] * width
    return qml.expval(qml.Projector(all_zeros, wires=wires))
# Two example feature vectors to compare with the kernel.
x1 = np.array([0.5, 0.25])
x2 = np.array([0.28571429, 0.53571429])
# Variational parameters, one entry per ansatz layer. In each layer, row 0
# holds one RY angle per wire and row 1 one CRZ angle per ring edge, giving a
# shape of (depth, 2, width).
params = np.array([[[5.27502056, 1.47427201, 4.93076496, 2.46091687],
                    [5.31970254, 3.72079515, 5.23318151, 3.6141986]],
                   [[4.56873422, 4.71736634, 3.74526732, 4.03018843],
                    [2.50768043, 6.05648213, 0.15691596, 4.6339621]],
                   [[3.13801229, 4.06608748, 3.67299568, 0.56942335],
                    [3.7249661, 0.21089213, 1.81957038, 5.11239481]],
                   [[0.23371406, 0.38418188, 1.57283641, 2.1281051],
                    [4.82044232, 2.29357766, 2.28309303, 1.36251918]],
                   [[3.81316379, 1.72378489, 5.33909833, 1.99670676],
                    [0.95241059, 4.91302177, 5.11174524, 0.50581628]],
                   [[2.93171839, 4.72425322, 4.13123708, 2.1082139],
                    [3.44454441, 3.57751799, 4.82533803, 3.01520779]],
                   [[4.82192352, 4.9103957, 5.62978572, 2.72576366],
                    [2.52244612, 2.06081289, 6.28092847, 4.55329168]],
                   [[5.67932649, 1.43286835, 6.13243934, 3.72696941],
                    [5.50192702, 1.69022441, 3.61985605, 4.94172026]]])
# Fail early, with a clear message, if the hard-coded parameters drift out of
# sync with the declared circuit size; otherwise the mismatch would only show
# up as an indexing error deep inside the circuit construction.
if params.shape != (depth, 2, width):
    raise ValueError(
        f"params has shape {params.shape}; expected {(depth, 2, width)} "
        f"for depth={depth} layers on width={width} wires"
    )
# Text drawing: qml.draw takes the QNode itself (not its result) and returns a
# function that must be called with the QNode's arguments. decimals=2 prints
# the rotation angles next to each gate.
print(qml.draw(kernel, decimals=2)(x1, x2, params))
# Matplotlib drawing of the same circuit.
fig, ax = qml.draw_mpl(kernel)(x1, x2, params)
plt.show()