AttributeError: module 'jax.core' has no attribute 'ConcreteArray'


import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn import datasets
import seaborn as sns
import jax;

jax.config.update('jax_platform_name', 'cpu')
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

import optax  # optimization using jax

import pennylane as qml
import pennylane.numpy as pnp

sns.set()

seed = 0
rng = np.random.default_rng(seed=seed)


##############################################################################
# To construct a convolutional and pooling layer in a quantum circuit, we will
# follow the QCNN construction proposed in [#CongQuantumCNN]_. The former layer
# will extract local correlations, while the latter allows reducing the dimensionality
# of the feature vector. In a quantum circuit, the convolutional layer, consisting of a kernel swept
# along the entire image, is a two-qubit unitary that correlates neighbouring
# qubits.  As for the pooling layer, we will use a conditioned single-qubit unitary that depends
# on the measurement of a neighboring qubit. Finally, we use a *dense layer* that entangles all
# qubits of the final state using an all-to-all unitary gate as shown in the figure below.
#
# .. figure:: /_static/demonstration_assets/learning_few_data/qcnn-architecture.png
#     :width: 75%
#     :align: center
#
#     QCNN architecture. Taken from Ref. [#CongQuantumCNN]_.
#
# Breaking down the layers
# --------------------------
#
# The convolutional layer should have the weights of the two-qubit unitary as an input, which are
# updated at every training step.  In PennyLane, we model this arbitrary two-qubit unitary
# with a particular sequence of gates: two single-qubit  :class:`~.pennylane.U3` gates (parametrized by three
# parameters, each), three Ising interactions between both qubits (each interaction is
# parametrized by one parameter), and two additional :class:`~.pennylane.U3` gates on each of the two
# qubits. 

def convolutional_layer(weights, wires, skip_first_layer=True):
    """Apply the two-qubit convolution kernel to neighbouring pairs of wires.

    The same 15 weights are reused for every pair. Pairs that start at an even
    position of ``wires`` are processed first, followed by pairs that start at
    an odd position.

    Args:
        weights (np.array): 1D array with 15 weights of the parametrized gates.
        wires (list[int]): Wires the convolutional layer acts on (at least 3).
        skip_first_layer (bool): If True, the leading pair of U3 gates
            (``weights[:6]``) on even-position pairs is omitted.
    """
    n_wires = len(wires)
    assert n_wires >= 3, "this circuit is too small!"

    for parity in (0, 1):
        # Every start position of the given parity that still has a right neighbour.
        for pos in range(parity, n_wires - 1, 2):
            first = wires[pos]
            second = wires[pos + 1]

            # The optional leading U3 pair is only ever applied on the even sweep.
            if parity == 0 and not skip_first_layer:
                qml.U3(*weights[:3], wires=[first])
                qml.U3(*weights[3:6], wires=[second])

            qml.IsingXX(weights[6], wires=[first, second])
            qml.IsingYY(weights[7], wires=[first, second])
            qml.IsingZZ(weights[8], wires=[first, second])
            qml.U3(*weights[9:12], wires=[first])
            qml.U3(*weights[12:], wires=[second])


##############################################################################
# The pooling layer's inputs are the weights of the single-qubit conditional unitaries, which in
# this case are :class:`~.pennylane.U3` gates. Then, we apply these conditional measurements to half of the
# unmeasured wires, reducing our system size by a factor of 2.


def pooling_layer(weights, wires):
    """Measure every second wire and conditionally rotate its left neighbour.

    Args:
        weights (np.array): Array with the weights of the conditional U3 gate.
        wires (list[int]): List of wires to apply the pooling layer on
            (at least 2).
    """
    n_wires = len(wires)
    assert n_wires >= 2, "this circuit is too small!"

    # Odd positions are measured; the wire just before each one receives the
    # U3 gate conditioned on that measurement outcome.
    for pos in range(1, n_wires, 2):
        outcome = qml.measure(wires[pos])
        qml.cond(outcome, qml.U3)(*weights, wires=wires[pos - 1])


##############################################################################
# We can construct a QCNN by combining both layers and using an arbitrary unitary to model
# a dense layer. It will take a set of features — the image — as input, encode these features using
# an embedding map, apply rounds of convolutional and pooling layers, and eventually output the
# desired measurement statistics of the circuit.


