Thanks for your response @antalszava
Since T will change every time, I don’t think that making it static will be very efficient. To give you a better idea of what I am doing, here is the full portion of code:
def Circuits_Observable_phi_list_jit(params, p, Observable_list, inputs_n, inputs_m):
    """Run the state-preparation + ansatz circuit and return expectation values.

    The goal stated by the author is to initialise the state
    ``|inputs_n> + i**p |inputs_m>`` (inputs are bitstrings), then apply
    ``brick_wall_entangling(params)`` and measure every observable.

    Args:
        params: Parameters forwarded to ``brick_wall_entangling``.
        p: Integer phase exponent; it is reduced modulo 4.
        Observable_list: Iterable of observables; one ``qml.expval`` is
            returned per entry.
        inputs_n: First bitstring (array of 0/1 values).
        inputs_m: Second bitstring (array of 0/1 values).

    Returns:
        The QNode output: a list with one expectation value per observable.
    """
    # k: first index at which the two bitstrings differ.
    k = jnp.nonzero(inputs_n != inputs_m, size=1)[0][0]
    bit_k = inputs_n[k]

    # If inputs_n has a 1 at position k, swap the roles of the two bitstrings
    # (and negate p modulo 4) so that new_inputs_n carries a 0 at position k.
    new_p = (-1) ** bit_k * p % 4
    new_inputs_n = bit_k * inputs_m + (1 - bit_k) * inputs_n
    new_inputs_m = bit_k * inputs_n + (1 - bit_k) * inputs_m

    # S: the differing positions after the first one (which is k).
    # jnp.nonzero pads with index 0 up to `size`; padded entries are harmless
    # below because the CRot target index j is always >= 1.
    S = jnp.nonzero(new_inputs_n != new_inputs_m, size=6)[0][1:]

    # T: positions of the 1 bits in new_inputs_n.
    # NOTE(review): padding also uses index 0 here, so with fewer than two set
    # bits an RX(pi) would be applied on wire 0 -- assumes exactly two set
    # bits; TODO confirm.
    T = jnp.nonzero(new_inputs_n == 1, size=2)[0]

    # NOTE(review): the device has n_qubits//2 wires while the loops below run
    # over N wires -- presumably N == n_qubits//2; verify.
    dev = qml.device('default.qubit.jax', wires=n_qubits//2)

    @qml.qnode(dev, interface='jax', diff_method="backprop")
    def qnode(params, T, S, Observable_list):
        # Gates are applied on every wire with an angle that is zero unless
        # the (possibly traced) index matches, so the circuit structure does
        # not depend on the index values.
        for t in T:
            for i in range(N):
                qml.RX(phi=((t == i) * jnp.pi), wires=i)

        for i in range(N):
            on_k = (k == i)
            qml.Rot(phi=(on_k * jnp.pi), theta=(on_k * jnp.pi / 4), omega=0, wires=i)
            # The two phase shifts add up to new_p * pi/2 on wire k.
            qml.PhaseShift(jnp.isin(new_p, jnp.array([2, 3])) * on_k * jnp.pi, wires=i)
            qml.PhaseShift(jnp.isin(new_p, jnp.array([1, 3])) * on_k * jnp.pi / 2, wires=i)

        for l in S:
            for i in range(N):
                for j in range(i + 1, N):
                    # Non-zero angle only for the pair (control k, target l).
                    angle = jnp.pi * ((k == i) & (l == j))
                    qml.CRot(phi=angle, theta=angle, omega=0, wires=[i, j])

        brick_wall_entangling(params)
        return [qml.expval(Obs) for Obs in Observable_list]

    # Forward the caller-supplied parameters; the original passed the global
    # `params_A` here, which silently ignored the `params` argument.
    return qnode(params, T, S, Observable_list)
with
def brick_wall_entangling(params):
    """Apply ``U_1ex`` to every pair of neighbouring wires, layer by layer.

    ``params`` must expose a three-element ``shape``. For layer ``i`` and
    pair ``j``, ``params[i][j][0]`` is passed as ``phi`` and
    ``params[i][j][1]`` as ``theta``, acting on wires ``[j, j + 1]``.
    """
    n_layers, n_pairs, _ = params.shape
    for layer in range(n_layers):
        layer_params = params[layer]
        for left in range(n_pairs):
            pair_params = layer_params[left]
            U_1ex(
                phi=pair_params[0],
                theta=pair_params[1],
                wires=[left, left + 1],
            )
I am trying to optimize the code and, given that I use this function with a lot of different values of inputs_n and inputs_m, I think it could help to trace them and be able to jit:
# Fix `params` and `Observable_list` with functools.partial, then JIT-compile
# the resulting callable.
# NOTE(review): `params` is the first positional parameter but is bound by
# keyword, so the remaining arguments (p, inputs_n, inputs_m) must be passed
# by keyword; passing them positionally raises
# "got multiple values for argument 'params'".
jax.jit(partial(Circuits_Observable_phi_list_jit, params=params_A, Observable_list = H_oA))
So this code works (as mentioned in the first post). However, it doesn’t scale well with the number of layers and the number of elements in Observable_list (which is large, and I need to get every expectation value separately). I don’t know if it’s possible to do better or if I am missing something.
I hope I’m not boring you too much with all my code,
Thanks a lot for everything!!!