Hi @mlxd here is a minimal example. I believe the problem occurs due to parameter-shift and since Qiskit does not allow it, it doesn’t work properly.
import tensorflow as tf
import pennylane as qml
from pennylane import numpy as np
dev1 = qml.device("qiskit.aer", wires = 2, shots=10, backend='qasm_simulator')
dev2 = qml.device("default.qubit.tf", wires = 2, shots=None)
@qml.qnode(dev2, diff_method="backprop", interface="tf")
def circuit2(inputs, weights):
qml.AngleEmbedding(inputs, wires = range(2), rotation="Y")
qml.RY(weights[0], wires=0)
qml.RY(weights[1], wires=1)
qml.CNOT(wires = [0, 1])
return qml.probs(op=qml.PauliZ(1))
@qml.qnode(dev1, diff_method="parameter-shift", interface="tf")
def circuit1(inputs, weights):
qml.AngleEmbedding(inputs, wires = range(2), rotation="Y")
qml.RY(weights[0], wires=0)
qml.RY(weights[1], wires=1)
qml.CNOT(wires = [0, 1])
return qml.probs(op=qml.PauliZ(1))
weights = tf.Variable(tf.random.uniform((2,), dtype=tf.float64), trainable=True)
inputs = tf.random.uniform((10,2), dtype=tf.float64)
circ = tf.function(circuit2)
contract = lambda inpts : tf.vectorized_map(lambda vec: circ(vec, weights), inpts)
with tf.GradientTape() as tape:
yhat = tf.reduce_mean(contract(inputs))
tape.gradient(yhat, weights)
# Output : <tf.Tensor: shape=(2,), dtype=float64, numpy=array([ 1.38777878e-17, -2.77555756e-17])>
circ = tf.function(circuit1)
contract = lambda inpts : tf.vectorized_map(lambda vec: circ(vec, weights), inpts)
with tf.GradientTape() as tape:
yhat = tf.reduce_mean(contract(inputs))
tape.gradient(yhat, weights)
# Output: <tf.Tensor: shape=(2,), dtype=float64, numpy=array([0., 0.])>
My yhat
is poorly choosen but it shows whats happening here I believe. circuit1
always gives zero gradient no matter what but I can get proper results from circuit2
. Also If instead of using vectorized_map
if I use the batching function I wrote here it gives me a good gradient result as well but this does not parallelize the execution on GPU. So I believe I need to use vectorized_map
to parallelize the batch execution or is there any other way that you can suggest?
Thanks