Qiskit.Aer plugin with JAX not working

Hello, I’m having some trouble using the Qiskit plugin with JAX. In particular, I’m using qiskit.aer as the device, together with a noise_model coming from a fake Qiskit backend.

I’m using JAX to jit and vmap the function executing the circuit:

def create_circuit(n_qubits, layers, ansatz):
    """Return a jitted QNode that runs ``layers`` ansatz blocks on a noisy Aer device.

    The noise model is derived from the FakeMontrealV2 fake backend. The
    returned callable takes ``(x, theta)``. ``x`` is angle-embedded with RY
    rotations on every wire. ``theta`` is consumed in ``layers`` consecutive
    slices, each the size of the ansatz's per-layer parameter count. The
    callable returns <Z> on wire 0.
    """
    backend_noise = noise.NoiseModel.from_backend(FakeMontrealV2())
    aer_device = qml.device('qiskit.aer', wires=n_qubits,  noise_model=backend_noise)
    layer_op, n_per_layer = get_ansatz(ansatz, n_qubits)

    @qml.qnode(aer_device, interface='jax')
    def circuit(x, theta):
        qml.AngleEmbedding(x, wires=range(n_qubits), rotation='Y')
        for layer in range(layers):
            lo = layer * n_per_layer
            layer_op(theta[lo:lo + n_per_layer], wires=range(n_qubits))
        return qml.expval(qml.PauliZ(wires=0))

    return jax.jit(circuit)
# quantum circuit: create_circuit returns a QNode that is already wrapped in jax.jit
qnn_tmp = create_circuit(n_qubits,layers,ansatz)
# apply vmap on x (first circuit param); theta (in_axes None) is shared by every sample
qnn_batched = jax.vmap(qnn_tmp, (0, None))
# Jit for faster execution (this re-wraps the already-jitted qnn_tmp together with vmap)
qnn = jax.jit(qnn_batched)

Then, the circuit is executed by simply calling:

output = qnn(X, theta)  # X is mapped over axis 0; theta is shared across the batch

However I’m getting this error:

jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error: CircuitError: "Invalid param type <class 'list'> for gate ry."

Without using jax.vmap, the code runs smoothly.
It also runs smoothly if I comment out the embedding (#qml.AngleEmbedding(x, wires=range(n_qubits), rotation='Y')), i.e. if the circuit applies no rotations that depend on x. It seems there is a problem when JAX’s batched values are passed to Qiskit’s ry rotations.

Hey @Andrea_Ceschini! Welcome to the forum :muscle:

jax.vmap can sometimes be problematic. But since the code works without it, do you actually need it? PennyLane should be able to handle the broadcasting! Parameter broadcasting is a feature we added last year, and we continue to improve it and extend its support :grin:.

Hi, thanks for the quick reply! Actually, I need to use jax.vmap because it speeds up the code. Without it, simulating the circuit with the noise model (qiskit.aer backend) is too slow.

With vmap, I can pass to the circuit both my x, whose shape is (200 x 5), i.e. 200 samples with 5 features each, and an array of parameters of shape (layers * params_per_layer, ).
With my implementation of qnn_batched = jax.vmap(qnn_tmp, (0, None)), the problem is the embedding of x with the instruction qml.AngleEmbedding(x, wires=range(n_qubits), rotation='Y'). It gives me the error above; without the embedding, i.e. with just the parameterized ansatz, the code works fine.

After some investigation, it looks like the problem is that JAX’s BatchTracer objects are not compatible with Qiskit’s ry operation. Is it possible to fix this?

1 Like

Oh interesting… Thanks for explaining more! I think you should submit a bug report actually!

Hello, I’m experiencing the same error using qiskit.aer as the device together with JAX’s vmap function. My goal is to train a Quantum Graph Neural Network using Qiskit backends. A simplified version is given below.

The initialization of the training state fails with the error:

jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error: CircuitError: "Invalid param type <class 'list'> for gate ry."

Is there any news about this?

import psutil, optax
from typing import Callable
import numpy as np
from datetime import datetime as dt
import pennylane as qml
import jax, jax.numpy as jnp
from flax.training import train_state
from flax import linen as nn
from jax import config
# Enable 64-bit (x64) types in JAX for the whole script.
config.update("jax_enable_x64", True)

# Base node/edge counts for each random graph; generate_graph perturbs them per graph.
n_node = 50
n_edge = 80
# Number of graphs in the training and validation parts of the random dataset.
n_train = 5
n_valid = 2
# Output width of the classical hidden Dense layers.
hid_dim = 4
# Number of CZ-ring + RY layers in the circuit.
n_layers = 3
# Number of wires on the device; also the width of the Dense layer feeding the circuit.
n_qubits = 4

def create_train_state(model, key, graph, learning_rate=0.01):
    """Initialise ``model`` on ``graph`` and bundle it with an Adam optimiser.

    Args:
        model: Flax module whose ``init``/``apply`` are used.
        key: PRNG key passed to ``model.init``.
        graph: example input used to trace shapes during initialisation.
        learning_rate: Adam step size (defaults to the previous hard-coded 0.01).

    Returns:
        A ``flax.training.train_state.TrainState`` holding the initial params.
    """
    params = model.init(key, graph)['params']
    optimizer = optax.adam(learning_rate=learning_rate)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)

