Discarding and replacing a qubit inside a parametric circuit

Hi everyone,

I am using PennyLane to train a parametrized quantum circuit (PQC) that transforms an input state \lvert \psi \rangle so that the output matches a target state. Some of the operations I use are non-standard, and I was wondering whether my code could be improved.

Theory
Right now I am experimenting with non-linearity. To introduce it, I extend my state with an ancilla qubit and feed the result to a StronglyEntanglingLayers template, i.e. the input is

\lvert 0 \rangle \otimes \lvert \psi \rangle

At the output, I get something like

\lvert 0 \rangle \otimes \lvert \phi \rangle + \lvert 1 \rangle \otimes \lvert \chi \rangle

I then post-select only the first term (equivalently, I measure the ancilla and accept the output only if it is found in \lvert 0 \rangle). So this circuit takes \lvert \psi \rangle and outputs \lvert \phi \rangle, renormalised to unit length.
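One way to see what the post-selection does is to write the block unitary U in 2×2 block form with respect to the ancilla, taken as the first (most significant) qubit as in the code below. The blocks A, B, C, D are notation introduced here for illustration, each of size 2^n × 2^n:

```latex
U = \begin{pmatrix} A & B \\ C & D \end{pmatrix}, \qquad
U \left( \lvert 0 \rangle \otimes \lvert \psi \rangle \right)
  = \lvert 0 \rangle \otimes A \lvert \psi \rangle
  + \lvert 1 \rangle \otimes C \lvert \psi \rangle ,

\lvert \phi \rangle = A \lvert \psi \rangle , \qquad
\lvert \chi \rangle = C \lvert \psi \rangle , \qquad
p_0 = \lVert A \lvert \psi \rangle \rVert^2 .
```

So the accepted output is A \lvert \psi \rangle / \lVert A \lvert \psi \rangle \rVert, obtained with probability p_0. Since A is generally not unitary, this renormalised map is non-linear in \lvert \psi \rangle.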

Code implementation

To implement this circuit, I am using the following code:

```python
import pennylane as qml
import torch

NUM_QUBITS = 3
BATCH_SIZE = 2

# get the device: NUM_QUBITS system qubits plus one ancilla (wire 0)
dev = qml.device("default.qubit", wires=NUM_QUBITS + 1)

# define unitary block
@qml.qnode(dev, interface="torch")
def block(thetas, state):
    # load the (batched) initial state;
    # qml.StatePrep is called qml.QubitStateVector in older PennyLane releases
    qml.StatePrep(state, wires=range(NUM_QUBITS + 1))

    # create parameterized circuit
    qml.StronglyEntanglingLayers(thetas, wires=range(NUM_QUBITS + 1))

    # return state vector
    return qml.state()

# define general circuit
def circuit(thetas, state):
    # extend state by concatenating the ancilla in |0>: appending 2**NUM_QUBITS zeros
    # gives |0> (x) |psi>, because wire 0 is the most significant qubit
    ext_state = torch.cat(
        (state, torch.zeros(state.shape[0], 2**NUM_QUBITS, dtype=state.dtype)), dim=1
    )
    # apply unitary block
    state = block(thetas, ext_state)
    # measure ancilla qubit and accept remaining qubits only if ancilla is |0>,
    # namely take the first half of the state
    state = state[:, : 2**NUM_QUBITS]
    # renormalise
    state = state / torch.linalg.vector_norm(state, dim=1).view(-1, 1)

    return state

# get initial state (batched) and weights
state = torch.randn(BATCH_SIZE, 2**NUM_QUBITS)
state = state / torch.linalg.vector_norm(state, dim=1).view(-1, 1)
# StronglyEntanglingLayers weights have shape (n_layers, n_wires, 3)
thetas = torch.randn(2, NUM_QUBITS + 1, 3, requires_grad=True)

# feed forward
output = circuit(thetas, state)
print(output)
```
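The slicing in circuit() relies on the ancilla being the most significant qubit (wire 0 in PennyLane's default wire ordering): appending 2^n zeros gives |0⟩⊗|ψ⟩, and the first 2^n output amplitudes are the ancilla-|0⟩ branch, i.e. the top-left 2^n × 2^n block of the unitary applied to |ψ⟩. Below is a small NumPy check of both facts, independent of PennyLane; a random unitary stands in for the StronglyEntanglingLayers block (an assumption for illustration only).

```python
import numpy as np

n = 3
d = 2**n
rng = np.random.default_rng(0)

psi = rng.normal(size=d) + 1j * rng.normal(size=d)
psi /= np.linalg.norm(psi)

# |0> (x) |psi> with the ancilla as the most significant qubit (wire 0 in PennyLane)
ext = np.kron(np.array([1.0, 0.0]), psi)
# Same vector as the torch.cat((state, zeros), dim=1) step in circuit()
matches_concat = np.allclose(ext, np.concatenate([psi, np.zeros(d)]))

# Random unitary standing in for the block's unitary (assumption, not the real circuit)
z = rng.normal(size=(2 * d, 2 * d)) + 1j * rng.normal(size=(2 * d, 2 * d))
u, _ = np.linalg.qr(z)
out = u @ ext

# Post-select ancilla |0>: keep the first d amplitudes, then renormalise
phi_unnorm = out[:d]
p0 = np.vdot(phi_unnorm, phi_unnorm).real  # acceptance probability
phi = phi_unnorm / np.sqrt(p0)             # renormalised output state

# The kept half equals the top-left d x d block of u applied to psi
print(matches_concat, np.allclose(phi_unnorm, u[:d, :d] @ psi), round(p0, 4))
```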

The script above is a simplified version of my actual one. In reality, I am experimenting with hundreds of these layers: inside circuit() I loop over them, each time extending the state with a fresh ancilla, feeding it to block() and keeping one half of the state (post-selecting). I suspect this is computationally inefficient, since every circuit() call makes hundreds of block() calls; with ~100 layers, a batch size of 64 and 8 qubits, a single circuit() call takes around 1 s.
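Because each intermediate renormalisation only divides by a positive scalar, the loop described above should be equivalent to multiplying the top-left 2^n × 2^n blocks of the per-layer unitaries (ancilla as the most significant qubit) and normalising once at the end. Below is a minimal NumPy sketch of that equivalence, under stated assumptions: random unitaries stand in for the per-layer StronglyEntanglingLayers unitaries, and the helper names are hypothetical.

```python
import numpy as np

def random_unitary(dim, rng):
    """Random unitary via QR; stands in for one layer's unitary (assumption)."""
    z = rng.normal(size=(dim, dim)) + 1j * rng.normal(size=(dim, dim))
    q, _ = np.linalg.qr(z)
    return q

def normalise(batch):
    """Normalise each row (one state per row) to unit norm."""
    return batch / np.linalg.norm(batch, axis=1, keepdims=True)

def layer_by_layer(unitaries, batch, n):
    """Loop as in the post: append |0> ancilla, apply U, keep first half, renormalise."""
    d = 2**n
    for u in unitaries:
        ext = np.concatenate([batch, np.zeros_like(batch)], axis=1)  # |0> (x) |psi>
        out = ext @ u.T                                              # U applied to each row
        batch = normalise(out[:, :d])                                # post-select ancilla |0>
    return batch

def composed(unitaries, batch, n):
    """Same map: multiply the top-left d x d blocks, normalise once at the end."""
    d = 2**n
    m = np.eye(d, dtype=complex)
    for u in unitaries:
        m = u[:d, :d] @ m       # later layers act on the left
        m /= np.linalg.norm(m)  # scalar rescale only, avoids underflow over many layers
    return normalise(batch @ m.T)

n, num_layers, batch_size = 3, 100, 4
rng = np.random.default_rng(1)
unitaries = [random_unitary(2 ** (n + 1), rng) for _ in range(num_layers)]
psi = normalise(rng.normal(size=(batch_size, 2**n)))

slow = layer_by_layer(unitaries, psi, n)
fast = composed(unitaries, psi, n)
print(np.allclose(slow, fast))
```

If this holds for the real circuit, each training step would still need the layer unitaries themselves (PennyLane's qml.matrix can, I believe, return the matrix of a quantum function, though I have not checked its batching or gradient behaviour for this use), after which the whole batch goes through a single matrix product instead of ~100 simulator calls. The per-step rescaling of the product is only a numerical safeguard: with hundreds of layers, the unnormalised product can otherwise shrink towards underflow.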

Question

Is there a way to add an ancilla in the state \lvert 0 \rangle and post-select from within block(), without having to call it every time inside another function? Are there any obvious ways to make it faster?

Thank you very much!

Hey @Andrea! Welcome to the forum! :rocket:

I may not be understanding correctly, but can you project out the state(s) you want with qml.Projector?
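For reference, projecting the ancilla onto |0⟩ corresponds to the operator P = |0⟩⟨0| ⊗ I on the full register. Its expectation value in the output state is the post-selection probability, and P|out⟩ keeps only the accepted branch (still unnormalised). As I understand it, this is what something like qml.Projector([0], wires=0) would measure in PennyLane; the sketch below only checks the linear algebra in NumPy, with a random unitary standing in for the trained block (assumption).

```python
import numpy as np

n = 3
d = 2**n
rng = np.random.default_rng(2)

# Random unitary standing in for the trained block (assumption, not the real circuit)
z = rng.normal(size=(2 * d, 2 * d)) + 1j * rng.normal(size=(2 * d, 2 * d))
u, _ = np.linalg.qr(z)

psi = rng.normal(size=d) + 1j * rng.normal(size=d)
psi /= np.linalg.norm(psi)
out = u @ np.kron([1.0, 0.0], psi)  # U (|0> (x) |psi>), ancilla most significant

# Projector onto ancilla |0>: P = |0><0| (x) I
proj = np.kron(np.diag([1.0, 0.0]), np.eye(d))
p0 = np.vdot(out, proj @ out).real  # <P> = post-selection probability
accepted = (proj @ out)[:d]         # P zeroes the second half, keeps the first
phi = accepted / np.sqrt(p0)        # renormalised post-selected state

print(round(p0, 6), np.isclose(np.linalg.norm(phi), 1.0))
```

Note that an expectation value of P gives the acceptance probability, not the post-selected state itself, so for training on the state the slicing (or an equivalent matrix) is still needed. Newer PennyLane releases also accept a postselect argument in qml.measure for mid-circuit measurements, which might let the ancilla be added and post-selected inside a single QNode; I have not checked how that combines with qml.state() and batched inputs.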