QAOA implementation with JAX for the Max-Cut problem

Hi, I’m trying to use QAOA for the Max-Cut problem on an 18-node graph. I tried to follow a strategy similar to the QAOA example in the codebook, but it takes too long to respond. So I tried to implement it with JAX for 4 nodes, following the QAOA codebook code, but I ran into issues related to ArrayBox elements and JAX type mismatches. Can you please guide me to resolve them?

Hi @roysuman088, welcome to the Forum!
Can you please post a minimal version of your code that I can run to see if I can replicate your error?
Also please post the output of qml.about() and the full error traceback.

ArrayBox issues sometimes happen when you’re trying to update global variables from inside the function being differentiated, or when you’re using ._value.

However I would need to see your code in order to have a better idea of what can be wrong.
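For context: ArrayBox is the wrapper that autograd (behind PennyLane's default NumPy interface) puts around values while it traces a gradient, and JAX uses tracer objects in the same role. Below is a minimal JAX sketch, assuming the problem is a value written to a global list from inside a differentiated function; the function names are hypothetical. The leaked entry is a tracer rather than a number, whereas returning the extra quantity through jax.value_and_grad(..., has_aux=True) hands back concrete values.

```python
import jax
import jax.numpy as jnp

# Pitfall (hypothetical example): the differentiated function writes into a global list
history = []


def leaky_loss(x):
    y = jnp.sum(x ** 2)
    history.append(y)  # during jax.grad this stores a tracer, not a number
    return y


# Safer pattern: hand the extra quantity back as auxiliary output
def loss_with_aux(x):
    y = jnp.sum(x ** 2)
    return y, {"norm": jnp.sqrt(y)}


x = jnp.array([3.0, 4.0])
jax.grad(leaky_loss)(x)
print(type(history[0]).__name__)  # a tracer class rather than a concrete array

(value, aux), grads = jax.value_and_grad(loss_with_aux, has_aux=True)(x)
print(value, aux["norm"], grads)  # concrete values, usable outside the gradient call
```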

# import libraries
import pennylane as qml
from pennylane import numpy as np
import jax
import jax.numpy as jnp
import optax
from functools import partial

# `from jax.config import config` is deprecated in recent JAX releases;
# update the config through the jax module (with plain ASCII quotes)
jax.config.update("jax_enable_x64", True)


def main(graph, n_wires, offset):

    # unitary operator U_B with parameter beta
    def U_B(beta):
        for wire in range(n_wires):
            qml.RX(2 * beta, wires=wire)

    # unitary operator U_C with parameter gamma
    # (note: the edge weights edge[2] are not used here, only in the objective)
    def U_C(gamma):
        for edge in graph:
            wire1 = edge[0]
            wire2 = edge[1]
            qml.CNOT(wires=[wire1, wire2])
            qml.RZ(gamma, wires=wire2)
            qml.CNOT(wires=[wire1, wire2])

    def bitstring_to_int(bit_string_sample):
        # int(bs) turns each 0-d array element into a plain 0/1 digit
        bit_string = "".join(str(int(bs)) for bs in bit_string_sample)
        return int(bit_string, base=2)

    dev = qml.device("lightning.qubit", wires=n_wires, shots=1)

    # edge and n_layers (arguments 2 and 3) are static, so jax.jit hashes them:
    # they must be hashable, which is why an edge has to be a tuple, not a list
    @partial(jax.jit, static_argnums=(2, 3))
    @qml.qnode(dev, interface="jax")
    def circuit(gammas, betas, edge, n_layers):
        # apply Hadamards to get the n qubit |+> state
        for wire in range(n_wires):
            qml.Hadamard(wires=wire)
        # p instances of unitary operators
        for i in range(n_layers):
            U_C(gammas[i])
            U_B(betas[i])
        if edge is None:
            # measurement phase
            return qml.sample()
        # during the optimization phase we are evaluating a term
        # in the objective using expval
        H = qml.PauliZ(wires=edge[0]) @ qml.PauliZ(wires=edge[1])
        return qml.expval(H)

    def cost(gammas, betas, edge, n_layers):
        # pass everything positionally so that static_argnums applies
        return circuit(gammas, betas, edge, n_layers)

    def qaoa_maxcut(n_layers=1):

        # initialize the parameters near zero
        init_params = jnp.array(0.01 * np.random.rand(2, n_layers))

        # minimize the negative of the objective function
        def objective(params):
            gammas = params[0]
            betas = params[1]
            neg_obj = 0
            for edge in graph:
                # objective for the MaxCut problem; tuple(edge) also accepts list edges
                neg_obj -= 0.5 * edge[2] * (1 - cost(gammas, betas, tuple(edge), n_layers))
            return neg_obj

        # initialize optimizer: Adam with step size 0.2
        opt = optax.adam(0.2)

        # optimize parameters in objective
        params = init_params
        opt_state = opt.init(init_params)
        steps = 1500
        for i in range(steps):
            grads = jax.grad(objective)(params)
            updates, opt_state = opt.update(grads, opt_state)
            params = optax.apply_updates(params, updates)
            if (i + 1) % 5 == 0:
                print("Objective after step {:5d}: {: .7f}".format(i + 1, -objective(params)))

        # sample measured bitstrings 100 times
        bit_strings = []
        n_samples = 100
        for i in range(0, n_samples):
            # edge=None and n_layers passed positionally, as static arguments
            sample = circuit(params[0], params[1], None, n_layers)
            bit_strings.append(bitstring_to_int(sample))

        # print optimal parameters and most frequently sampled bitstring
        counts = np.bincount(np.array(bit_strings))
        most_freq_bit_string = int(np.argmax(counts))
        print("Optimized (gamma, beta) vectors:\n{}".format(params[:, :n_layers]))
        # pad to n_wires digits (the codebook's {:04b} only fits a 4-node graph)
        print("Most frequently sampled bit string is: {:0{}b}".format(most_freq_bit_string, n_wires))

        return -objective(params), bit_strings, most_freq_bit_string

    bitstrings = qaoa_maxcut(n_layers=1)
    bitstrings1, most_freq_bit_string_1 = bitstrings[1], bitstrings[2]

    bitstrings2 = qaoa_maxcut(n_layers=2)
    bitstrings2, most_freq_bit_string_2 = bitstrings2[1], bitstrings2[2]
    return most_freq_bit_string_1, most_freq_bit_string_2

graph = [(0, 1, 1), (0, 3, 0), (0, 9, 10), (0, 10, 0), (1, 2, 1), (1, 4, 0), (1, 9, 10), (1, 10, 0), (2, 5, 0), (2, 9, 10), (2, 10, 0), (3, 4, 1), (3, 6, 1), (3, 9, 0), (3, 10, -10), (4, 5, 1), (4, 7, 1), (4, 9, 0), (4, 10, -10), (5, 8, 1), (5, 9, 0), (5, 10, -10), (6, 7, 1), (6, 9, 0), (6, 10, -10), (7, 8, 1), (7, 9, 0), (7, 10, -10), (8, 9, 0), (8, 10, -10)]

# the graph has 11 nodes (0 to 10); offset is not used inside main
most_freq_bit_string_1, most_freq_bit_string_2 = main(graph, n_wires=11, offset=0)
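The objective in the code relies on the fact that, for a computational basis state, <Z_i Z_j> is +1 when bits i and j agree and -1 when they differ, so 0.5 * w * (1 - <Z_i Z_j>) equals w for a cut edge and 0 otherwise. For graphs of this size a brute-force check of the sampled bit string is cheap: 2^11 = 2048 strings for the 11-node graph above, 2^18 = 262,144 for 18 nodes. A minimal sketch, assuming edges are (node_i, node_j, weight) tuples as in the post; cut_values and the ring graph are hypothetical.

```python
import jax.numpy as jnp


def cut_values(graph, n_wires):
    """Weighted cut value of every bit string of length n_wires.

    Bit strings are indexed by integer with wire 0 as the most significant bit,
    matching bitstring_to_int in the code above.
    """
    ints = jnp.arange(2 ** n_wires)
    shifts = n_wires - 1 - jnp.arange(n_wires)
    bits = (ints[:, None] >> shifts[None, :]) & 1  # shape (2**n_wires, n_wires)
    values = jnp.zeros(2 ** n_wires)
    for i, j, w in graph:
        # an edge contributes its weight when its endpoints are on different sides
        values = values + w * (bits[:, i] != bits[:, j])
    return values


# Hypothetical 4-node ring to illustrate the bit ordering
ring = [(0, 1, 1), (1, 2, 1), (2, 3, 1), (3, 0, 1)]
values = cut_values(ring, 4)
best = int(jnp.argmax(values))
print("best cut value:", float(values[best]), "bit string:", format(best, "04b"))
```

Comparing cut_values(graph, 11)[most_freq_bit_string] with the maximum shows how close the QAOA sample is to the optimum; the 2^n enumeration becomes impractical for much larger graphs.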

Hi @CatalinaAlbornoz, I’m actually working on the Max-Cut problem using JAX for a larger number of nodes. For the time being, the code above works fine, except that the graph does not accept each edge as a list, only as a tuple. Can you check that part? Also, is it possible to run a graph with more than 20 nodes, and do you have any suggestions for that? Even with JAX and lightning.qubit it still takes quite a long time; can you suggest something?

Thanks for your response.
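On the tuple-vs-list question: circuit is wrapped in @partial(jax.jit, static_argnums=(2, 3)), so JAX treats edge and n_layers as compile-time constants and hashes them to look up the compiled version. Tuples are hashable and lists are not, so a list edge is rejected. A minimal sketch without PennyLane (edge_term is a hypothetical stand-in for circuit); it catches both TypeError and ValueError because the exact exception type may depend on the JAX version.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical stand-in for the circuit: argument 1 (the edge) is static
@partial(jax.jit, static_argnums=(1,))
def edge_term(z, edge):
    i, j, w = edge
    return 0.5 * w * (1 - z[i] * z[j])


z = jnp.array([1.0, -1.0, 1.0])
print(edge_term(z, (0, 1, 2)))  # tuple edge: hashable, accepted

try:
    edge_term(z, [0, 1, 2])  # list edge: unhashable, rejected by jax.jit
    list_rejected = False
except (TypeError, ValueError) as err:
    list_rejected = True
    print(type(err).__name__, err)

# Converting at the call site keeps a list-of-lists graph usable
graph = [[0, 1, 2], [1, 2, 1]]
total = sum(edge_term(z, tuple(edge)) for edge in graph)
```

Every distinct static value triggers its own compilation, so with the edge as a static argument the circuit is compiled once per edge.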

Hi @roysuman088,

This system is very big so it will inevitably take a long time. Using JAX and lightning is the right approach though.
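To put a rough number on it, here is a back-of-the-envelope count of circuit executions per optimization step for the code above. It assumes PennyLane differentiates the finite-shot device (shots=1) with the parameter-shift rule, i.e. two shifted executions per gate parameter, and counts one RZ per edge and one RX per node in each layer. The helper is hypothetical and forward passes are ignored.

```python
def parameter_shift_runs_per_step(n_edges, n_nodes, n_layers, runs_per_param=2):
    """Rough count of circuit executions for one gradient of the per-edge objective.

    Assumes one circuit per edge, each with n_layers * (n_edges + n_nodes) gate
    parameters (one RZ per edge and one RX per node, per layer), differentiated
    with a two-term parameter shift. Forward passes are not counted.
    """
    gate_params_per_circuit = n_layers * (n_edges + n_nodes)
    return n_edges * runs_per_param * gate_params_per_circuit


graph_edges, graph_nodes, steps = 30, 11, 1500  # the posted graph and step count
per_step = parameter_shift_runs_per_step(graph_edges, graph_nodes, n_layers=1)
print(per_step, "executions per step;", per_step * steps, "over", steps, "steps")
```

For the posted graph this is 2,460 executions per step and roughly 3.7 million over 1,500 steps, and the count grows like n_edges * (n_edges + n_nodes) per layer. Measuring the whole cost observable in one circuit and using adjoint gradients should bring it down substantially.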

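One additional thing that might help is setting the differentiation method to 'adjoint', which you can specify when you define your QNode. Keep in mind that the adjoint method computes exact gradients from the simulated state, so it is meant for a device without shots (shots=None), while your current device uses shots=1.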

@qml.qnode(dev, diff_method='adjoint', interface='jax')

Please let me know if this helps!

Thanks @CatalinaAlbornoz, I will look into it.