Hello, I’m having some trouble using the PennyLane-Qiskit plugin with JAX. In particular, I’m using qiskit.aer as the device, together with a noise_model built from a fake Qiskit backend. I’m using JAX to jit and vmap the function that executes the circuit:
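import jax
import pennylane as qml
from qiskit_aer import noise  # older installs: from qiskit.providers.aer import noise
from qiskit.providers.fake_provider import FakeMontrealV2
# get_ansatz(ansatz, n_qubits) is my own helper (not shown): it returns the
# ansatz template and the number of parameters per layer

def create_circuit(n_qubits, layers, ansatz):
    fake_backend = FakeMontrealV2()
    noise_model = noise.NoiseModel.from_backend(fake_backend)
    device = qml.device('qiskit.aer', wires=n_qubits, noise_model=noise_model)
    ansatz, params_per_layer = get_ansatz(ansatz, n_qubits)

    @qml.qnode(device, interface='jax')
    def circuit(x, theta):
        qml.AngleEmbedding(x, wires=range(n_qubits), rotation='Y')
        for i in range(layers):
            ansatz(theta[i * params_per_layer: (i + 1) * params_per_layer], wires=range(n_qubits))
        return qml.expval(qml.PauliZ(wires=0))

    return jax.jit(circuit)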
# quantum circuit
qnn_tmp = create_circuit(n_qubits,layers,ansatz)
# apply vmap on x (first circuit param)
qnn_batched = jax.vmap(qnn_tmp, (0, None))
# Jit for faster execution
qnn = jax.jit(qnn_batched)
Then, the circuit is executed by simply calling:
output = qnn(X, theta)
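For reference, the shapes involved: X holds one sample per row, with shape (batch_size, n_qubits), and theta is a single flat parameter vector shared by all samples, so in_axes=(0, None) maps over the rows of X only. The sketch below checks this with a hypothetical pure-JAX stand-in, toy_circuit, in place of the Qiskit QNode (the real circuit needs the plugin, the noise model and the get_ansatz helper); the batch size, n_qubits and theta length are made-up values.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the QNode returned by create_circuit:
# on |0>, RY(x[0]) followed by RY(theta[0]) gives <Z> = cos(x[0] + theta[0]).
def toy_circuit(x, theta):
    return jnp.cos(x[0] + theta[0])

n_qubits, batch_size, n_params = 4, 8, 6  # assumed sizes, for illustration only
X = jnp.linspace(0.0, 1.0, batch_size * n_qubits).reshape(batch_size, n_qubits)
theta = jnp.zeros(n_params)

# Map over the rows of X (axis 0) and reuse the same theta for every row
qnn = jax.jit(jax.vmap(jax.jit(toy_circuit), in_axes=(0, None)))
output = qnn(X, theta)
print(output.shape)  # (8,)
```

The output has one expectation value per row of X, which is what qnn(X, theta) is expected to return with the real circuit.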
However, I’m getting this error:
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error: CircuitError: "Invalid param type <class 'list'> for gate ry."
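To make the shapes in this error concrete: inside jax.vmap the traced function sees x as a single sample, but if the whole batched array is handed to the host-side simulator in one go (a hypothesis, not confirmed here), each RY of AngleEmbedding receives a whole column x[:, wire], one angle per sample, instead of a single float. The NumPy sketch below mimics that with a hypothetical check_ry_param function; it is not Qiskit’s actual validation code, only an illustration of the mismatch the error message describes.

```python
import numpy as np

def check_ry_param(param):
    """Hypothetical mimic of a gate that only accepts a single real angle."""
    if isinstance(param, (int, float, np.floating)):
        return float(param)
    raise TypeError(f"Invalid param type {type(param)} for gate ry.")

n_qubits, batch_size = 3, 4
X = np.random.default_rng(0).uniform(0, np.pi, size=(batch_size, n_qubits))

# Per-sample execution: every RY angle is one float, so every call is accepted
per_sample = [[check_ry_param(x[w]) for w in range(n_qubits)] for x in X]

# Batched execution: the "angle" for wire 0 is a whole column of the batch
try:
    check_ry_param(list(X[:, 0]))
except TypeError as err:
    print(err)  # Invalid param type <class 'list'> for gate ry.
```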
Without jax.vmap, the code runs smoothly. It also runs smoothly if I comment out the embedding (#qml.AngleEmbedding(x, wires=range(n_qubits), rotation='Y')), i.e. if the circuit doesn’t apply any rotation that depends on x. It seems that there’s some problem with how JAX’s batches are broadcast to Qiskit’s ry rotations.
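Since the code works without jax.vmap, one possible workaround (not tested with the Qiskit plugin here) is to replace jax.vmap with jax.lax.map, which loops over the rows of X so the circuit is only ever called with a single sample. The sketch again uses the hypothetical pure-JAX stand-in toy_circuit instead of the real QNode; with the real code, the same pattern would wrap qnn_tmp.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the QNode returned by create_circuit
def toy_circuit(x, theta):
    return jnp.cos(x[0] + theta[0])

@jax.jit
def qnn_sequential(X, theta):
    # lax.map scans over the leading axis of X, calling toy_circuit once per sample
    return jax.lax.map(lambda x: toy_circuit(x, theta), X)

X = jnp.arange(12.0).reshape(4, 3) / 10.0
theta = jnp.array([0.3, 0.0])

output = qnn_sequential(X, theta)
print(output.shape)  # (4,)
```

The trade-off is that samples are evaluated one after another rather than as a single batch, so this gives up any speed-up vmap might offer in exchange for never passing a batched angle to a gate.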