def conv_and_pooling(kernel_weights, n_wires, skip_first_layer=True):
    """Apply one convolutional layer followed by one pooling layer.

    Args:
        kernel_weights (np.array): 1D array; the first 15 entries parametrize
            the convolution, the remaining entries the pooling U3 gate.
        n_wires (list[int]): Despite its name, this is the list of wires that
            both layers act on; it is forwarded unchanged to each of them.
        skip_first_layer (bool): Forwarded to ``convolutional_layer``.
    """
    conv_weights = kernel_weights[:15]
    pool_weights = kernel_weights[15:]

    convolutional_layer(conv_weights, n_wires, skip_first_layer=skip_first_layer)
    pooling_layer(pool_weights, n_wires)


def dense_layer(weights, wires):
    """Apply ``qml.ArbitraryUnitary`` with the given weights to ``wires``.

    Args:
        weights (np.array): Parameters of the arbitrary unitary.
        wires (list[int]): Wires the unitary acts on.
    """
    qml.ArbitraryUnitary(weights, wires=wires)


# Number of qubits in the circuit; conv_net builds its wire list from this value.
num_wires = 6
# PennyLane device on which the conv_net QNode below is executed.
device = qml.device("default.qubit", wires=num_wires)


@qml.qnode(device)
def conv_net(weights, last_layer_weights, features):
    """Define the QCNN circuit.

    Args:
        weights (np.array): Parameters of the convolution and pooling layers;
            column ``j`` holds the parameters of layer ``j``.
        last_layer_weights (np.array): Parameters of the last dense layer.
        features (np.array): Input data to be embedded using AmplitudeEmbedding.

    Returns:
        The measurement probabilities of wire 0.
    """
    n_layers = weights.shape[1]
    wires = list(range(num_wires))

    # load the input features into the amplitudes of the initial state
    qml.AmplitudeEmbedding(features=features, wires=wires, pad_with=0.5)
    qml.Barrier(wires=wires, only_visual=True)

    # each round of convolution + pooling keeps every second wire
    for layer in range(n_layers):
        conv_and_pooling(weights[:, layer], wires, skip_first_layer=(layer != 0))
        wires = wires[::2]
        qml.Barrier(wires=wires, only_visual=True)

    assert last_layer_weights.size == 4 ** (len(wires)) - 1, (
        "The size of the last layer weights vector is incorrect!"
        f" \n Expected {4 ** (len(wires)) - 1}, Given {last_layer_weights.size}"
    )
    dense_layer(last_layer_weights, wires)
    return qml.probs(wires=0)


# Draw the circuit once using random placeholder arguments: an (18, 2) array of
# convolution/pooling weights, 4**2 - 1 dense-layer weights and 2**num_wires features.
fig, ax = qml.draw_mpl(conv_net)(
    np.random.rand(18, 2), np.random.rand(4 ** 2 - 1), np.random.rand(2 ** num_wires)
)
plt.show()

##############################################################################
# In the problem we will address, we need to encode 64 features
# in our quantum state. Thus, we require six qubits (:math:`2^6 = 64`) to encode
# each feature value in the amplitude of each computational basis state.
#
# Training the QCNN on the digits dataset
# ---------------------------------------
# In this demo, we are going to classify the digits ``0`` and ``1`` from the classical ``digits`` dataset.
# Each hand-written digit image is represented as an :math:`8 \times 8` array of pixels as shown below:

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.utils import shuffle

# Load the CSV file
# NOTE(review): the arrays produced in this section (X_train_scaled, X_test_scaled,
# y_train, y_test) are not passed to train_qcnn, which obtains its data from
# load_digits_data instead.
# The path is absolute, so the file must exist at exactly this location.
data_path = "/content/drive/MyDrive/SleepyEEG.csv"
df = pd.read_csv(data_path)

# Extract features and labels
features = df.iloc[:, :-1].values  # All columns except the last one
labels = df.iloc[:, -1].values  # Last column (classification: 0 or 1)

# Count occurrences of each class
unique, counts = np.unique(labels, return_counts=True)
class_distribution = dict(zip(unique, counts))
print("Class distribution:", class_distribution)

# Shuffle the data and labels
features, labels = shuffle(features, labels, random_state=42)

# Split the dataset into training and testing sets (80% / 20%)
X_train, X_test, y_train, y_test = train_test_split(features, labels, test_size=0.2, random_state=42)

# Normalize the data: the scaler is fitted on the training split only and then
# reused for the test split.
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# Print the shapes of the datasets
print("Training data shape:", X_train_scaled.shape)
print("Testing data shape:", X_test_scaled.shape)
print("Training labels shape:", y_train.shape)
print("Testing labels shape:", y_test.shape)

