def circ(x):
    qml.RY(x, 0)
    y = func(x)
    if y == 1:
        qml.PauliZ(0)
    qml.Hadamard(0)
    return qml.expval(qml.PauliZ(wires=0))
When I use this with JAX's JIT, the `y == 1` line raises this error:
jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=0/1)>
The problem arises because the Python `if` calls `bool` on `y == 1`, which under tracing is an abstract value rather than a concrete Python boolean.
Without JIT, it works. It also fails under vmap.
There are ways to remove the `if` from the code, such as `jax.lax.switch` or `jax.lax.cond`, but with these I get a `TypeError` for `qml.PauliZ` saying it is not a valid JAX type.
Is there a way to make the code compatible with JIT and vmap?
If it helps, y can only be 0 or 1. This is just an example; in my work, I need to use qml.MultiControlledX instead of qml.PauliZ.
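The error in the question can be reproduced without PennyLane. The sketch below is plain JAX, and the function names are hypothetical. A Python `if` needs a concrete boolean, which a tracer cannot supply. `jnp.where` keeps the selection inside the traced computation. This works only when both alternatives are arrays of compatible shape and dtype, not operations such as gates.

```python
import jax
import jax.numpy as jnp


def branch_with_if(x):
    y = x - 1
    # Python `if` calls bool() on `y == 1`; under jit/vmap this is an abstract tracer
    if y == 1:
        return jnp.pi
    return 0.0


def branch_with_where(x):
    y = x - 1
    # jnp.where selects between two numeric values without needing a concrete bool
    return jnp.where(y == 1, jnp.pi, 0.0)


# Without transformations, y is a concrete Python int and the `if` works.
print(branch_with_if(2))  # prints 3.141592653589793

try:
    jax.jit(branch_with_if)(2)
except jax.errors.ConcretizationTypeError as err:
    print("jit failed with", type(err).__name__)

# The jnp.where version traces fine under both jit and vmap.
print(jax.jit(branch_with_where)(2))
print(jax.vmap(branch_with_where)(jnp.array([1, 2, 2, 1])))
```

The limitation in the thread comes from the second requirement. A gate application is not an array value, so it cannot be one of the two alternatives.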
Complete example without JIT or vmap (this works):
import pennylane as qml
import jax
import jax.numpy as jnp
import pennylane.numpy as np

dev = qml.device('default.qubit.jax', wires=1, shots=None)

def func(x):
    return x - 1

@qml.qnode(dev, interface="jax")
def circ(x):
    qml.RY(x, 0)
    y = func(x)
    if y == 1:
        qml.PauliZ(0)
    qml.Hadamard(0)
    return qml.expval(qml.PauliZ(wires=0))

def no_map(X):
    # Plain Python loop: no JAX transformation, so y is concrete and `if` works
    e = 0
    for x in X:
        e = e + circ(x)
    return e / len(X)

X = np.random.randint(1, 3, 10)  # ten integers, each 1 or 2
print(no_map(X))
Using vmap (does not work):
import pennylane as qml
import jax
import jax.numpy as jnp
import pennylane.numpy as np

dev = qml.device('default.qubit.jax', wires=1, shots=None)

def func(x):
    return x - 1

@qml.qnode(dev, interface="jax")
def circ(x):
    qml.RY(x, 0)
    y = func(x)
    if y == 1:  # fails: under vmap, y is a BatchTracer, not a concrete value
        qml.PauliZ(0)
    qml.Hadamard(0)
    return qml.expval(qml.PauliZ(wires=0))

def vmap_mean(X):
    return jnp.mean(jax.vmap(circ)(X))

X = np.random.randint(1, 3, 10)
print(vmap_mean(X))
Using JIT (does not work):
import pennylane as qml
import jax
import jax.numpy as jnp

dev = qml.device('default.qubit.jax', wires=1, shots=None)

def func(x):
    return x - 1

@jax.jit
@qml.qnode(dev, interface="jax")
def circ(x):
    qml.RY(x, 0)
    y = func(x)
    if y == 1:
        qml.PauliZ(0)
    qml.Hadamard(0)
    return qml.expval(qml.PauliZ(wires=0))

print(circ(2))
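Another possible workaround is sketched below in plain JAX. The function name is hypothetical. If `x` only ever takes a few Python values, marking it static with `static_argnums` makes it concrete at trace time, so the Python `if` runs normally. The cost is one compilation per distinct value. It also cannot be combined with `vmap` over `x`. Whether this works with a PennyLane QNode on `default.qubit.jax` has not been checked here.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=0)
def angle_for(x):
    # x is static: it arrives as a plain (hashable) Python int, so `y == 1` is a real bool
    y = x - 1
    if y == 1:
        return jnp.array(jnp.pi)
    return jnp.array(0.0)


print(angle_for(2))  # compiled for x=2
print(angle_for(1))  # compiled again for x=1 (separate cache entry)
```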
Hi @ankit27kh, this seems to be an issue with JAX. Under JIT, `y == 1` is a traced boolean rather than a concrete Python `bool`, so it cannot drive an `if`. It seems to me that you need a workaround that doesn't use a boolean in this way. You can read more about this error here.
Hey @CatalinaAlbornoz, I am aware of the JAX limitation. I was asking whether the PennyLane team has any other ideas.
@josh, I have tried `jax.lax.cond`, `jax.numpy.where` and `jax.lax.switch`, but with no luck.
I also tried to remove the condition entirely by adding an extra wire to the circuit and applying the gates to that wire when my condition is not true, like this:
# Applies the gate to wire 1 when value is 1 and to wire 0 when value is 0.
# Then ignore wire 0 in the computations.
target_wire = jnp.min(jnp.array([1, value * 99]))
When I use this as the wire of a PennyLane operator, I get this error:
pennylane.wires.WireError: Wires must be hashable; got object of type <class 'jax.interpreters.batching.BatchTracer'>.
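The `WireError` arises because wire labels must be concrete, hashable values, so a tracer cannot choose the wire. One branch-free alternative keeps the wires fixed and makes the gate's matrix depend on the data instead. The target gate is `X` when `y == 1` and the identity when `y == 0`, and the control pattern stays the same. The sketch below is a small hand-written JAX statevector calculation, not PennyLane code, and the helper names are hypothetical. Because only array values change, it traces under `jit` and `vmap`. A PennyLane analogue would presumably wrap `qml.QubitUnitary` in `qml.ctrl`, but that has not been verified to trace on `default.qubit.jax`.

```python
import jax
import jax.numpy as jnp

I2 = jnp.eye(2)
X = jnp.array([[0.0, 1.0], [1.0, 0.0]])


def controlled_matrix(u, control_bits):
    """Matrix of the 2x2 gate `u` on the last qubit, controlled on `control_bits`.

    `control_bits` is a static tuple such as (1, 0); `u` may depend on traced data.
    Qubit 0 is the most significant bit of the basis-state index.
    """
    dim = 2 ** len(control_bits)
    idx = int("".join(str(b) for b in control_bits), 2)
    proj = jnp.zeros((dim, dim)).at[idx, idx].set(1.0)  # |controls><controls|
    return jnp.kron(proj, u) + jnp.kron(jnp.eye(dim) - proj, I2)


def target_flip_probability(y, control_bits=(1, 1)):
    # Select X or the identity numerically instead of branching in Python;
    # the wire layout never depends on y.
    u = jnp.where(y == 1, X, I2)
    n = len(control_bits) + 1
    # Start in |control_bits>|0>, i.e. with the controls satisfied
    start = int("".join(str(b) for b in control_bits) + "0", 2)
    state = jnp.zeros(2 ** n).at[start].set(1.0)
    out = controlled_matrix(u, control_bits) @ state
    # Probability that the target (last) qubit is |1> afterwards
    return jnp.sum(jnp.abs(out[1::2]) ** 2)


ys = jnp.array([0, 1, 1, 0])
print(jax.jit(jax.vmap(target_flip_probability))(ys))  # target flips only where y == 1
```

The design choice is to move the data dependence from the circuit structure into a numeric parameter. JAX can trace a numeric parameter but not a change in structure.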
I ran into similar problems while trying to make some of the optimization transforms in PennyLane jittable (see here for the example in question). I didn't find a good way to conditionally apply a gate based on some value; one gets errors complaining that gates are not valid JAX types, etc.
jax.lax.cond seems to work best when dealing only with functions of numerical parameters. So one approach is to “trick” it by using a parametrized gate, where the parameter values are what get selected by cond. This way the shape and type are consistent for both branches of the conditional. So for example:
import pennylane as qml
import jax
import jax.numpy as jnp

dev = qml.device('default.qubit.jax', wires=1, shots=None)

def func(x):
    return x - 1

# Note: I will apply jit to the circuit later; the non-jitted version is so that
# we can draw it and compare the jitted vs. non-jitted results
@qml.qnode(dev, interface="jax")
def circ(x):
    qml.RY(x, wires=0)
    y = func(x)
    # Use the conditional to select a rotation angle
    rotation_angle = jax.lax.cond(
        y == 1,                        # Condition
        lambda _: jnp.array(jnp.pi),   # Branch taken if the condition is true
        lambda _: jnp.array(0.0),      # Branch taken if the condition is false
        y                              # Operands passed to the branches go at the end
    )
    # RZ(pi) equals PauliZ up to a global phase; RZ(0) is the identity
    qml.RZ(rotation_angle, wires=0)
    qml.Hadamard(0)
    return qml.expval(qml.PauliZ(wires=0))

jitted_circ = jax.jit(circ)

for x in [1, 2]:
    print(qml.draw(circ)(x))
    print(circ(x))
    print(jitted_circ(x))
    print()
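The `RZ` trick reproduces the original circuit for the following reason. With the standard convention RZ(phi) = diag(exp(-i phi/2), exp(i phi/2)), RZ(pi) = -iZ. That equals `PauliZ` up to a global phase, which no expectation value can detect. If `y` is guaranteed to be 0 or 1, as stated in the question, the `cond` could even be replaced by the angle `pi * y`. The sketch below checks both claims with plain JAX 2x2 matrices. It is a hand-written single-qubit simulation with hypothetical helper names, not PennyLane's implementation.

```python
import jax
import jax.numpy as jnp

Z = jnp.diag(jnp.array([1.0, -1.0], dtype=jnp.complex64))
H = jnp.array([[1.0, 1.0], [1.0, -1.0]], dtype=jnp.complex64) / jnp.sqrt(2.0)


def rz(theta):
    # RZ(theta) = diag(exp(-i theta/2), exp(i theta/2))
    phase = jnp.exp(-0.5j * theta)
    return jnp.diag(jnp.array([phase, jnp.conj(phase)]))


def ry(theta):
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.array([[c, -s], [s, c]], dtype=jnp.complex64)


def expval_z(x):
    y = x - 1
    # Branch-free: RZ(0) = I for y = 0 and RZ(pi) = -iZ for y = 1 (valid only for y in {0, 1})
    state = jnp.array([1.0, 0.0], dtype=jnp.complex64)
    state = H @ rz(jnp.pi * y) @ ry(x) @ state
    return jnp.real(jnp.vdot(state, Z @ state))


print(jnp.allclose(rz(jnp.pi), -1j * Z, atol=1e-6))  # True: RZ(pi) = -iZ
print(jax.jit(jax.vmap(expval_z))(jnp.array([1, 2])))  # approximately [sin(1), -sin(2)]
```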
Thank you, @glassnotes_alt. This is a neat trick that will work for most cases. Unfortunately for me, I am using a qml.MultiControlledX gate:
if y == 1:
    qml.MultiControlledX(
        wires=control_wire_list + [target_wire],
        control_values=control_str
    )
Here the ‘parameter’ is the choice of wires. I tried your trick by adding an extra qubit to act as a dummy target wire, but this does not work and throws an error:
Hi @ankit27kh! I wanted to give a quick update here. Since your post, we've been working hard on Catalyst, a compiler that adds a JAX-compatible `@qjit` decorator to PennyLane.
This allows you to JIT arbitrary hybrid quantum-classical code. We've just added support for native Python `if` statements (enabled with `autograph=True`), so your example now appears to work with Catalyst.
If you are on Linux or macOS, you can install Catalyst with
pip install pennylane-catalyst
The following program should then work:
import pennylane as qml
from catalyst import qjit
import jax
import jax.numpy as jnp
import numpy as np

dev = qml.device('lightning.qubit', wires=1, shots=None)

def func(x):
    return x - 1

@qjit(autograph=True)
@qml.qnode(dev)
def circ(x):
    qml.RY(x, 0)
    y = func(x)
    if y == 1:
        qml.PauliZ(0)
    qml.Hadamard(0)
    return qml.expval(qml.PauliZ(wires=0))
Executing the circuit as-is:
>>> circ(2)
array(-0.90929743)
Using vmap:
>>> x = np.random.randint(1, 3, 10)
>>> jnp.mean(jax.vmap(circ)(x))
Array(-0.03391322, dtype=float64)
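As an independent sanity check, here is a derivation of the expected values. For this one-qubit circuit, RY(x) followed by Hadamard gives <Z> = sin(x), and inserting PauliZ in between flips the sign. So with y = x - 1 in {0, 1}, the QNode should return (1 - 2y) sin(x). For x = 2 that is -sin(2), approximately -0.90929743, which matches `circ(2)`. A sample of five 1s and five 2s averages to (5 sin 1 - 5 sin 2) / 10, approximately -0.03391322. That is consistent with the vmap output. The plain-JAX formula below, with a hypothetical function name, runs under `jit` and `vmap` and can serve as a reference.

```python
import jax
import jax.numpy as jnp


@jax.jit
def circ_closed_form(x):
    y = x - 1
    # (1 - 2y) is +1 when y == 0 (no PauliZ) and -1 when y == 1 (PauliZ applied)
    return (1 - 2 * y) * jnp.sin(x)


print(circ_closed_form(2))  # approximately -0.90929743 (= -sin 2)

# Five 1s and five 2s
x = jnp.array([1, 2] * 5)
print(jnp.mean(jax.vmap(circ_closed_form)(x)))  # approximately -0.03391322
```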