def rescale01(X):
    """Min-max rescale ``X`` into [0, 1] using its global minimum and maximum.

    The function is written with ``jax.numpy`` so that it can be traced by
    ``jax.jit`` / ``jax.grad``; NumPy reductions cannot consume JAX tracers.
    When every element of ``X`` is equal (e.g. an all-zero ReLU output), the
    range is 0. Instead of producing NaNs from 0/0, the result is then all
    zeros.

    Returns:
        A JAX array with the same shape as ``X``.
    """
    lo = jnp.min(X)
    span = jnp.max(X) - lo
    # Substitute 1 for a zero range: the numerator is then exactly 0 too, so the
    # output is 0 and the gradient stays finite.
    safe_span = jnp.where(span == 0, 1.0, span)
    return (X - lo) / safe_span

# @qml.qnode(qml.device("default.qubit.jax", wires=n_qubits), interface="jax-python", diff_method="backprop")
@qml.qnode(qml.device('qiskit.aer', wires=n_qubits), interface="jax")
def circuit(iec_params, pqc_params, n_qubits, n_layers):
    """Angle-encode the inputs, run a CZ-ring/RY ansatz and measure <Z> on every wire.

    Args:
        iec_params: input-encoding angles; entry ``k`` drives RY on wire ``k``.
        pqc_params: trainable angles, consumed ``n_qubits`` at a time: one block
            for the initial RY layer plus one per entangling layer, so
            ``n_qubits * (n_layers + 1)`` entries are read.
        n_qubits: number of wires to act on.
        n_layers: number of CZ-ring + RY layers.

    Returns:
        Tuple of ``n_qubits`` PauliZ expectation values, ordered by wire.
    """
    for wire in range(n_qubits):
        qml.RY(iec_params[wire], wires=wire)

    # Initial trainable rotation layer reads pqc_params[0:n_qubits].
    for wire in range(n_qubits):
        qml.RY(pqc_params[wire], wires=wire)

    offset = n_qubits
    for _ in range(n_layers):
        # CZ ring over pairs (n-2, n-1), (n-3, n-2), ..., (n-1, 0), indices mod n_qubits.
        for step in range(n_qubits):
            first = (n_qubits - 2 - step) % n_qubits
            second = (n_qubits - 1 - step) % n_qubits
            qml.CZ(wires=[first, second])
        for wire in range(n_qubits):
            qml.RY(pqc_params[offset + wire], wires=wire)
        offset += n_qubits

    return tuple(qml.expval(qml.PauliZ(wire)) for wire in range(n_qubits))

class QLayer(nn.Module):
    """Flax layer that evaluates a PennyLane QNode on every row of its input via jax.vmap.

    The layer owns one trainable weight vector ('qparams') that is shared by
    every row of the input batch.
    """
    # QNode called as my_circuit(inputs, weights, n_qubits, n_layers).
    my_circuit: Callable
    # NOTE(review): not read by this class; the weight count is recomputed from
    # n_qubits and n_layers in init_params.
    num_params: int
    # Number of entangling layers forwarded to my_circuit.
    n_layers: int
    # Number of qubits forwarded to my_circuit.
    n_qubits: int

    def init_params(self, key: jnp.ndarray):
        """Initialiser for 'qparams': a vector of n_qubits * (n_layers + 1) ones.

        ``key`` is the PRNG key Flax hands to every initialiser; it is unused
        because the initial weights are deterministic.
        """
        return jnp.ones(self.n_qubits*(self.n_layers+1))

    @nn.compact
    def __call__(self, X):
        """Apply ``my_circuit`` to each row of ``X`` (mapped over axis 0).

        Returns whatever the vmapped QNode returns. For a QNode that returns a
        tuple of expectation values, that is presumably a tuple of arrays, each
        with a leading batch dimension.
        """
        # Flax only allows declaring parameters inside @nn.compact (or setup).
        qparams = self.param('qparams', self.init_params)
        # Only the inputs are batched; weights and the int sizes are broadcast.
        # NOTE(review): vmap over a QNode assumes the device natively supports
        # parameter broadcasting; on qiskit.aer this raises "Invalid param type
        # <class 'list'> for gate ry". jax.lax.map may be a sequential
        # workaround — confirm.
        circuit_vmap = jax.vmap(self.my_circuit, in_axes=(0, None, None, None))
        return circuit_vmap(X, qparams, self.n_qubits, self.n_layers)