# Print the container types of the datasets
print("Training data type:", type(X_train_scaled))
print("Testing data type:", type(X_test_scaled))
print("Training label type:", type(y_train))
print("Testing label type:", type(y_test))


##############################################################################
# For convenience, we create a ``load_digits_data`` function that will make random training and
# testing sets from the ``digits`` dataset from ``sklearn.datasets``:


def load_digits_data(num_train, num_test, rng):
    """Return random, disjoint training and testing subsets of the digits dataset.

    Only the classes ``0`` and ``1`` are kept, and every feature vector is
    scaled to unit Euclidean norm.

    Args:
        num_train (int): Number of training examples to draw.
        num_test (int): Number of testing examples to draw.
        rng: Random generator providing ``choice`` (e.g. ``np.random.Generator``).

    Returns:
        tuple: ``(x_train, y_train, x_test, y_test)`` as ``jax.numpy`` arrays.
    """
    dataset = datasets.load_digits()
    all_features, all_labels = dataset.data, dataset.target

    # only use first two classes
    keep = (all_labels == 0) | (all_labels == 1)
    features = all_features[keep]
    labels = all_labels[keep]

    # normalize every row to unit norm
    row_norms = np.linalg.norm(features, axis=1, keepdims=True)
    features = features / row_norms

    # subsample train and test split; test indices come from what is left over
    n_samples = len(labels)
    train_indices = rng.choice(n_samples, num_train, replace=False)
    remaining = np.setdiff1d(range(n_samples), train_indices)
    test_indices = rng.choice(remaining, num_test, replace=False)

    parts = (
        features[train_indices],
        labels[train_indices],
        features[test_indices],
        labels[test_indices],
    )
    return tuple(jnp.asarray(part) for part in parts)


##############################################################################
# To optimize the weights of our variational model, we define the cost and accuracy functions
# to train and quantify the performance on the classification task of the previously described QCNN:


@jax.jit
def compute_out(weights, weights_last, features, labels):
    """Return, for every sample, the QCNN output probability of its own label.

    The circuit is vectorized with ``jax.vmap`` over the leading axis of
    ``features`` and ``labels``; both weight arguments are shared by all samples.
    """

    def prob_of_label(w, w_last, feature, label):
        # conv_net returns the probabilities of wire 0; pick the entry of the label.
        return conv_net(w, w_last, feature)[label]

    batched = jax.vmap(prob_of_label, in_axes=(None, None, 0, 0), out_axes=0)
    return batched(weights, weights_last, features, labels)


def compute_accuracy(weights, weights_last, features, labels):
    """Return the fraction of samples whose correct-label probability exceeds 0.5."""
    probs = compute_out(weights, weights_last, features, labels)
    n_correct = jnp.sum(probs > 0.5)
    return n_correct / len(probs)


def compute_cost(weights, weights_last, features, labels):
    """Return one minus the mean probability assigned to the correct labels."""
    probs = compute_out(weights, weights_last, features, labels)
    mean_prob = jnp.sum(probs) / len(labels)
    return 1.0 - mean_prob


def init_weights():
    """Draw standard-normal initial weights for the QCNN model.

    Returns:
        tuple: ``(weights, weights_last)`` as ``jax.numpy`` arrays of shape
        ``(18, 2)`` and ``(4 ** 2 - 1,)`` respectively.
    """
    conv_pool_shape = (18, 2)
    dense_size = 4 ** 2 - 1

    weights = pnp.random.normal(loc=0, scale=1, size=conv_pool_shape, requires_grad=True)
    weights_last = pnp.random.normal(loc=0, scale=1, size=dense_size, requires_grad=True)
    return jnp.array(weights), jnp.array(weights_last)


# Jit-compiled function returning the cost together with its gradients with
# respect to both weight arrays (positional arguments 0 and 1 of compute_cost).
value_and_grad = jax.jit(jax.value_and_grad(compute_cost, argnums=[0, 1]))



