Hi, I met some problems while doing loss.backward()
related to the trainable parameters of
quantum circuit.
The example code:
import pennylane as qml
import torch
import torch.nn as nn
import numpy as np
import math
def create_qnode(info_size, anc_size):
info_qubit = int(np.ceil(math.log2(info_size)))
anc_qubit = int(np.ceil(math.log2(anc_size)))
total_qubit = info_qubit + anc_qubit
dev = qml.device("default.qubit", wires=total_qubit)
@qml.qnode(dev)
def qnode(input_feature, weight):
qml.AmplitudeEmbedding(features=input_feature, wires=range(total_qubit), normalize=True)
for ii in range(5):
for jj in range(anc_qubit):
qml.RY(weight[jj], wires=info_qubit+jj)
for jj in range(anc_qubit-1):
qml.CNOT(wires=[info_qubit+jj,info_qubit+jj+1])
return qml.probs(wires=list(range(total_qubit)))
return qnode
def QuantumCircuit(input_state, info_size, anc_size, params):
qnode = create_qnode(info_size=info_size, anc_size=anc_size)
qnode_output_tmp = qnode(input_state.clone(),params[0].detach().numpy())
qnode_output = torch.from_numpy(qnode_output_tmp)
return qnode_output
information_size = 4
ancilla_size = 4
q_params = nn.ParameterList([nn.Parameter(torch.randn(5 * information_size), requires_grad=True)])
output = QuantumCircuit(input_state=torch.arange(16), info_size=information_size, anc_size=ancilla_size, params=q_params)
target = output[0]
label = torch.tensor(0.5,dtype=torch.float64)
criterion = nn.BCELoss()
loss = criterion(output[0],label)
loss.backward()
I get an error from it:
And I thought the error meant it was unsuccessful while computing the gradient on parameters.
This line from the self-defined function QuantumCircuit(...)
might have caused this error.
qnode_output_tmp = qnode(input_state.clone(), params[0].detach().numpy())
The variable/parameter params[0].detach().numpy()
detaches the params[0] from the current computational graph and makes loss.backward()
fail to require gradient, so I thought this is why the above error occurred.
However, when I removed detach()
from the params[0], such as:
qnode_output_tmp = qnode(input_state.clone(),params[0])
And the other error might show:
Please help me if you have any idea to fix the above error, thanks!