How to use conditionals with JAX-JIT?

Hi @ankit27kh,

I ran into similar problems while trying to make some of the optimization transforms in PennyLane jittable (see here for the example in question). I didn't end up finding a good way to conditionally apply a gate (or not) based on some value: one gets errors complaining that gates are not valid JAX types, etc.
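For context, here is a minimal pure-JAX sketch (no PennyLane; the function names are made up for illustration) of why a plain Python `if` can't branch on a value inside `jax.jit`. Under jit the value is an abstract tracer, so JAX cannot turn `y == 1` into a Python `bool`. `jax.lax.cond` instead traces both branches and picks one at run time, which is why both branches have to return arrays of matching shape and dtype.

```python
import jax
import jax.numpy as jnp


def angle_with_python_if(y):
    # Under jit, `y == 1` is a traced array, not a Python bool,
    # so this `if` raises a ConcretizationTypeError (or a subclass of it).
    if y == 1:
        return jnp.array(jnp.pi)
    return jnp.array(0.0)


def angle_with_cond(y):
    # lax.cond traces both branches; each returns a float32 scalar.
    return jax.lax.cond(
        y == 1,
        lambda _: jnp.array(jnp.pi),
        lambda _: jnp.array(0.0),
        y,
    )


try:
    jax.jit(angle_with_python_if)(1)
    python_if_failed = False
except jax.errors.ConcretizationTypeError:
    python_if_failed = True

print("Python if failed under jit:", python_if_failed)
print(jax.jit(angle_with_cond)(1), jax.jit(angle_with_cond)(2))
```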

jax.lax.cond seems to work best when dealing only with functions of numerical parameters: both branches have to return JAX arrays with the same shape and dtype. So one approach is to “trick” it by always applying a parametrized gate and letting cond select only the parameter value. This way the shape and type are consistent for both branches of the conditional. For example:

```python
import pennylane as qml

import jax
import jax.numpy as jnp

dev = qml.device('default.qubit.jax', wires=1, shots=None)

def func(x):
    return x - 1

# Note: I will apply jit to the circuit later; the non-jitted version is so that
# we can draw it and compare the jitted vs. non-jitted results
@qml.qnode(dev, interface="jax")
def circ(x):
    qml.RY(x, wires=0)

    y = func(x)

    # Use the conditional to select a rotation angle
    rotation_angle = jax.lax.cond(
        y == 1,  # Condition (a traced boolean under jit)
        lambda _: jnp.array(jnp.pi),  # Branch taken if the condition is true
        lambda _: jnp.array(0.0),  # Branch taken if the condition is false
        y,  # Operand(s) passed to whichever branch runs
    )
    # RZ is always applied; only its angle changes, so both branches match in shape/type
    qml.RZ(rotation_angle, wires=0)

    qml.Hadamard(0)
    return qml.expval(qml.PauliZ(wires=0))

jitted_circ = jax.jit(circ)

for y in [1, 2]:
    print(qml.draw(circ)(y))
    print(circ(y))
    print(jitted_circ(y))
    print()
```

Running this gives:

```
0: ──RY(1.00)──RZ(0.00)──H─┤  <Z>
0.84147096
0.84147096

0: ──RY(2.00)──RZ(3.14)──H─┤  <Z>
-0.90929735
-0.90929735
```
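As a sanity check on those numbers, the same one-qubit circuit can be written out by hand as 2×2 matrices in plain JAX. This is only a sketch, not how PennyLane simulates it internally, and the helper names are hypothetical. After `RY(x)` the state is cos(x/2)|0⟩ + sin(x/2)|1⟩; measuring Z after a Hadamard is the same as measuring X, which gives sin(x), and `RZ(π)` flips the sign of ⟨X⟩. So the expected values are sin(1) ≈ 0.8415 for x = 1 (y = 0, angle 0) and -sin(2) ≈ -0.9093 for x = 2 (y = 1, angle π).

```python
import jax
import jax.numpy as jnp


def ry(theta):
    # Single-qubit Y rotation matrix
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.array([[c, -s], [s, c]], dtype=jnp.complex64)


def rz(phi):
    # Single-qubit Z rotation matrix: diag(e^{-i phi/2}, e^{i phi/2})
    return jnp.array(
        [[jnp.exp(-0.5j * phi), 0.0], [0.0, jnp.exp(0.5j * phi)]],
        dtype=jnp.complex64,
    )


HADAMARD = jnp.array([[1, 1], [1, -1]], dtype=jnp.complex64) / jnp.sqrt(2.0)
PAULI_Z = jnp.array([[1, 0], [0, -1]], dtype=jnp.complex64)


@jax.jit
def circ_by_hand(x):
    y = x - 1
    # Same trick as in the QNode: the condition only selects the RZ angle
    angle = jax.lax.cond(
        y == 1, lambda _: jnp.array(jnp.pi), lambda _: jnp.array(0.0), y
    )
    state = jnp.array([1, 0], dtype=jnp.complex64)
    state = HADAMARD @ (rz(angle) @ (ry(x) @ state))
    # <Z> = <psi|Z|psi>; jnp.vdot conjugates its first argument
    return jnp.real(jnp.vdot(state, PAULI_Z @ state))


for x in [1.0, 2.0]:
    print(x, circ_by_hand(x))
```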

You can see from the linked example how the conditionals can involve more complicated functions, multiple arguments, etc.
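As a rough illustration of the multiple-argument case (plain JAX, hypothetical function names): in current JAX, `jax.lax.cond` passes all trailing operands to whichever branch runs, and the branches may return a tuple as long as both return the same structure, shapes, and dtypes. For a simple scalar choice like the rotation angle above, `jnp.where` is an alternative technique worth knowing: it computes both candidate values and selects one, which is fine when both are cheap. Inside the QNode, passing such a selected value to `qml.RZ` should play the same role as the `lax.cond` call.

```python
import jax
import jax.numpy as jnp


@jax.jit
def select_angles(y, scale):
    # Every trailing operand (y, scale) is passed to the branch that runs.
    # Both branches return a pair of float32 scalars, so their outputs match.
    return jax.lax.cond(
        y == 1,
        lambda y, s: (s * jnp.pi, jnp.zeros_like(s)),
        lambda y, s: (jnp.zeros_like(s), s * y),
        y,
        scale,
    )


@jax.jit
def angle_with_where(y):
    # jnp.where computes both candidate values and selects one elementwise.
    return jnp.where(y == 1, jnp.pi, 0.0)


print(select_angles(jnp.array(1.0), jnp.array(0.5)))
print(select_angles(jnp.array(2.0), jnp.array(0.5)))
print(angle_with_where(jnp.array(1.0)), angle_with_where(jnp.array(2.0)))
```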

Hope that helps!
