Optimizing molecular coordinates using QPE and VQE

Suppose I have two H atoms in 3D space and I would like to optimize their geometry without reducing the problem to 1D. I can see two approaches for that: the first is VQE, the second is QPE.

VQE

I wrote some code to do that, but it seems that PennyLane converting the coordinates to an autograd tensor is at odds with my current implementation.

The code

Here I fix one H atom at position (0, 0, 0) and optimize the coordinates of the other atom, H_1.

import pennylane as qml
from pennylane import numpy as np
import jax
import optax


dev = qml.device("default.qubit", 4)

@qml.qnode(dev)
def circuit_expected(H):
    qml.BasisState([1,1,0,0], wires=[0, 1, 2, 3])
    qml.DoubleExcitation(0.2, wires=[0, 1, 2, 3])
    return qml.expval(H)

def loss_f(coord):
    symbols = ["H", "H"]
    # first H atom fixed at the origin; `coord` holds the second atom's coordinates
    H, qb = qml.qchem.molecular_hamiltonian(symbols, np.array([0, 0, 0, *coord]))  # <-- error is raised here
    return circuit_expected(H)

def optimize():
    # initial guess for the coordinates of the second H atom
    H_1 = np.array([1., 1., 1.])

    opt = optax.sgd(learning_rate=0.4)
    opt_coords_state = opt.init(H_1)

    for i in range(10):
        grad_coordinates = jax.grad(loss_f, 0)(H_1)
        updates, opt_coords_state = opt.update(grad_coordinates, opt_coords_state)
        H_1 = optax.apply_updates(H_1, updates)
        print(grad_coordinates)
    
optimize()
The error
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[].
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

The error is raised on the line H, qb = qml.qchem.molecular_hamiltonian(symbols, np.array([0, 0, 0, *coord])).

I tried np.array(), jnp.array(), and jnp.concatenate(), but the error persists. What am I doing wrong here?

QPE

Suppose I use QPE to find \phi in U |\psi \rangle = e^{i \phi} |\psi \rangle and use that to find the eigenstate |\psi\rangle. How can I derive the position operator? My guess is that if I take the expectation value of the position operator I can arrive at the answer, but I don't yet know how to define that position operator.
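In symbols, the plan would be: use QPE to estimate \phi and prepare the corresponding eigenstate, then read the geometry off as the expectation value \langle \hat{x} \rangle = \langle \psi | \hat{x} | \psi \rangle, where \hat{x} is the position operator that I don't yet know how to write down in the qubit representation.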

Hi @mchau ,

Your code seems very complex.

My recommendation would be to follow our demo for optimization of molecular geometries or our demo on Differentiable Hartree-Fock instead.

Let me know if these do what you were looking for. If not, it's best to split the problem into small parts and create a minimal working example in order to find the cause of the issue.

I hope this helps you!


Sure, let me simplify the problem. I am trying to optimize the coordinates of two H atoms so that they form H_2. I initialize their positions as H_{000} = [0, 0, 0] and H_{111} = [1, 1, 1].

I have an array named coords = [1, 1, 1], which I would like to optimize using JAX:

>>> coords
Traced<ConcreteArray([1. 1. 1.], dtype=float32)>with<JVPTrace(level=2/0)>
...

To calculate the Hamiltonian, I have to concatenate coords with np.array([0, 0, 0]). Rather than optimizing all 6 coordinates, I fix H_{000} and only optimize the H_{111} coordinates. Therefore, I do

np.array([0, 0, 0, *coord])

This doesn't work; the error is
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[].

I tried jnp.concatenate() and jnp.array(), but they don't work when applied in the full code below.
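For reference, the concatenation by itself can be written with pure JAX ops (full_coords below is just an illustrative helper); this part traces fine under jax.grad, and the tracer error only shows up once the result is handed to molecular_hamiltonian:

import jax.numpy as jnp

def full_coords(coord):
    # prepend the fixed H atom at the origin to the trainable coordinates;
    # jnp.concatenate and jnp.zeros work under tracing, the failure happens
    # later inside qml.qchem.molecular_hamiltonian
    return jnp.concatenate([jnp.zeros(3), coord])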

Code
import pennylane as qml
from pennylane import numpy as np
import jax
import optax

dev = qml.device("default.qubit", 4)

@qml.qnode(dev)
def circuit_expected(H):
    qml.BasisState([1, 1, 0, 0], wires=[0, 1, 2, 3])
    qml.DoubleExcitation(0.2, wires=[0, 1, 2, 3])
    return qml.expval(H)


def loss_f(coord):
    symbols = ["H", "H"]
    H, qb = qml.qchem.molecular_hamiltonian(symbols, np.array([0, 0, 0, *coord]))  # <-- TracerArrayConversionError raised here
    return circuit_expected(H)

H_1 = np.array([1., 1., 1.])
opt = optax.sgd(learning_rate=0.4)
opt_coords_state = opt.init(H_1)

for i in range(10):
    grad_coordinates = jax.grad(loss_f, 0)(H_1)
    updates, opt_coords_state = opt.update(grad_coordinates, opt_coords_state)
    H_1 = optax.apply_updates(H_1, updates)
    print(grad_coordinates)

I hope that is clearer, @CatalinaAlbornoz?

Hi @mchau,

Unfortunately JAX isn't as versatile as we would like it to be. It does seem like a purely JAX issue. Are you able to switch from JAX to something else? Torch usually fails less, so you could use PennyLane's Torch interface instead.


Hi @mchau,

Actually the issue here might be that you’re mixing PennyLane’s numpy with JAX numpy. They don’t work well together. Why don’t you try using JAX without vanilla numpy and let me know if this solves your problem?


Hi Catalina,

Thank you for your answer. Yes, I tried your suggestion.

Below is the code with that approach.

Code
import pennylane as qml
from pennylane import numpy as np
import jax
import optax

dev = qml.device("default.qubit", 4)

@qml.qnode(dev)
def circuit_expected(H):
    qml.BasisState([1, 1, 0, 0], wires=[0, 1, 2, 3])
    qml.DoubleExcitation(0.2, wires=[0, 1, 2, 3])
    return qml.expval(H)


def loss_f(coord):
    symbols = ["H", "H"]
    H, qb = qml.qchem.molecular_hamiltonian(symbols, jax.numpy.array([0, 0, 0, *coord]))  # now using jax.numpy, but the error persists
    return circuit_expected(H)

H_1 = np.array([1., 1., 1.])
opt = optax.sgd(learning_rate=0.4)
opt_coords_state = opt.init(H_1)

for i in range(10):
    grad_coordinates = jax.grad(loss_f, 0)(H_1)
    updates, opt_coords_state = opt.update(grad_coordinates, opt_coords_state)
    H_1 = optax.apply_updates(H_1, updates)
    print(grad_coordinates)

I think the main trouble here is that I don't want to compute the gradient for every coordinate in one tensor, but in the current setting of PennyLane I guess that is not possible?

Hi @mchau ,

It looks like molecular_hamiltonian unfortunately only works with autograd, so you would need to go back to using PennyLane's NumPy instead of JAX :crying_cat_face: .

Thank you for opening this forum thread and the issue on GitHub; we'll look into adding a warning in the documentation about this.
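For example, here is a rough sketch of an autograd-only version in the spirit of the geometry-optimization demo; it takes the gradient with respect to the nuclear coordinates by a central finite difference (finite_diff_grad is just an illustrative helper), so nothing has to differentiate through molecular_hamiltonian:

import pennylane as qml
from pennylane import numpy as np  # PennyLane's autograd-wrapped NumPy

symbols = ["H", "H"]
dev = qml.device("default.qubit", wires=4)

def energy(coords):
    # energy of the fixed-ansatz circuit for a given geometry;
    # the first H atom is pinned at the origin, `coords` is the second atom
    H, _ = qml.qchem.molecular_hamiltonian(symbols, np.concatenate([np.zeros(3), coords]))

    @qml.qnode(dev)
    def circuit():
        qml.BasisState(np.array([1, 1, 0, 0]), wires=[0, 1, 2, 3])
        qml.DoubleExcitation(0.2, wires=[0, 1, 2, 3])
        return qml.expval(H)

    return circuit()

def finite_diff_grad(f, x, delta=0.01):
    # central finite-difference gradient of f at x
    grad = np.zeros_like(x)
    for i in range(len(x)):
        shift = np.zeros_like(x)
        shift[i] = 0.5 * delta
        grad[i] = (f(x + shift) - f(x - shift)) / delta
    return grad

coords = np.array([1.0, 1.0, 1.0])  # initial guess for the second H atom (Bohr)
for i in range(10):
    g = finite_diff_grad(energy, coords)
    coords = coords - 0.4 * g  # plain gradient-descent step with learning rate 0.4
    print(f"step {i}: E = {float(energy(coords)):.6f}")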


Hi @mchau ,

It looks like diff_hamiltonian does work with JAX. Why don't you try it and let us know if it works for you?
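For instance, here is an untested sketch of what that could look like, assuming your installed PennyLane version lets qml.qchem.Molecule and diff_hamiltonian accept JAX arrays:

import jax
import jax.numpy as jnp
import optax
import pennylane as qml

jax.config.update("jax_enable_x64", True)  # the qchem integrals generally need float64

symbols = ["H", "H"]
dev = qml.device("default.qubit", wires=4)

def loss_f(coord):
    # pin the first H atom at the origin; only the second atom's coordinates are trained
    geometry = jnp.concatenate([jnp.zeros(3), coord]).reshape(2, 3)
    mol = qml.qchem.Molecule(symbols, geometry)
    H = qml.qchem.diff_hamiltonian(mol)(geometry)

    @qml.qnode(dev, interface="jax")
    def circuit():
        qml.BasisState([1, 1, 0, 0], wires=[0, 1, 2, 3])
        qml.DoubleExcitation(0.2, wires=[0, 1, 2, 3])
        return qml.expval(H)

    return circuit()

coord = jnp.array([1.0, 1.0, 1.0])
opt = optax.sgd(learning_rate=0.4)
opt_state = opt.init(coord)

for i in range(10):
    grad = jax.grad(loss_f)(coord)
    updates, opt_state = opt.update(grad, opt_state)
    coord = optax.apply_updates(coord, updates)
    print(i, grad)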