def train_qcnn(n_train, n_test, n_epochs):
    """Train the QCNN on a freshly sampled digits subset and record metrics.

    Args:
        n_train  (int): number of training examples
        n_test   (int): number of test examples
        n_epochs (int): number of training epochs

    Returns:
        dict: per-epoch lists under the keys
        n_train,
        step,
        train_cost,
        train_acc,
        test_cost,
        test_acc

    """
    # load data (uses the module-level rng)
    x_train, y_train, x_test, y_test = load_digits_data(n_train, n_test, rng)

    # init weights and optimizer
    weights, weights_last = init_weights()

    # learning rate decay
    cosine_decay_scheduler = optax.cosine_decay_schedule(0.1, decay_steps=n_epochs, alpha=0.95)
    optimizer = optax.adam(learning_rate=cosine_decay_scheduler)
    opt_state = optimizer.init((weights, weights_last))

    # data containers
    train_cost_epochs, test_cost_epochs, train_acc_epochs, test_acc_epochs = [], [], [], []

    for step in range(n_epochs):
        # Training step with (adam) optimizer
        train_cost, grad_circuit = value_and_grad(weights, weights_last, x_train, y_train)
        updates, opt_state = optimizer.update(grad_circuit, opt_state)
        weights, weights_last = optax.apply_updates((weights, weights_last), updates)

        # note: this cost was evaluated with the weights from before the update,
        # whereas the accuracies and test cost below use the updated weights
        train_cost_epochs.append(train_cost)

        # compute accuracy on training data
        train_acc = compute_accuracy(weights, weights_last, x_train, y_train)
        train_acc_epochs.append(train_acc)

        # compute accuracy and cost on testing data
        # (a single circuit evaluation is shared by both metrics)
        test_out = compute_out(weights, weights_last, x_test, y_test)
        test_acc = jnp.sum(test_out > 0.5) / len(test_out)
        test_acc_epochs.append(test_acc)
        test_cost = 1.0 - jnp.sum(test_out) / len(test_out)
        test_cost_epochs.append(test_cost)

    # steps are reported 1-based; n_train is repeated so every row carries it
    return dict(
        n_train=[n_train] * n_epochs,
        step=np.arange(1, n_epochs + 1, dtype=int),
        train_cost=train_cost_epochs,
        train_acc=train_acc_epochs,
        test_cost=test_cost_epochs,
        test_acc=test_acc_epochs,
    )


##############################################################################
# .. note::
#
#     There are some small intricacies for speeding up this code that are worth mentioning. We are using ``jax`` for our training
#     because it allows for `just-in-time <https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html>`_ (``jit``) compilation. A function decorated with ``@jax.jit`` will be compiled upon its first execution
#     and cached for future executions. This means the first execution will take longer, but all subsequent executions are substantially faster.
#     Further, we use ``jax.vmap`` to vectorize the execution of the QCNN over all input states, as opposed to looping through the training and test set at every execution.

##############################################################################
# Training for different training set sizes yields different accuracies, as seen below. As we increase the training data size, the overall test accuracy,
# a proxy for the model's generalization capabilities, increases:

# Number of test examples drawn for every training run.
n_test = 100
# Number of optimization steps per training run.
n_epochs = 100
# Number of independent training runs per training-set size.
n_reps = 100


def run_iterations(n_train):
    """Repeat the QCNN training ``n_reps`` times and stack the per-epoch results.

    Each repetition calls ``train_qcnn`` with the module-level ``n_test`` and
    ``n_epochs`` settings.

    Args:
        n_train (int): Number of training examples used in every repetition.

    Returns:
        pd.DataFrame: One row per (repetition, epoch) with the columns
        ``train_acc``, ``train_cost``, ``test_acc``, ``test_cost``, ``step``
        and ``n_train``, indexed from 0.
    """
    columns = ["train_acc", "train_cost", "test_acc", "test_cost", "step", "n_train"]

    # Collect the frames first and concatenate once: growing a DataFrame inside
    # the loop copies all previous rows on every iteration, and starting from an
    # empty frame relies on deprecated dtype handling in pandas.concat.
    frames = [
        pd.DataFrame.from_dict(train_qcnn(n_train=n_train, n_test=n_test, n_epochs=n_epochs))
        for _ in range(n_reps)
    ]
    if not frames:
        return pd.DataFrame(columns=columns)

    # Re-select the columns to keep the documented column order.
    return pd.concat(frames, axis=0, ignore_index=True)[columns]


# run training for multiple sizes
train_sizes = [2, 5, 10, 20, 40, 80]
# Every size is taken from train_sizes (no separately hard-coded first size),
# so editing the list is enough to change the whole sweep.
results_df = pd.concat([run_iterations(n_train=size) for size in train_sizes])

##############################################################################
# Finally, we plot the loss and accuracy for both the training and testing set
# for all training epochs, and compare the test and train accuracy of the model:

# aggregate dataframe: mean and standard deviation over all repetitions
# for every (training-set size, epoch) pair
df_agg = results_df.groupby(["n_train", "step"]).agg(["mean", "std"])
df_agg = df_agg.reset_index()

