Is there a way to make this code work?
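```python
import pennylane as qml

dev = qml.device("default.qubit", wires=1)  # assumed device setup


@qml.qnode(dev, interface="jax")
def circ(x):
    qml.RY(x, 0)
    y = func(x)  # user-defined; returns 0 or 1
    if y == 1:  # Python branch on a traced value: fails under jax.jit / jax.vmap
        qml.PauliZ(0)
    qml.Hadamard(0)
    return qml.expval(qml.PauliZ(wires=0))
```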
When I run this under `jax.jit`, the `if y == 1:` line raises:

```
jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=0/1)>
```
The error comes from the implicit `bool()` conversion that the `if` statement applies to the traced comparison `y == 1`.
Without JIT it works, but it also fails under `jax.vmap`.
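A minimal pure-JAX sketch of the failure mode, with no PennyLane involved (`func` here is a hypothetical stand-in that returns 0 or 1): a Python `if` needs a concrete `bool`, which a tracer cannot provide under `jit` or `vmap`, whereas `jnp.where` selects between already-computed values without ever asking for one. Newer JAX versions raise `TracerBoolConversionError`, which is a subclass of `ConcretizationTypeError`, so the `except` below catches either.

```python
import jax
import jax.numpy as jnp


def func(x):
    # Hypothetical stand-in for the question's func: returns 0 or 1.
    return (x > 0.5).astype(jnp.int32)


def branch_python(x):
    y = func(x)
    if y == 1:  # calls bool() on a tracer: fails under jax.jit and jax.vmap
        return -x
    return x


def branch_traced(x):
    y = func(x)
    # Both candidates are computed; jnp.where selects one with no Python bool.
    return jnp.where(y == 1, -x, x)


try:
    jax.jit(branch_python)(1.0)
except jax.errors.ConcretizationTypeError as err:
    # Newer JAX raises TracerBoolConversionError, a subclass of this error.
    print("Python `if` under jit:", type(err).__name__)

print(jax.jit(branch_traced)(1.0))
print(jax.vmap(branch_traced)(jnp.array([0.2, 1.0])))
```

For gates the same idea applies, except that the selection has to act on a numeric gate parameter rather than on which operation gets applied.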
I can remove the `if` by using `jax.lax.cond` or `jax.lax.switch`, but then `qml.PauliZ` raises a `TypeError` saying it is not a valid JAX type.
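`jax.lax.cond` and `jax.lax.switch` require each branch to return JAX arrays (or pytrees of them); a PennyLane operator instance is neither, which is presumably where the `TypeError` comes from. Assuming `y` is 0 or 1 (as stated in the question), one way to avoid branching entirely is to move `y` into a gate angle: `PhaseShift(pi * y) = diag(1, exp(i*pi*y))` is exactly `Z` when `y == 1` and the identity when `y == 0`. In PennyLane that would presumably be `qml.PhaseShift(jnp.pi * y, wires=0)` in place of the `if` block. The sketch below does not use PennyLane: it simulates the question's single-qubit circuit with plain JAX matrices so the trick can be checked under `jit` plus `vmap`. `func` and its threshold are hypothetical; the expected result is `<Z> = (-1)**y * sin(x)`.

```python
import jax
import jax.numpy as jnp

# Single-qubit gates as 2x2 matrices (basis order |0>, |1>).
HADAMARD = jnp.array([[1.0, 1.0], [1.0, -1.0]], dtype=jnp.complex64) / jnp.sqrt(2.0)
PAULI_Z = jnp.diag(jnp.array([1.0, -1.0], dtype=jnp.complex64))


def ry(theta):
    # Same matrix as qml.RY(theta): [[cos, -sin], [sin, cos]] of theta / 2.
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.stack([jnp.stack([c, -s]), jnp.stack([s, c])]).astype(jnp.complex64)


def phase_shift(phi):
    # diag(1, exp(i*phi)), the same matrix as qml.PhaseShift(phi).
    return jnp.diag(jnp.exp(1j * phi * jnp.array([0.0, 1.0])))


def func(x):
    # Hypothetical stand-in for the question's func: returns 0 or 1.
    return (x > 1.0).astype(jnp.int32)


def circ(x):
    state = jnp.array([1.0, 0.0], dtype=jnp.complex64)  # start in |0>
    state = ry(x) @ state
    y = func(x)
    # PhaseShift(pi * y) equals Z for y == 1 and I for y == 0, so the
    # Python `if y == 1: Z` becomes an ordinary traced gate angle.
    state = phase_shift(jnp.pi * y) @ state
    state = HADAMARD @ state
    return jnp.real(jnp.vdot(state, PAULI_Z @ state))  # <Z>


xs = jnp.array([0.3, 0.9, 1.5, 2.5])
print(jax.jit(jax.vmap(circ))(xs))
```

Because the branch is now just a number inside the gate matrix, there is nothing for `jit` or `vmap` to concretize.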
Is there a way to make this code compatible with both `jax.jit` and `jax.vmap`?
If it helps, `y` can only be 0 or 1. This is a simplified example; in my actual code I need `qml.MultiControlledX` instead of `qml.PauliZ`.
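For the multi-controlled case the same angle trick should carry over, since `X = H Z H`: conjugating a multi-controlled `PhaseShift(pi * y)` by Hadamards on the target gives a multi-controlled `X**y`. In PennyLane this would presumably look like `qml.Hadamard(wires=target)`, then `qml.ctrl(qml.PhaseShift(jnp.pi * y, wires=target), control=controls)`, then `qml.Hadamard(wires=target)`; whether `qml.ctrl` accepts an operator instance depends on the installed PennyLane version, so treat that as an assumption. The sketch below only checks the matrix identity in plain JAX, for a hypothetical layout of two control wires with the target as the last (least significant) wire.

```python
import jax
import jax.numpy as jnp

N_CONTROLS = 2  # hypothetical: two control wires; the target is the last wire
DIM = 2 ** (N_CONTROLS + 1)
HADAMARD = jnp.array([[1.0, 1.0], [1.0, -1.0]], dtype=jnp.complex64) / jnp.sqrt(2.0)


def controlled_on_last(u):
    # Identity everywhere except the block where all controls are |1>,
    # where the 2x2 gate u acts on the target (last) qubit.
    return jnp.eye(DIM, dtype=jnp.complex64).at[DIM - 2:, DIM - 2:].set(u)


def h_on_target():
    # Hadamard on the last wire, identity on the control wires.
    return jnp.kron(jnp.eye(DIM // 2, dtype=jnp.complex64), HADAMARD)


def multi_controlled_x_power(y):
    # y is 0 or 1 (possibly traced); returns the multi-controlled X**y matrix.
    phase = jnp.diag(jnp.exp(1j * jnp.pi * y * jnp.array([0.0, 1.0])))
    return h_on_target() @ controlled_on_last(phase) @ h_on_target()


mats = jax.jit(jax.vmap(multi_controlled_x_power))(jnp.array([0, 1]))
print(jnp.round(mats[1].real, 3))  # should match a Toffoli (2-control X) matrix
```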