Param shift, intermediary evaluations

Hello PennyLane team,
Say I have a QNode circuit.
Is there a way, instead of directly getting the gradient with qml.gradients.param_shift(circuit), to obtain the intermediate shifted evaluations, so that they can be reused in addition to computing the gradient?

So far, I have had to reimplement the shift rule myself, but if there is a nicer way, I’d be glad to learn it.

Hi @cnada!

There is indeed a way of doing this, but it requires ‘dropping down’ to a lower level of abstraction and working with tapes. Tapes are the internal data structures that represent the variational circuit; a QNode creates one internally whenever it is executed.

You can create a tape manually:

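```python
import pennylane as qml

# operations queued inside the context are recorded on the tape, not executed
with qml.tape.QuantumTape() as tape:
    qml.RX(0.5, wires=0)
    qml.CNOT(wires=[0, 1])
    qml.expval(qml.PauliZ(1))
```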
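Conceptually, the `with` block records the operations instead of running them. The toy class below is a hypothetical, heavily simplified illustration of that recording idea in plain Python; it is not PennyLane's `QuantumTape`, and the helper names `rx`, `cnot` and `expval_z` are made up for this sketch.

```python
class ToyTape:
    """Hypothetical, simplified recorder: gates called inside `with` are stored, not executed."""

    _active = None  # the tape currently recording, if any

    def __init__(self):
        self.operations = []    # (name, params, wires) tuples in circuit order
        self.measurements = []  # (kind, observable, wires) tuples

    def __enter__(self):
        ToyTape._active = self
        return self

    def __exit__(self, *exc_info):
        ToyTape._active = None
        return False  # do not suppress exceptions


def rx(theta, wires):
    # record an RX gate on the active tape
    ToyTape._active.operations.append(("RX", (theta,), tuple(wires)))


def cnot(wires):
    # record a CNOT gate on the active tape
    ToyTape._active.operations.append(("CNOT", (), tuple(wires)))


def expval_z(wire):
    # record a PauliZ expectation-value measurement on the active tape
    ToyTape._active.measurements.append(("expval", "PauliZ", (wire,)))


with ToyTape() as tape:
    rx(0.5, wires=[0])
    cnot(wires=[0, 1])
    expval_z(1)

print(tape.operations)    # [('RX', (0.5,), (0,)), ('CNOT', (), (0, 1))]
print(tape.measurements)  # [('expval', 'PauliZ', (1,))]
```

Because nothing is executed while recording, the same description can later be copied, modified (for example, with a shifted parameter) and run as many times as needed.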

or you can extract it from a QNode:

```python
import pennylane as qml
from pennylane import numpy as np

dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev)
def circuit(x):
    qml.RX(x, wires=0)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.PauliZ(1))

# first we 'construct' the QNode with a particular
# parameter value (requires_grad=True marks x as trainable)
x = np.array(0.6, requires_grad=True)

# pass a list of QNode arguments and
# a dict of keyword arguments
circuit.construct([x], {})

# the tape recorded during construction
tape = circuit.tape
```
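To sanity-check the numbers that follow without PennyLane, the same circuit can be simulated with a few lines of NumPy. This is an independent sketch, not PennyLane's simulator; it assumes wire 0 is the most significant qubit of the state vector, and `circuit_sim` is a hypothetical name. It shows that the expectation value of PauliZ on wire 1 is simply cos(x).

```python
import numpy as np

def rx(theta):
    """Single-qubit RX(theta) = exp(-i theta X / 2)."""
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([[c, -1j * s], [-1j * s, c]])

# CNOT with wire 0 as control and wire 1 as target (wire 0 = most significant bit)
CNOT = np.array([[1, 0, 0, 0],
                 [0, 1, 0, 0],
                 [0, 0, 0, 1],
                 [0, 0, 1, 0]], dtype=complex)
I2 = np.eye(2)
Z = np.diag([1.0, -1.0])

def circuit_sim(x):
    """Return <Z> on wire 1 after RX(x) on wire 0 followed by CNOT(0, 1)."""
    state = np.zeros(4, dtype=complex)
    state[0] = 1.0                      # start in |00>
    state = np.kron(rx(x), I2) @ state  # RX acts on wire 0 only
    state = CNOT @ state
    obs = np.kron(I2, Z)                # PauliZ on wire 1
    return float(np.real(np.conj(state) @ obs @ state))

print(round(circuit_sim(0.6), 8))  # 0.82533561, i.e. cos(0.6)
```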

Now, we can apply qml.gradients.param_shift directly to the tape. Unlike when we apply this gradient transform to a QNode, no quantum evaluation will occur.

Instead, it returns a list of gradient tapes (one for each shifted parameter value required by the parameter-shift rule) and a post-processing function.
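For a rotation gate like RX, the gradient tapes follow the two-term parameter-shift rule: each trainable parameter x_i is evaluated at x_i + s and x_i - s, and the gradient component is (f(x_i + s) - f(x_i - s)) / (2 sin s), with s = π/2 in the tapes drawn below. The NumPy sketch below mimics only the *shape* of what `param_shift` returns (shifted inputs plus a post-processing function), assuming a cost of cos(x); `param_shift_sketch` is a hypothetical helper, not a PennyLane function, and it covers only this two-term rule.

```python
import numpy as np

def param_shift_sketch(params, shift=np.pi / 2):
    """Hypothetical helper: build shifted parameter sets and a post-processing
    function for the two-term shift rule (valid for RX-type rotation gates)."""
    params = np.asarray(params, dtype=float)
    shifted = []
    for i in range(len(params)):
        for sign in (+1, -1):
            p = params.copy()
            p[i] += sign * shift  # shift only the i-th parameter
            shifted.append(p)

    def fn(results):
        r = np.asarray(results, dtype=float)
        # results come in (+shift, -shift) pairs, one pair per parameter
        return (r[0::2] - r[1::2]) / (2 * np.sin(shift))

    return shifted, fn


shifted, fn = param_shift_sketch([0.6])
results = [np.cos(p[0]) for p in shifted]  # assumed cost: cos(x), as for the circuit above
print(fn(results))                          # close to -sin(0.6) = -0.56464247
```

The point of the split is the same as in PennyLane: the shifted inputs can be executed by whatever means you like, and the evaluations stay in your hands until you decide to post-process them.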

```python
>>> g_tapes, fn = qml.gradients.param_shift(tape)
>>> len(g_tapes)
2
>>> for t in g_tapes:
...     print(t.draw())
 0: ──RX(2.17)──╭C──┤
 1: ────────────╰X──┤ ⟨Z⟩

 0: ──RX(-0.971)──╭C──┤
 1: ──────────────╰X──┤ ⟨Z⟩
```
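The angles in the two drawings are just x ± π/2 for x = 0.6, rounded for display, which a short check confirms:

```python
import numpy as np

x = 0.6
shift = np.pi / 2  # the shift visible in the drawn gradient tapes
plus, minus = x + shift, x - shift
print(round(plus, 2), round(minus, 3))  # 2.17 -0.971
```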

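We can now evaluate these ‘gradient tapes’ using qml.execute; the third argument (the gradient function) is set to None here, since these executions do not need to be differentiated: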

```python
>>> results = qml.execute(g_tapes, dev, None)
>>> results
[array([-0.56464247]), array([0.56464247])]
```
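These two numbers can be checked by hand. Since the circuit's expectation value is f(x) = cos x, the shifted evaluations and the resulting parameter-shift gradient at x = 0.6 are:

```latex
\begin{aligned}
f(x) &= \langle Z_1 \rangle = \cos x, \\
f\!\left(x + \tfrac{\pi}{2}\right) &= \cos\!\left(x + \tfrac{\pi}{2}\right) = -\sin x \approx -0.56464247, \\
f\!\left(x - \tfrac{\pi}{2}\right) &= \cos\!\left(x - \tfrac{\pi}{2}\right) = \sin x \approx 0.56464247, \\
\frac{\partial f}{\partial x} &= \frac{f\!\left(x + \tfrac{\pi}{2}\right) - f\!\left(x - \tfrac{\pi}{2}\right)}{2} = -\sin x \approx -0.56464247.
\end{aligned}
```

Each entry of `results` is therefore a full circuit evaluation at a shifted point, and the gradient is just a linear combination of them, which is why the raw evaluations are worth keeping.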

and compute the gradient by ‘post-processing’ them with the generated function fn:

```python
>>> fn(results)
array([[-0.56464247]])
```
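Because `results` is an ordinary list, it can be kept and reused alongside the gradient instead of being recomputed. The NumPy sketch below shows one hypothetical reuse pattern: cache each shifted evaluation by its parameter value, so that a later request at the same point (here, a second gradient call) is served from the cache rather than re-running the circuit. `expensive_cost` stands in for executing one gradient tape and is assumed to equal cos(x); none of these names are PennyLane API.

```python
import numpy as np

calls = {"count": 0}  # counts how often the "circuit" actually runs

def expensive_cost(x):
    """Hypothetical stand-in for executing one shifted tape; assumed to equal cos(x)."""
    calls["count"] += 1
    return np.cos(x)

cache = {}  # maps a rounded parameter value to its evaluation

def cached_eval(x):
    key = round(float(x), 12)  # round so tiny float noise doesn't create new keys
    if key not in cache:
        cache[key] = expensive_cost(x)
    return cache[key]

def gradient(x, shift=np.pi / 2):
    """Two-term parameter-shift gradient built from cached shifted evaluations."""
    f_plus, f_minus = cached_eval(x + shift), cached_eval(x - shift)
    return (f_plus - f_minus) / (2 * np.sin(shift))

g1 = gradient(0.6)
g2 = gradient(0.6)                       # served entirely from the cache
shifted_values = sorted(cache.values())  # the intermediate evaluations themselves
print(g1, calls["count"])                # about -0.5646, and only 2 circuit runs in total
```

The same idea applies to the PennyLane objects above: hold on to `results`, pass it to `fn` when you need the gradient, and use the individual entries for anything else.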