sns.set_style('whitegrid')
colors = sns.color_palette()
fig, axes = plt.subplots(ncols=3, figsize=(16.5, 5))

generalization_errors = []

# plot losses and accuracies
for i, n_train in enumerate(train_sizes):
    # Local names are chosen so that the module-level ``df`` and ``labels``
    # from the data-loading section are not overwritten.
    df_size = df_agg[df_agg.n_train == n_train]

    curves = [
        df_size.train_cost["mean"],
        df_size.test_cost["mean"],
        df_size.train_acc["mean"],
        df_size.test_acc["mean"],
    ]
    line_styles = ["o-", "x--", "o-", "x--"]
    curve_labels = [fr"$N={n_train}$", None, fr"$N={n_train}$", None]
    axis_indices = [0, 0, 2, 2]  # losses on the left panel, accuracies on the right

    for axis_index, curve, line_style, curve_label in zip(
        axis_indices, curves, line_styles, curve_labels
    ):
        ax = axes[axis_index]
        ax.plot(
            df_size.step,
            curve,
            line_style,
            label=curve_label,
            markevery=10,
            color=colors[i],
            alpha=0.8,
        )

    # final loss difference, taken at the last epoch; n_epochs is used instead
    # of a literal so the lookup follows the configured number of epochs
    final_epoch = df_size[df_size.step == n_epochs]
    dif = final_epoch.test_cost["mean"] - final_epoch.train_cost["mean"]
    generalization_errors.append(dif)

# format loss plot
ax = axes[0]
ax.set_title('Train and Test Losses', fontsize=14)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')

# format generalization error plot
ax = axes[1]
ax.plot(train_sizes, generalization_errors, "o-", label=r"$gen(\alpha)$")
ax.set_xscale('log')
ax.set_xticks(train_sizes)
ax.set_xticklabels(train_sizes)
ax.set_title(r'Generalization Error $gen(\alpha) = R(\alpha) - \hat{R}_N(\alpha)$', fontsize=14)
ax.set_xlabel('Training Set Size')

# format accuracy plot
ax = axes[2]
ax.set_title('Train and Test Accuracies', fontsize=14)
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
ax.set_ylim(0.5, 1.05)

# shared legend: one colour entry per training-set size plus the train/test markers
legend_elements = [
    mpl.lines.Line2D([0], [0], label=f'N={n}', color=colors[i]) for i, n in enumerate(train_sizes)
    ] + [
    mpl.lines.Line2D([0], [0], marker='o', ls='-', label='Train', color='Black'),
    mpl.lines.Line2D([0], [0], marker='x', ls='--', label='Test', color='Black')
    ]

axes[0].legend(handles=legend_elements, ncol=3)
axes[2].legend(handles=legend_elements, ncol=3)

axes[1].set_yscale('log', base=2)
plt.show()

####  
Error Message
Class distribution: {0: 2135, 1: 1600}
Training data shape: (2988, 10)
Testing data shape: (747, 10)
Training labels shape: (2988,)
Testing labels shape: (747,)
Training data type: <class 'numpy.ndarray'>
Testing data type: <class 'numpy.ndarray'>
Training label type: <class 'numpy.ndarray'>
Testinging label type: <class 'numpy.ndarray'>
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-38-4b023ddae810> in <cell line: 0>()
    371 # run training for multiple sizes
    372 train_sizes = [2, 5, 10, 20, 40, 80]
--> 373 results_df = run_iterations(n_train=2)
    374 for n_train in train_sizes[1:]:
    375     results_df = pd.concat([results_df, run_iterations(n_train=n_train)])

15 frames
    [... skipping hidden 29 frame]

    [... skipping hidden 13 frame]

    [... skipping hidden 7 frame]

/usr/local/lib/python3.11/dist-packages/jax/_src/deprecations.py in getattr(name)
     55       warnings.warn(message, DeprecationWarning, stacklevel=2)
     56       return fn
---> 57     raise AttributeError(f"module {module!r} has no attribute {name!r}")
     58 
     59   return getattr

AttributeError: module 'jax.core' has no attribute 'ConcreteArray'

Hi @MKK_QML , welcome to the Forum!

This is probably due to the version of JAX you’re using. If you use JAX version 0.4.33 you probably won’t see the issue. If you use JAX >=0.5 then you probably will.

Check out thread #8053 if you need some help on how to downgrade your JAX version.

Let us know if you’re having trouble downgrading or if you still face this issue after downgrading.

And please let us know if this worked to solve the issue!