QubitUnitary and jax.jit

Hi everyone,
I am using QubitUnitary together with JAX, but when I jit or vmap my circuits I get some very cryptic errors. There is an easy fix, but it might require a separate class for JAX.

A minimal working example to reproduce the error:

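```python
from functools import partial

import jax
import jax.numpy as jnp
import pennylane as qml


def qnode_YX():
    # Product of the PauliX and PauliY matrices, as a jax.numpy array
    m = jnp.matmul(qml.PauliX(0).matrix, jnp.array(qml.PauliY(0).matrix))
    # Bind matrix and wire so the operation can be applied without arguments
    U = partial(qml.QubitUnitary, m, wires=0)
    dev = qml.device('default.qubit.jax', wires=1, shots=None)

    # The (hashable) partial object is passed to jax.jit as a static argument
    @partial(jax.jit, static_argnums=0)
    @qml.qnode(dev, interface='jax', diff_method=None)
    def circuit(U):
        U()
        return qml.state()

    return circuit(U)
```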

The error message is quite long, but it mainly complains about tracer values. The errors arise because the QubitUnitary class checks whether the given matrix has the right dimensions and whether it is actually unitary, and these checks need concrete values rather than tracers. When I remove these checks, QubitUnitary works with jax.jit.
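The same kind of failure can be reproduced in plain JAX, without PennyLane. The sketch below uses a hypothetical helper `check_unitary` (not PennyLane's actual validation code) that, like the checks described above, turns a value-dependent comparison into a Python `bool`. Under `jax.jit` the matrix is an abstract tracer, so that conversion raises a `ConcretizationTypeError` (newer JAX versions raise its subclass `TracerBoolConversionError`). In this sketch the shape check only reads `U.shape`, which stays static under tracing, so it is the unitarity check that fails. `check_unitary_if_concrete` is an equally hypothetical workaround that skips the value check whenever it cannot be evaluated.

```python
import jax
import jax.numpy as jnp


def check_unitary(U, atol=1e-6):
    """Hypothetical validation helper mimicking the checks described above."""
    # Shape check: U.shape is static, so this also works on tracers
    if U.ndim != 2 or U.shape[0] != U.shape[1]:
        raise ValueError("matrix must be square")
    # Unitarity check: `not` forces a concrete bool, which fails on a tracer
    if not jnp.allclose(U @ jnp.conj(U).T, jnp.eye(U.shape[0]), atol=atol):
        raise ValueError("matrix is not unitary")
    return U


def check_unitary_if_concrete(U, atol=1e-6):
    """Hypothetical jit-friendly variant: skip the value check when U is traced."""
    if U.ndim != 2 or U.shape[0] != U.shape[1]:
        raise ValueError("matrix must be square")
    is_unitary = jnp.allclose(U @ jnp.conj(U).T, jnp.eye(U.shape[0]), atol=atol)
    try:
        ok = bool(is_unitary)
    except jax.errors.ConcretizationTypeError:
        return U  # traced under jit: the value is unknown, so skip the check
    if not ok:
        raise ValueError("matrix is not unitary")
    return U


# U = X @ Y, the same product as in the example above
X = jnp.array([[0, 1], [1, 0]], dtype=jnp.complex64)
Y = jnp.array([[0, -1j], [1j, 0]], dtype=jnp.complex64)
U = X @ Y

check_unitary(U)  # works eagerly: U is a concrete array

try:
    jax.jit(check_unitary)(U)
except jax.errors.ConcretizationTypeError as err:
    print("jit failed:", type(err).__name__)

print(jax.jit(check_unitary_if_concrete)(U))  # works: the value check is skipped
```

Skipping the check under tracing trades safety for traceability: a non-unitary matrix passed inside a jitted function would go undetected.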

Minimal working example of a QubitUnitary class with these checks removed:

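```python
import pennylane as qml
from pennylane.operation import AnyWires, Operation
from pennylane.wires import Wires


class QubitUnitary(Operation):
    """QubitUnitary without the dimension/unitarity checks, so it works with jax.jit."""

    num_wires = AnyWires
    num_params = 1
    grad_method = None

    def __init__(self, *params, wires, do_queue=True):
        wires = Wires(wires)
        # No validation of params[0]: checking its values needs concrete arrays
        super().__init__(*params, wires=wires, do_queue=do_queue)

    @classmethod
    def _matrix(cls, *params):
        # The matrix representation is simply the parameter itself
        return params[0]

    def adjoint(self):
        # The inverse of a unitary is its conjugate transpose
        return QubitUnitary(qml.math.T(qml.math.conj(self.matrix)), wires=self.wires)

    def label(self, decimals=None, base_label=None):
        return super().label(decimals=decimals, base_label=base_label or "U")
```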
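Without the checks, the class only uses its parameter in two ways: `_matrix` returns it unchanged, and `adjoint` conjugate-transposes it. Both are ordinary array operations that JAX can trace. The plain-JAX sketch below (no PennyLane involved; `ry_matrix` and `apply_then_adjoint` are hypothetical helpers) mirrors those two steps under `jax.jit` combined with `jax.vmap` over a batch of single-qubit rotations.

```python
import jax
import jax.numpy as jnp


def ry_matrix(theta):
    """Hypothetical helper: the single-qubit rotation RY(theta) as a complex matrix."""
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.array([[c, -s], [s, c]], dtype=jnp.complex64)


def apply_then_adjoint(U):
    """Apply U to |0>, then undo it with the conjugate transpose (as in `adjoint`)."""
    state = jnp.array([1, 0], dtype=jnp.complex64)  # |0>
    rotated = U @ state
    U_dag = jnp.conj(U).T  # same operation as qml.math.T(qml.math.conj(...))
    return rotated, U_dag @ rotated


thetas = jnp.linspace(0.0, jnp.pi, 4)
Us = jax.vmap(ry_matrix)(thetas)  # batch of 4 unitaries, shape (4, 2, 2)

# No value-dependent Python checks, so jit + vmap trace without errors
rotated, restored = jax.jit(jax.vmap(apply_then_adjoint))(Us)
print(rotated.shape)  # (4, 2)
print(jnp.allclose(restored, jnp.array([1, 0], dtype=jnp.complex64), atol=1e-6))
```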

I thought it might be of interest to other people.

Hi @PatrickHuembeli, thank you for pointing this out!

We will look into it.

Thanks for posting this here!

Hi @PatrickHuembeli, I just wanted to let you know that we will be opening a bug report about this. Thanks again for reporting it!