JAX + VQE TypeError problem

Hi, I need a little help. I’m trying to use the JAX interface to accelerate VQE, but I’m running into data-type problems and I can’t tell where the error comes from (the error message and a minimal example that reproduces it are below). I was thinking I could use JAX’s grad function to update the parameters (as is done in the tutorial Using JAX with PennyLane | PennyLane Demos), but I’d like to keep using PennyLane’s optimizers if possible.
Any tips or suggestions are welcome. Thanks, and happy holidays.
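The approach from that demo boils down to differentiating the cost with `jax.grad` and applying the update by hand. Below is a minimal sketch of that loop. It assumes a hypothetical stand-in `toy_energy` (a generic `a*cos(theta) + b*sin(theta)` landscape, not the real H2 energy) in place of the `cost_fn` QNode, so it runs with JAX alone; the step size and tolerance mirror the code below.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # float64, as in the thread


def toy_energy(theta):
    # Hypothetical stand-in for the QNode: E(theta) = a*cos(theta) + b*sin(theta)
    return 0.5 * jnp.cos(theta) - 0.2 * jnp.sin(theta)


grad_fn = jax.jit(jax.grad(toy_energy))  # compiled dE/dtheta
energy_fn = jax.jit(toy_energy)          # compiled energy evaluation

stepsize = 0.4
conv_tol = 1e-6
theta = jnp.array(0.0)
energy = [energy_fn(theta)]

for n in range(100):
    theta = theta - stepsize * grad_fn(theta)  # plain gradient-descent update
    energy.append(energy_fn(theta))
    if jnp.abs(energy[-1] - energy[-2]) <= conv_tol:
        break

print(f"theta = {float(theta):.6f}, energy = {float(energy[-1]):.8f}")
```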

Jax

Error message

TypeError: Can't find vector space for value -1.1173489211359313 of type <class 'jaxlib.xla_extension.ArrayImpl'>. Valid types are dict_keys([<class 'autograd.core.SparseObject'>, <class 'list'>, <class 'tuple'>, <class 'dict'>, <class 'numpy.ndarray'>, <class 'float'>, <class 'numpy.longdouble'>, <class 'numpy.float64'>, <class 'numpy.float32'>, <class 'numpy.float16'>, <class 'complex'>, <class 'numpy.clongdouble'>, <class 'numpy.complex64'>, <class 'numpy.complex128'>, <class 'numpy.linalg.linalg.EigResult'>, <class 'numpy.linalg.linalg.EighResult'>, <class 'numpy.linalg.linalg.QRResult'>, <class 'numpy.linalg.linalg.SlogdetResult'>, <class 'numpy.linalg.linalg.SVDResult'>, <class 'pennylane.numpy.tensor.tensor'>])

Code

from pennylane import numpy as np
import pennylane as qml
from jax.config import config
import jax

config.update("jax_enable_x64", True)


symbols = ["H", "H"]
coordinates = np.array([0.0, 0.0, -0.6614, 0.0, 0.0, 0.6614])
H, qubits = qml.qchem.molecular_hamiltonian(symbols, coordinates)

dev = qml.device("lightning.qubit", wires=qubits)

electrons = 2
hf = qml.qchem.hf_state(electrons, qubits)


def circuit(param, wires):
    qml.BasisState(hf, wires=wires)
    qml.DoubleExcitation(param, wires=[0, 1, 2, 3])


@qml.qnode(dev, interface="jax")
def cost_fn(param):
    circuit(param, wires=range(qubits))
    return qml.expval(H)


opt = qml.GradientDescentOptimizer(stepsize=0.4)
theta = np.array(0.0, requires_grad=True)

energy = [cost_fn(theta)]
angle = [theta]

max_iterations = 100
conv_tol = 1e-06

for n in range(max_iterations):
    theta, prev_energy = opt.step_and_cost(cost_fn, theta)

    energy.append(cost_fn(theta))
    angle.append(theta)

    conv = np.abs(energy[-1] - prev_energy)

    if n % 2 == 0:
        print(f"Step = {n},  Energy = {energy[-1]:.8f} Ha")

    if conv <= conv_tol:
        break

Libraries

I’m using Python 3.9.6, PennyLane 0.33.1 and JAX 0.4.23.

Jax-jit

When I change the interface to “jax-jit”, I get a similar TypeError:

Error

TypeError: Argument 'Autograd ArrayBox with value 0.0' of type <class 'autograd.numpy.numpy_boxes.ArrayBox'> is not a valid JAX type.

Code

from pennylane import numpy as np
import pennylane as qml
from jax.config import config
import jax

config.update("jax_enable_x64", True)


symbols = ["H", "H"]
coordinates = np.array([0.0, 0.0, -0.6614, 0.0, 0.0, 0.6614])
H, qubits = qml.qchem.molecular_hamiltonian(symbols, coordinates)

dev = qml.device("lightning.qubit", wires=qubits)

electrons = 2
hf = qml.qchem.hf_state(electrons, qubits)


def circuit(param, wires):
    qml.BasisState(hf, wires=wires)
    qml.DoubleExcitation(param, wires=[0, 1, 2, 3])


@qml.qnode(dev, interface="jax-jit")
def cost_fn(param):
    circuit(param, wires=range(qubits))
    return qml.expval(H)


opt = qml.GradientDescentOptimizer(stepsize=0.4)
theta = np.array(0.0, requires_grad=True)

energy = [cost_fn(theta)]
angle = [theta]

max_iterations = 100
conv_tol = 1e-06

for n in range(max_iterations):
    theta, prev_energy = opt.step_and_cost(cost_fn, theta)

    energy.append(cost_fn(theta))
    angle.append(theta)

    conv = np.abs(energy[-1] - prev_energy)

    if n % 2 == 0:
        print(f"Step = {n},  Energy = {energy[-1]:.8f} Ha")

    if conv <= conv_tol:
        break

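Hey @jnorambu! I recommend checking out this part of the documentation, which explains how PennyLane interfaces with JAX when optimizing circuits: JAX interface — PennyLane 0.33.0 documentation. The built-in PennyLane optimizers (such as qml.GradientDescentOptimizer) compute gradients with Autograd, so they can’t be used with the JAX interface; that is why both tracebacks complain about Autograd types.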
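If the appeal of the PennyLane optimizers is mainly the `step_and_cost` loop structure, that part is easy to reproduce in pure JAX. The sketch below uses `jax.value_and_grad` to get the cost and the gradient in one pass. `make_step_and_cost` and `toy_energy` are hypothetical names (the latter stands in for the QNode), and, like PennyLane’s `step_and_cost`, the returned cost is the one evaluated before the update.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def toy_energy(theta):
    # Hypothetical stand-in for the PennyLane QNode `cost_fn`
    return 0.5 * jnp.cos(theta) - 0.2 * jnp.sin(theta)


def make_step_and_cost(cost, stepsize):
    """Hypothetical helper mimicking a gradient-descent step_and_cost in pure JAX."""
    value_and_grad = jax.value_and_grad(cost)

    @jax.jit
    def step_and_cost(params):
        value, grad = value_and_grad(params)  # cost and gradient at the current params
        return params - stepsize * grad, value

    return step_and_cost


step_and_cost = make_step_and_cost(toy_energy, stepsize=0.4)

theta = jnp.array(0.0)
energy = [toy_energy(theta)]
for n in range(100):
    theta, prev_energy = step_and_cost(theta)
    energy.append(toy_energy(theta))
    if jnp.abs(energy[-1] - prev_energy) <= 1e-6:
        break
```

With the real QNode in place of `toy_energy`, the body of the original loop stays the same apart from the optimizer call.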

Let me know if this helps and if you have more questions!

Hi @jnorambu,

Unfortunately, if you use JAX, you need to use optimizers from JAXopt or Optax instead.

I updated your code to work with Optax. I hope this helps you make the transition, and let me know if you have any questions!

import pennylane as qml
import jax
import jax.numpy as jnp
import optax

jax.config.update("jax_enable_x64", True)  # use 64-bit floats for parameters and energies

symbols = ["H", "H"]
coordinates = jnp.array([0.0, 0.0, -0.6614, 0.0, 0.0, 0.6614])
H, qubits = qml.qchem.molecular_hamiltonian(symbols, coordinates)

dev = qml.device("lightning.qubit", wires=qubits)

electrons = 2
hf = qml.qchem.hf_state(electrons, qubits)


def circuit(param, wires):
    qml.BasisState(hf, wires=wires)
    qml.DoubleExcitation(param, wires=[0, 1, 2, 3])

@jax.jit
@qml.qnode(dev, interface="jax")
def cost_fn(param):
    circuit(param, wires=range(qubits))
    return qml.expval(H)

theta = jnp.array(1.0)

energy = [cost_fn(theta)]
angle = [theta]

max_iterations = 100
conv_tol = 1e-06

learning_rate = 0.15
optimizer = optax.adam(learning_rate)

opt_state = optimizer.init(theta)

for n in range(max_iterations):
    grads = jax.grad(cost_fn)(theta)  # gradient of the energy w.r.t. theta
    updates, opt_state = optimizer.update(grads, opt_state)  # Adam update and new state
    theta = optax.apply_updates(theta, updates)  # theta <- theta + updates

    energy.append(cost_fn(theta))
    angle.append(theta)

    conv = jnp.abs(energy[-1] - energy[-2])
    if n % 2 == 0:
        print(f"Step = {n},  Energy = {energy[-1]:.8f} Ha")

    if conv <= conv_tol:
        break


Thanks, Catalina, that’s exactly what I wanted. I still need to test a few things, but in essence it’s what I need. There are a lot of documents to review :sweat_smile:. Have a nice Christmas, both of you.