class QEdgeNet(nn.Module):
    """Hybrid edge network: scores each edge from the features of its two endpoint nodes."""
    # Quantum layer applied to the n_qubits-wide angle encoding of each edge.
    Qlayer: nn.Module

    @nn.compact
    def __call__(self, X, Ri, Ro):
        """Return one sigmoid score per edge.

        Args:
            X: node feature matrix; contracted with Ri/Ro along axis 0.
            Ri, Ro: node-edge incidence matrices, presumably (n_nodes, n_edges)
                as built by generate_graph — TODO confirm for real data.

        Returns:
            Array whose last dimension is 1, one row per edge.
        """
        # Gather endpoint features per edge: (n_edges, n_features) if Ro/Ri are (n_nodes, n_edges).
        bo = jnp.tensordot(Ro, X, axes=([0], [0]))
        bi = jnp.tensordot(Ri, X, axes=([0], [0]))
        B = jnp.concatenate([bo, bi], axis=1)
        # Project to one value per qubit, then min-max scale into [0, pi] for RY encoding.
        I = nn.Dense(n_qubits)(B)
        I = nn.relu(I)
        I = rescale01(I) * jnp.pi
        Q = self.Qlayer(I)
        # Stack the per-qubit expectation values and put the edge axis first.
        Q = jnp.asarray(Q).transpose(1, 0)
        O = nn.Dense(1)(Q)
        O = nn.sigmoid(O)
        return O
class QNodeNet(nn.Module):
    """Hybrid node network: updates node features from edge-weighted neighbour messages."""
    # Quantum layer applied to the n_qubits-wide angle encoding of each node.
    Qlayer: nn.Module

    @nn.compact
    def __call__(self, X, e, Ri, Ro):
        """Return updated node features of width ``hid_dim``.

        Args:
            X: node feature matrix; contracted with Ri/Ro along axis 0.
            e: edge scores; only column 0 is used, to weight each edge.
            Ri, Ro: node-edge incidence matrices, presumably (n_nodes, n_edges)
                — TODO confirm for real data.
        """
        # Endpoint features gathered per edge.
        bo = jnp.tensordot(Ro, X, axes=([0], [0]))
        bi = jnp.tensordot(Ri, X, axes=([0], [0]))
        # Weight every incidence column by its edge score (broadcast over the edge axis).
        Rwo = Ro * e[:, 0]
        Rwi = Ri * e[:, 0]
        # Scatter the weighted endpoint features back onto the nodes.
        mi = jnp.tensordot(Rwi, bo, axes=([1], [0]))
        mo = jnp.tensordot(Rwo, bi, axes=([1], [0]))
        M = jnp.concatenate([mi, mo, X], axis=1)
        # Project to one value per qubit, then min-max scale into [0, pi] for RY encoding.
        I = nn.Dense(n_qubits)(M)
        I = nn.relu(I)
        I = rescale01(I) * jnp.pi
        Q = self.Qlayer(I)
        # Stack the per-qubit expectation values and put the node axis first.
        Q = jnp.asarray(Q).transpose(1, 0)
        O = nn.Dense(hid_dim)(Q)
        O = nn.relu(O)
        return O
