Hi!
I came across an issue when using diff_method="adjoint" with the JAX backend under the hood of the built-in metric_tensor. It seems that JAX only allows real-valued outputs for reverse-mode differentiation, so when the QNode returns the full quantum state (a complex vector), JAX apparently has to be told to treat the function as holomorphic.
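As far as I can tell, this restriction can be reproduced in plain JAX without PennyLane. Below is a minimal sketch; toy_state is a hypothetical stand-in for a state-returning function (real parameters in, complex vector out), not my actual circuit. It also suggests that holomorphic=True only helps when the inputs are complex, whereas differentiating the real and imaginary parts separately works with real parameters.

```python
import jax
import jax.numpy as jnp

# Use 64-bit precision so the output dtype is complex128, as in my traceback.
jax.config.update("jax_enable_x64", True)


def toy_state(theta):
    """Hypothetical stand-in for a state-returning QNode (assumption, not my circuit)."""
    return jnp.exp(-0.5j * theta) / jnp.sqrt(theta.size)


theta = jnp.array([jnp.pi / 2, jnp.pi / 2])

# 1) Reverse mode on a complex-valued output with real inputs raises a TypeError.
try:
    jax.jacrev(toy_state)(theta)
    raised = False
except TypeError:
    raised = True

# 2) Differentiate the real and imaginary parts separately, then recombine.
jac_re = jax.jacrev(lambda t: jnp.real(toy_state(t)))(theta)
jac_im = jax.jacrev(lambda t: jnp.imag(toy_state(t)))(theta)
jac_split = jac_re + 1j * jac_im

# 3) holomorphic=True is accepted only if the inputs are complex as well.
jac_holo = jax.jacrev(toy_state, holomorphic=True)(theta.astype(jnp.complex128))

print(raised)
print(jnp.allclose(jac_split, jac_holo))
```

Both Jacobians agree here only because this toy function happens to be holomorphic in theta; I do not know whether that carries over to a general circuit.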
I could work around this by using qml.expval(); however, in that case I later run into a ProbabilityMP error when I optimize the circuit. I would appreciate any advice on how to combine adjoint differentiation, JAX, and the full state.
Here is my code:
import jax.numpy as jnp
import pennylane as qml

dev = qml.device('default.qubit')

# Parameter values
theta = jnp.array([jnp.pi / 2, jnp.pi / 2])  # One theta per node

@qml.qnode(dev, interface="jax", diff_method="adjoint")
def circuit(theta):
    # Hamiltonian
    obs = [qml.Z(0), qml.X(1)]
    H_Z = qml.Hamiltonian(theta, obs)
    # Dynamics
    n = 3  # number of Trotter steps
    for _ in range(n):
        qml.ApproxTimeEvolution(H_Z, 1/n, 1)
    # Return
    return qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliZ(1))

QFIM = 4*qml.gradients.metric_tensor(circuit)(theta)
print(QFIM)
This is the error message I get:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/tmp/ipython-input-99-1342454709.py in <cell line: 0>()
----> 1 QFIM = 4*qml.gradients.metric_tensor(circuit)(theta)
2 print(QFIM)
10 frames
/usr/local/lib/python3.11/dist-packages/jax/_src/api.py in _check_output_dtype_revderiv(name, holomorphic, x)
759 f"but got {aval.dtype.name}.")
760 elif dtypes.issubdtype(aval.dtype, np.complexfloating):
--> 761 raise TypeError(f"{name} requires real-valued outputs (output dtype that is "
762 f"a sub-dtype of np.floating), but got {aval.dtype.name}. "
763 "For holomorphic differentiation, pass holomorphic=True. "
TypeError: jacrev requires real-valued outputs (output dtype that is a sub-dtype of np.floating), but got complex128. For holomorphic differentiation, pass holomorphic=True. For differentiation of non-holomorphic functions involving complex outputs, use jax.vjp directly.
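For reference, the quantity I am after can be written directly in terms of a pure state as F_ij = 4 Re[⟨∂_i ψ|∂_j ψ⟩ − ⟨∂_i ψ|ψ⟩⟨ψ|∂_j ψ⟩]. Below is a minimal pure-JAX sketch of that formula. It assumes a hypothetical single-qubit bloch_state instead of my circuit, and it uses forward-mode jax.jacfwd (which accepts complex outputs for real inputs) rather than adjoint differentiation, so it is only a cross-check and not a solution to the question above.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def bloch_state(theta):
    """Hypothetical single-qubit state (assumption): cos(t0/2)|0> + e^{i t1} sin(t0/2)|1>."""
    return jnp.array([
        jnp.cos(theta[0] / 2),
        jnp.exp(1j * theta[1]) * jnp.sin(theta[0] / 2),
    ])


def qfim_from_state(state_fn, theta):
    """QFIM of a pure state from the state vector and its forward-mode Jacobian."""
    psi = state_fn(theta)                # shape (d,)
    dpsi = jax.jacfwd(state_fn)(theta)   # shape (d, p), column j is d psi / d theta_j
    gram = jnp.conj(dpsi).T @ dpsi       # <d_i psi | d_j psi>
    overlap = jnp.conj(psi) @ dpsi       # <psi | d_j psi>
    # <d_i psi | psi><psi | d_j psi> = conj(overlap_i) * overlap_j
    return 4 * jnp.real(gram - jnp.outer(jnp.conj(overlap), overlap))


theta = jnp.array([jnp.pi / 2, jnp.pi / 2])
qfim = qfim_from_state(bloch_state, theta)
print(qfim)
```

For this toy state the analytic result is diag(1, sin²θ₀), which gives the identity matrix at θ₀ = π/2.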
The PennyLane version I used:
Name: PennyLane
Version: 0.41.1
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author:
Author-email:
License: Apache License 2.0
Location: /usr/local/lib/python3.11/dist-packages
Requires: appdirs, autograd, autoray, cachetools, diastatic-malt, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, tomlkit, typing-extensions
Required-by: PennyLane_Lightning
Platform info: Linux-6.1.123+-x86_64-with-glibc2.35
Python version: 3.11.13
Numpy version: 2.0.2
Scipy version: 1.15.3
Installed devices:
- lightning.qubit (PennyLane_Lightning-0.41.1)
- default.clifford (PennyLane-0.41.1)
- default.gaussian (PennyLane-0.41.1)
- default.mixed (PennyLane-0.41.1)
- default.qubit (PennyLane-0.41.1)
- default.qutrit (PennyLane-0.41.1)
- default.qutrit.mixed (PennyLane-0.41.1)
- default.tensor (PennyLane-0.41.1)
- null.qubit (PennyLane-0.41.1)
- reference.qubit (PennyLane-0.41.1)