Problems of jitting the optimization process when simulating a circuit consist of sub-circuits with wires as arguments

Hello! I want to simulate a circuit which is constructed by some sub-circuits that have the same structure but operating on different wires. I wrote the code of the sub-circuit as a function, and let the wires be an argument of it. However, I encountered several errors when trying to use the jax.jit to accelerate the optimization process. To illustrate the problem, let’s use the ansätze of Picture 2 in this article as a simple example. Suppose we regard the 2 adjacent column of the SU(4) gates in the picture as a layer of the circuit, and the operating wires of the left column as “wires pattern A” while the right as “wires pattern B”. The code below is what I’m trying to run. The two_qubit_decomp() function is from the article, and I use the fidelity between the target states and the output states as the cost function. The update_params_jax() is to update the parameters in a single epoch of the training, and the training_process_jax() is the loop of the whole training. The optimization_jax() prints the initial cost and start the training loop.

import jax
jax.config.update("jax_enable_x64", True)
import matplotlib
matplotlib.use('Agg')
import optax
import pennylane as qml
import time
from functools import partial
from jax import numpy as jnp
from matplotlib import pyplot as plt

N=5 # Number of wires
L=5 # Number of layers
dev=qml.device('lightning.qubit', wires=N)
opt=optax.adam(-0.01, b2=0.99)
initial_states_data=jnp.load(f'ini_states_SU_{N}bit_1600round.npy')
target_states_data=jnp.load(f'tar_states_SU_{N}bit_1600round.npy')

def two_qubit_decomp(params, wires):
    """Implement an arbitrary SU(4) gate on two qubits
    using the decomposition from Theorem 5 in
    https://arxiv.org/pdf/quant-ph/0308006.pdf"""
    i, j = wires
    # Single U(2) parameterization on both qubits separately
    qml.Rot(*params[:3], wires=i)
    qml.Rot(*params[3:6], wires=j)
    qml.CNOT(wires=[j, i])  # First CNOT
    qml.RZ(params[6], wires=i)
    qml.RY(params[7], wires=j)
    qml.CNOT(wires=[i, j])  # Second CNOT
    qml.RY(params[8], wires=j)
    qml.CNOT(wires=[j, i])  # Third CNOT
    # Single U(2) parameterization on both qubits separately
    qml.Rot(*params[9:12], wires=i)
    qml.Rot(*params[12:15], wires=j)

