Hello, I’m experimenting the same error using qiskit.aer
as device and the JAX vmap
function. My goal is to train a Quantum Graph Neural Network using qiskit backends. A simplified version is given below.
The initialization of the training state fails with the error:
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error: CircuitError: "Invalid param type <class 'list'> for gate ry."
Is there any news about this?
import psutil, optax
from typing import Callable
import numpy as np
from datetime import datetime as dt
import pennylane as qml
import jax, jax.numpy as jnp
from flax.training import train_state
from flax import linen as nn
from jax import config
config.update("jax_enable_x64", True)
n_node = 50
n_edge = 80
n_train = 5
n_valid = 2
hid_dim = 4
n_layers = 3
n_qubits = 4
def create_train_state(model, key, graph):
params = model.init(key, graph)['params']
optimizer = optax.adam(learning_rate=0.01)
return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)
def rescale01(X):
return (X-np.min(X))/(np.max(X)-np.min(X))
# @qml.qnode(qml.device("default.qubit.jax", wires=n_qubits), interface="jax-python", diff_method="backprop")
@qml.qnode(qml.device('qiskit.aer', wires=n_qubits), interface="jax")
def circuit(iec_params, pqc_params, n_qubits, n_layers):
for i in range(n_qubits):
qml.RY(iec_params[i], wires=i)
w_iter = -1
for i in range(n_qubits):
w_iter = w_iter + 1
qml.RY(pqc_params[w_iter], wires=i)
for _ in range(n_layers):
qml.Barrier()
for i in range(n_qubits):
qml.CZ(wires=[(n_qubits-2-i)%n_qubits, (n_qubits-1-i)%n_qubits])
for i in range(n_qubits):
w_iter = w_iter + 1
qml.RY(pqc_params[w_iter], wires=i)
exp_vals = [qml.expval(qml.PauliZ(position)) for position in range(n_qubits)]
return tuple(exp_vals)
class QLayer(nn.Module):
my_circuit: Callable
num_params: int
n_layers: int
n_qubits: int
def init_params(self, key: jnp.ndarray):
return jnp.ones(self.n_qubits*(self.n_layers+1))
@nn.compact
def __call__(self, X):
qparams = self.param('qparams', self.init_params)
circuit_vmap = jax.vmap(self.my_circuit, in_axes=(0, None, None, None))
return circuit_vmap(X, qparams, self.n_qubits, self.n_layers)
class QEdgeNet(nn.Module):
Qlayer: nn.Module
@nn.compact
def __call__(self, X, Ri, Ro):
bo = jnp.tensordot(Ro, X, axes=([0],[0]))
bi = jnp.tensordot(Ri, X, axes=([0],[0]))
B = jnp.concatenate([bo, bi], axis = 1)
I = nn.Dense(n_qubits)(B)
I = nn.relu(I)
I = rescale01(I) * jnp.pi
Q = self.Qlayer(I)
Q = jnp.asarray(Q).transpose(1,0)
O = nn.Dense(1)(Q)
O = nn.sigmoid(O)
return O
class QNodeNet(nn.Module):
Qlayer: nn.Module
@nn.compact
def __call__(self, X, e, Ri, Ro):
bo = jnp.tensordot(Ro, X, axes=([0],[0]))
bi = jnp.tensordot(Ri, X, axes=([0],[0]))
Rwo = Ro * e[:,0]
Rwi = Ri * e[:,0]
mi = jnp.tensordot(Rwi, bo, axes=([1],[0]))
mo = jnp.tensordot(Rwo, bi, axes=([1],[0]))
M = jnp.concatenate([mi, mo, X], axis=1)
I = nn.Dense(n_qubits)(M)
I = nn.relu(I)
I = rescale01(I) * np.pi
Q = self.Qlayer(I)
Q = jnp.asarray(Q).transpose(1,0)
O = nn.Dense(hid_dim)(Q)
O = nn.relu(O)
return O
class QGNN(nn.Module):
EdgeLayer: nn.Module
NodeLayer: nn.Module
@nn.compact
def __call__(self, graph_array):
X, Ri, Ro = graph_array
H = nn.Dense(hid_dim)(X)
H = nn.relu(H)
H = jnp.concatenate([H, X], axis=1)
for i in range(3):
e = self.EdgeLayer(H, Ri, Ro)
H = self.NodeLayer(H, e, Ri, Ro)
H = jnp.concatenate([H, X], axis=1)
H = self.EdgeLayer(H, Ri, Ro)
H = jnp.squeeze(H, axis=1)
return H
def generate_graph(nNode, nEdge, key):
subk = jax.random.split(key, num=5)
r = jax.random.randint(subk[4], (1,), -9, 9)
nNode = int(nNode + nNode / 100 * r[0])
nEdge = int(nEdge + nEdge / 100 * r[0])
return (jax.random.normal(subk[0], (nNode, 3), dtype=np.float32),
jax.random.randint(subk[1], (nNode, nEdge), 0, 2).astype(np.float32),
jax.random.randint(subk[2], (nNode, nEdge), 0, 2).astype(np.float32),
jax.random.randint(subk[3], (nEdge,), 0, 2))
def generate_random_dataset(key):
dataset_dim = n_train+n_valid
dataset = []
subkeys = jax.random.split(key, dataset_dim)
for i in range(dataset_dim):
dataset.append(generate_graph(n_node, n_edge, subkeys[i]))
return dataset
if __name__ == "__main__":
# Initialize dataset
process = psutil.Process()
key = jax.random.PRNGKey(0)
dataset = generate_random_dataset(key)
train_list = [i for i in range(n_train)]
valid_list = [i+n_train for i in range(n_valid)]
# Initalize model
print('[{}] Dataset loaded'.format(dt.now())
model = QGNN(QEdgeNet(QLayer(circuit,n_qubits*(n_layers+1), n_layers, n_qubits)),
QNodeNet(QLayer(circuit,n_qubits*(n_layers+1), n_layers, n_qubits)))
X, Ri, Ro, y = dataset[train_list[0]]
key, init_key = jax.random.split(key)
state = create_train_state(model, init_key, (X, Ri, Ro))