Error combining JAX with multiple devices

Good morning Pennylane community. I have an issue I would like to discuss with you, hoping to get a solution. I am working on QAOA and I am trying to use JAX for faster simulation but I have a problem where I call the decorator @jax.jit on a qnode.
I have started from your tutorial on Maxcut and mi idea is to make a function where store the qnode since I will create multiple devices with different number of qubits each time to have a statistic behavior of the task depending on the number of qubits.

def mutable_qnode(device, new_params, graph, edge=None):

    @jax.jit
    @qml.qnode(device, interface="jax")
    def qnode(new_params=new_params, graph=graph, edge=edge):
        [qml.Hadamard(i) for i in range(qubits)]
        for l in range(fixed_layers):
            gamma_circuit(opt_params[l, 0], graph=graph)
            beta_circuit(opt_params[l, 1])
        
        # variational block
        gamma_circuit(new_params[0], graph=graph)
        beta_circuit(new_params[1])

        '''if edge is None:
            return qml.counts()'''
        H = qml.PauliZ(edge[0]) @ qml.PauliZ(edge[1])
        return qml.expval(H)
    
    result = qnode(new_params, graph, edge=edge)
    return result

I have never seem it before. My dream is to use jax jit to speed the simulation.
Many thanks in advance.

francescoaldoventurelli@AirdiFrancesco QAOA % /usr/local/bin/python3.10 /Users/francescoaldoventurelli/qml/QAOA/j
ax_unmodified.py
/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/_src/core.py:678: FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.
  warnings.warn(
/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/_src/core.py:678: FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.
  warnings.warn(
jax.pure_callback failed
Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/_src/callback.py", line 86, in pure_callback_impl
    return tree_util.tree_map(np.asarray, callback(*args))
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/_src/callback.py", line 64, in __call__
    return tree_util.tree_leaves(self.callback_func(*args, **kwargs))
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pennylane/workflow/interfaces/jax_jit.py", line 187, in wrapper
    return _to_jax(jpc.execute_and_compute_jacobian(new_tapes))
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pennylane/workflow/jacobian_products.py", line 312, in execute_and_compute_jacobian
    jac_tapes, jac_postprocessing = self._gradient_transform(tapes, **self._gradient_kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pennylane/transforms/core/transform_dispatcher.py", line 135, in __call__
    return self._batch_transform(obj, targs, tkwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pennylane/transforms/core/transform_dispatcher.py", line 343, in _batch_transform
    new_tapes, fn = self(t, *targs, **tkwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pennylane/transforms/core/transform_dispatcher.py", line 100, in __call__
    intermediate_tapes, post_processing_fn = self._transform(
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pennylane/gradients/parameter_shift.py", line 1120, in param_shift
    diff_methods = find_and_validate_gradient_methods(tape, method, trainable_params)
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pennylane/gradients/gradient_transform.py", line 220, in find_and_validate_gradient_methods
    diff_methods = _find_gradient_methods(tape, trainable_param_indices, use_graph=use_graph)
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pennylane/gradients/gradient_transform.py", line 163, in _find_gradient_methods
    return {
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pennylane/gradients/gradient_transform.py", line 164, in <dictcomp>
    idx: _try_zero_grad_from_graph_or_get_grad_method(
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pennylane/gradients/gradient_transform.py", line 153, in _try_zero_grad_from_graph_or_get_grad_method
    if not any(tape.graph.has_path(op_or_mp, mp) for mp in tape.measurements):
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pennylane/gradients/gradient_transform.py", line 153, in <genexpr>
    if not any(tape.graph.has_path(op_or_mp, mp) for mp in tape.measurements):
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pennylane/tape/qscript.py", line 966, in graph
    self._graph = qml.CircuitGraph(
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pennylane/circuit_graph.py", line 127, in __init__
    wire = wires.index(w)
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pennylane/wires.py", line 250, in index
    return self._labels.index(wire)
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py", line 739, in op
    return getattr(self.aval, f"_{name}")(self, *args)
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py", line 265, in deferring_binary_op
    return binary_op(*args)
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/_src/pjit.py", line 327, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/_src/pjit.py", line 185, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/_src/core.py", line 2829, in bind
    top_trace = find_top_trace(args)
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/_src/core.py", line 1362, in find_top_trace
    top_tracer._assert_live()
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 1736, in _assert_live
    raise core.escaped_tracer_error(self, None)
jax.errors.UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type int64[] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was qnode at /Users/francescoaldoventurelli/qml/QAOA/jax_unmodified.py:56 traced for jit.
------------------------------
The leaked intermediate value was created on line /Users/francescoaldoventurelli/qml/QAOA/jax_unmodified.py:73 (mutable_qnode). 
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
------------------------------

To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/francescoaldoventurelli/qml/QAOA/jax_unmodified.py", line 188, in <module>
    energy, counts, optimal_last_gamma_beta, ar = qaoa_execution(dev, graph1)
  File "/Users/francescoaldoventurelli/qml/QAOA/jax_unmodified.py", line 156, in qaoa_execution
    grads = jax.grad(obj_function)(params)
  File "/Users/francescoaldoventurelli/qml/QAOA/jax_unmodified.py", line 147, in obj_function
    cost -= 0.5 * (1 - mutable_qnode(device, new_params, graph, edge=edge))
  File "/Users/francescoaldoventurelli/qml/QAOA/jax_unmodified.py", line 73, in mutable_qnode
    result = qnode(new_params, graph, edge=edge)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error: Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pennylane/gradients/gradient_transform.py", line 153, in <genexpr>
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pennylane/tape/qscript.py", line 966, in graph
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pennylane/circuit_graph.py", line 127, in __init__
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pennylane/wires.py", line 251, in index
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py", line 739, in op
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py", line 265, in deferring_binary_op
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/_src/pjit.py", line 327, in cache_miss
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/_src/pjit.py", line 194, in _python_pjit_helper
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/_src/core.py", line 2829, in bind
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/_src/core.py", line 1362, in find_top_trace
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 1736, in _assert_live
UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type int64[] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was qnode at /Users/francescoaldoventurelli/qml/QAOA/jax_unmodified.py:56 traced for jit.
------------------------------
The leaked intermediate value was created on line /Users/francescoaldoventurelli/qml/QAOA/jax_unmodified.py:73 (mutable_qnode). 
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
------------------------------

To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
francescoaldoventurelli@AirdiFrancesco QAOA % 

Pennylane versions:
`Name: PennyLane
Version: 0.37.0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: GitHub - PennyLaneAI/pennylane: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Author:
Author-email:
License: Apache License 2.0
Location: /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, semantic-version, toml, typing-extensions
Required-by: PennyLane_Lightning

Platform info: macOS-14.4.1-arm64-arm-64bit
Python version: 3.10.10
Numpy version: 1.26.4
Scipy version: 1.10.1
Installed devices:

  • lightning.qubit (PennyLane_Lightning-0.37.0)
  • default.clifford (PennyLane-0.37.0)
  • default.gaussian (PennyLane-0.37.0)
  • default.mixed (PennyLane-0.37.0)
  • default.qubit (PennyLane-0.37.0)
  • default.qubit.autograd (PennyLane-0.37.0)
  • default.qubit.jax (PennyLane-0.37.0)
  • default.qubit.legacy (PennyLane-0.37.0)
  • default.qubit.tf (PennyLane-0.37.0)
  • default.qubit.torch (PennyLane-0.37.0)
  • default.qutrit (PennyLane-0.37.0)
  • default.qutrit.mixed (PennyLane-0.37.0)
  • default.tensor (PennyLane-0.37.0)
  • null.qubit (PennyLane-0.37.0)
    None`.

Hi @checcopo ,

The code you shared looks very different to the PennyLane demo on Maxcut. I suggest changing one thing at a time, either changing the code to Jax first and then try to add changeable devices or vice versa.

The code below shows the same code from the demo with some small edits to allow you to change devices. The most important change is that instead of using a decorator to attach the circuit to a device, you can instantiate the QNode via qml.QNode (learn more here). The QNode is then called ‘qnode’ instead of ‘circuit’ so I changed this in a couple of places.

I’ve marked my edits with comments. Let me know if you have any questions about it.

# Code with changeable devices
import pennylane as qml
from pennylane import numpy as np

np.random.seed(42)

n_wires = 4
graph = [(0, 1), (0, 3), (1, 2), (2, 3)]

# unitary operator U_B with parameter beta
def U_B(beta):
    for wire in range(n_wires):
        qml.RX(2 * beta, wires=wire)


# unitary operator U_C with parameter gamma
def U_C(gamma):
    for edge in graph:
        wire1 = edge[0]
        wire2 = edge[1]
        qml.CNOT(wires=[wire1, wire2])
        qml.RZ(gamma, wires=wire2)
        qml.CNOT(wires=[wire1, wire2])


def bitstring_to_int(bit_string_sample):
    bit_string = "".join(str(bs) for bs in bit_string_sample)
    return int(bit_string, base=2)

def circuit(gammas, betas, edge=None, n_layers=1):
    # apply Hadamards to get the n qubit |+> state
    for wire in range(n_wires):
        qml.Hadamard(wires=wire)
    # p instances of unitary operators
    for i in range(n_layers):
        U_C(gammas[i])
        U_B(betas[i])
    if edge is None:
        # measurement phase
        return qml.sample()
    # during the optimization phase we are evaluating a term
    # in the objective using expval
    H = qml.PauliZ(edge[0]) @ qml.PauliZ(edge[1])
    return qml.expval(H)

def qaoa_maxcut(device, n_layers=1): # Edited this
    print("\np={:d}".format(n_layers))

    # initialize the parameters near zero
    init_params = 0.01 * np.random.rand(2, n_layers, requires_grad=True)

    qnode = qml.QNode(circuit,device) # Added this

    # minimize the negative of the objective function
    def objective(params):
        gammas = params[0]
        betas = params[1]
        neg_obj = 0
        for edge in graph:
            # objective for the MaxCut problem
            neg_obj -= 0.5 * (1 - qnode(gammas, betas, edge=edge, n_layers=n_layers)) # Edited this
        return neg_obj

    # initialize optimizer: Adagrad works well empirically
    opt = qml.AdagradOptimizer(stepsize=0.5)

    # optimize parameters in objective
    params = init_params
    steps = 30
    for i in range(steps):
        params = opt.step(objective, params)
        if (i + 1) % 5 == 0:
            print("Objective after step {:5d}: {: .7f}".format(i + 1, -objective(params)))

    # sample measured bitstrings 100 times
    bit_strings = []
    n_samples = 100
    for i in range(0, n_samples):
        bit_strings.append(bitstring_to_int(qnode(params[0], params[1], edge=None, n_layers=n_layers))) # Edited this

    # print optimal parameters and most frequently sampled bitstring
    counts = np.bincount(np.array(bit_strings))
    most_freq_bit_string = np.argmax(counts)
    print("Optimized (gamma, beta) vectors:\n{}".format(params[:, :n_layers]))
    print("Most frequently sampled bit string is: {:04b}".format(most_freq_bit_string))

    return -objective(params), bit_strings

# Added two options for devices
dev1 = qml.device("lightning.qubit", wires=n_wires, shots=1)
dev2 = qml.device("default.qubit", wires=n_wires, shots=1)

# perform qaoa on our graph with p=1,2 and
# keep the bitstring sample lists
bitstrings1 = qaoa_maxcut(dev1, n_layers=1)[1] # Using dev1
bitstrings2 = qaoa_maxcut(dev2, n_layers=2)[1] # Using dev2

import matplotlib.pyplot as plt

xticks = range(0, 16)
xtick_labels = list(map(lambda x: format(x, "04b"), xticks))
bins = np.arange(0, 17) - 0.5

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.title("n_layers=1")
plt.xlabel("bitstrings")
plt.ylabel("freq.")
plt.xticks(xticks, xtick_labels, rotation="vertical")
plt.hist(bitstrings1, bins=bins)
plt.subplot(1, 2, 2)
plt.title("n_layers=2")
plt.xlabel("bitstrings")
plt.ylabel("freq.")
plt.xticks(xticks, xtick_labels, rotation="vertical")
plt.hist(bitstrings2, bins=bins)
plt.tight_layout()
plt.show()

Thanks @CatalinaAlbornoz for replying. The current code must be different from the one represented in the tutorial, since it develops a specific and precise task. As I mentioned, I have started from the one shared on the tutorial, but putting my hands on it to do something different. Anyway, it’s an error related to Jax which cannot understand some types of data. Could you provide an answer to that error from a Jax’s side?
Thanks.

Hi @checcopo ,

Unfortunately the code you provided is not enough to reproduce your error. Are you able to share a minimal (but self-contained) working example?

This is the simplest version of the code that reproduces the problem. It should include all necessary imports, data, functions, etc., so that we can copy-paste the code and reproduce the problem. However it shouldn’t contain any unnecessary data, functions, …, for example gates and functions that can be removed to simplify the code.

In this case for example you should either include the code for gamma_circuit and beta_circuit or remove them from the code in case they don’t contribute to the error.

Sorry for being late.

I can upload parts of the code. But, I recommend that I have tried to include Catalyst in the existing script. Here is what I got.

import jax
from jax import numpy as jnp
import pennylane as qml
import networkx as nx
from maxcut import *
import optax
from RandomGraphGeneration import RandomGraph
import time
from qaoa_circuit_utils import GammaCircuit, BetaCircuit
import numpy as np
import pandas as pd
import sys
import warnings
from catalyst import qjit, for_loop
import catalyst

warnings.filterwarnings("ignore")

jax.config.update("jax_enable_x64", True)

save_path = "/home/fv/storage1/qml/QAOA_transferability/res_selfopt"
shots = 1024
seed = 40
#qubits = int(sys.argv[1])
qubits = 6
dev = qml.device("lightning.qubit", wires=qubits, shots=shots)
#dev = qml.device("default.qubit.jax", wires=qubits, shots=shots)

layers = 4


@for_loop(0, qubits, 1)
def hadamard_gate(i):
    qml.Hadamard(wires=i)


@for_loop(0, layers, 1)
def gamma_beta_circuit(i, weights, graph):
    GammaCircuit(weights[i, 0], graph)
    BetaCircuit(weights[i, 1], qubits)


@qjit
@qml.qnode(dev)
def qnode(weights, graph, edge):
    [qml.Hadamard(wires=i) for i in range(qubits)]
    gamma_beta_circuit(weights, graph)
    '''for j in range(layers):
        GammaCircuit(weights[j, 0], graph)
        BetaCircuit(weights[j, 1], qubits)'''
    H = qml.PauliZ(edge[0]) @ qml.PauliZ(edge[1])
    return qml.expval(H)


@qjit
@qml.qnode(dev)
def qnode_count(weights, graph, edge):
    [qml.Hadamard(wires=i) for i in range(qubits)]
    gamma_beta_circuit(weights, graph)
    return qml.counts()


@qjit
def obj_function(weights, graph):
    cost = 0
    for edge in graph:
        cost -= 0.5 * (1 - qnode(weights, graph, edge))
    return cost


@qjit
def updating(i, args):
    params, opt_state, optax_optimizer = args
    grads = catalyst.grad(obj_function, method="fd")(params)
    updates, opt_state = optax_optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return (params, opt_state)


def qaoa_execution(seed: int, graph: list, graph_sorgent: nx.Graph) -> tuple:
    optax_optmizer = optax.adagrad(learning_rate=0.1)  ### Adagrad
    key = jax.random.PRNGKey(seed)
    weights = jnp.asarray(jax.random.normal(key, shape=(layers, 2)))
    params = {"weights": weights}
    opt_state = optax_optmizer.init(params)
    steps = 40

    for i in range(steps):
        params, opt_state = updating(i, (params, opt_state, optax_optmizer))
        print(f"Iteration {i}:", obj_function(params, graph))

    print("Last parameters updated:\n", params)

    counts = qnode_count(params, graph, edge=None)

    min_key, min_energy = maximum_cut(counts, graph_sorgent)
    print("The ground states are: ", min_key, "with energy: ", min_energy)

    most_freq_bit_string = max(counts, key=counts.get)
    res = [int(x) for x in str(most_freq_bit_string)]
    maxcut_val = maxcut_obj(res, graph_sorgent)
    print("Most frequent bit-string is: ", most_freq_bit_string)
    print("The cut value of most frequent bit-string is: ", maxcut_val)

    approximation_ratio = jnp.divide(obj_function(params), min_energy)
    print(approximation_ratio)
    return -obj_function(params), counts, params, approximation_ratio


def experiment():
    time_list, opt_beta_gamma_res, energy_res, ar_res, counts_res = [], [], [], [], []
    for s in range(seed):
        print(f"It: {s + 1}")
        graph_generator = RandomGraph(qubits, prob=0.7, seed=s)
        graph = list(graph_generator.edges)
        t0 = time.time()
        energy, counts, opt_beta_gamma, ar = qaoa_execution(s, graph, graph_generator)
        tf = time.time()
        dt = np.subtract(tf, t0)
        time_list.append(np.asarray(dt))
        energy_res.append(np.asarray(energy))
        opt_beta_gamma_res.append(np.asarray(opt_beta_gamma))
        ar_res.append(np.asarray(ar))
        counts_res.append(counts)
    print("Stop.")
    data = [energy_res, opt_beta_gamma_res, counts_res, ar_res, time_list]
    return data


if __name__ == "__main__":
    print("Self optimization")
    data = experiment()
    dataset = pd.DataFrame({'Ground energy': data[0],
                            'Opt_gamma_beta': data[1],
                            'Counts': data[2],
                            'Approx. ratio': data[3],
                            'Elapsed time': data[4]})
    '''data_seed_ = dataset.to_csv(
        save_path + "/data" + str(seed) + "_qubit" + str(qubits) + ".csv")'''

ERROR:

/home/fv/anaconda3/bin/conda run -p /home/fv/anaconda3 --no-capture-output python /mnt/storage1/fv/qml/QAOA_transferability/qjit_qaoa.py 
Self optimization
It: 1
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
Traceback (most recent call last):
  File "/home/fv/anaconda3/lib/python3.11/site-packages/jax/_src/api_util.py", line 584, in shaped_abstractify
    return _shaped_abstractify_handlers[type(x)](x)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^
KeyError: <class 'function'>

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/mnt/storage1/fv/qml/QAOA_transferability/qjit_qaoa.py", line 132, in <module>
    data = experiment()
           ^^^^^^^^^^^^
  File "/mnt/storage1/fv/qml/QAOA_transferability/qjit_qaoa.py", line 117, in experiment
    energy, counts, opt_beta_gamma, ar = qaoa_execution(s, graph, graph_generator)
                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/storage1/fv/qml/QAOA_transferability/qjit_qaoa.py", line 89, in qaoa_execution
    params, opt_state = updating(i, (params, opt_state, optax_optmizer))
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fv/anaconda3/lib/python3.11/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/fv/anaconda3/lib/python3.11/site-packages/catalyst/jit.py", line 499, in __call__
    requires_promotion = self.jit_compile(args)
                         ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fv/anaconda3/lib/python3.11/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/fv/anaconda3/lib/python3.11/site-packages/catalyst/jit.py", line 570, in jit_compile
    self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(args)
                                                              ^^^^^^^^^^^^^^^^^^
  File "/home/fv/anaconda3/lib/python3.11/site-packages/catalyst/debug/instruments.py", line 143, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/fv/anaconda3/lib/python3.11/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/fv/anaconda3/lib/python3.11/site-packages/catalyst/jit.py", line 621, in capture
    dynamic_sig = get_abstract_signature(dynamic_args)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fv/anaconda3/lib/python3.11/site-packages/catalyst/tracing/type_signatures.py", line 70, in get_abstract_signature
    abstract_args = [shaped_abstractify(arg) for arg in flat_args]
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fv/anaconda3/lib/python3.11/site-packages/catalyst/tracing/type_signatures.py", line 70, in <listcomp>
    abstract_args = [shaped_abstractify(arg) for arg in flat_args]
                     ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fv/anaconda3/lib/python3.11/site-packages/jax/_src/api_util.py", line 586, in shaped_abstractify
    return _shaped_abstractify_slow(x)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fv/anaconda3/lib/python3.11/site-packages/jax/_src/api_util.py", line 575, in _shaped_abstractify_slow
    raise TypeError(
TypeError: Cannot interpret value of type <class 'function'> as an abstract array; it does not have a dtype attribute
ERROR conda.cli.main_run:execute(125): `conda run python /mnt/storage1/fv/qml/QAOA_transferability/qjit_qaoa.py` failed. (See above for error)

Process finished with exit code 1

Gamma and beta circuits are the same you find in the tutorial.
Thanks in advance.

Hi @checcopo ,

Unfortunately qaoa_circuit_utils is not something that exists in PennyLane and I don’t know what GammaCircuit and BetaCircuit are since they’re not in our demo on MaxCut: QAOA for MaxCut | PennyLane Demos

Can you please copy-paste the code for GammaCircuit and BetaCircuit here? Thanks!

Here the qaoa_circuit_utils.py

import pennylane as qml
from jax import numpy as jnp
import networkx as nx


def BetaCircuit(beta: jnp.array, qubits: int):
    for wire in range(qubits):
        qml.RX(phi=2*beta, wires=wire)


def GammaCircuit(gamma: jnp.array, graph: nx.Graph):
    for edge in graph:
        wire1 = edge[0]
        wire2 = edge[1]
        qml.CNOT(wires=[wire1, wire2])
        qml.RZ(phi=2*gamma, wires=wire2)
        qml.CNOT(wires=[wire1, wire2])

I have also to re-write the code presented here qml.qaoa — PennyLane 0.37.0 documentation with the introduction of jax and catalyst, but when I call the @jax.jit decorator or @qjit, or when I apply the catalyst.grad() method to the function which is going to be differentiated, it seems it has not a valid dtype for this purpose, but honestly, I am stuck on that.
This is the additional code:

import pennylane as qml
from pennylane import qaoa
import catalyst
#from catalyst import qjit
import optax
import jax
from jax import numpy as jnp
from RandomGraphGeneration import RandomGraph


layers = 4
qubits = 6
prob = 0.7
seed = 40
nshots = 1024


dev = qml.device('lightning.qubit', wires=qubits, shots=nshots)

graph = RandomGraph(node=qubits, prob=prob, seed=seed)

#print(list(graph.edges))

cost_h, mixer_h = qaoa.maxcut(graph)


def qaoa_layer(gamma, beta):
    qaoa.cost_layer(gamma, cost_h)
    qaoa.mixer_layer(beta, mixer_h)


#@qjit
#@jax.jit
def circuit(params):
    for w in range(qubits):
        qml.Hadamard(wires=w)
    qml.layer(qaoa_layer, layers, params["gamma_weights"], params["beta_weights"])


#@qjit
#@jax.jit
@qml.qnode(dev)
def cost_function(params):
    circuit(params)
    return qml.expval(cost_h)


#@qjit
#### NOT USED !!!! 
def updating(i, args):
    params, opt_state, optax_optimizer = args
    grads = catalyst.grad(cost_function, method="fd")(params)
    updates, opt_state = optax_optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return (params, opt_state)


def jax_updating(params, opt_state, optax_optimizer):
    grads = jax.grad(cost_function)(params)
    updates, opt_state = optax_optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return (params, opt_state)


steps = 20
optax_optmizer = optax.adagrad(learning_rate=0.1)  ### Adagrad
key = jax.random.PRNGKey(seed)
gamma_w = jnp.asarray(jax.random.normal(key, shape=(layers,)))
beta_w = jnp.asarray(jax.random.normal(key, shape=(layers,)))
params = {"gamma_weights": gamma_w, "beta_weights": beta_w}
opt_state = optax_optmizer.init(params)


for i in range(steps):
    params, opt_state = jax_updating(params, opt_state, optax_optmizer)
    print(f"Iteration {i}:", cost_function(params))

print("Last parameters updated:\n", params)

Hi @checcopo ,

I’ve taken the example from the qml.qaoa() docs and updated it to work with Catalyst and Jax.

Your code example had changed a lot of things. I’d suggest starting with the code below which is simpler and make small changes if you need them. This way if you get stuck further ahead it will be easier to see what you change and what the problem might be.

# Docs with Jax and qjit v2
import pennylane as qml
from pennylane import qaoa
from networkx import Graph
import jax
import jaxopt
from jax import numpy as jnp

# Defines the wires and the graph on which MaxCut is being performed
wires = range(3)
graph = Graph([(0, 1), (1, 2), (2, 0)])

# Defines the QAOA cost and mixer Hamiltonians
cost_h, mixer_h = qaoa.maxcut(graph)

# Defines a layer of the QAOA ansatz from the cost and mixer Hamiltonians
def qaoa_layer(gamma, alpha):
    qaoa.cost_layer(gamma, cost_h)
    qaoa.mixer_layer(alpha, mixer_h)

# Defines the device and the QAOA cost function
dev = qml.device('lightning.qubit', wires=len(wires), shots = 1000)

@qml.qjit
@qml.qnode(dev)
def cost_function(params):
    # Repeatedly applies layers of the QAOA ansatz
    @qml.for_loop(0, wires, 1)
    def loop_H(i):
      qml.Hadamard(wires=i)

    qml.layer(qaoa_layer, 2, params[0], params[1])
    return qml.expval(cost_h)

print(cost_function([[1, 1], [1, 1]]))

# Optimization
@jax.jit
def optimization():
    # initial parameter
    params = jnp.array([[0.54, 0.3154], [0.54, 0.3154]])

    # define the optimizer using a qjit-decorated function
    opt = jaxopt.GradientDescent(cost_function, stepsize=0.4)
    update = lambda i, args: tuple(opt.update(*args))

    # perform optimization loop
    state = opt.init_state(params)
    (params, _) = jax.lax.fori_loop(0, 100, update, (params, state))

    return params

optimization()

I hope this helps!

For anyone else reading this thread, note that you will need to pip install pennylane pennylane-catalyst jaxopt .