@qml.qnode(dev, interface='jax')
def circuit_jax(params, wires, layer_num, initial_state, target_state):
    wires_pattern_A = [[wires[2 * i], wires[2 * i + 1]] for i in range(len(wires) // 2)]
    wires_pattern_B = [[wires[2 * i + 1], wires[(2 * i + 2) % len(wires)]] for i in range(len(wires) // 2)]

    def body_fun_A(layer, loop_args): # Outer loop in the circuit
        params, wires, wires_pattern_A, wires_pattern_B = loop_args

        def body_fun_B1(w, loop_args): # First inner loop in the circuit
            params, wires, layer = loop_args
            two_qubit_decomp(params[2 * layer * len(wires) // 2 * 15:(2 * layer * len(wires) // 2 + 1) * 15], wires_pattern_A[w])
            return (params, wires, layer)

        def body_fun_B2(w, loop_args): # Second inner loop in the circuit
            params, wires, layer = loop_args
            two_qubit_decomp(params[(2 * layer + 1) * len(wires) // 2 * 15:((2 * layer + 1) * len(wires) // 2 + 1) * 15], wires_pattern_B[w])
            return (params, wires, layer)
        
        jax.lax.fori_loop(0, len(wires_pattern_A), body_fun_B1, (params, wires, layer))
        jax.lax.fori_loop(0, len(wires_pattern_B), body_fun_B2, (params, wires, layer))

        return (params, wires, wires_pattern_A, wires_pattern_B)
    
    qml.StatePrep(initial_state, wires=wires)
    jax.lax.fori_loop(0, layer_num, body_fun_A, (params, wires, wires_pattern_A, wires_pattern_B))
    return qml.expval(qml.Hermitian(qml.math.dm_from_state_vector(target_state), wires))

@partial(jax.jit, static_argnames=['wires', 'layer_num'])
def cost_jax(params, wires, layer_num, initial_state_batch, target_state_batch):

    def body_fun(i, c):
        ini_state = initial_state_batch[i]
        tar_state = target_state_batch[i]
        return c + circuit_jax(params, wires, layer_num, ini_state, tar_state)

    c = jax.lax.fori_loop(0, len(initial_state_batch), body_fun, 0)
    return c / len(initial_state_batch)

@partial(jax.jit, static_argnames=['wires', 'layer_num'])
def update_params_jax(i, args, wires, layer_num):
    params, opt_state, initial_state_batch, target_state_batch= args
    cur_cost, grads = jax.value_and_grad(lambda p: cost_jax(p, wires, layer_num, initial_state_batch, target_state_batch))(params)
    updates, opt_state=opt.update(grads, opt_state)
    params=optax.apply_updates(params, updates)
    jax.lax.cond(jnp.logical_and(i % 10 == 0, i), lambda _: jax.debug.print('Epoch {i}, current fidelity is {j}', i=i, j=cur_cost), lambda _: None, operand=None)
    return params, opt_state, cur_cost

@partial(jax.jit, static_argnames=['wires', 'layer_num', 'round_limit'])
def training_process_jax(params, wires, layer_num, opt_state, cur_cost, round_limit, costs):
    converge_count = 0
    prev_cost = 1

    def cond_fun(loop_state):
        params, opt_state, prev_cost, cur_cost, converge_count, i, costs = loop_state
        return jnp.logical_and(converge_count < 5, i <= round_limit)

    def body_fun(loop_state):
        params, opt_state, prev_cost, cur_cost, converge_count, i, costs = loop_state
        converge_count = jax.lax.cond(jnp.logical_and(abs(cur_cost - prev_cost) < 0.003, cur_cost < 0.1), lambda cc: cc + 1, lambda cc: 0, converge_count)
        prev_cost = cur_cost
        initial_state_batch = initial_states_data[i]
        target_state_batch = target_states_data[i]
        params, opt_state, cur_cost = update_params_jax(i, (params, opt_state, initial_state_batch, target_state_batch), wires, layer_num)
        costs = costs.at[i].set(1 - cur_cost)
        i += 1
        return params, opt_state, prev_cost, cur_cost, converge_count, i, costs
    params, opt_state, prev_cost, cur_cost, converge_count, i, costs = jax.lax.while_loop(cond_fun, body_fun, (params, opt_state, prev_cost, cur_cost, converge_count, 1, costs))

    return params, cur_cost, i, costs

@partial(jax.jit, static_argnames=['wires', 'layer_num'])
def optimization_jax(params, wires, layer_num):
    initial_state_batch = initial_states_data[0]
    target_state_batch = target_states_data[0]
    opt_state = opt.init(params)
    params, opt_state, cur_cost = update_params_jax(0, (params, opt_state, initial_state_batch, target_state_batch), wires, layer_num)
    print(f'Initial fidelity is {cur_cost}, time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}')
    costs = jnp.concatenate((jnp.array([cur_cost]), jnp.zeros(len(initial_states_data))))
    params, cur_cost, round_number,  costs = training_process_jax(params, wires, layer_num, opt_state, cur_cost, len(initial_states_data), costs)
    return params, round_number, costs

if __name__ == '__main__':
    t=time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
    print(f'start time: {t}')
    param_num=L*N//2*30
    print(f'The number of parameters is {param_num}')
    params=jnp.zeros(param_num)
    params, costs=optimization_jax(params, range(N), L)
    fig, ax=plt.subplots()
    ax.plot(range(len(costs)+1), costs)
    ax.set_xlabel('Training epoch')
    ax.set_ylabel('1-Fidelity')
    ax.set_title(f'Training of a {N} qubit circuit')
    plt.savefig(f'qml_SU4_ansatz_{N}bit_{L}layer.png')
    plt.clf()
    t=time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
    print(f'end time: {t}')

To run this code, you will need the 2 files below (please delete the extention txt, it is added due to the restriction of the Pennylane Forum). The initial states are some 32-dimentional randomly-generated jnp vectors, and the target states are generated by simply multiplying a Special Unitary matrix with the initial states.

ini_states_SU_5bit_1600round.npy.txt (6.3 MB)

tar_states_SU_5bit_1600round.npy.txt (6.3 MB)

The error occurs at the line jax.lax.fori_loop(0, layer_num, body_fun_A, (params, wires, wires_pattern_A, wires_pattern_B)) in the circuit_jax(), saying that TypeError: Argument ‘range(0, 5)’ of type ‘<class ‘range’>’ is not a valid JAX type. If I change the line params, costs=optimization_jax(params, range(N), L) to params, costs=optimization_jax(params, list(range(N)), L), then the error occurs at this line itself, saying ValueError: Non-hashable static arguments are not supported. An error occurred while trying to hash an object of type <class ‘list’>, [0, 1, 2, 3, 4]. The error was:
TypeError: unhashable type: ‘list’. I also read the post 2256 and changed the line to params, costs=optimization_jax(params, (0,1,2,3,4), L), and now the error occurs at the line two_qubit_decomp(params[2 * layer * len(wires) // 2 * 15:(2 * layer * len(wires) // 2 + 1) * 15], wires_pattern_A[w]) in body_fun_B1() of circuit_jax(), saying IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(Traced<~int64>with, Traced<~int64>with, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).. So how to use the jax.jit when there is sub-circuit functions in my circuit, especially when I need to distribute the parameters to the sub-circuits based on the length of the wires?

I also noticed that in simple cases without sub-circuits (such as just training the two_qubit_decomp() on 2 qubits), if I pass the wires argument with a tuple, and decorate the two_qubit_decomp() with @partial(jax.jit, static_argnames=['wires']), an error will be thrown at the second line of two_qubit_decomp() saying that abstract wires is not supported in quantum gates. Do I need to jit the circuit function and the sub-circuit functions (if exist)? If I do not jit them, will the compilation time be affected?

Hi @Tz_19 ,

I’m trying to understand your issue but your code is rather complex so it’s hard to know what’s actually going on. Can you please write a minimal reproducible example? This would be a small example that I can copy-paste and run, without all of the added complexity of your full code.

Hello @CatalinaAlbornoz, sorry to take you much time to understand my issue. Is the circuit structure of my example or the training and optimization functions in the code makes you confused? Since the grammar of jax.lax.cond(), jax.lax.fori_loop() and jax.lax.while_loop() makes the code awkward, I rewrite this code only using Pennylane to illustrate what I’m trying to do. I’m able to use the normal if, for and while in the code, so I hope that this code is much more clear to you.

import matplotlib
matplotlib.use('Agg')
import pennylane as qml
import time
from matplotlib import pyplot as plt
from pennylane import numpy as np

N=5
L=5
dev=qml.device('lightning.qubit', wires=N)
initial_states_data=np.load(f'ini_states_SU_{N}bit_1600round_qml.npy')
target_states_data=np.load(f'tar_states_SU_{N}bit_1600round_qml.npy')

def two_qubit_decomp(params, wires):
    """Implement an arbitrary SU(4) gate on two qubits
    using the decomposition from Theorem 5 in
    https://arxiv.org/pdf/quant-ph/0308006.pdf"""
    i, j = wires
    # Single U(2) parameterization on both qubits separately
    qml.Rot(*params[:3], wires=i)
    qml.Rot(*params[3:6], wires=j)
    qml.CNOT(wires=[j, i])  # First CNOT
    qml.RZ(params[6], wires=i)
    qml.RY(params[7], wires=j)
    qml.CNOT(wires=[i, j])  # Second CNOT
    qml.RY(params[8], wires=j)
    qml.CNOT(wires=[j, i])  # Third CNOT
    # Single U(2) parameterization on both qubits separately
    qml.Rot(*params[9:12], wires=i)
    qml.Rot(*params[12:15], wires=j)

@qml.qnode(dev)
def circuit(params, wires, layer_num, initial_state, target_state):
    qml.StatePrep(initial_state, wires=wires)
    wires_pattern_A=[[wires[2*i], wires[2*i+1]] for i in range(len(wires)//2)]
    wires_pattern_B=[[wires[2*i+1], wires[(2*i+2)%len(wires)]] for i in range(len(wires)//2)]
    for layer in range(layer_num):
        for w in wires_pattern_A:
            two_qubit_decomp(params[2*layer*len(wires)//2*15:(2*layer*len(wires)//2+1)*15], w)
        for w in wires_pattern_B:
            two_qubit_decomp(params[(2*layer+1)*len(wires)//2*15:((2*layer+1)*len(wires)//2+1)*15], w)
    return qml.expval(qml.Hermitian(qml.math.dm_from_state_vector(target_state), wires))

def cost(params, wires, layer_num, initial_state_batch, target_state_batch):
    c=0
    for ini_state, tar_state in zip(initial_state_batch, target_state_batch):
        c+=circuit(params, wires, layer_num, ini_state, tar_state)
    return c/len(initial_state_batch)

def training_process_qml(params, wires, layer_num, opt):
    initial_fidelity=cost(params, wires, layer_num, initial_states_data[0], target_states_data[0])
    print(f'Initial fidelity is {initial_fidelity}')
    prev_cost=0
    cur_cost=initial_fidelity
    costs=[initial_fidelity]
    converge_count=0
    i=0
    while converge_count<5 and i<len(initial_states_data):
        if abs(cur_cost-prev_cost)<0.003 and cur_cost<0.1:
            converge_count+=1
        else:
            converge_count=0
        prev_cost=cur_cost
        params, cur_cost=opt.step_and_cost(lambda p: cost(p, range(N), L, initial_states_data[i], target_states_data[i]), params)
        costs.append(cur_cost)
        i += 1
        if not (i % 10):
            print(f'Epoch {i}, current fidelity is {cur_cost}, time: {time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())}')
    return params, costs

def optimization_qml(params, wires, layer_num):
    opt=qml.AdamOptimizer(-0.01)
    params, costs=training_process_qml(params, wires, layer_num, opt)
    return params, costs

if __name__ == '__main__':
    t=time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
    print(f'start time: {t}')
    param_num=L*N//2*30
    print(f'The number of parameters is {param_num}')
    params=np.zeros(param_num)
    params, costs=optimization_qml(params, range(N), L)
    fig, ax=plt.subplots()
    ax.plot(range(len(costs)+1), costs)
    ax.set_xlabel('Training epoch')
    ax.set_ylabel('1-Fidelity')
    ax.set_title(f'Training of a {N} qubit circuit')
    plt.savefig(f'qml_SU4_ansatz_{N}bit_{L}layer.png')
    plt.clf()
    t=time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
    print(f'end time: {t}')

The training_process_qml() corresponds to the update_params_jax() and training_process_jax(). The optimization_qml() corresponds to the optimization_jax(). This code can run without any problem, but the code that use JAX to optimize the circuit will cause errors described in the last post, although the logic of the 2 versions of the code is generally the same. If you want to run this code, you will need the 2 files below:

ini_states_SU_5bit_1600round_qml.npy.txt (6.3 MB)

tar_states_SU_5bit_1600round_qml.npy.txt (6.3 MB)

Hi @Tz_19 ,

Thank you for adding this new version of the code, I appreciate it.

Unfortunately the issue is that neither of these examples are actually minimal. Both the previous code and the new one include a lot of complexity that distracts from the error. Part of the hard work in software is finding these minimal examples so that you can actually find the cause of your errors. It’s hard but it’s how we can solve our own issues!

After spending a while with both versions of your code I think the issue is in how you’re defining wires_pattern_A and wires_pattern_B you have a for loop with a range there. I cannot guarantee that this is the cause for the issue (or the only issue) but it’s an avenue worth exploring!

The best step forward would be to simplify the code to remove complexity until the errors go away, and then slowly add the complexity one step at a time.

I hope this helps!

Hi @CatalinaAlbornoz , thank you very much for spending time helping me with my problems! I managed to simplify the code to make the problems more clear. There are mainly 2 problems I’m facing. Let’s use the following code to illustrate the first problem:

import jax
jax.config.update("jax_enable_x64", True)
import pennylane as qml
import time
from functools import partial
from jax import numpy as jnp

N=5 # Number of wires
dev=qml.device('lightning.qubit', wires=N)

@qml.qnode(dev, interface='jax')
def circuit_jax(params, wires, initial_state, target_state):

    def body_fun(i, state):
        params, wires=state
        qml.Rot(params[3*i], params[3*i+1], params[3*i+2], wires=wires[i])
        return state
    
    qml.StatePrep(initial_state, wires)
    jax.lax.fori_loop(0, len(wires), body_fun, (params, wires))
    return qml.Hermitian(qml.math.dm_from_state_vector(target_state), wires=wires)

@partial(jax.jit, static_argnames=['wires'])
def cost_jax(params, wires, initial_state, target_state):
    return 1-circuit_jax(params, wires, initial_state, target_state)

if __name__ == '__main__':
    t=time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
    print(f'start time: {t}')
    params=jnp.zeros(3*N)
    initial_state=jax.random.normal(jax.random.PRNGKey(0), (2**N,))+1j*jax.random.normal(jax.random.PRNGKey(1), (2**N,))
    initial_state=initial_state/jnp.linalg.norm(initial_state)
    target_state=jax.random.normal(jax.random.PRNGKey(2), (2**N,))+1j*jax.random.normal(jax.random.PRNGKey(3), (2**N,))
    target_state=target_state/jnp.linalg.norm(target_state)
    print(cost_jax(params, range(N), initial_state, target_state))
    t=time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
    print(f'end time: {t}')

Running this code will result into the following error:

  File "/home/nodegpu/ztz/Quantum_LM/test.py", line 30, in circuit_jax
    jax.lax.fori_loop(0, len(wires), body_fun, (params, wires))
    ~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Argument 'range(0, 5)' of type '<class 'range'>' is not a valid JAX type

Replacing the range(N) with tuple(range(N)) in the line print(cost_jax(params, range(N), initial_state, target_state)) in if __name__ == '__main__': will result into another error:

  File "/home/nodegpu/ztz/Quantum_LM/test.py", line 26, in body_fun
    qml.Rot(params[3*i], params[3*i+1], params[3*i+2], wires=wires[i])
                                                             ~~~~~^^^
jax.errors.TracerIntegerConversionError: The __index__() method was called on traced array with shape int64[]
The error occurred while tracing the function body_fun at /home/nodegpu/ztz/Quantum_LM/test.py:24 for scan. This concrete value was not available in Python because it depends on the value of the argument loop_carry[0].

So how to pass the wires argument to the circuit function if the circuit has some operations like applying the same kind of quantum gates to all the wires when using jax.jit in the code?

To illustrate the second problem, let’s take the following code as an example. The code shows the condition that I want to apply a circuit (Rotations() in the code) repeatedly on different wires each time. In circuit_jax(), the variable wires_for_subcircuit contains the wires that each Rotations() should apply to. All the parameters are passed to circuit_jax(), and they are split for each Rotations() when calling Rotations() in body_fun() in circuit_jax().

import jax
jax.config.update("jax_enable_x64", True)
import pennylane as qml
import time
from functools import partial
from jax import numpy as jnp

N=5 # Number of wires
L=2 # Number of layers
dev=qml.device('lightning.qubit', wires=N)

def Rotations(params, wires):
    
    def body_fun(i, state):
        params, wires=state
        qml.Rot(params[3*i], params[3*i+1], params[3*i+2], wires=wires[i])
        return state
    
    jax.lax.fori_loop(0, len(wires), body_fun, (params, wires))

@qml.qnode(dev, interface='jax')
def circuit_jax(params, wires, initial_state, target_state):
    wires_for_subcircuit=[wires]*L

    def body_fun(i, params):
        Rotations(params[i*len(wires):(i+1)*len(wires)], wires_for_subcircuit[i])
        return params

    qml.StatePrep(initial_state, wires)
    jax.lax.fori_loop(0, L, body_fun, params)
    return qml.Hermitian(qml.math.dm_from_state_vector(target_state), wires=wires)

@partial(jax.jit, static_argnames=['wires'])
def cost_jax(params, wires, initial_state, target_state):
    return 1-circuit_jax(params, wires, initial_state, target_state)

if __name__ == '__main__':
    t=time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
    print(f'start time: {t}')
    params=jnp.zeros(3*N*L)
    initial_state=jax.random.normal(jax.random.PRNGKey(0), (2**N,))+1j*jax.random.normal(jax.random.PRNGKey(1), (2**N,))
    initial_state=initial_state/jnp.linalg.norm(initial_state)
    target_state=jax.random.normal(jax.random.PRNGKey(2), (2**N,))+1j*jax.random.normal(jax.random.PRNGKey(3), (2**N,))
    target_state=target_state/jnp.linalg.norm(target_state)
    print(cost_jax(params, tuple(range(N)), initial_state, target_state))
    t=time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
    print(f'end time: {t}')

The code will result into the following error:

  File "/home/nodegpu/ztz/Quantum_LM/test.py", line 38, in body_fun
    Rotations(params[i*len(wires):(i+1)*len(wires)], wires_for_subcircuit[i])
              ~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(Traced<~int64[]>with<DynamicJaxprTrace>, Traced<~int64[]>with<DynamicJaxprTrace>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).

So how to split the parameters for the sub-circuits correctly when using jax.jit?

Hello @CatalinaAlbornoz . In the second code of the last post, I also tried to replace the tuple(range(N)) with jnp.array(range(N)) in the line print(cost_jax(params, tuple(range(N)), initial_state, target_state)) and replace the @partial(jax.jit, static_argnames=['wires']) with @jax.jit since the jnp array is not hashable. However, it will result into the same error at the same line. So, is it possible to call the Rotations() multiple times and split the parameters for each calling inside the circuit_jax() in this example?

Hi @Tz_19 ,

Thanks for following up and explaining the problem. Note that this still isn’t a minimal example.

I have made an example based on our demo How to optimize a QML model using JAX and Optax. From what I could tell from your code you were trying to compare two quantum states. In this case it’s best to use fidelity or a different distance measure. Please check out the Codebook Module on Distance Measures if you have questions about this.

In the example below you will see that I have commented out a few lines from the demo to get closer to the problem you’re trying to solve. In the future it’s best if you start small, or start from an example in the demos, and do small iterations instead of starting with complex piece of code which will be harder to debug.

import pennylane as qml
import jax
from jax import numpy as jnp
import optax

n_wires = 5
data = jnp.sin(jnp.mgrid[-2:2:0.2].reshape(n_wires, -1)) ** 3
#targets = jnp.array([-0.2, 0.4, 0.35, 0.2])
target_state = jnp.array([1/jnp.sqrt(2**n_wires)]*2**n_wires)

dev = qml.device("default.qubit", wires=n_wires)

def subcircuit(weights, my_wires):
    """Quantum circuit ansatz"""
    # trainable ansatz
    for i in my_wires:
        qml.RX(weights[i, 0], wires=i)
        qml.RY(weights[i, 1], wires=i)
        qml.RX(weights[i, 2], wires=i)
        qml.CNOT(wires=[i, (i + 1) % n_wires])

@qml.qnode(dev)
def circuit(data, weights):
    """Quantum circuit"""

    # data embedding
    for i in range(n_wires):
        # data[i] will be of shape (4,); we are
        # taking advantage of operation vectorization here
        qml.RY(data[i], wires=i)

    # trainable ansatz
    # each subcircuit acts on a different set of qubits
    subcircuit(weights, range(0,2))
    subcircuit(weights, range(2, n_wires))

    # use qml.state() instead
    return qml.state()#qml.expval(qml.sum(*[qml.PauliZ(i) for i in range(n_wires)]))

def my_model(data, weights, bias):
    return circuit(data, weights) + bias

@jax.jit
def loss_fn(params, data, target):
    predictions = my_model(data, params["weights"], params["bias"])
    # Compute the fidelities instead of a simple subtraction
    fidelities = qml.math.fidelity_statevector(target, predictions)
    loss = jnp.sum((fidelities) ** 2 / len(data))
    return loss

weights = jnp.ones([n_wires, 3])
bias = jnp.array(0.)
params = {"weights": weights, "bias": bias}

print("loss_fn",loss_fn(params, data, target_state))

print("grad",jax.grad(loss_fn)(params, data, target_state))

opt = optax.adam(learning_rate=0.3)
opt_state = opt.init(params)

def update_step(opt, params, opt_state, data, target):
    loss_val, grads = jax.value_and_grad(loss_fn)(params, data, target)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_val

loss_history = []

for i in range(100):
    params, opt_state, loss_val = update_step(opt, params, opt_state, data, target_state)

    if i % 5 == 0:
        print(f"Step: {i} Loss: {loss_val}")

    loss_history.append(loss_val)

I hope this helps.

Hi @CatalinaAlbornoz , thank you for your effort for helping me with my problem! After a period of time of debugging, I find that the problem is about the usage of JAX, not Pennylane. Sorry for occupying you time and thank you so much for your help!

No worries @Tz_19 !

Thanks for asking your questions here, it helps others having similar issues too. I’m glad you were able to troubleshoot your problem! If you have any pointers that can help others having similar issues feel free to share them here.