I just realized I'd been thinking about it the wrong way, and it wouldn't be as hard to add as I expected.
I make no promises about performance scaling, or that it works well with larger circuits and every interface, but I have a prototype on the `adjoint-diff-state` branch of PennyLaneAI/pennylane on GitHub.
On this branch, I can do:
```python
import pennylane as qml
import jax

# Reference: the returned state differentiated with the default diff_method
@qml.qnode(qml.devices.experimental.DefaultQubit2())
def circuit(x, y):
    qml.RX(x, 0)
    qml.RY(y, 0)
    qml.CNOT((0, 1))
    return qml.state()

# holomorphic=True requires complex-valued inputs, hence the +0j
jac1 = jax.jacobian(circuit, holomorphic=True, argnums=1)(
    jax.numpy.array(0.1 + 0j), jax.numpy.array(0.2 + 0j)
)

# Same circuit, differentiated with the adjoint method
@qml.qnode(qml.devices.experimental.DefaultQubit2(), diff_method="adjoint")
def circuit2(x, y):
    qml.RX(x, 0)
    qml.RY(y, 0)
    qml.CNOT((0, 1))
    return qml.state()

jac2 = jax.jacobian(circuit2, holomorphic=True, argnums=1)(
    jax.numpy.array(0.1 + 0j), jax.numpy.array(0.2 + 0j)
)
```

```
>>> jac1
Array([-0.04985433+0.02486474j,  0.        +0.j        ,
        0.        +0.j        ,  0.49688035+0.0024948j ], dtype=complex64, weak_type=True)
>>> jac2
Array([-0.04985433+0.02486474j,  0.        +0.j        ,
        0.        +0.j        ,  0.49688032+0.0024948j ], dtype=complex64, weak_type=True)
```

The two Jacobians agree up to float32 rounding (the last entry differs only in the eighth significant digit).
Unfortunately, this isn't on our roadmap, and adding it properly would take more testing and documentation. We won't be adding it to any of our performance simulators, but it may be an option for our next-generation Python simulator.
Feel free to explore that branch and let me know if it works OK. If we find time between our other priorities, we may slip this change in.
The relevant block of code is here: https://github.com/PennyLaneAI/pennylane/blob/93d169d21c3558550b6cd17b91b1d37af76b8a68/pennylane/devices/qubit/adjoint_jacobian.py#L51
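For anyone curious about the idea, here is a minimal NumPy sketch of one way adjoint differentiation can produce a state Jacobian. It treats each amplitude ⟨j|ψ⟩ as its own "observable", so the usual backward pass runs with one bra per computational basis state. This is an illustration, not the code on the branch. The helper names (`rx`, `drx`, `ry`, `dry`, `state_jacobian_adjoint`) are hypothetical, and the circuit from the example is hard-coded with wire 0 as the most significant bit.

```python
import numpy as np

# Single-qubit rotations and their derivatives with respect to the angle.
def rx(t):
    c, s = np.cos(t / 2), np.sin(t / 2)
    return np.array([[c, -1j * s], [-1j * s, c]])

def drx(t):
    c, s = np.cos(t / 2), np.sin(t / 2)
    return 0.5 * np.array([[-s, -1j * c], [-1j * c, -s]])

def ry(t):
    c, s = np.cos(t / 2), np.sin(t / 2)
    return np.array([[c, -s], [s, c]], dtype=complex)

def dry(t):
    c, s = np.cos(t / 2), np.sin(t / 2)
    return 0.5 * np.array([[-s, -c], [c, -s]], dtype=complex)

I2 = np.eye(2)
# CNOT with control on wire 0 (most significant bit), target on wire 1.
CNOT = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]], dtype=complex)

def state_jacobian_adjoint(ops, dim):
    """ops: list of (U, dU) full-size matrices; dU is None for non-trainable gates.
    Returns a (dim, n_params) array whose entry [j, k] is d<j|psi>/d theta_k."""
    # Forward pass: psi = U_n ... U_1 |0>
    psi = np.zeros(dim, dtype=complex)
    psi[0] = 1.0
    for U, _ in ops:
        psi = U @ psi
    # Row j of `bras` holds <j| U_n ... U_{k+1}; initially just <j|.
    bras = np.eye(dim, dtype=complex)
    cols = []
    for U, dU in reversed(ops):
        psi = U.conj().T @ psi  # un-apply U_k: psi becomes U_{k-1} ... U_1 |0>
        if dU is not None:
            cols.append(bras @ (dU @ psi))  # <j| U_n..U_{k+1} dU_k U_{k-1}..U_1 |0>
        bras = bras @ U  # absorb U_k before stepping further back
    return np.stack(cols[::-1], axis=1)

x, y = 0.1, 0.2
ops = [
    (np.kron(rx(x), I2), np.kron(drx(x), I2)),
    (np.kron(ry(y), I2), np.kron(dry(y), I2)),
    (CNOT, None),
]
jac = state_jacobian_adjoint(ops, 4)
print(jac[:, 1])  # column 1 is d(state)/dy, comparable to jac1/jac2 above
```

Because this sketch carries one bra per basis state, its memory and time grow with the full state dimension (a dim × dim array of bras), so it is only meant for small circuits.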