How do I encode the MSE in a certain cost function?

Hello, I’m using Pennylane and JAX to implement a certain circuit called cost and compute its gradients.

def get_observables(N):
    
    observables = []

    # Coupling operators
    for i in range(N-1):
        observables.append(qml.PauliZ(i) @ qml.PauliZ(i+1))

    # Identity operator
    for i in range(N):
        observables.append(qml.Identity(i))

    return observables

def get_coeffs(params, N):

    coeffs = []

    # Coupling coeffs
    for i in range(N-1):
        coeffs.append((params[0])**2/params[1])

    # Constant coeffs
    for i in range(N):
        coeffs.append(params[1])

    return coeffs

def create_Hamiltonian(params):

    coeffs = get_coeffs(params, nqubits)
    obs = get_observables(nqubits)
    
    # H = qml.dot(coeffs, obs)
    H = qml.Hamiltonian(coeffs, obs)

    return H

def get_Sy(nqubits, a):
    S_0 = nqubits/2
    c = 0

    for i in range(nqubits):
        c += (1/(2*S_0))*qml.PauliY(wires=i)
    return qml.s_prod(a, c)

def create_params(L, scale = 0.1):
    
    params = jnp.array([])

    print('Params in create_params before:' + str(type(params)))

    for i in range(L):
        J = scale*np.random.uniform()
        O = 1.0
        theta = scale*np.random.uniform()

        # Convert the list to a NumPy array
        param_array = jnp.array([J, O, theta])
        
        params = jnp.append(params, param_array)

    print('Params in create_params after:' + str(type(params)))
            
    return params

def U1(params):
    start_index = 0
    num_trotter_steps = 10

    for i in range(L):
        new_params = params[start_index:start_index + 3]
        H = create_Hamiltonian(new_params[0:2])
        qml.evolve(H, num_steps = num_trotter_steps)
        for j in range(nqubits):
            qml.RX(new_params[2], wires=j) # Change params to make sure that theta value changes for each L
        start_index += 3 # Put state_index in again

dev = qml.device("default.qubit.jax", wires=nqubits, shots=None)

# @jax.jit
# @qml.qnode(dev, interface='jax')
def circuit(params, a, phi):
    print('Type of params in circuit:'+str(type(params)))
    print('Type of a in circuit:'+str(type(a)))

    for i in range(nqubits): # Making the initial CSS
        qml.Hadamard(wires=i)

    U1(params)

    for z in range(nqubits): # Perturbation
        qml.RY(phi, wires = z)

    qml.adjoint(U1)(params)

    # expectation_values = [qml.expval(qml.PauliY(wires=i)) for i in range(nqubits)]

    c = get_Sy(nqubits, a)

    print(c)

    return qml.expval(c)

# @jax.jit
@qml.qnode(dev, interface='jax')
def cost(params, a, phi):
    print('Type of params in cost: '+str(type(params)))
    print('Type of a in cost: '+str(type(a)))
    circuit_output = circuit(params, a, phi)
    mse = jnp.mean((phi-circuit_output)**2)
    return circuit_output

However, when I do try to calculate the gradients via

L = 1
params = create_params(L)
phi = jnp.array(0.001)
a = jnp.array(0.001)

val, grads = jax.value_and_grad(cost)(params, a, phi)

I get this error

    117 print('Type of a in cost: '+str(type(a)))
    118 circuit_output = circuit(params, a, phi)
--> 119 mse = jnp.mean((phi-circuit_output)**2)
    120 return circuit_output

TypeError: unsupported operand type(s) for -: 'ArrayImpl' and 'ExpectationMP'

How do I actually encode the MSE as the output of my cost QNode? Thank you so much!

Hey @NickGut0711,

It looks like you need to uncomment # @qml.qnode(dev, interface='jax') above your definition of circuit :). Also, your cost function can’t be decorated by @qml.qnode since it’s no a function that contains quantum instructions and a measurement process.

Let me know if that helps!