Hi @Amandeep,
Optimization problems are problem-specific, meaning that it’s very hard to tell which optimizer will give better or faster results than another. There’s no one-size-fits-all.
That being said, I can give you some pointers.
I made a QNGO test script that you can find below. For that script, the table below answers the question: “Does this `diff_method` and device combination work?”
| diff_method | default.qubit | lightning.qubit |
| --- | --- | --- |
| backprop | Yes | No |
| adjoint | No | No |
| parameter-shift | Yes | Yes |
| finite-diff | Yes | Yes |
| hadamard | Yes | Yes |
| device | No | No |
| spsa | Yes | Yes |
| best | Yes | Yes |
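If you want to reproduce this table on your own machine, here is a minimal sketch (assuming you have `lightning.qubit` installed) that loops over all combinations, tries a single QNGO step with the same circuit as the full script below, and reports failures together with rough timings:

```python
import time
import pennylane as qml
from pennylane import numpy as pnp

data = pnp.array([0., 1.], requires_grad=False)

devices = ['default.qubit', 'lightning.qubit']
diff_methods = ['backprop', 'adjoint', 'parameter-shift',
                'finite-diff', 'hadamard', 'device', 'spsa', 'best']

for dev_name in devices:
    for diff_method in diff_methods:
        try:
            # One extra wire: QNGO's metric tensor needs an auxiliary wire
            dev = qml.device(dev_name, wires=3)

            @qml.qnode(dev, diff_method=diff_method)
            def circuit(params):
                qml.RX(data[0], wires=0)
                qml.RX(data[1], wires=1)
                qml.Rot(params[0], params[1], params[2], wires=0)
                qml.Rot(params[0], params[1], params[2], wires=1)
                qml.Hadamard(wires=0)
                qml.CNOT(wires=[0, 1])
                return qml.expval(qml.Z(0))

            def cost_f(p):
                return pnp.abs(circuit(p))

            opt = qml.QNGOptimizer()
            mt_fn = qml.metric_tensor(circuit)
            params = pnp.array([1., 2., 3.], requires_grad=True)

            start = time.perf_counter()
            opt.step(cost_f, params, metric_tensor_fn=mt_fn)  # one QNGO step
            print(f"{dev_name} + {diff_method}: works "
                  f"({time.perf_counter() - start:.3f} s)")
        except Exception as err:
            print(f"{dev_name} + {diff_method}: fails ({type(err).__name__})")
```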
Not all combinations will work equally fast. Generally, `default.qubit` with `backprop` will be fastest, but with 15 qubits or more `lightning.qubit` might be faster, as seen on our performance page. `parameter-shift`, in particular, can be very slow, so you'll have to test for yourself what works best for your particular use case; the sketch above should help with that.
```python
import pennylane as qml
from pennylane import numpy as pnp

# Data (non-trainable, accessed inside the QNode; see the note below)
data = pnp.array([0., 1.], requires_grad=False)

# Device
n_qubits = 2
# We create a device with one extra wire because we need an auxiliary wire when using QNGO
dev = qml.device('default.qubit', wires=n_qubits + 1)

# QNode
diff_method = 'backprop'

@qml.qnode(dev, diff_method=diff_method)
def circuit(params):
    # Data embedding
    qml.RX(data[0], wires=0)
    qml.RX(data[1], wires=1)
    # Parametrized layer
    qml.Rot(params[0], params[1], params[2], wires=0)
    qml.Rot(params[0], params[1], params[2], wires=1)
    qml.Hadamard(wires=0)
    qml.CNOT(wires=[0, 1])
    # Measurement
    return qml.expval(qml.Z(0))

# Initial value of the parameters
params = pnp.array([1., 2., 3.], requires_grad=True)

# Initial value of the circuit
print(circuit(params))

# Cost function
def cost_f(params):
    return pnp.abs(circuit(params))

# Optimizer
opt = qml.QNGOptimizer()

# If we're using QNGO we need to define a metric tensor function
mt_fn = qml.metric_tensor(circuit)
print(mt_fn(params))

# Optimization loop
for it in range(10):
    params = opt.step(cost_f, params, metric_tensor_fn=mt_fn)
    print(params)
    print('Cost: ', cost_f(params))
```
Note that the non-trainable data here is accessed inside the QNode through the enclosing scope rather than passed in as an argument; otherwise things start breaking. (If you need to swap in different data, see the sketch below.) I hope this helps!
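As a side note, if you want to build the same circuit for several data points while still keeping the data out of the QNode's arguments, one pattern is a small factory function that closes over the data. This is just a sketch of that pattern; `make_circuit` is a hypothetical helper, not a PennyLane function:

```python
import pennylane as qml
from pennylane import numpy as pnp

def make_circuit(data, dev, diff_method='backprop'):
    """Build a QNode that closes over `data`, keeping it out of the arguments."""
    @qml.qnode(dev, diff_method=diff_method)
    def circuit(params):
        qml.RX(data[0], wires=0)
        qml.RX(data[1], wires=1)
        qml.Rot(params[0], params[1], params[2], wires=0)
        qml.Rot(params[0], params[1], params[2], wires=1)
        qml.Hadamard(wires=0)
        qml.CNOT(wires=[0, 1])
        return qml.expval(qml.Z(0))
    return circuit

dev = qml.device('default.qubit', wires=3)
circuit = make_circuit(pnp.array([0.5, 1.5], requires_grad=False), dev)
# The metric tensor function still works, since params is the only QNode argument
mt_fn = qml.metric_tensor(circuit)
```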