Hi,
I have a question about the JAX interface and whether it can be used to simulate noisy quantum computers. I am interested in JAX because of the jax.vmap function, which would let me parallelize the quantum circuit computations.
What I tried to do is the following. First I defined a noisy quantum circuit such as
def noisy_circuit(prob, **kwargs):
    for k in range(len(G.nodes)):
        qml.BitFlip(prob, wires=k)
    return qml.expval(qml.PauliZ(0))
and then parallelize it and run it as follows:
dev = qml.device("default.mixed", wires=len(G.nodes))
qcircuit = qml.QNode(noisy_circuit, dev, interface="jax")
vcircuit = jax.vmap(qcircuit)
where G is a graph that I defined with the networkx package.
If I define probs = jax.numpy.array([0., 0.05, 0.1]), then vcircuit(probs) raises the following error:
NotImplementedError: batching rules are implemented only for id_tap, not for call.
However, qcircuit(0.01) works perfectly, as expected.
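For reference, the batched output I would expect can be sketched in plain JAX, without PennyLane. This is only an illustration under stated assumptions: every qubit starts in |0>, only wire 0 is measured, and the bit flips on the other wires therefore do not affect the result, so a single-qubit density matrix is enough. The helper name bit_flip_expval_z is hypothetical and not part of any library.

```python
import jax
import jax.numpy as jnp

# Single-qubit Pauli matrices and the |0><0| density matrix.
X = jnp.array([[0.0, 1.0], [1.0, 0.0]])
Z = jnp.array([[1.0, 0.0], [0.0, -1.0]])
RHO0 = jnp.array([[1.0, 0.0], [0.0, 0.0]])


def bit_flip_expval_z(prob):
    """<Z> on one qubit after a bit-flip channel (hypothetical helper)."""
    # Bit-flip channel: rho -> (1 - p) * rho + p * X rho X
    rho = (1.0 - prob) * RHO0 + prob * (X @ RHO0 @ X)
    return jnp.trace(Z @ rho)


probs = jnp.array([0.0, 0.05, 0.1])
# vmap maps the scalar function over the leading axis of probs,
# giving one expectation value per probability (analytically 1 - 2p).
expected = jax.vmap(bit_flip_expval_z)(probs)
```

This is the shape and kind of result I am hoping vcircuit(probs) would return: an array with one expectation value per entry of probs.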
Thank you very much in advance!
Cheers,
Javier.
P.S.: Sorry for the code formatting.