Hello! I want to simulate a circuit which is constructed by some sub-circuits that have the same structure but operating on different wires. I wrote the code of the sub-circuit as a function, and let the wires be an argument of it. However, I encountered several errors when trying to use the jax.jit to accelerate the optimization process. To illustrate the problem, let’s use the ansätze of Picture 2 in this article as a simple example. Suppose we regard the 2 adjacent column of the SU(4) gates in the picture as a layer of the circuit, and the operating wires of the left column as “wires pattern A” while the right as “wires pattern B”. The code below is what I’m trying to run. The two_qubit_decomp() function is from the article, and I use the fidelity between the target states and the output states as the cost function. The update_params_jax() is to update the parameters in a single epoch of the training, and the training_process_jax() is the loop of the whole training. The optimization_jax() prints the initial cost and start the training loop.
import jax
jax.config.update("jax_enable_x64", True)
import matplotlib
matplotlib.use('Agg')
import optax
import pennylane as qml
import time
from functools import partial
from jax import numpy as jnp
from matplotlib import pyplot as plt
N=5 # Number of wires
L=5 # Number of layers
dev=qml.device('lightning.qubit', wires=N)
opt=optax.adam(-0.01, b2=0.99)
initial_states_data=jnp.load(f'ini_states_SU_{N}bit_1600round.npy')
target_states_data=jnp.load(f'tar_states_SU_{N}bit_1600round.npy')
def two_qubit_decomp(params, wires):
"""Implement an arbitrary SU(4) gate on two qubits
using the decomposition from Theorem 5 in
https://arxiv.org/pdf/quant-ph/0308006.pdf"""
i, j = wires
# Single U(2) parameterization on both qubits separately
qml.Rot(*params[:3], wires=i)
qml.Rot(*params[3:6], wires=j)
qml.CNOT(wires=[j, i]) # First CNOT
qml.RZ(params[6], wires=i)
qml.RY(params[7], wires=j)
qml.CNOT(wires=[i, j]) # Second CNOT
qml.RY(params[8], wires=j)
qml.CNOT(wires=[j, i]) # Third CNOT
# Single U(2) parameterization on both qubits separately
qml.Rot(*params[9:12], wires=i)
qml.Rot(*params[12:15], wires=j)
@qml.qnode(dev, interface='jax')
def circuit_jax(params, wires, layer_num, initial_state, target_state):
wires_pattern_A = [[wires[2 * i], wires[2 * i + 1]] for i in range(len(wires) // 2)]
wires_pattern_B = [[wires[2 * i + 1], wires[(2 * i + 2) % len(wires)]] for i in range(len(wires) // 2)]
def body_fun_A(layer, loop_args): # Outer loop in the circuit
params, wires, wires_pattern_A, wires_pattern_B = loop_args
def body_fun_B1(w, loop_args): # First inner loop in the circuit
params, wires, layer = loop_args
two_qubit_decomp(params[2 * layer * len(wires) // 2 * 15:(2 * layer * len(wires) // 2 + 1) * 15], wires_pattern_A[w])
return (params, wires, layer)
def body_fun_B2(w, loop_args): # Second inner loop in the circuit
params, wires, layer = loop_args
two_qubit_decomp(params[(2 * layer + 1) * len(wires) // 2 * 15:((2 * layer + 1) * len(wires) // 2 + 1) * 15], wires_pattern_B[w])
return (params, wires, layer)
jax.lax.fori_loop(0, len(wires_pattern_A), body_fun_B1, (params, wires, layer))
jax.lax.fori_loop(0, len(wires_pattern_B), body_fun_B2, (params, wires, layer))
return (params, wires, wires_pattern_A, wires_pattern_B)
qml.StatePrep(initial_state, wires=wires)
jax.lax.fori_loop(0, layer_num, body_fun_A, (params, wires, wires_pattern_A, wires_pattern_B))
return qml.expval(qml.Hermitian(qml.math.dm_from_state_vector(target_state), wires))
@partial(jax.jit, static_argnames=['wires', 'layer_num'])
def cost_jax(params, wires, layer_num, initial_state_batch, target_state_batch):
def body_fun(i, c):
ini_state = initial_state_batch[i]
tar_state = target_state_batch[i]
return c + circuit_jax(params, wires, layer_num, ini_state, tar_state)
c = jax.lax.fori_loop(0, len(initial_state_batch), body_fun, 0)
return c / len(initial_state_batch)
@partial(jax.jit, static_argnames=['wires', 'layer_num'])
def update_params_jax(i, args, wires, layer_num):
params, opt_state, initial_state_batch, target_state_batch= args
cur_cost, grads = jax.value_and_grad(lambda p: cost_jax(p, wires, layer_num, initial_state_batch, target_state_batch))(params)
updates, opt_state=opt.update(grads, opt_state)
params=optax.apply_updates(params, updates)
jax.lax.cond(jnp.logical_and(i % 10 == 0, i), lambda _: jax.debug.print('Epoch {i}, current fidelity is {j}', i=i, j=cur_cost), lambda _: None, operand=None)
return params, opt_state, cur_cost
@partial(jax.jit, static_argnames=['wires', 'layer_num', 'round_limit'])
def training_process_jax(params, wires, layer_num, opt_state, cur_cost, round_limit, costs):
converge_count = 0
prev_cost = 1
def cond_fun(loop_state):
params, opt_state, prev_cost, cur_cost, converge_count, i, costs = loop_state
return jnp.logical_and(converge_count < 5, i <= round_limit)
def body_fun(loop_state):
params, opt_state, prev_cost, cur_cost, converge_count, i, costs = loop_state
converge_count = jax.lax.cond(jnp.logical_and(abs(cur_cost - prev_cost) < 0.003, cur_cost < 0.1), lambda cc: cc + 1, lambda cc: 0, converge_count)
prev_cost = cur_cost
initial_state_batch = initial_states_data[i]
target_state_batch = target_states_data[i]
params, opt_state, cur_cost = update_params_jax(i, (params, opt_state, initial_state_batch, target_state_batch), wires, layer_num)
costs = costs.at[i].set(1 - cur_cost)
i += 1
return params, opt_state, prev_cost, cur_cost, converge_count, i, costs
params, opt_state, prev_cost, cur_cost, converge_count, i, costs = jax.lax.while_loop(cond_fun, body_fun, (params, opt_state, prev_cost, cur_cost, converge_count, 1, costs))
return params, cur_cost, i, costs
@partial(jax.jit, static_argnames=['wires', 'layer_num'])
def optimization_jax(params, wires, layer_num):
initial_state_batch = initial_states_data[0]
target_state_batch = target_states_data[0]
opt_state = opt.init(params)
params, opt_state, cur_cost = update_params_jax(0, (params, opt_state, initial_state_batch, target_state_batch), wires, layer_num)
print(f'Initial fidelity is {cur_cost}, time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}')
costs = jnp.concatenate((jnp.array([cur_cost]), jnp.zeros(len(initial_states_data))))
params, cur_cost, round_number, costs = training_process_jax(params, wires, layer_num, opt_state, cur_cost, len(initial_states_data), costs)
return params, round_number, costs
if __name__ == '__main__':
t=time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
print(f'start time: {t}')
param_num=L*N//2*30
print(f'The number of parameters is {param_num}')
params=jnp.zeros(param_num)
params, costs=optimization_jax(params, range(N), L)
fig, ax=plt.subplots()
ax.plot(range(len(costs)+1), costs)
ax.set_xlabel('Training epoch')
ax.set_ylabel('1-Fidelity')
ax.set_title(f'Training of a {N} qubit circuit')
plt.savefig(f'qml_SU4_ansatz_{N}bit_{L}layer.png')
plt.clf()
t=time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
print(f'end time: {t}')
To run this code, you will need the 2 files below (please delete the extention txt, it is added due to the restriction of the Pennylane Forum). The initial states are some 32-dimentional randomly-generated jnp vectors, and the target states are generated by simply multiplying a Special Unitary matrix with the initial states.
ini_states_SU_5bit_1600round.npy.txt (6.3 MB)
tar_states_SU_5bit_1600round.npy.txt (6.3 MB)
The error occurs at the line jax.lax.fori_loop(0, layer_num, body_fun_A, (params, wires, wires_pattern_A, wires_pattern_B)) in the circuit_jax(), saying that TypeError: Argument ‘range(0, 5)’ of type ‘<class ‘range’>’ is not a valid JAX type. If I change the line params, costs=optimization_jax(params, range(N), L) to params, costs=optimization_jax(params, list(range(N)), L), then the error occurs at this line itself, saying ValueError: Non-hashable static arguments are not supported. An error occurred while trying to hash an object of type <class ‘list’>, [0, 1, 2, 3, 4]. The error was:
TypeError: unhashable type: ‘list’. I also read the post 2256 and changed the line to params, costs=optimization_jax(params, (0,1,2,3,4), L), and now the error occurs at the line two_qubit_decomp(params[2 * layer * len(wires) // 2 * 15:(2 * layer * len(wires) // 2 + 1) * 15], wires_pattern_A[w]) in body_fun_B1() of circuit_jax(), saying IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(Traced<~int64>with, Traced<~int64>with, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).. So how to use the jax.jit when there is sub-circuit functions in my circuit, especially when I need to distribute the parameters to the sub-circuits based on the length of the wires?
I also noticed that in simple cases without sub-circuits (such as just training the two_qubit_decomp() on 2 qubits), if I pass the wires argument with a tuple, and decorate the two_qubit_decomp() with @partial(jax.jit, static_argnames=['wires']), an error will be thrown at the second line of two_qubit_decomp() saying that abstract wires is not supported in quantum gates. Do I need to jit the circuit function and the sub-circuit functions (if exist)? If I do not jit them, will the compilation time be affected?