Error using adjoint differentiation for metric tensor with complex statevector outputs

Hi!

I came across an issue when I used diff_method="adjoint" with the JAX interface inside the built-in metric_tensor. It seems that JAX only allows real-valued outputs for reverse-mode differentiation (jacrev). Since the QNode returns the full quantum state (a complex vector), JAX apparently has to be told to treat the function as holomorphic.
I could work around this by returning qml.expval() instead, but then, later, when I optimize the circuit, I run into a ProbabilityMP error. I would appreciate any advice on how to use adjoint differentiation with JAX and the full state.
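On the holomorphic point: here is a minimal pure-JAX sketch using a hypothetical toy function `f` (not the QNode). It shows that plain `jax.jacrev` rejects complex outputs. It also shows that `holomorphic=True` only helps when the inputs are complex as well, so it is rejected for real inputs such as the circuit parameters `theta`.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def f(z):
    # Hypothetical holomorphic toy function with a complex-valued output
    return jnp.exp(1j * z)

z = jnp.array([0.5 + 0.0j, 1.0 + 0.0j])

# Plain jacrev rejects the complex output (same kind of TypeError as in the traceback)
try:
    jax.jacrev(f)(z)
except TypeError as err:
    print("jacrev:", err)

# holomorphic=True is accepted here because both the input and the output are complex
J = jax.jacrev(f, holomorphic=True)(z)  # equals diag(1j * exp(1j * z))

# With a real input, like real circuit parameters, holomorphic=True is rejected as well
real_input_rejected = False
try:
    jax.jacrev(f, holomorphic=True)(jnp.array([0.5, 1.0]))
except TypeError as err:
    real_input_rejected = True
    print("holomorphic=True with real input:", err)
```

So for a state that depends on real parameters, `holomorphic=True` is not a drop-in fix on its own.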

Here is my code:

import jax.numpy as jnp
import pennylane as qml

dev = qml.device('default.qubit')

# Parameter values
theta = jnp.array([jnp.pi / 2, jnp.pi / 2])  # One theta per node

@qml.qnode(dev, interface="jax", diff_method="adjoint")
def circuit(theta):
    # Hamiltonian H = theta[0] Z_0 + theta[1] X_1
    obs = [qml.Z(0), qml.X(1)]
    H_Z = qml.Hamiltonian(theta, obs)

    # Dynamics: n applications of exp(-i H / n), i.e. total evolution time 1
    n = 3  # number of Trotter steps
    for _ in range(n):
        qml.ApproxTimeEvolution(H_Z, 1 / n, 1)

    # Return
    return qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliZ(1))


# QFIM = 4 * (Fubini-Study) metric tensor
QFIM = 4 * qml.gradients.metric_tensor(circuit)(theta)
print(QFIM)

This is the error message I get:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipython-input-99-1342454709.py in <cell line: 0>()
----> 1 QFIM = 4*qml.gradients.metric_tensor(circuit)(theta)
      2 print(QFIM)

10 frames
/usr/local/lib/python3.11/dist-packages/jax/_src/api.py in _check_output_dtype_revderiv(name, holomorphic, x)
    759                       f"but got {aval.dtype.name}.")
    760   elif dtypes.issubdtype(aval.dtype, np.complexfloating):
--> 761     raise TypeError(f"{name} requires real-valued outputs (output dtype that is "
    762                     f"a sub-dtype of np.floating), but got {aval.dtype.name}. "
    763                     "For holomorphic differentiation, pass holomorphic=True. "

TypeError: jacrev requires real-valued outputs (output dtype that is a sub-dtype of np.floating), but got complex128. For holomorphic differentiation, pass holomorphic=True. For differentiation of non-holomorphic functions involving complex outputs, use jax.vjp directly.
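The message's other hint concerns non-holomorphic functions with complex outputs, such as a state vector that depends on real parameters. Below is a sketch under stated assumptions. It uses a hand-written one-qubit toy state `toy_state` and hypothetical helper names. It also uses reverse-mode `jax.jacrev` rather than adjoint differentiation, and it is not how `qml.gradients.metric_tensor` works internally. The idea is to differentiate the real and imaginary parts of the state separately. The pure-state QFIM, 4 * Re(<d_i psi|d_j psi> - <d_i psi|psi><psi|d_j psi>), is then assembled from that Jacobian.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def toy_state(theta):
    """Hypothetical one-qubit state RZ(theta[1]) RX(theta[0]) |0>, written out by hand."""
    c, s = jnp.cos(theta[0] / 2), jnp.sin(theta[0] / 2)
    return jnp.array([jnp.exp(-0.5j * theta[1]) * c,
                      -1j * jnp.exp(0.5j * theta[1]) * s])

def state_jacobian(state_fn, theta):
    """d psi_k / d theta_j for real theta, built from two real-valued jacrev calls."""
    jac_re = jax.jacrev(lambda t: jnp.real(state_fn(t)))(theta)
    jac_im = jax.jacrev(lambda t: jnp.imag(state_fn(t)))(theta)
    return jac_re + 1j * jac_im  # shape (dim, n_params)

def qfim(state_fn, theta):
    """Pure-state QFIM: 4 * Re(<d_i psi|d_j psi> - <d_i psi|psi><psi|d_j psi>)."""
    psi = state_fn(theta)
    jac = state_jacobian(state_fn, theta)
    overlaps = jac.conj().T @ jac  # entry (i, j) is <d_i psi|d_j psi>
    berry = jac.conj().T @ psi     # entry i is <d_i psi|psi>
    return 4 * jnp.real(overlaps - jnp.outer(berry, berry.conj()))

theta = jnp.array([jnp.pi / 2, jnp.pi / 2])
print(qfim(toy_state, theta))
```

For this toy state the QFIM works out analytically to diag(1, sin^2(theta[0])), so at theta = (pi/2, pi/2) the printed matrix is approximately the 2x2 identity.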

The PennyLane version I used:

Name: PennyLane
Version: 0.41.1
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author: 
Author-email: 
License: Apache License 2.0
Location: /usr/local/lib/python3.11/dist-packages
Requires: appdirs, autograd, autoray, cachetools, diastatic-malt, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, tomlkit, typing-extensions
Required-by: PennyLane_Lightning

Platform info:           Linux-6.1.123+-x86_64-with-glibc2.35
Python version:          3.11.13
Numpy version:           2.0.2
Scipy version:           1.15.3
Installed devices:
- lightning.qubit (PennyLane_Lightning-0.41.1)
- default.clifford (PennyLane-0.41.1)
- default.gaussian (PennyLane-0.41.1)
- default.mixed (PennyLane-0.41.1)
- default.qubit (PennyLane-0.41.1)
- default.qutrit (PennyLane-0.41.1)
- default.qutrit.mixed (PennyLane-0.41.1)
- default.tensor (PennyLane-0.41.1)
- null.qubit (PennyLane-0.41.1)
- reference.qubit (PennyLane-0.41.1)

Sorry, it seems that updating PennyLane to 0.42 solved the issue.
