How to use conditionals with JAX-JIT?

Is there a way to make this code work?

def circ(x):
    qml.RY(x, 0)
    y = func(x)
    if y == 1:
        qml.PauliZ(0)
    qml.Hadamard(0)
    return qml.expval(qml.PauliZ(wires=0))

With JAX JIT, this raises the following error at the y == 1 line:

jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function.

Without JIT, the code works. It also fails when I try to use vmap.

There are ways to remove the if condition from the code, such as jax.lax.switch or jax.lax.cond, but with these I get a TypeError saying that qml.PauliZ is not a valid JAX type.

Is there a way to make the code compatible with JIT and vmap?

If it helps, y can only be 0 or 1. This is just an example; in my work, I need to use qml.MultiControlledX instead of qml.PauliZ.

Hi @ankit27kh, could you please share your full code if possible?

Hey @CatalinaAlbornoz, here is a simple, reproducible example:

  1. No JIT and vmap. This works:
import pennylane as qml
import jax
import jax.numpy as jnp
import pennylane.numpy as np

dev = qml.device('default.qubit.jax', wires=1, shots=None)

def func(x):
    return x - 1

@qml.qnode(dev, interface="jax")
def circ(x):
    qml.RY(x, 0)
    y = func(x)
    if y == 1:
        qml.PauliZ(0)
    qml.Hadamard(0)
    return qml.expval(qml.PauliZ(wires=0))

def no_map(X):
    e = 0
    for x in X:
        e = e + circ(x)
    return e / len(X)

X = np.random.randint(1, 3, 10)
print(no_map(X))
  2. Using vmap. Does not work:
import pennylane as qml
import jax
import jax.numpy as jnp
import pennylane.numpy as np

dev = qml.device('default.qubit.jax', wires=1, shots=None)

def func(x):
    return x - 1

@qml.qnode(dev, interface="jax")
def circ(x):
    qml.RY(x, 0)
    y = func(x)
    if y == 1:
        qml.PauliZ(0)
    qml.Hadamard(0)
    return qml.expval(qml.PauliZ(wires=0))

def map(X):
    return jnp.mean(jax.vmap(circ)(X))

X = np.random.randint(1, 3, 10)
print(map(X))
  3. Using JIT. Does not work:
import pennylane as qml
import jax
import jax.numpy as jnp

dev = qml.device('default.qubit.jax', wires=1, shots=None)

def func(x):
    return x - 1

@jax.jit
@qml.qnode(dev, interface="jax")
def circ(x):
    qml.RY(x, 0)
    y = func(x)
    if y == 1:
        qml.PauliZ(0)
    qml.Hadamard(0)
    return qml.expval(qml.PauliZ(wires=0))

print(circ(2))

Hi @ankit27kh, this seems to be an issue with JAX: under JIT, y == 1 is a traced boolean, and a Python if statement needs a concrete value. It seems to me like you need to find a workaround that doesn’t use a boolean in this way. You can read more about this error here.
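
To make the JAX side of this concrete, here is a minimal pure-JAX sketch with no PennyLane involved. The function names with_python_if and with_where are made up for this illustration. A Python if calls bool() on the comparison, which a traced value cannot satisfy under jax.jit, whereas jnp.where keeps the selection inside the traced computation as long as both branches are numeric.

```python
import jax
import jax.numpy as jnp
from jax.errors import ConcretizationTypeError

def with_python_if(x):
    # Python `if` calls bool() on the comparison, which fails on a tracer.
    if x - 1 == 1:
        return -x
    return x

def with_where(x):
    # jnp.where selects between two numeric values inside the traced program.
    return jnp.where(x - 1 == 1, -x, x)

try:
    jax.jit(with_python_if)(2.0)
    if_failed = False
except ConcretizationTypeError:
    if_failed = True

print(if_failed)                                     # whether the jitted `if` raised
print(jax.jit(with_where)(2.0))                      # jit-compatible selection
print(jax.vmap(with_where)(jnp.array([1.0, 2.0])))   # vmap-compatible selection
```

This only covers numeric branches; it says nothing yet about applying a gate conditionally.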

@ankit27kh you might also find jax.lax.cond useful here

Hey @CatalinaAlbornoz, I am aware of the JAX limitation. I asked to know if the PennyLane people have some other ideas. :)

@josh, I have tried using jax.lax.cond, jax.numpy.where and jax.lax.switch but no luck.

I also tried to remove the condition completely by using an extra wire in the circuit and applying extra gates to that wire when my condition is not True like this:

# Applies the gate to wire 1 when value is 1 and to wire 0 when value is 0.
# Then ignore wire 0 in the computations.
target_wire = jnp.min(jnp.array([1, value * 99]))

When using this in a PennyLane operator, I get this error:

pennylane.wires.WireError: Wires must be hashable; got object of type <class 'jax.interpreters.batching.BatchTracer'>.

Hi @ankit27kh,

One of my colleagues found a workaround. You should still check that this works the way you expect it to though!

import pennylane as qml
import jax
import jax.numpy as jnp
from functools import partial
from pennylane import numpy as np

dev = qml.device('default.qubit.jax', wires=1, shots=None)

def func(x):
    return x - 1

@jax.jit
@qml.qnode(dev, interface="jax")
def circ():
    qml.RY(x, 0)
    y = func(x)
    np.where((y == 1), qml.PauliZ(0), None)
    qml.Hadamard(0)
    return qml.expval(qml.PauliZ(wires=0))

x=2
print(circ()) 

Please let me know if this works for you.

Hey @CatalinaAlbornoz, I have tried this already. Unfortunately, this does not work.
It only appears to work here because x is a global Python value rather than an argument of the function, so it is never traced.

Once you pass x as an argument, it throws an error. You then need to use jnp.where instead of np.where, but this also throws an error:

TypeError: where requires ndarray or scalar arguments, got <class 'pennylane.ops.qubit.non_parametric_ops.PauliZ'> at position 1.
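
The message says that jnp.where only selects between arrays or scalars, so a gate object cannot be one of its branches. As a minimal pure-JAX sketch (the names Z_MATRIX, IDENTITY and select_matrix are made up for this illustration, and no PennyLane operators are involved), the same kind of call is accepted once both branches are arrays of matching shape and dtype:

```python
import jax
import jax.numpy as jnp

# Matrix representations used only for illustration.
Z_MATRIX = jnp.array([[1.0, 0.0], [0.0, -1.0]])
IDENTITY = jnp.eye(2)

@jax.jit
def select_matrix(y):
    # Both branches are 2x2 float arrays, so jnp.where can choose between them
    # even though y is a traced value here.
    return jnp.where(y == 1, Z_MATRIX, IDENTITY)

print(select_matrix(1))  # the Z matrix
print(select_matrix(0))  # the identity matrix
```

This selects numbers, not operations, which is why passing qml.PauliZ(0) directly is rejected.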

Hi @ankit27kh,

I ran into similar problems while trying to make some of the optimization transforms in PennyLane jittable (see here for the example in question). In the end I did not find a good way to conditionally apply a gate based on some value; one gets errors complaining that gates are not valid JAX types, etc.

jax.lax.cond seems to work best when dealing only with functions of numerical parameters. So one approach is to “trick” it by using a parametrized gate, where the parameter values are what get selected by cond. This way the shape and type are consistent for both branches of the conditional. So for example:

import pennylane as qml

import jax
import jax.numpy as jnp

dev = qml.device('default.qubit.jax', wires=1, shots=None)

def func(x):
    return x - 1

# Note: I will apply jit to the circuit later; the non-jitted version is so that
# we can draw it and compare the jitted vs. non-jitted results
@qml.qnode(dev, interface="jax")
def circ(x):
    qml.RY(x, wires=0)
    
    y = func(x)
    
    # Use the conditional to select a rotation angle
    rotation_angle = jax.lax.cond(
        y == 1, # Condition
        lambda x: jnp.array(jnp.pi), # What to do if condition is true
        lambda x: jnp.array(0.0), # What to do if condition is false
        y # Parameters for the condition go at the end
    )
    qml.RZ(rotation_angle, wires=0)
    
    qml.Hadamard(0)
    return qml.expval(qml.PauliZ(wires=0))

jitted_circ = jax.jit(circ)

# The loop variable is the circuit input x (func(x) = x - 1 is computed inside circ)
for x in [1, 2]:
    print(qml.draw(circ)(x))
    print(circ(x))
    print(jitted_circ(x))
    print()

Running this gives:

0: ──RY(1.00)──RZ(0.00)──H─┤  <Z>
0.84147096
0.84147096

0: ──RY(2.00)──RZ(3.14)──H─┤  <Z>
-0.90929735
-0.90929735

You can see from the linked example how the conditionals can involve more complicated functions, multiple arguments, etc.
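
The reason RZ can stand in for PauliZ here is that RZ(π) = -iZ, so the two differ only by a global phase and give the same expectation values. As a sanity check, here is a small pure-JAX sketch that simulates the single-qubit circuit with explicit matrices. It is not PennyLane's simulator, and the names ry, rz, expval_z, rz_version and z_version are made up for this illustration. It compares the lax.cond angle trick against selecting the Z matrix directly, under both jit and vmap.

```python
from functools import partial

import jax
import jax.numpy as jnp

H = jnp.array([[1, 1], [1, -1]], dtype=jnp.complex64) / jnp.sqrt(2)
Z = jnp.array([[1, 0], [0, -1]], dtype=jnp.complex64)

def ry(theta):
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.array([[c, -s], [s, c]], dtype=jnp.complex64)

def rz(phi):
    # RZ(phi) = diag(exp(-i phi / 2), exp(i phi / 2))
    return jnp.diag(jnp.exp(jnp.array([-0.5j, 0.5j]) * phi))

def expval_z(x, use_rz):
    x = jnp.asarray(x, dtype=jnp.float32)
    y = x - 1  # same role as func(x) in the thread
    state = ry(x) @ jnp.array([1, 0], dtype=jnp.complex64)
    if use_rz:  # static Python flag, not a traced value
        # Select a rotation angle: both branches return a scalar array.
        angle = jax.lax.cond(
            y == 1,
            lambda _: jnp.array(jnp.pi),
            lambda _: jnp.array(0.0),
            y,
        )
        gate = rz(angle)
    else:
        # Reference: select the Z matrix or the identity directly.
        gate = jnp.where(y == 1, Z, jnp.eye(2, dtype=jnp.complex64))
    state = H @ gate @ state
    return jnp.real(jnp.vdot(state, Z @ state))

rz_version = jax.jit(partial(expval_z, use_rz=True))
z_version = jax.jit(partial(expval_z, use_rz=False))

xs = jnp.array([1.0, 2.0])
print(jax.vmap(rz_version)(xs))  # angle-selection trick
print(jax.vmap(z_version)(xs))   # conditional Z, for comparison
```

Both versions should agree with the values printed above, sin(1) for x = 1 and -sin(2) for x = 2, up to floating-point error.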

Hope that helps!


Thank you, @glassnotes_alt. This is a neat trick that will work for most cases. Unfortunately for me, I am using a qml.MultiControlledX gate:

if y == 1:
    qml.MultiControlledX(
        wires=control_wire_list + [target_wire],
        control_values=control_str,
    )

Here the ‘parameter’ is the set of wires. I tried your trick with an extra qubit acting as a dummy target wire, but this does not work and throws an error:

t = get_target_wire()
target = jax.lax.cond(y == 1, lambda x: t, lambda x: 0, y)
qml.MultiControlledX(
    wires=[3, 4, 5, 6] + [target],
    control_values='0101',
)

pennylane.wires.WireError: Wires must be hashable; got object of type <class 'jax.interpreters.batching.BatchTracer'>.

I am marking this as an answer as this will probably help someone else with other gates.

Thank you again. If you find a way to use this in this specific case, please let me know!