QNGOptimizer with Variational Classifiers

Hi @Amandeep ,

Optimizer performance is problem-specific, meaning that it’s very hard to tell in advance which optimizer will give better or faster results than another. There’s no one-size-fits-all.

That being said, I can give you some pointers.

I made a QNGO test script that you can see below. For that script the table below answers the question: “Does this diff_method and device combination work?”

| diff_method     | default.qubit | lightning.qubit |
| --------------- | ------------- | --------------- |
| backprop        | Yes           | No              |
| adjoint         | No            | No              |
| parameter-shift | Yes           | Yes             |
| finite-diff     | Yes           | Yes             |
| hadamard        | Yes           | Yes             |
| device          | No            | No              |
| spsa            | Yes           | Yes             |
| best            | Yes           | Yes             |

Not all working combinations are equally fast. Generally, default.qubit with backprop will be fastest, although once you reach roughly 15 qubits lightning.qubit can pull ahead, as shown on our performance page. parameter-shift in particular can be very slow, so you’ll have to test which setup works best for your particular use case.

import pennylane as qml
from pennylane import numpy as pnp

# Data (non-trainable, captured by the circuit below)
data = pnp.array([0., 1.], requires_grad=False)

# Device
n_qubits=2
# We create a device with one extra wire because we need an auxiliary wire when using QNGO
dev = qml.device('default.qubit', wires=n_qubits+1)

# QNode
diff_method='backprop'

@qml.qnode(dev, diff_method=diff_method)
def circuit(params):
  # Data embedding
  qml.RX(data[0], wires=0)
  qml.RX(data[1], wires=1)

  # Parametrized layer
  qml.Rot(params[0], params[1], params[2], wires=0)
  qml.Rot(params[0], params[1], params[2], wires=1)
  qml.Hadamard(wires=0)
  qml.CNOT(wires=[0, 1])

  # Measurement
  return qml.expval(qml.Z(0))

# Initial value of the parameters
params = pnp.array([1.,2.,3.],requires_grad=True)

# Initial output of the circuit
print(circuit(params))

# Cost function
def cost_f(params):
  return pnp.abs(circuit(params))

# Optimizer
opt = qml.QNGOptimizer()

# If we're using QNGO we need to define a metric tensor function
mt_fn = qml.metric_tensor(circuit)
print(mt_fn(params))

# Optimization loop
for it in range(10):
  params = opt.step(cost_f, params, metric_tensor_fn=mt_fn)
  print(f"Iteration {it}: params = {params}, cost = {cost_f(params)}")

Note that the non-trainable data here is captured directly inside the function (as a closure) rather than passed as an argument; otherwise things start breaking. I hope this helps!