Hi everyone,
I am using QubitUnitary together with JAX, but when I jit or vmap my circuits I get some very cryptic errors. There is an easy fix, although it might require a separate class for JAX.
A minimal working example to reproduce the error:
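from functools import partial

import jax
import jax.numpy as jnp
import pennylane as qml

def qnode_YX():
    # m is the product of the PauliX and PauliY matrices (a unitary)
    m = jnp.matmul(qml.PauliX(0).matrix, jnp.array(qml.PauliY(0).matrix))
    U = partial(qml.QubitUnitary, m, wires=0)
    dev = qml.device('default.qubit.jax', wires=1, shots=None)

    @partial(jax.jit, static_argnums=0)
    @qml.qnode(dev, interface='jax', diff_method=None)
    def circuit(U):
        U()
        return qml.state()

    return circuit(U)

qnode_YX()  # fails with a tracer-related error under jax.jit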
The error message is quite long, but it mainly complains about tracer values. These appear because the QubitUnitary class checks whether the given matrix has the right dimensions and whether it is actually unitary. When I remove these checks, QubitUnitary works with jax.jit.
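To illustrate why a value-dependent check breaks under jit, here is a small sketch in plain JAX (no PennyLane involved; `check_unitary` is a hypothetical stand-in for the kind of check QubitUnitary performs, not its actual code). Under `jax.jit` the matrix is an abstract tracer, so a Python `if` on the result of `jnp.allclose` cannot be evaluated, and JAX raises a concretization error.

```python
import jax
import jax.numpy as jnp


def check_unitary(U, atol=1e-6):
    """Hypothetical check: raise if U is not unitary (uses a Python `if`)."""
    # U.shape is known at trace time, so reading it works even for a tracer.
    n = U.shape[0]
    # jnp.allclose returns a JAX boolean; `not`/`if` need a concrete value,
    # which a tracer does not have under jax.jit.
    if not jnp.allclose(U @ jnp.conj(U).T, jnp.eye(n), atol=atol):
        raise ValueError("matrix is not unitary")
    return U


# PauliX @ PauliY = i * PauliZ, which is unitary
U = jnp.array([[1j, 0], [0, -1j]])

check_unitary(U)  # eager call: the check runs on concrete values

try:
    jax.jit(check_unitary)(U)
except jax.errors.ConcretizationTypeError as err:
    # Newer JAX versions raise TracerBoolConversionError, a subclass.
    print("jit failed:", type(err).__name__)
```

The eager call succeeds while the jitted call fails, even though the matrix is the same.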
Here is a minimal version of the QubitUnitary class with these checks removed:
import pennylane as qml
from pennylane.operation import AnyWires, Operation
from pennylane.wires import Wires

class QubitUnitary(Operation):
    num_wires = AnyWires
    num_params = 1
    grad_method = None

    def __init__(self, *params, wires, do_queue=True):
        # No shape or unitarity checks here, so traced matrices pass through
        wires = Wires(wires)
        super().__init__(*params, wires=wires, do_queue=do_queue)

    @classmethod
    def _matrix(cls, *params):
        # The operation's matrix is simply the matrix that was passed in
        return params[0]

    def adjoint(self):
        # Adjoint of a unitary is its conjugate transpose
        return QubitUnitary(qml.math.T(qml.math.conj(self.matrix)), wires=self.wires)

    def label(self, decimals=None, base_label=None):
        return super().label(decimals=decimals, base_label=base_label or "U")
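An alternative that keeps validation in eager mode might be to run the unitarity check only when the matrix is concrete and skip it when it is traced. The sketch below is a hypothetical helper in plain JAX (`check_unitary_if_concrete` and `apply` are illustrative names, not part of PennyLane); it keeps the static shape check, tries to force a concrete boolean, and falls back to skipping the value check if JAX raises a concretization error.

```python
import jax
import jax.numpy as jnp


def check_unitary_if_concrete(U, atol=1e-6):
    """Hypothetical helper: validate U eagerly, skip the value check if traced."""
    n = U.shape[0]
    # Shape information is static, so this check also runs under jax.jit.
    if U.shape != (n, n):
        raise ValueError(f"expected a square matrix, got shape {U.shape}")
    try:
        # bool() forces a concrete value; this raises for a tracer under jax.jit.
        is_unitary = bool(jnp.allclose(U @ jnp.conj(U).T, jnp.eye(n), atol=atol))
    except jax.errors.ConcretizationTypeError:
        # Traced input: only the shape check above was performed.
        return U
    if not is_unitary:
        raise ValueError("matrix is not unitary")
    return U


@jax.jit
def apply(U, psi):
    # Under jit the helper returns U unchanged instead of raising.
    return check_unitary_if_concrete(U) @ psi


U = jnp.array([[1j, 0], [0, -1j]])
psi = jnp.array([1.0, 0.0])
print(apply(U, psi))
```

The trade-off is that a non-unitary matrix passed inside a jitted function goes unchecked; validating the matrix once outside the jitted function would be another way to keep that safety.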
I thought it might be of interest to other people.