Hello, I’m having some trouble using the Qiskit plugin with JAX. In particular, I’m using `qiskit.aer` as the device, together with a `noise_model` coming from a fake Qiskit backend.

I’m using JAX to `jit` and `vmap` the function that executes the circuit:

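```python
import jax
import pennylane as qml
from qiskit_aer import noise  # older Qiskit: from qiskit.providers.aer import noise
from qiskit.providers.fake_provider import FakeMontrealV2  # module location varies by Qiskit version

def create_circuit(n_qubits, layers, ansatz):
    # Noise model taken from a fake IBM backend
    fake_backend = FakeMontrealV2()
    noise_model = noise.NoiseModel.from_backend(fake_backend)
    device = qml.device('qiskit.aer', wires=n_qubits, noise_model=noise_model)
    # get_ansatz is defined elsewhere (not shown): returns the layer template and its parameter count
    ansatz, params_per_layer = get_ansatz(ansatz, n_qubits)

    @qml.qnode(device, interface='jax')
    def circuit(x, theta):
        # Encode the features x as RY rotations, one per qubit
        qml.AngleEmbedding(x, wires=range(n_qubits), rotation='Y')
        for i in range(layers):
            ansatz(theta[i * params_per_layer: (i + 1) * params_per_layer], wires=range(n_qubits))
        return qml.expval(qml.PauliZ(wires=0))

    return jax.jit(circuit)
```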

```python
# quantum circuit
qnn_tmp = create_circuit(n_qubits, layers, ansatz)
# apply vmap on x (first circuit argument); theta is shared across the batch
qnn_batched = jax.vmap(qnn_tmp, (0, None))
# jit for faster execution
qnn = jax.jit(qnn_batched)
```

Then, the circuit is executed by simply calling:

```python
output = qnn(X, theta)
```
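For reference, with `in_axes=(0, None)` the batched call is meant to be equivalent to evaluating the circuit row by row, i.e. `qnn(X, theta)[i]` should equal `qnn_tmp(X[i], theta)`. A small pure-JAX check of those semantics, using a hypothetical stand-in `toy_circuit` in place of the QNode (no PennyLane or Qiskit involved):

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the QNode: one feature vector x of shape
# (n_qubits,) plus a flat parameter vector theta -> one scalar.
def toy_circuit(x, theta):
    return jnp.cos(x[0]) * jnp.cos(theta[0])

# in_axes=(0, None): map over the rows of X, pass theta unchanged to every row.
toy_batched = jax.jit(jax.vmap(toy_circuit, in_axes=(0, None)))

X = jnp.linspace(0.0, 1.0, 12).reshape(4, 3)  # 4 samples, 3 "qubits"
theta = jnp.array([0.3, 0.1, 0.2])

out = toy_batched(X, theta)
# Should match calling the unbatched function on each row separately.
expected = jnp.stack([toy_circuit(x, theta) for x in X])
print(out.shape)  # (4,)
```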

However, I’m getting this error:

```
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error: CircuitError: "Invalid param type <class 'list'> for gate ry."
```

Without `jax.vmap`, the code runs smoothly.

It also runs smoothly if I comment out the embedding (`# qml.AngleEmbedding(x, wires=range(n_qubits), rotation='Y')`), i.e. if the circuit doesn’t apply any rotations that depend on `x`. It seems that there’s some problem with Qiskit receiving JAX’s batches as the parameters of its `ry` rotations.
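A possible workaround, not verified against the Qiskit plugin: if the failure comes from the whole batch of `x` values reaching Qiskit’s `ry` as a list, `jax.lax.map` can replace `jax.vmap`, since it evaluates the function once per row of `X` (inside a scan) instead of passing the device a batched parameter. The sketch below replaces the QNode with a hypothetical host-side function `fake_device_run`, called through `jax.pure_callback`; like the gate in the error, it rejects anything other than a single sample’s angles. The trade-off is that samples are executed sequentially.

```python
import numpy as np
import jax
import jax.numpy as jnp

N_QUBITS = 3

def fake_device_run(x, theta):
    # Hypothetical host-side "device": like Qiskit's ry in the error above,
    # it only accepts one sample's angles, not a whole batch.
    x = np.asarray(x)
    if x.shape != (N_QUBITS,):
        raise ValueError(f"expected shape ({N_QUBITS},), got {x.shape}")
    return np.asarray(np.cos(x[0]) * np.cos(theta[0]), dtype=np.float32)

def circuit(x, theta):
    # Call the non-JAX "simulator" from traced code; it returns a float32 scalar.
    out_spec = jax.ShapeDtypeStruct((), jnp.float32)
    return jax.pure_callback(fake_device_run, out_spec, x, theta)

@jax.jit
def qnn_mapped(X, theta):
    # lax.map runs `circuit` once per row of X, so each callback
    # receives a single feature vector of shape (N_QUBITS,).
    return jax.lax.map(lambda x: circuit(x, theta), X)

X = jnp.full((4, N_QUBITS), 0.5, dtype=jnp.float32)
theta = jnp.zeros(6, dtype=jnp.float32)
out = qnn_mapped(X, theta)
print(out)
```

With the names from the post, the equivalent would be `qnn = jax.jit(lambda X, theta: jax.lax.map(lambda x: qnn_tmp(x, theta), X))`, assuming the jitted QNode behaves like `circuit` here when traced inside a scan.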