Wires specified in a JAX DeviceArray

Hi, I am working with JAX and I am trying to create a circuit that applies gates on wires specified in a DeviceArray. For example,

dev = qml.device('default.qubit.jax', wires=n_qubits//2)
@partial(jax.jit, static_argnums=(3))
@qml.qnode(dev, interface='jax', diff_method="backprop")
def qnode(params, T, S, Observable):
    for t in T:
        qml.PauliX(t)

    ...(more gates)...

However, I get the error “Wires must be hashable; got object of type <class 'jaxlib.xla_extension.DeviceArray'>.”
I have managed to find a workaround:

dev = qml.device('default.qubit.jax', wires=n_qubits//2)
@partial(jax.jit, static_argnums=(3))
@qml.qnode(dev, interface='jax', diff_method="backprop")
def qnode(params, T, S, Observable):
    for t in T:
        for i in range(N):
            qml.RX(phi=((t==i)*jnp.pi), wires=i)

However, this is not very elegant and it doesn’t scale well with N. Is there a better solution? (I also tried jax.lax.scan, without success.)
Thanks a lot

Hi @paulin_ds, thanks for the question! :slightly_smiling_face:

As the error message suggests, the Wires class that PennyLane creates internally expects hashable inputs. So, for now, our recommended approach is to make T a static argument and pass it in as a tuple (which is hashable).
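A PennyLane-free sketch of why a static tuple works: arguments listed in static_argnums must be hashable, and inside the jitted function they are ordinary Python values. Here jnp.flip needs a concrete axis, much like a gate needs a concrete wire label. N_QUBITS and prepare_basis_state are hypothetical names used only for this illustration, and the "state" is a tiny hand-rolled state vector, not a PennyLane device.

```python
from functools import partial

import jax
import jax.numpy as jnp

N_QUBITS = 3  # hypothetical register size for this sketch


@partial(jax.jit, static_argnums=(0,))
def prepare_basis_state(T):
    # |000> stored as a tensor with one axis per qubit (qubit 0 = first axis)
    state = jnp.zeros((2,) * N_QUBITS).at[(0,) * N_QUBITS].set(1.0)
    for t in T:  # T is a static tuple, so each t is a plain Python int
        state = jnp.flip(state, axis=t)  # Pauli-X on qubit t flips that axis
    return state.reshape(-1)


psi = prepare_basis_state((0, 2))  # X on qubits 0 and 2 -> |101>
print(int(jnp.argmax(psi)))  # 5
```

The trade-off is that JAX compiles one version of the function per distinct tuple, so every new T triggers a retrace; passing a jnp array as T would be rejected, because arrays are not hashable.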

Just double-checking: is the motivation for having the wires traced (i.e., passed as a DeviceArray) to be able to optimize over the T input, or to be able to jit the function?

If it's the latter, one option that comes to mind is to make N a static argument, as mentioned above, and keep T traced. The gate parameters (rather than the wires) would then be tracers:

for t, i in zip(T, range(N)):
    qml.RX(phi=(t*jnp.pi), wires=i)

Here T would hold one 0/1 entry per wire, so each RX gate acts as the identity when phi = 0 and as an X gate (up to a global phase) when phi = π. Such “dummy” rotations could then be removed by adding a quantum compilation pass.
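The same idea can be sketched in plain JAX, without PennyLane: the wire indices come from a static range(N), and T becomes a traced 0/1 mask that only enters through the rotation angles, so one compiled function serves every mask. rx, apply_1q and prepare_from_mask are hypothetical helpers forming a tiny state-vector simulator for illustration; they are not PennyLane API.

```python
import jax
import jax.numpy as jnp

N = 3  # number of qubits: a static Python int


def rx(theta):
    # RX(theta) = [[cos(theta/2), -i sin(theta/2)], [-i sin(theta/2), cos(theta/2)]]
    c = jnp.cos(theta / 2)
    s = jnp.sin(theta / 2)
    return jnp.array([[c, -1j * s], [-1j * s, c]])


def apply_1q(state, u, wire):
    # contract the gate with the axis of `wire`, then move the new axis back in place
    return jnp.moveaxis(jnp.tensordot(u, state, axes=([1], [wire])), 0, wire)


@jax.jit
def prepare_from_mask(mask):
    # mask plays the role of T: a traced length-N 0/1 array; wires 0..N-1 stay static
    state = jnp.zeros((2,) * N, dtype=jnp.complex64).at[(0,) * N].set(1.0)
    for i in range(N):
        # RX(0) is the identity, RX(pi) = -iX flips qubit i
        state = apply_1q(state, rx(jnp.pi * mask[i]), i)
    return jnp.abs(state.reshape(-1)) ** 2  # basis-state probabilities


probs_a = prepare_from_mask(jnp.array([1.0, 0.0, 1.0]))  # |101>
probs_b = prepare_from_mask(jnp.array([0.0, 1.0, 0.0]))  # |010>, same compiled code
```

Because the mask has a fixed shape and dtype, the second call reuses the first compilation; only the angle values change.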

Note that I might be off here and look forward to hearing more about the exact use case.

Thanks for your response @antalszava
Since T will change every time, I don’t think making it static would be very efficient. To give you a better idea of what I am doing, here is the full portion of code:

def Circuits_Observable_phi_list_jit(params, p, Observable_list, inputs_n, inputs_m):
    # k, S and T are sets of indices; the goal is to initialize the state
    # |inputs_n> + i**p |inputs_m>  (inputs_n and inputs_m are bitstrings)
    k = jnp.nonzero(inputs_n != inputs_m, size=1)[0][0]
    new_p = (-1)**inputs_n[k]*p % 4
    new_inputs_n = inputs_n[k]*inputs_m + (1-inputs_n[k])*inputs_n
    new_inputs_m = inputs_n[k]*inputs_n + (1-inputs_n[k])*inputs_m
    S = jnp.nonzero(new_inputs_n != new_inputs_m, size=6)[0][1:]
    T = jnp.nonzero(new_inputs_n == 1, size=2)[0]

    dev = qml.device('default.qubit.jax', wires=n_qubits//2)

    @qml.qnode(dev, interface='jax', diff_method="backprop")
    def qnode(params, T, S, Observable_list):
        # RX(pi) (an X gate up to a global phase) on each wire listed in T;
        # every other (t, i) pair gets RX(0), i.e. the identity
        for t in T:
            for i in range(N):
                qml.RX(phi=((t==i)*jnp.pi), wires=i)

        # rotation and phase shifts with nonzero angles only on wire k
        for i in range(N):
            qml.Rot(phi=((k==i)*jnp.pi), theta=((k==i)*jnp.pi/4), omega=0, wires=i)
            qml.PhaseShift(jnp.isin(new_p, jnp.array([2, 3]))*(k==i)*jnp.pi, wires=i)
            qml.PhaseShift(jnp.isin(new_p, jnp.array([1, 3]))*(k==i)*jnp.pi/2, wires=i)

        # controlled rotation with nonzero angles only on the pair (i, j) == (k, l)
        for l in S:
            for i in range(N):
                for j in range(i+1, N):
                    cond = ((k==i)*1 + (l==j)*1) == 2
                    qml.CRot(phi=jnp.pi*cond, theta=jnp.pi*cond, omega=0, wires=[i, j])

        brick_wall_entangling(params)
        return [qml.expval(Obs) for Obs in Observable_list]

    return qnode(params, T, S, Observable_list)

with

def brick_wall_entangling(params):
    layers, x, _ = params.shape
    qubits = x+1
    for i in range(layers):
        for j in range(qubits-1):
            U_1ex(phi=params[i][j][0], theta=params[i][j][1], wires=[j, j+1])
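One detail in the code above that may be worth double-checking: under jit, jnp.nonzero needs a static size, and when there are fewer matches than size, the remaining slots are filled with fill_value, which defaults to 0. Unless new_inputs_n always contains at least two 1s, a padded entry in T equals wire 0, so (t==i) is true for i == 0 and an extra RX(π) lands on wire 0. A minimal sketch that pads with -1 instead, which never matches a wire index; N = 4, wires_to_flip and flip_mask are hypothetical names for this illustration.

```python
import jax
import jax.numpy as jnp

N = 4  # hypothetical number of wires


@jax.jit
def wires_to_flip(bits):
    # size must be static under jit; unused slots are padded with fill_value
    return jnp.nonzero(bits == 1, size=2, fill_value=-1)[0]


@jax.jit
def flip_mask(T):
    # per-wire 0/1 mask: entry i is 1 if some t in T equals i (padding -1 never does)
    return (T[:, None] == jnp.arange(N)).any(axis=0).astype(jnp.float32)


bits = jnp.array([0, 1, 0, 0])
T = wires_to_flip(bits)  # [1, -1]; with the default fill_value=0 it would be [1, 0]
mask = flip_mask(T)      # [0., 1., 0., 0.]
```

Assuming the entries of T are distinct, a per-wire mask like this would also let the double loop over T and range(N) shrink to a single RX(mask[i]*jnp.pi) per wire.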

I am trying to optimize the code, and since I call this function with many different values of inputs_n and inputs_m, I think it would help to trace them so that I can jit it with

jax.jit(partial(Circuits_Observable_phi_list_jit, params=params_A, Observable_list = H_oA))
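For reference, a minimal PennyLane-free sketch of the jax.jit(partial(...)) pattern above; circuit_like, scale and the numbers are hypothetical. Arguments bound with partial are captured by the compiled function as constants, while the arguments passed at call time are traced, so new bitstrings reuse the compiled code. Because params is bound by keyword, Python only accepts the remaining arguments by keyword as well (a positional first argument would collide with params).

```python
from functools import partial

import jax
import jax.numpy as jnp


def circuit_like(params, p, scale, inputs_n, inputs_m):
    # hypothetical stand-in for the circuit: any JAX-traceable computation
    return scale * jnp.sum(params) * p + jnp.sum(inputs_n != inputs_m)


params_A = jnp.ones(3)

# params and scale are captured as constants; p, inputs_n and inputs_m are traced
f = jax.jit(partial(circuit_like, params=params_A, scale=2.0))

# params is bound by keyword, so the remaining arguments are passed by keyword too
out = f(p=1, inputs_n=jnp.array([0, 1]), inputs_m=jnp.array([1, 1]))
print(out)  # 7.0
```

Note that each new partial object is a new function as far as jax.jit is concerned, so changing a bound argument means compiling again.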

So this code works (as mentioned in the first post). However, it doesn’t scale well with the number of layers or with the number of elements in Observable_list (which is large, and I need every expectation value separately). I don’t know whether it’s possible to do better or whether I am missing something.
I hope I’m not boring you too much with all my code :sweat_smile:
Thanks a lot for everything!!! :relaxed:
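On the scaling with the number of layers: Python for loops inside a jitted function are unrolled at trace time, so compile time grows with the number of layers, whereas jax.lax.scan traces the loop body once. A hedged, PennyLane-free sketch of that idea follows; each layer is just an RY rotation on every qubit (a hypothetical stand-in, since U_1ex is not defined in the thread), and ry, apply_1q, layer and run_layers are illustrative names. Whether the same trick applies inside a QNode depends on PennyLane's JAX support and is not shown here.

```python
import jax
import jax.numpy as jnp

N = 3  # number of qubits (static)


def ry(theta):
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.array([[c, -s], [s, c]])


def apply_1q(state, u, wire):
    # contract the gate with the axis of `wire`, then move the new axis back in place
    return jnp.moveaxis(jnp.tensordot(u, state, axes=([1], [wire])), 0, wire)


def layer(state, thetas):
    # one layer: RY(thetas[w]) on every wire w; wires stay static Python ints
    for w in range(N):
        state = apply_1q(state, ry(thetas[w]), w)
    return state, None  # (new carry, no per-layer output)


@jax.jit
def run_layers(params):
    # params: shape (layers, N); lax.scan traces `layer` once for all layers
    state0 = jnp.zeros((2,) * N).at[(0,) * N].set(1.0)
    state, _ = jax.lax.scan(layer, state0, params)
    return state.reshape(-1) ** 2  # probabilities (the state stays real here)


params = jnp.zeros((4, N)).at[0].set(jnp.pi)  # first layer flips every qubit
probs = run_layers(params)
```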

Hi @paulin_ds,

Glad to discuss, no worries!

Could you elaborate on why the scaling is poor here? We are doing something similar when applying X gates in the decomposition of basis state preparation (see this recent PR https://github.com/PennyLaneAI/pennylane/pull/3239), so frankly, I’m not sure a much more efficient solution is available. Allowing wires to be tracer objects is something we’re experimenting with, but I don’t expect it to become a feature in the near future; it’s definitely good to see interest, though!

OK, I see. Nice that you are working on allowing the wires to be traced, it must be interesting! Maybe the scaling is poor because my measurements are returned as a list:
return [qml.expval(Obs) for Obs in Observable_list]
and there are ~700 observables. In this case, is the circuit re-executed (i.e., is the state prepared again) for every observable?

Hi @paulin_ds,

The return [qml.expval(Obs) for Obs in Observable_list] syntax returns multiple measurements from a single QNode execution: the state is prepared once, and all the expectation values are computed from that same pre-measurement state.
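To make the point concrete outside PennyLane: once the state (here, its basis-state probabilities) is known, many expectation values can be read off it in one batched operation rather than one circuit run per observable. This sketch assumes every observable is a tensor product of PauliZ and Identity, which is diagonal in the computational basis; z_diagonal and all_expvals are hypothetical helpers, and this is not how PennyLane computes expectation values internally.

```python
import itertools

import jax
import jax.numpy as jnp

N = 3


def z_diagonal(mask):
    # Diagonal of the Pauli-Z string selected by `mask` (1 = PauliZ, 0 = Identity).
    # Qubit 0 is the most significant bit of the basis-state index.
    idx = jnp.arange(2 ** N)
    bits = (idx[:, None] >> (N - 1 - jnp.arange(N))) & 1  # shape (2**N, N)
    parity = (bits * jnp.asarray(mask)).sum(axis=1) % 2
    return 1.0 - 2.0 * parity


# All 2**N - 1 non-trivial Z strings, stacked into one (n_obs, 2**N) matrix
masks = [m for m in itertools.product([0, 1], repeat=N) if any(m)]
diagonals = jnp.stack([z_diagonal(m) for m in masks])


@jax.jit
def all_expvals(probs):
    # one matrix-vector product gives every <Z-string> from the same state
    return diagonals @ probs


probs = jnp.zeros(2 ** N).at[5].set(1.0)  # basis state |101>
vals = all_expvals(probs)
```

With ~700 observables, diagonals would simply have ~700 rows; the state is still computed only once.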

Just to clarify: would you expect making the wires tracers to boost performance?

OK, thanks, I see. Yes, I think so. Using JAX with PennyLane, I also ran into an error saying that observables cannot be traced because they are not a valid JAX type. I think this can also hurt performance :slight_smile:

Hi @paulin_ds,

Could you send through a minimal example of the error saying that observables cannot be traced because they are not a valid JAX type?

Hi @antalszava
Yes, of course. For example, suppose one needs to execute a circuit like

def Circuits_ObservableB(params, Observable, inputs):
    dev = qml.device('default.qubit.jax', wires=N_B)
    @qml.qnode(dev, interface='jax', diff_method="backprop")
    def qnode(params, inputs, Observable):
        # encode the bitstring: RX(pi) on every wire whose input bit is 1
        for i in range(N_B):
            qml.RX(jnp.pi*inputs[i], wires=i)
        brick_wall_entanglingB(params)
        return qml.expval(Observable)
    return qnode(params, inputs, Observable)

Since this circuit will be executed many times, it is worth jitting it. However, if we try

Circuits_ObservableB_jit = jax.jit(Circuits_ObservableB)
Circuits_ObservableB_jit(params=params_B,Observable=H_B, inputs=bitstringB[0])

we get the error

TypeError: Argument 'PauliZ(wires=[0]) @ Identity(wires=[1]) @ Identity(wires=[2]) @ Identity(wires=[3]) @ Identity(wires=[4])' of type <class 'pennylane.operation.Tensor'> is not a valid JAX type.

Indeed, a jitted function cannot take a PennyLane operator (here a pennylane.operation.Tensor) as an argument, since it is not a valid JAX type. A possible workaround is to use

Circuits_ObservableB_jit = jax.jit(partial(Circuits_ObservableB, Observable=H_B))

But then the function has to be recompiled whenever the observable changes.
I hope my example was clear :slight_smile:
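One possible workaround for the recompilation, sketched under the assumption that the observables are PauliZ/Identity strings like the one in the error message: encode each observable as a plain array (1 = PauliZ, 0 = Identity on that qubit), so it is a valid JAX type and can be traced. The expectation value is then computed from basis-state probabilities (which a QNode can return via qml.probs) instead of from an operator object. expval_z_string and the mask encoding are hypothetical, not a PennyLane feature; the trace counter only increments while JAX traces the function, so it shows that new observables reuse the compiled code.

```python
import jax
import jax.numpy as jnp

N = 3
trace_count = 0  # incremented only while JAX traces the function


@jax.jit
def expval_z_string(probs, mask):
    # mask: traced 0/1 array, 1 = PauliZ and 0 = Identity on that qubit
    global trace_count
    trace_count += 1
    idx = jnp.arange(2 ** N)
    bits = (idx[:, None] >> (N - 1 - jnp.arange(N))) & 1  # qubit 0 = most significant bit
    parity = (bits * mask).sum(axis=1) % 2
    return jnp.sum(probs * (1.0 - 2.0 * parity))  # sum_x p(x) * (-1)^parity(x)


probs = jnp.zeros(2 ** N).at[5].set(1.0)  # basis state |101>
z0 = expval_z_string(probs, jnp.array([1, 0, 0]))    # PauliZ(0): -1.0
z1 = expval_z_string(probs, jnp.array([0, 1, 0]))    # PauliZ(1): 1.0
z0z2 = expval_z_string(probs, jnp.array([1, 0, 1]))  # PauliZ(0) @ PauliZ(2): 1.0
print(trace_count)  # 1
```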

Hi @paulin_ds,

Indeed, this is not possible at the moment. I’ll make a note of it, along with your specific use case.

Thanks for the feature request! :slightly_smiling_face:

Hi @antalszava,
My pleasure, don’t hesitate to reach out if I can help with anything! :relaxed:
