Adjoint Gradient with complicated targets

PennyLane’s adjoint differentiation is really great.
I was looking at the table of the various diff methods and their PennyLane support at Gradients and training — PennyLane 0.31.0 documentation

My questions are about Adjoint diff with statevector return type:

  1. I see it’s designated as note 6: “Not supported. The adjoint differentiation algorithm is only implemented for computing the expectation values of observables.”
    Does that mean it is implementable in principle, i.e. there is no theoretical reason it can’t be done?
  2. Assuming it is implementable, do you have a timeline for when you expect to add support for it?

To add more context - I am interested in using the adjoint differentiation method where the target function I am differentiating is not an observable’s expectation value, but some function that depends on the whole statevector.
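Roughly this kind of shape, just to illustrate what I mean (a toy cost I made up for this post, not my actual target function):

import pennylane as qml
from pennylane import numpy as np

dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev)  # this is where I would love to be able to use diff_method="adjoint"
def circuit(params):
    qml.RX(params[0], wires=0)
    qml.RY(params[1], wires=1)
    qml.CNOT(wires=[0, 1])
    return qml.state()

# the "target function": not an expectation value, but an L2 distance between
# the full statevector and some fixed target state
target = np.array([1.0, 0.0, 0.0, 1.0]) / np.sqrt(2)

def cost(params):
    return np.sum(np.abs(circuit(params) - target) ** 2)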

Thanks


Hey @Hayk_Tepanyan! Welcome back :sunglasses:

I spoke to someone a bit more knowledgeable on this and here’s what they said:

There should be no theoretical barrier to doing this. You would run backprop as usual up until the statevector step, then you would swap over to an adjoint implementation from then onwards. PL itself is currently not structured to take advantage of this, but there are bits in PL-lightning that should actually support this if standard PL ever does.
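To make that idea a bit more concrete, here is a rough single-qubit sketch in plain NumPy of what I understand by “backprop to the state, then adjoint backwards” (my own illustration, not code that exists in PL or PL-lightning): the cost function supplies the cotangent dC/d(state), and then one backward sweep undoes the gates adjoint-style.

import numpy as np

def rx(theta):
    # single-qubit RX gate matrix
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([[c, -1j * s], [-1j * s, c]])

def drx(theta):
    # derivative of RX with respect to theta
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return 0.5 * np.array([[-s, -1j * c], [-1j * c, -s]])

# a real-valued cost that depends on the whole statevector: here |<1|psi>|^2
weights = np.array([0.0, 1.0])

def adjoint_vjp(thetas):
    # forward pass: one sweep through the circuit, O(g * 2^n)
    psi = np.array([1.0 + 0j, 0.0 + 0j])
    for t in thetas:
        psi = rx(t) @ psi
    # "backprop" step: the cotangent dC/d(psi*) of the cost, worked out by hand
    # here; an autodiff framework would supply this for an arbitrary cost function
    lam = weights * psi
    # adjoint sweep: undo one gate at a time from both vectors, O(g * 2^n) total
    grads = np.zeros(len(thetas))
    for k in reversed(range(len(thetas))):
        psi = rx(thetas[k]).conj().T @ psi    # state before gate k
        grads[k] = 2 * np.real(lam.conj() @ (drx(thetas[k]) @ psi))
        lam = rx(thetas[k]).conj().T @ lam    # push the cotangent backwards
    return grads

# quick check against the analytic gradient of sin^2(sum(thetas)/2)
print(adjoint_vjp(np.array([0.1, 0.2])), np.sin(0.1 + 0.2) / 2)

The cotangent plays the role here that “observable applied to the state” plays in the usual adjoint method.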

See these bits of code in PL-lightning:

Hope this helps!


Hi @isaacdevlugt,

Thanks for the info.
I looked at the math of the adjoint derivatives and it seems to rely on the fact that the derivative is taken of an observable’s expectation value. If instead it’s a more complex function, say an L2 norm, it still seems doable to compute the derivatives (as your response suggests), but now the runtime and memory overhead would be quadratically worse, e.g. O(2^{2n}) instead of O(2^n), and there is a big difference between the two.
Please let me know if this is your conclusion as well or whether I am missing something.

Thanks!

Hi @Hayk_Tepanyan,

I’m checking with one of my colleagues but it would seem that it’s O(2^n). I will take a deeper look at this with her and one of us will come back to you with a more complete answer.

I just realized I’ve been thinking about it the wrong way and it wouldn’t be as hard to add as I thought it would be.

I make no promises about any sorts of performance scaling or that it works well with larger circuits and all interfaces, but I have a prototype here: GitHub - PennyLaneAI/pennylane at adjoint-diff-state

On this branch, I can do:

import pennylane as qml
import jax

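# reference calculation: no diff_method specified here, so JAX differentiates through the simulation directly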
@qml.qnode(qml.devices.experimental.DefaultQubit2())
def circuit(x, y):
    qml.RX(x, 0)
    qml.RY(y, 0)
    qml.CNOT((0,1))
    return qml.state()

jac1 = jax.jacobian(circuit, holomorphic=True, argnums=(1))(jax.numpy.array(0.1+0j), jax.numpy.array(0.2+0j))


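# same circuit, but explicitly requesting the adjoint method for the state Jacobian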
@qml.qnode(qml.devices.experimental.DefaultQubit2(), diff_method="adjoint")
def circuit2(x, y):
    qml.RX(x, 0)
    qml.RY(y, 0)
    qml.CNOT((0,1))
    return qml.state()

jac2 = jax.jacobian(circuit2, holomorphic=True, argnums=(1))(jax.numpy.array(0.1+0j), jax.numpy.array(0.2+0j))
>>> jac1 
Array([-0.04985433+0.02486474j,  0.        +0.j        ,
        0.        +0.j        ,  0.49688035+0.0024948j ],      dtype=complex64, weak_type=True)
>>> jac2
Array([-0.04985433+0.02486474j,  0.        +0.j        ,
        0.        +0.j        ,  0.49688032+0.0024948j ],      dtype=complex64, weak_type=True)

Unfortunately, this isn’t on our roadmap, and adding it would take some more testing and documentation. We won’t be adding it to any of our performance simulators, but it may be an option for our next-gen Python simulator.

Feel free to explore that branch and let me know if it works ok. If we have time between our other priorities, we may slip this change in.

Relevant block of code is here: https://github.com/PennyLaneAI/pennylane/blob/93d169d21c3558550b6cd17b91b1d37af76b8a68/pennylane/devices/qubit/adjoint_jacobian.py#L51


Thanks @christina for the code pointer!

Performance is very important to us, and we think adjoint differentiation is fascinatingly efficient - which is awesome.
The method you provided, correct me if I am wrong, has a runtime of O(g^2*2^n), where g is the number of gates (and/or params), compared to the regular adjoint with observables, which is O(g*2^n).
No need to provide an implementation, but theoretically, do you think it’s possible to implement adjoint diff with complicated targets (i.e. not observables) with runtime O(g*2^n)?
Looking at the math of adjoint diff, I think it’s not possible, since once you move away from observables you have to re-compute \langle b_i| for every param - just like your implementation does.
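To spell out how I am reading the math (correct me if this is off): for an expectation value \langle\psi|\hat{O}|\psi\rangle there is a single bra \langle b| = \langle\psi|\hat{O} that gets updated in place as the gates are undone, so the whole backward sweep costs O(g*2^n). For the full state Jacobian, column k is

\frac{\partial|\psi\rangle}{\partial\theta_k} = U_g \cdots U_{k+1} \frac{\partial U_k}{\partial\theta_k} U_{k-1} \cdots U_1 |0\rangle,

and there is no fixed bra to contract everything into, so the trailing gates U_{k+1} \cdots U_g have to be applied again for every parameter - that re-application is where the extra factor of g comes from.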

Let me know if I am missing something. Thanks!

@Hayk_Tepanyan Yes, great insights!

It does not seem like you’re missing anything, and this seems to be a very open theoretical question which would need a fair amount of pondering on our side. Let us know if you come up with anything.

Cheers,

Alvaro
