Hi!
I was able to run your code without errors by adding the @qml.qjit decorator on top of the @qml.qnode decorator:
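import jax
import pennylane as qml

def make_circuit(dev, n_qubits):
    # @qml.qjit (requires the Catalyst package) compiles the QNode beneath it
    @qml.qjit
    @qml.qnode(dev)
    def circuit(x):
        # One RX rotation per wire, with angle x[i] applied to wire i
        for i in range(n_qubits):
            qml.RX(x[i], wires=i)
        return qml.expval(qml.PauliZ(wires=0))

    # Vectorize over the leading axis of x: each row is one set of angles
    return jax.vmap(circuit, in_axes=0)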
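If you want to sanity-check what the batched circuit returns, note that each RX acts on its own wire and only wire 0 is measured, so for every row of the batch the expectation value should be cos(x[0]). The sketch below is a JAX-only illustration (no PennyLane or Catalyst involved) of how `jax.vmap` with `in_axes=0` maps that closed-form expression over the rows of a batch; `expval_z0` and the sample angles are hypothetical names and values chosen for illustration.

```python
import jax
import jax.numpy as jnp

def expval_z0(x):
    # Hypothetical closed form: RX(x[0]) on |0> gives <Z> = cos(x[0]).
    # Rotations on the other wires do not affect the wire-0 expectation.
    return jnp.cos(x[0])

# in_axes=0 maps over the leading (batch) axis, as in make_circuit above
batched_expval = jax.vmap(expval_z0, in_axes=0)

# Three sets of angles for a 3-qubit circuit, one set per row
xs = jnp.array([[0.0, 0.1, 0.2],
                [jnp.pi / 2, 0.3, 0.4],
                [jnp.pi, 0.5, 0.6]])
print(batched_expval(xs))  # approximately [1, 0, -1]
```

Passing the same `xs` array to the function returned by `make_circuit(dev, 3)` should give values close to these, up to numerical precision.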
You can see this post where a similar issue was discussed.