Qiskit.Aer plugin with JAX not working

Hello, I’m having some trouble using the Qiskit plugin with JAX. In particular, I’m using qiskit.aer as the device, together with a noise_model built from a fake Qiskit backend.

I’m using JAX to jit and vmap the function that executes the circuit:

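```python
import jax
import pennylane as qml
# Import paths as of Qiskit ~0.4x; they may differ in other versions
from qiskit.providers.fake_provider import FakeMontrealV2
from qiskit_aer import noise

def create_circuit(n_qubits, layers, ansatz):
    # Noise model taken from a fake IBM backend
    fake_backend = FakeMontrealV2()
    noise_model = noise.NoiseModel.from_backend(fake_backend)
    device = qml.device('qiskit.aer', wires=n_qubits, noise_model=noise_model)
    # get_ansatz (user-defined, not shown) returns the ansatz function
    # and the number of parameters it consumes per layer
    ansatz, params_per_layer = get_ansatz(ansatz, n_qubits)

    @qml.qnode(device, interface='jax')
    def circuit(x, theta):
        qml.AngleEmbedding(x, wires=range(n_qubits), rotation='Y')
        for i in range(layers):
            ansatz(theta[i * params_per_layer: (i + 1) * params_per_layer], wires=range(n_qubits))
        return qml.expval(qml.PauliZ(wires=0))
    return jax.jit(circuit)

# quantum circuit
qnn_tmp = create_circuit(n_qubits, layers, ansatz)

# apply vmap on x (first circuit param); theta is shared across the batch
qnn_batched = jax.vmap(qnn_tmp, (0, None))

# Jit for faster execution
qnn = jax.jit(qnn_batched)
```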

Then the circuit is executed by simply calling:

```python
output = qnn(X, theta)
```

However, I’m getting this error:

```
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error: CircuitError: "Invalid param type <class 'list'> for gate ry."
```

Without jax.vmap, the code runs smoothly.
It also runs smoothly if I comment out the embedding (`#qml.AngleEmbedding(x, wires=range(n_qubits), rotation='Y')`), i.e. if the circuit doesn’t apply any rotations that depend on x. It seems that there’s some problem with JAX’s batched values being passed to Qiskit’s ry rotations.

Hey @Andrea_Ceschini! Welcome to the forum :muscle:

jax.vmap can sometimes be problematic. But since the code works without it, do you actually need it? PennyLane should be able to handle the broadcasting itself! Parameter broadcasting is a feature we added last year, and we continue to improve it and extend its support :grin:.

Hi, thanks for the quick reply! Actually, I need jax.vmap because it speeds up the code. Without it, simulating the circuit with the noise model on the qiskit.aer backend is too slow.

With vmap, I can pass the circuit both my x, of shape (200, 5), i.e. 200 samples with 5 features each, and a parameter array of shape (layers * params_per_layer,).
With my implementation, `qnn_batched = jax.vmap(qnn_tmp, (0, None))`, the problem is the embedding of x via `qml.AngleEmbedding(x, wires=range(n_qubits), rotation='Y')`: it gives me the error above. Without the embedding, i.e. with just the parameterized ansatz, the code works fine.
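What `in_axes=(0, None)` does with these shapes can be checked with a pure-JAX stand-in for the circuit. The function below is hypothetical: for a single qubit prepared with RY(x[0]) followed by RY(theta[0]), the expectation value of Z is cos(x[0] + theta[0]), so it mimics one sample's output without any quantum device. vmap maps over the 200 rows of x and reuses the same theta for every row:

```python
import jax
import jax.numpy as jnp

def toy_circuit(x, theta):
    # Hypothetical stand-in: <Z> after RY(x[0]) then RY(theta[0]) on |0>
    return jnp.cos(x[0] + theta[0])

X = jax.random.uniform(jax.random.PRNGKey(0), (200, 5))  # 200 samples, 5 features
theta = jnp.array([0.3, 0.1])

# Map over axis 0 of X; pass the same theta to every sample
toy_batched = jax.jit(jax.vmap(toy_circuit, in_axes=(0, None)))
out = toy_batched(X, theta)  # shape (200,), one value per sample
```

The function is traced once and evaluated for all rows at once, which is where the speed-up over a Python loop comes from.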

After inspecting a bit, I’ve found that the problem seems to be that the BatchTracer object is not compatible with Qiskit’s ry operation. Is it possible to solve this?
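The BatchTracer mentioned above can be observed with a minimal pure-JAX sketch (the function name `body` is hypothetical): under jax.vmap, the mapped argument reaches the function body as a tracer whose visible shape is that of a single sample, not the whole batch. JAX operations handle it transparently, but it is not a plain array of numbers. This sketch does not reproduce the Qiskit error itself.

```python
import jax
import jax.numpy as jnp

seen = []

def body(x, theta):
    # Record what jax.vmap actually passes in for the mapped argument
    seen.append((type(x).__name__, x.shape))
    return jnp.sum(jnp.sin(x)) * theta

out = jax.vmap(body, in_axes=(0, None))(jnp.ones((200, 5)), 2.0)
print(seen)  # the type name and per-sample shape seen inside the body
```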

Oh, interesting… Thanks for explaining further! I think you should actually submit a bug report!