class QGNN(nn.Module):
    """Hybrid quantum graph neural network producing one score per edge."""
    # Edge network, reused (with shared parameters) in every message-passing round.
    EdgeLayer: nn.Module
    # Node network, reused (with shared parameters) in every message-passing round.
    NodeLayer: nn.Module

    @nn.compact
    def __call__(self, graph_array):
        """Run the network on ``graph_array = (X, Ri, Ro)``.

        Returns:
            The final edge scores with their trailing size-1 axis squeezed away.
        """
        X, Ri, Ro = graph_array
        H = nn.Dense(hid_dim)(X)
        H = nn.relu(H)
        # Keep the raw input features next to the hidden state.
        H = jnp.concatenate([H, X], axis=1)
        # Three edge -> node message-passing rounds. The same EdgeLayer/NodeLayer
        # instances are called each round, so their parameters are shared.
        for _ in range(3):
            e = self.EdgeLayer(H, Ri, Ro)
            H = self.NodeLayer(H, e, Ri, Ro)
            H = jnp.concatenate([H, X], axis=1)
        H = self.EdgeLayer(H, Ri, Ro)
        H = jnp.squeeze(H, axis=1)
        return H

def generate_graph(nNode, nEdge, key):
    """Sample a random graph whose sizes are jittered by a common integer percentage.

    The percentage is drawn from randint(-9, 9) (upper bound exclusive) and
    applied to both ``nNode`` and ``nEdge`` before truncating to int.

    Returns:
        Tuple ``(X, Ri, Ro, y)``:
        - ``X``: (nodes, 3) float32 normal features.
        - ``Ri``, ``Ro``: (nodes, edges) float32 0/1 matrices.
        - ``y``: (edges,) integer 0/1 labels.
    """
    k_feat, k_ri, k_ro, k_lbl, k_size = jax.random.split(key, num=5)
    pct = jax.random.randint(k_size, (1,), -9, 9)[0]
    node_count = int(nNode + nNode / 100 * pct)
    edge_count = int(nEdge + nEdge / 100 * pct)

    features = jax.random.normal(k_feat, (node_count, 3), dtype=np.float32)
    incidence_in = jax.random.randint(k_ri, (node_count, edge_count), 0, 2).astype(np.float32)
    incidence_out = jax.random.randint(k_ro, (node_count, edge_count), 0, 2).astype(np.float32)
    labels = jax.random.randint(k_lbl, (edge_count,), 0, 2)
    return (features, incidence_in, incidence_out, labels)

def generate_random_dataset(key):
    """Build n_train + n_valid random graphs, one per subkey split from ``key``.

    Returns:
        List of ``generate_graph(n_node, n_edge, subkey)`` tuples, in subkey order.
    """
    subkeys = jax.random.split(key, n_train + n_valid)
    return [generate_graph(n_node, n_edge, subkey) for subkey in subkeys]
if __name__ == "__main__":
    # Initialize dataset
    # NOTE(review): `process` is not used in this snippet; presumably kept for
    # memory monitoring in the full script.
    process = psutil.Process()
    key = jax.random.PRNGKey(0)
    dataset = generate_random_dataset(key)
    # Graph indices: the first n_train graphs for training, the next n_valid for validation.
    train_list = list(range(n_train))
    valid_list = list(range(n_train, n_train + n_valid))
    # Initalize model
    print('[{}] Dataset loaded'.format(dt.now()))
    # Each network gets its own QLayer instance, hence its own quantum weights.
    n_qparams = n_qubits * (n_layers + 1)
    model = QGNN(QEdgeNet(QLayer(circuit, n_qparams, n_layers, n_qubits)),
                 QNodeNet(QLayer(circuit, n_qparams, n_layers, n_qubits)))
    X, Ri, Ro, y = dataset[train_list[0]]
    # NOTE(review): `key` was already consumed by generate_random_dataset before
    # being split here; splitting first would be more idiomatic JAX, but that
    # would change the reproducible dataset/initialisation.
    key, init_key = jax.random.split(key)
    state = create_train_state(model, init_key, (X, Ri, Ro))

Thanks for posting this @Laura_Cappelli. This does seem to be a bug on our side.

Our handling of vmap currently assumes that the device natively supports parameter broadcasting, which is only true for a limited subset of devices.

Would you mind opening a bug report on our GitHub repo? Issues · PennyLaneAI/pennylane · GitHub

A more minimal example of the problem is:

import jax
import pennylane as qml

dev = qml.device('lightning.qubit', wires=10)

@qml.qnode(dev, interface="jax")
def circuit(x):
    """Apply RX(x) on wire 0 and return <Z> on wire 0."""
    qml.RX(x, wires=0)
    return qml.expval(qml.PauliZ(0))

# vmap over a batch of three angles. Expected to fail on devices that do not
# natively support parameter broadcasting, since vmap handling assumes they do.
jax.vmap(circuit)(jax.numpy.array([0.5, 0.6, 0.7]))

1 Like