Hi, I need a little of help, I was trying to use the JAX interface to accelerate VQE, but I’m having some troubles with data types and I don’t know where is the error (below you can find the error message and the minimal code that reproduced). I was thinking that I should use the grad JAX’s function to update the parameters (as is done in this tutorial Using JAX with PennyLane | PennyLane Demos), but I’m looking the option to keep the pennylane’s optimizers.
Any tip or suggestions is welcome. Thanks and have a happy holiday.
Jax
Error message
TypeError: Can't find vector space for value -1.1173489211359313 of type <class 'jaxlib.xla_extension.ArrayImpl'>. Valid types are dict_keys([<class 'autograd.core.SparseObject'>, <class 'list'>, <class 'tuple'>, <class 'dict'>, <class 'numpy.ndarray'>, <class 'float'>, <class 'numpy.longdouble'>, <class 'numpy.float64'>, <class 'numpy.float32'>, <class 'numpy.float16'>, <class 'complex'>, <class 'numpy.clongdouble'>, <class 'numpy.complex64'>, <class 'numpy.complex128'>, <class 'numpy.linalg.linalg.EigResult'>, <class 'numpy.linalg.linalg.EighResult'>, <class 'numpy.linalg.linalg.QRResult'>, <class 'numpy.linalg.linalg.SlogdetResult'>, <class 'numpy.linalg.linalg.SVDResult'>, <class 'pennylane.numpy.tensor.tensor'>])
Code
from pennylane import numpy as np
import pennylane as qml
from jax.config import config
import jax
config.update("jax_enable_x64", True)
symbols = ["H", "H"]
coordinates = np.array([0.0, 0.0, -0.6614, 0.0, 0.0, 0.6614])
H, qubits = qml.qchem.molecular_hamiltonian(symbols, coordinates)
dev = qml.device("lightning.qubit", wires=qubits)
electrons = 2
hf = qml.qchem.hf_state(electrons, qubits)
def circuit(param, wires):
qml.BasisState(hf, wires=wires)
qml.DoubleExcitation(param, wires=[0, 1, 2, 3])
@qml.qnode(dev, interface="jax")
def cost_fn(param):
circuit(param, wires=range(qubits))
return qml.expval(H)
opt = qml.GradientDescentOptimizer(stepsize=0.4)
theta = np.array(0.0, requires_grad=True)
energy = [cost_fn(theta)]
angle = [theta]
max_iterations = 100
conv_tol = 1e-06
for n in range(max_iterations):
theta, prev_energy = opt.step_and_cost(cost_fn, theta)
energy.append(cost_fn(theta))
angle.append(theta)
conv = np.abs(energy[-1] - prev_energy)
if n % 2 == 0:
print(f"Step = {n}, Energy = {energy[-1]:.8f} Ha")
if conv <= conv_tol:
break
Libraries
I’m using Python 3.9.6, pennylane’s 0.33.1 and JAX 0.4.23.
Jax-jit
And changing to “jax-jit” I have a similar TypeError problem
Error
TypeError: Argument 'Autograd ArrayBox with value 0.0' of type <class 'autograd.numpy.numpy_boxes.ArrayBox'> is not a valid JAX type.
Code
from pennylane import numpy as np
import pennylane as qml
from jax.config import config
import jax
config.update("jax_enable_x64", True)
symbols = ["H", "H"]
coordinates = np.array([0.0, 0.0, -0.6614, 0.0, 0.0, 0.6614])
H, qubits = qml.qchem.molecular_hamiltonian(symbols, coordinates)
dev = qml.device("lightning.qubit", wires=qubits)
electrons = 2
hf = qml.qchem.hf_state(electrons, qubits)
def circuit(param, wires):
qml.BasisState(hf, wires=wires)
qml.DoubleExcitation(param, wires=[0, 1, 2, 3])
@qml.qnode(dev, interface="jax-jit")
def cost_fn(param):
circuit(param, wires=range(qubits))
return qml.expval(H)
opt = qml.GradientDescentOptimizer(stepsize=0.4)
theta = np.array(0.0, requires_grad=True)
energy = [cost_fn(theta)]
angle = [theta]
max_iterations = 100
conv_tol = 1e-06
for n in range(max_iterations):
theta, prev_energy = opt.step_and_cost(cost_fn, theta)
energy.append(cost_fn(theta))
angle.append(theta)
conv = np.abs(energy[-1] - prev_energy)
if n % 2 == 0:
print(f"Step = {n}, Energy = {energy[-1]:.8f} Ha")
if conv <= conv_tol:
break