Problem using jax with DiagonalQubitUnitary

I’m trying to run batch circuit executions with pennylane 0.22.1 using jax 0.3.6. Here’s the original that I run, which works fine:

import pennylane as qml
from pennylane import numpy as qnp
import numpy as np

# One-wire "default.qubit" device, created with shots=None.
dev = qml.device("default.qubit", wires=1, shots=None)

@qml.qnode(dev)
def circuit(x):
    """Apply a diagonal unitary parametrized by ``x`` and return <PauliX> on wire 0.

    The diagonal handed to ``DiagonalQubitUnitary`` is
    ``exp(1j * x * [-1, 1] / 2)``, so each entry has unit modulus for real ``x``.
    """
    qml.DiagonalQubitUnitary(qnp.exp(1j * x * np.array([-1, 1]) / 2), wires=0)
    return qml.expval(qml.PauliX(0))

# Evaluate the QNode once per parameter value with an ordinary Python loop
# (one separate circuit execution per entry) and print the list of results.
batch_params = [0, 1]
print([circuit(x) for x in batch_params])

which outputs [tensor(0., requires_grad=True), tensor(0., requires_grad=True)] as expected. Then I convert this to a circuit compatible with jax:

import jax
import jax.numpy as jnp

# One-wire "default.qubit" device, created with shots=None.
dev = qml.device("default.qubit", wires=1, shots=None)

@qml.qnode(dev, interface="jax")
def circuit(x):
    """JAX-interface QNode: diagonal unitary parametrized by ``x``, then <PauliX> on wire 0.

    When this QNode is called through ``jax.vmap``, ``x`` is a traced value, and
    the traceback in this post shows a ``ConcretizationTypeError`` raised from
    the unitarity check in ``DiagonalQubitUnitary.compute_eigvals``.
    """
    # NOTE(review): the diagonal here uses [1, -1], whereas the NumPy version
    # earlier in this post uses [-1, 1] -- confirm whether the flip is intended.
    qml.DiagonalQubitUnitary(jnp.exp(1j * x * jnp.array([1, -1]) / 2), wires=0)
    return qml.expval(qml.PauliX(0))

# Batch execute with jax: vectorize the QNode over the entries of batch_params
# instead of looping over them in Python.
vcircuit = jax.vmap(circuit)
batch_params = jnp.asarray([0, 1])
# This is the call that raises the ConcretizationTypeError shown in the traceback.
res = vcircuit(batch_params)

This throws the error below. I get the feeling there’s an issue with using the unitary validator on a jax tracer type. But maybe my circuit is written poorly. Any thoughts on a workaround or fix?

Thanks


Traceback:

ConcretizationTypeError                   Traceback (most recent call last)
Input In [4], in <module>
     13 vcircuit = jax.vmap(circuit)
     14 batch_params = jnp.asarray([0, 1])
---> 15 res = vcircuit(batch_params)

    [... skipping hidden 3 frame]

File ~/projects/envs/xanadu-3.8/lib/python3.8/site-packages/pennylane/qnode.py:578, in QNode.__call__(self, *args, **kwargs)
    571 using_custom_cache = (
    572     hasattr(cache, "__getitem__")
    573     and hasattr(cache, "__setitem__")
    574     and hasattr(cache, "__delitem__")
    575 )
    576 self._tape_cached = using_custom_cache and self.tape.hash in cache
--> 578 res = qml.execute(
    579     [self.tape],
    580     device=self.device,
    581     gradient_fn=self.gradient_fn,
    582     interface=self.interface,
    583     gradient_kwargs=self.gradient_kwargs,
    584     override_shots=override_shots,
    585     **self.execute_kwargs,
    586 )
    588 if autograd.isinstance(res, (tuple, list)) and len(res) == 1:
    589     # If a device batch transform was applied, we need to 'unpack'
    590     # the returned tuple/list to a float.
   (...)
    597     # TODO: find a more explicit way of determining that a batch transform
    598     # was applied.
    600     res = res[0]

File ~/projects/envs/xanadu-3.8/lib/python3.8/site-packages/pennylane/interfaces/batch/__init__.py:342, in execute(tapes, device, gradient_fn, interface, mode, gradient_kwargs, cache, cachesize, max_diff, override_shots, expand_fn, max_expansion, device_batch_transform)
    338     return batch_fn(res)
    340 if gradient_fn == "backprop" or interface is None:
    341     return batch_fn(
--> 342         cache_execute(batch_execute, cache, return_tuple=False, expand_fn=expand_fn)(tapes)
    343     )
    345 # the default execution function is batch_execute
    346 execute_fn = cache_execute(batch_execute, cache, expand_fn=expand_fn)

File ~/projects/envs/xanadu-3.8/lib/python3.8/site-packages/pennylane/interfaces/batch/__init__.py:173, in cache_execute.<locals>.wrapper(tapes, **kwargs)
    169         return (res, []) if return_tuple else res
    171 else:
    172     # execute all unique tapes that do not exist in the cache
--> 173     res = fn(execution_tapes.values(), **kwargs)
    175 final_res = []
    177 for i, tape in enumerate(tapes):

File ~/projects/envs/xanadu-3.8/lib/python3.8/site-packages/pennylane/interfaces/batch/__init__.py:125, in cache_execute.<locals>.fn(tapes, **kwargs)
    123 def fn(tapes, **kwargs):  # pylint: disable=function-redefined
    124     tapes = [expand_fn(tape) for tape in tapes]
--> 125     return original_fn(tapes, **kwargs)

File /usr/lib/python3.8/contextlib.py:75, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
     72 @wraps(func)
     73 def inner(*args, **kwds):
     74     with self._recreate_cm():
---> 75         return func(*args, **kwds)

File ~/projects/envs/xanadu-3.8/lib/python3.8/site-packages/pennylane/_qubit_device.py:289, in QubitDevice.batch_execute(self, circuits)
    284 for circuit in circuits:
    285     # we need to reset the device here, else it will
    286     # not start the next computation in the zero state
    287     self.reset()
--> 289     res = self.execute(circuit)
    290     results.append(res)
    292 if self.tracker.active:

File ~/projects/envs/xanadu-3.8/lib/python3.8/site-packages/pennylane/_qubit_device.py:201, in QubitDevice.execute(self, circuit, **kwargs)
    198 self.check_validity(circuit.operations, circuit.observables)
    200 # apply all circuit operations
--> 201 self.apply(circuit.operations, rotations=circuit.diagonalizing_gates, **kwargs)
    203 # generate computational basis samples
    204 if self.shots is not None or circuit.is_sampled:

File ~/projects/envs/xanadu-3.8/lib/python3.8/site-packages/pennylane/devices/default_qubit.py:226, in DefaultQubit.apply(self, operations, rotations, **kwargs)
    224                 self._debugger.snapshots[len(self._debugger.snapshots)] = state_vector
    225     else:
--> 226         self._state = self._apply_operation(self._state, operation)
    228 # store the pre-rotated state
    229 self._pre_rotated_state = self._state

File ~/projects/envs/xanadu-3.8/lib/python3.8/site-packages/pennylane/devices/default_qubit.py:251, in DefaultQubit._apply_operation(self, state, operation)
    248     axes = self.wires.indices(wires)
    249     return self._apply_ops[operation.base_name](state, axes, inverse=operation.inverse)
--> 251 matrix = self._get_unitary_matrix(operation)
    253 if operation in diagonal_in_z_basis:
    254     return self._apply_diagonal_unitary(state, matrix, wires)

File ~/projects/envs/xanadu-3.8/lib/python3.8/site-packages/pennylane/devices/default_qubit.py:556, in DefaultQubit._get_unitary_matrix(self, unitary)
    545 """Return the matrix representing a unitary operation.
    546 
    547 Args:
   (...)
    553     a 1D array representing the matrix diagonal.
    554 """
    555 if unitary in diagonal_in_z_basis:
--> 556     return unitary.get_eigvals()
    558 return unitary.get_matrix()

File ~/projects/envs/xanadu-3.8/lib/python3.8/site-packages/pennylane/operation.py:1380, in Operation.get_eigvals(self)
   1379 def get_eigvals(self):
-> 1380     op_eigvals = super().get_eigvals()
   1382     if self.inverse:
   1383         return qml.math.conj(op_eigvals)

File ~/projects/envs/xanadu-3.8/lib/python3.8/site-packages/pennylane/operation.py:753, in Operator.get_eigvals(self)
    728 r"""Eigenvalues of the operator in the computational basis (static method).
    729 
    730 If :attr:`diagonalizing_gates` are specified and implement a unitary :math:`U`, the operator
   (...)
    749     tensor_like: eigenvalues
    750 """
    752 try:
--> 753     return self.compute_eigvals(*self.parameters, **self.hyperparameters)
    754 except EigvalsUndefinedError:
    755     # By default, compute the eigenvalues from the matrix representation.
    756     # This will raise a NotImplementedError if the matrix is undefined.
    757     try:

File ~/projects/envs/xanadu-3.8/lib/python3.8/site-packages/pennylane/ops/qubit/matrix_ops.py:406, in DiagonalQubitUnitary.compute_eigvals(D)
    380 r"""Eigenvalues of the operator in the computational basis (static method).
    381 
    382 If :attr:`diagonalizing_gates` are specified and implement a unitary :math:`U`,
   (...)
    402 tensor([ 1, -1])
    403 """
    404 D = qml.math.asarray(D)
--> 406 if not qml.math.allclose(D * qml.math.conj(D), qml.math.ones_like(D)):
    407     raise ValueError("Operator must be unitary.")
    409 return D

    [... skipping hidden 1 frame]

File ~/projects/envs/xanadu-3.8/lib/python3.8/site-packages/jax/core.py:1123, in concretization_function_error.<locals>.error(self, arg)
   1122 def error(self, arg):
-> 1123   raise ConcretizationTypeError(arg, fname_context)

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<BatchTrace(level=1/0)> with
  val = DeviceArray([ True,  True], dtype=bool)
  batch_dim = 0
The problem arose with the `bool` function. 
This Tracer was created on line /home/ares/projects/envs/xanadu-3.8/lib/python3.8/site-packages/autoray/autoray.py:84 (do)

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

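Hey @EvanPeters, you are exactly right: there is a check inside of qml.DiagonalQubitUnitary (in its compute_eigvals method) that validates that the matrix is unitary, and this check is not JIT compatible. More generally, it fails under any JAX tracing, including jax.vmap, because the `if` statement has to turn a traced boolean into a concrete Python bool. The check is the pair of lines shown at 406–407 of pennylane/ops/qubit/matrix_ops.py in your traceback:

    if not qml.math.allclose(D * qml.math.conj(D), qml.math.ones_like(D)):
        raise ValueError("Operator must be unitary.")

If you comment out these two lines in your local PennyLane source code, your example above with jax.vmap works perfectly.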

Let me make a quick PR to PennyLane to fix this.

1 Like

@EvanPeters I’ve created a PR fixing this bug: “Add support for JIT when using `DiagonalQubitUnitary`” (pull request #2445 in PennyLaneAI/pennylane on GitHub, by josh146). Feel free to use this branch directly in the meantime 🙂