Non-commuting observables with Catalyst

Hi! Firstly, thanks for an excellent piece of software. I recently started using PennyLane more and am enjoying it.

Here is the problem I'm trying to solve. I have a deep circuit composed of identical layers with different parameters (think of a typical brickwork circuit). I have many parameter configurations, and for each one I need to compute expectation values of many different (non-commuting) observables. By "compute" I mean simulate (noiseless, pure-state simulation).

To efficiently compute expectation values of many observables, I return them all in the same quantum node. As far as I understand, in this case the statevector is computed only once, and there is basically no overhead for additional observables.
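To see why a pure statevector simulator can, in principle, return non-commuting observables together, here is a small NumPy sketch of the single-qubit RX chain used below, written without PennyLane: the state is simulated once, and each expectation value ⟨ψ|O|ψ⟩ is then just a matrix-vector product. The helper functions are illustrative only, not PennyLane API.

```python
import numpy as np

# Pauli matrices used as observables (X, Y and Z pairwise do not commute)
X = np.array([[0, 1], [1, 0]], dtype=complex)
Y = np.array([[0, -1j], [1j, 0]], dtype=complex)
Z = np.array([[1, 0], [0, -1]], dtype=complex)


def rx(theta):
    """Single-qubit RX rotation matrix."""
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([[c, -1j * s], [-1j * s, c]], dtype=complex)


def simulate(angles):
    """Apply one RX layer per angle to |0> and return the final statevector."""
    state = np.array([1, 0], dtype=complex)
    for theta in angles:
        state = rx(theta) @ state
    return state


def expval(obs, state):
    """<psi|O|psi> for a pure state; real because O is Hermitian."""
    return float(np.real(np.vdot(state, obs @ state)))


angles = np.linspace(0, 1, 500)
state = simulate(angles)  # the expensive part, done only once
# Each extra observable costs one small matrix-vector product
results = {name: expval(O, state) for name, O in [("X", X), ("Y", Y), ("Z", Z)]}
print(results)
```

Because RX rotations about the same axis compose, the final state is RX(θ)|0⟩ with θ the sum of all angles, so ⟨Z⟩ = cos θ, ⟨X⟩ = 0 and ⟨Y⟩ = −sin θ, all read off the same state.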

But I also need to do this for many parameter configurations, so I'd like to jit the circuit. Here is one way to do it:

  import pennylane as qml
  import jax
  import jax.numpy as jnp
  
  dev = qml.device('lightning.qubit', wires=1)
  @qml.qnode(dev, interface='jax')
  def circuit(x):
      for xi in x:
          qml.RX(xi, wires=0)
      return qml.expval(qml.PauliX(0)), qml.expval(qml.PauliZ(0))
  
  expectation = jax.jit(circuit)
  x = jnp.linspace(0, 1, 500)
  print(expectation(x))

The problem is that for deep circuits the compile time becomes prohibitive. However, because all layers are essentially identical, I guessed this could be avoided with something like catalyst.for_loop. And indeed, I can get the circuit to compile very quickly with the following code:

import pennylane as qml
import jax
import jax.numpy as jnp
import catalyst

dev = qml.device('lightning.qubit', wires=1)
@catalyst.qjit
@qml.qnode(dev, interface='jax')
def circuit(x):
    def loop_fn(i):
        qml.RX(x[i], wires=0)

    catalyst.for_loop(0, len(x), 1)(loop_fn)()
    return qml.expval(qml.PauliX(0)) #, qml.expval(qml.PauliZ(0))

expectation = jax.jit(circuit)
x = jnp.linspace(0, 1, 10000)
print(expectation(x))

But if I include the non-commuting observables (e.g. uncomment qml.expval(qml.PauliZ(0))), I get:

pennylane.QuantumFunctionError: Only observables that are qubit-wise commuting Pauli words can be returned on the same wire, some of the following measurements do not commute:
[expval(X(0)), expval(Z(0))]
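The error refers to qubit-wise commuting (QWC) Pauli words: two Pauli words qubit-wise commute when, on every wire, their single-qubit factors are equal or at least one is the identity. X(0) and Z(0) fail this test on wire 0, so a device that can only measure one QWC group per circuit execution needs two executions. A minimal pure-Python sketch of the check and of a greedy grouping (the helper names are hypothetical, and this is not PennyLane's own grouping implementation):

```python
def qubit_wise_commute(p, q):
    """True if two Pauli words (e.g. 'XIZ') qubit-wise commute.

    On every wire the two single-qubit factors must be equal,
    or at least one of them must be the identity 'I'.
    """
    if len(p) != len(q):
        raise ValueError("Pauli words must act on the same number of wires")
    return all(a == b or "I" in (a, b) for a, b in zip(p, q))


def group_qwc(words):
    """Greedily split Pauli words into qubit-wise commuting groups.

    Each group could be measured from a single circuit execution,
    so the number of groups is the number of executions needed.
    """
    groups = []
    for w in words:
        for g in groups:
            if all(qubit_wise_commute(w, other) for other in g):
                g.append(w)
                break
        else:
            # w conflicts with every existing group: start a new one
            groups.append([w])
    return groups


print(group_qwc(["X", "Z"]))                 # two groups -> two executions
print(group_qwc(["XI", "IZ", "XZ", "ZI"]))   # 'ZI' clashes with 'XI' on wire 0
```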

Fundamentally, I do not see why returning non-commuting observables is fine in one case but not in the other.

Is there a workaround for that? Or perhaps another way to make compiling repeated layers fast/compute observables in batches efficiently?

Hi @idnm, thank you for trying out PennyLane & Catalyst, I'm glad you've been enjoying it!

> The problem here is that for deep circuits, the compile time becomes prohibitive. However, because all layers are basically the same, I guess this should be possible by using something like catalyst.for_loop. And indeed, I can get the circuit to compile very quickly with the following code

Nice! This is exactly one of the reasons we built Catalyst :)

> But if I include the non-commuting observables (e.g. uncomment qml.expval(qml.PauliZ(0))) I get pennylane.QuantumFunctionError

Unfortunately, the issue you stumbled upon is a limitation of the old device system in PennyLane, which didn't allow non-commuting observables on the same tape (instead, the circuit had to be split up into multiple tapes).
default.qubit has already made the transition to the new system, which does not have this limitation.
lightning.qubit is also completing this transition at the moment (note that even though your example appears to work with lightning.qubit, under the hood it splits the circuit into two copies, each with a different observable).
Similarly, we are updating Catalyst to work with the new system, so this issue should be resolved with the next release. In the meantime, you might have to use separate QNodes if you need non-commuting observables on the same qubit, but you can still compile them together:

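import pennylane as qml
import jax.numpy as jnp
import catalyst

dev = qml.device('lightning.qubit', wires=1)

def circuit_body(x):
    # Shared layer structure: the for_loop body is traced once rather than unrolled
    def loop_fn(i):
        qml.RX(x[i], wires=0)

    catalyst.for_loop(0, len(x), 1)(loop_fn)()

# One QNode per non-commuting observable
@qml.qnode(dev)
def circuit1(x):
    circuit_body(x)
    return qml.expval(qml.PauliX(0))

@qml.qnode(dev)
def circuit2(x):
    circuit_body(x)
    return qml.expval(qml.PauliZ(0))

# Both QNodes are compiled together into a single program
@catalyst.qjit
def main(x):
    return circuit1(x), circuit2(x)

x = jnp.linspace(0, 1, 10000)
print(main(x))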

(Note that you do not need to apply jax.jit again if you already applied catalyst.qjit, the latter can be used as a full replacement for the former.)

Hi @David_Ittah, thanks for your reply!

> thank you for trying out PennyLane & Catalyst

My pleasure:)

I believe I follow your explanation; I suspected something similar was going on. I'm looking forward to future releases.

In the meantime, however, I’d really like efficiency in both directions – I have many parameter configurations, each with many observables (not just two). Could you perhaps point me to a different workaround? I was thinking maybe it is possible to pull out the statevector of the circuit and then process the observables in batches using vmap.

Great thinking! This is a bit unorthodox, but there is an internal function you could use to compute the expectation values of your observables from the computed state (which is also jit-compatible):

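import pennylane as qml
import jax.numpy as jnp
import catalyst

dev = qml.device('lightning.qubit', wires=1)

@qml.qnode(dev)
def circuit(x):
    def loop_fn(i):
        qml.RX(x[i], wires=0)

    catalyst.for_loop(0, len(x), 1)(loop_fn)()
    # Return the full statevector instead of individual expectation values
    return qml.state()

@qml.qjit
def main(x):
    state = circuit(x)
    observables = [qml.PauliZ(0), qml.PauliY(0), qml.PauliX(0)]
    # qml.devices.qubit.measure is an internal helper that evaluates a measurement
    # on a given state, so every observable reuses the same simulated state.
    # (For multi-wire circuits, the flat state from qml.state() may need to be
    # reshaped to (2,) * n_wires first.)
    return [qml.devices.qubit.measure(qml.expval(obs), state) for obs in observables]

x = jnp.linspace(0, 1, 10000)
print(main(x))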

Let me know if something like this works for you.

(Note that vmap would not work directly with (most) observables, because vmap maps over array data, and the observable type (say PauliX vs PauliZ) is not data. With the Hermitian observable type, however, you might be able to vmap over the Hermitian matrix data.)
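Following the idea of pulling out the statevector and processing observables in batches, here is a NumPy sketch (independent of PennyLane and Catalyst) that evaluates many observables for many parameter configurations at once. The observables are stacked as Hermitian matrices, which is the kind of matrix data vmap could map over, and a single einsum computes every ⟨ψ_b|O_k|ψ_b⟩. The helpers `pauli_word_matrix` and `batched_expvals` are hypothetical names; the wire ordering (wire 0 as the leftmost Kronecker factor) is assumed to match PennyLane's convention, and in practice the states would come from a `qml.state()` QNode like the one above. The same einsum string should carry over to `jax.numpy.einsum` if you want to keep everything inside a jitted function. Dense 2^n × 2^n matrices only make sense for a modest number of wires.

```python
import numpy as np

PAULI = {
    "I": np.eye(2, dtype=complex),
    "X": np.array([[0, 1], [1, 0]], dtype=complex),
    "Y": np.array([[0, -1j], [1j, 0]], dtype=complex),
    "Z": np.array([[1, 0], [0, -1]], dtype=complex),
}


def pauli_word_matrix(word):
    """Dense matrix of a Pauli word such as 'XZ' (wire 0 is the leftmost factor)."""
    mat = np.array([[1.0 + 0j]])
    for letter in word:
        mat = np.kron(mat, PAULI[letter])
    return mat


def batched_expvals(states, observables):
    """Expectation values for every (state, observable) pair.

    states:      shape (n_configs, dim), one statevector per parameter configuration
    observables: shape (n_obs, dim, dim), stacked Hermitian matrices
    returns:     shape (n_configs, n_obs), entry [b, k] = <psi_b| O_k |psi_b>
    """
    return np.einsum("bi,kij,bj->bk", states.conj(), observables, states).real


# Usage: four (partly non-commuting) observables on two qubits
words = ["ZI", "XI", "IZ", "ZZ"]
obs = np.stack([pauli_word_matrix(w) for w in words])

# Two example 2-qubit states standing in for two parameter configurations: |00> and |+0>
ket00 = np.array([1, 0, 0, 0], dtype=complex)
plus0 = np.array([1, 0, 1, 0], dtype=complex) / np.sqrt(2)

values = batched_expvals(np.stack([ket00, plus0]), obs)
print(values)
```

Computing all pairs in one vectorized call avoids a Python loop over observables, and the cost of each additional observable is one more slice of the stacked matrix array rather than another circuit execution.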


@David_Ittah At least in this toy case, your suggestion indeed works exactly as I anticipated. Thanks a lot!
