Hi! First of all, thanks for an excellent piece of software. I recently started using PennyLane more and am enjoying it.
Here is the problem I'm trying to solve. I have a deep circuit composed of identical layers with different parameters (think of a typical brickwork circuit). I have many parameter configurations, and for each one I need to compute expectation values of many different (non-commuting) observables. By "compute" I mean simulate (noiseless, pure-state simulation).
To compute expectation values of many observables efficiently, I return them all from the same QNode. As far as I understand, in this case the statevector is computed only once, so additional observables add essentially no overhead.
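To illustrate what I mean, here is a minimal plain-NumPy sketch (not how PennyLane works internally, just the idea): the statevector for a toy one-wire RX circuit is simulated once, and expectation values of several mutually non-commuting observables are then read off that same vector. The helper names `rx`, `final_state` and `expvals` are hypothetical.

```python
import numpy as np

# Single-qubit Pauli matrices (complex dtype so they act on complex states)
X = np.array([[0, 1], [1, 0]], dtype=complex)
Y = np.array([[0, -1j], [1j, 0]], dtype=complex)
Z = np.array([[1, 0], [0, -1]], dtype=complex)

def rx(theta):
    """Matrix of RX(theta) = exp(-i * theta * X / 2)."""
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([[c, -1j * s], [-1j * s, c]], dtype=complex)

def final_state(params):
    """Apply one RX layer per parameter to |0> and return the statevector."""
    psi = np.array([1, 0], dtype=complex)
    for theta in params:
        psi = rx(theta) @ psi
    return psi

def expvals(psi, observables):
    """<psi|O|psi> for every observable O, reusing the same statevector."""
    return np.array([np.vdot(psi, obs @ psi).real for obs in observables])

params = np.linspace(0, 1, 500)
psi = final_state(params)       # the expensive part, done once
print(expvals(psi, [X, Y, Z]))  # X, Y, Z do not commute; that is fine here
```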
But I also need to do this for many parameter configurations, so I'd like to `jit` the circuit. Here is one way to do it:
```python
import pennylane as qml
import jax
import jax.numpy as jnp

dev = qml.device('lightning.qubit', wires=1)

@qml.qnode(dev, interface='jax')
def circuit(x):
    # One RX "layer" per parameter; this Python loop is unrolled at trace time
    for xi in x:
        qml.RX(xi, wires=0)
    # Non-commuting observables returned from the same QNode
    return qml.expval(qml.PauliX(0)), qml.expval(qml.PauliZ(0))

expectation = jax.jit(circuit)
x = jnp.linspace(0, 1, 500)
print(expectation(x))
```
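For this toy circuit the expected output can be worked out by hand, which is useful for checking any faster variant. All RX gates act on the same wire about the same axis, so they commute and compose into a single rotation by the total angle. Starting from $\lvert 0\rangle$ (the default initial state):

$$
\prod_i R_X(x_i) = R_X(\theta), \qquad \theta = \sum_i x_i, \qquad R_X(\theta)\lvert 0\rangle = \cos\tfrac{\theta}{2}\,\lvert 0\rangle - i\sin\tfrac{\theta}{2}\,\lvert 1\rangle
$$

$$
\langle X\rangle = 0, \qquad \langle Z\rangle = \cos^2\tfrac{\theta}{2} - \sin^2\tfrac{\theta}{2} = \cos\theta
$$

With `x = jnp.linspace(0, 1, 500)` the total angle is $\theta = 250$, so the script should print values close to $0$ and $\cos 250$.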
The problem is that for deep circuits the compile time becomes prohibitive. However, since all layers are essentially the same, I guessed that compilation could be kept fast with something like `catalyst.for_loop`. And indeed, I can get the circuit to compile very quickly with the following code:
```python
import pennylane as qml
import jax
import jax.numpy as jnp
import catalyst

dev = qml.device('lightning.qubit', wires=1)

@catalyst.qjit
@qml.qnode(dev, interface='jax')
def circuit(x):
    # Loop body is traced once instead of being unrolled len(x) times
    def loop_fn(i):
        qml.RX(x[i], wires=0)
    catalyst.for_loop(0, len(x), 1)(loop_fn)()
    return qml.expval(qml.PauliX(0)) #, qml.expval(qml.PauliZ(0))

expectation = jax.jit(circuit)
x = jnp.linspace(0, 1, 10000)
print(expectation(x))
```
But if I include the non-commuting observables (e.g. uncomment `qml.expval(qml.PauliZ(0))`), I get:
```
pennylane.QuantumFunctionError: Only observables that are qubit-wise commuting Pauli words can be returned on the same wire, some of the following measurements do not commute:
[expval(X(0)), expval(Z(0))]
```
Fundamentally, I don't see why returning non-commuting observables is fine in one case but not in the other.
Is there a workaround for this? Or perhaps another way to make compilation of repeated layers fast, or to compute observables in batches efficiently?
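For comparison, here is a sketch of one possible workaround outside PennyLane, in plain JAX: a hand-rolled single-qubit statevector simulator where `jax.lax.scan` plays the role of `catalyst.for_loop` (the layer body is traced once), `jax.vmap` batches over parameter configurations, and the non-commuting observables are contracted against the same final state. It only covers this toy one-wire circuit, uses 1000 layers and JAX's default float32 precision, and the names `rx`, `expvals` and `batched_expvals` are hypothetical.

```python
import jax
import jax.numpy as jnp

# Pauli X and Z as complex64 matrices
PAULI_X = jnp.array([[0, 1], [1, 0]], dtype=jnp.complex64)
PAULI_Z = jnp.array([[1, 0], [0, -1]], dtype=jnp.complex64)

def rx(theta):
    """RX(theta) = exp(-i * theta * X / 2) as a 2x2 complex64 matrix."""
    c = jnp.cos(theta / 2).astype(jnp.complex64)
    s = jnp.sin(theta / 2).astype(jnp.complex64)
    return jnp.array([[c, -1j * s], [-1j * s, c]])

def expvals(params):
    """Simulate one parameter configuration and return (<X>, <Z>)."""
    psi0 = jnp.array([1, 0], dtype=jnp.complex64)

    # lax.scan traces the layer body once, so compile time does not grow
    # with the number of layers (unlike a Python for-loop under jit)
    def layer(psi, theta):
        return rx(theta) @ psi, None

    psi, _ = jax.lax.scan(layer, psi0, params)
    # Non-commuting observables read off the same final statevector
    ex = jnp.vdot(psi, PAULI_X @ psi).real
    ez = jnp.vdot(psi, PAULI_Z @ psi).real
    return jnp.stack([ex, ez])

# vmap over a batch of parameter configurations, then jit the whole thing
batched_expvals = jax.jit(jax.vmap(expvals))

# Two configurations of 1000 layers each
params = jnp.stack([jnp.linspace(0, 1, 1000), jnp.linspace(0, 2, 1000)])
print(batched_expvals(params))  # shape (2, 2): rows = configs, cols = (<X>, <Z>)
```

Since the total angles here are 500 and 1000, the rows should be close to $(0, \cos 500)$ and $(0, \cos 1000)$.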