Here is my attempt to implement the Ansatz from scratch:

class Ansatz(nn.Module):
    """Product-state ansatz built from prime-indexed 2x2 rotations.

    Each of the ``n_states`` two-dimensional factors has a complex amplitude pair
    ``(A[i], B[i])`` and a real offset ``Xdot[i]``.  Factor ``i`` uses the
    rotation angle ``2*pi*(z[i] - xi[i]) / p_i``, where ``p_i`` is the i-th prime
    and ``xi[i] = 2**(n_states + 1) mod p_i``.

    ``Ydot``, ``Zdot``, ``GL_Xdot``/``GL_Ydot``/``GL_Zdot`` and ``arch_norm`` are
    read-only properties recomputed from the *current* ``Xdot`` on every access.
    This lets gradients reach ``Xdot``, and the values track optimizer updates
    (previously they were frozen at construction time).
    """

    def __init__(self, n_states):
        super().__init__()
        self.n_states = n_states
        self.A = nn.Parameter(torch.rand(n_states) + 1j * torch.rand(n_states))
        self.B = nn.Parameter(torch.rand(n_states) + 1j * torch.rand(n_states))
        self.Xdot = nn.Parameter(torch.rand(n_states))
        # First n_states primes; p_i is the modulus of the i-th rotation angle.
        self.primes = [int(sympy.prime(p)) for p in range(1, n_states + 1)]
        # Non-persistent buffer: follows .to(device) but stays out of state_dict,
        # so saved checkpoints keep the same keys (A, B, Xdot).
        self.register_buffer(
            "xi",
            torch.tensor([self.compute_p_root(n_states, p) for p in range(1, n_states + 1)]),
            persistent=False,
        )

    # GEOMETRIC TENSORS (recomputed from the current Xdot)
    @property
    def Ydot(self):
        """``Xdot`` rolled by +1 along dim 0."""
        return self.R(self.Xdot)

    @property
    def Zdot(self):
        """``Xdot`` rolled by -1 along dim 0."""
        return self.F(self.Xdot)

    @property
    def GL_Xdot(self):
        """Per-factor 2x2 rotations built from ``Xdot``."""
        return self._rotations(self.Xdot)

    @property
    def GL_Ydot(self):
        """Per-factor 2x2 rotations built from ``Ydot``."""
        return self._rotations(self.Ydot)

    @property
    def GL_Zdot(self):
        """Per-factor 2x2 rotations built from ``Zdot``."""
        return self._rotations(self.Zdot)

    @property
    def arch_norm(self):
        """Tuple of ``|xi - v|`` for ``v`` in (Xdot, Ydot, Zdot)."""
        return self.compute_Sn()

    def _rotations(self, z):
        """Build one GL2 rotation per factor from the angle offsets ``z``."""
        return [self.GL2(z, self.xi, p, i) for i, p in enumerate(self.primes)]

    @staticmethod
    def _kron_all(factors):
        """Kronecker product of ``factors`` taken left to right."""
        out = factors[0]
        for factor in factors[1:]:
            out = torch.kron(out, factor)
        return out

    def forward(self, x):
        """Return ``|z|`` for each row of ``x`` as an ``(N, 1)`` float64 tensor.

        Rows of ``x`` are multiplied element-wise with length ``2**n_states``
        state vectors, so ``x`` is presumably ``(N, 2**n_states)``.  State and
        operator are built once per call instead of once per row.
        """
        psi_0, psi_t, H = self._circuit()
        z = self._amplitude(x, psi_0, psi_t, H)
        return z.unsqueeze(1).abs()

    def arch_loss(self):
        """Mean pairwise absolute difference between the three ``arch_norm`` tensors."""
        n_x, n_y, n_z = self.arch_norm
        l0 = (n_x - n_y).abs()
        l1 = (n_x - n_z).abs()
        l2 = (n_y - n_z).abs()
        return (l0 + l1 + l2).mean()

    def compute_reps(self, x):
        """Return the complex amplitude for a single row ``x`` (0-d complex128)."""
        psi_0, psi_t, H = self._circuit()
        return self._amplitude(x, psi_0, psi_t, H)

    def _circuit(self):
        """Return (initial state, transformed state, operator) as complex128 tensors."""
        psi_0 = self.init_state().squeeze(1).to(torch.complex128)
        psi_t = self.unitary_trans().squeeze(1).to(torch.complex128)
        H = self.operator().to(torch.complex128)
        return psi_0, psi_t, H

    @staticmethod
    def _amplitude(x, psi_0, psi_t, H):
        """Compute ``(x*psi_t) @ H @ (x*psi_0)`` over the last dim of ``x``."""
        xc = x.to(torch.complex128)
        H0 = xc * psi_0
        H1 = xc * psi_t
        # NOTE(review): H1 is not conjugated, so this is the bilinear form
        # H1^T H H0 rather than the Hermitian <H1|H|H0>; confirm this is intended.
        return ((H1 @ H) * H0).sum(-1)

    # OPERATOR MEASUREMENT
    def operator(self):
        """Kronecker product of the ``GL_Zdot`` rotations, shape ``(2**n, 2**n)``."""
        return self._kron_all(self.GL_Zdot)

    # UNITARY TRANSFORMATION
    def unitary_trans(self):
        """Apply ``GL_Ydot[i]`` to each single-factor state, then take the Kronecker product."""
        rotated = [
            torch.matmul(g.type(torch.complex128), s.type(torch.complex128))
            for g, s in zip(self.GL_Ydot, self._single_states())
        ]
        return self._kron_all(rotated)

    # COMPUTING ARCHITECTURE DISTANCE ON S_n[x-[p]]-sphere
    def compute_Sn(self):
        """Return ``|xi - v|`` for ``v`` in (Xdot, Ydot, Zdot).

        ``abs`` equals the former ``sqrt((xi - v)**2)`` but its gradient stays
        finite where ``xi == v`` (sqrt's derivative is infinite at 0).
        """
        return tuple((self.xi - v).abs() for v in (self.Xdot, self.Ydot, self.Zdot))

    # CONSTRUCTING GL2 (Ry)
    def GL2(self, z, xi, p, i):
        """Return the 2x2 rotation ``[[c, -s], [s, c]]`` at angle ``2*pi*(z[i]-xi[i])/p``."""
        theta = 2 * torch.pi * (z[i] - xi[i]) / p
        c, s = torch.cos(theta), torch.sin(theta)
        # torch.stack (not torch.tensor) keeps the autograd graph back to z.
        return torch.stack([c, -s, s, c]).reshape(2, 2)

    def R(self, x):
        """Roll ``x`` by +1 along dim 0."""
        return torch.roll(x, shifts=1, dims=0)

    def F(self, x):
        """Roll ``x`` by -1 along dim 0."""
        return torch.roll(x, shifts=-1, dims=0)

    def _single_states(self):
        """One ``(2, 1)`` state per factor from ``(A[i], B[i])``."""
        return [self.psi(self.A[i], self.B[i]) for i in range(self.n_states)]

    def init_state(self):
        """Kronecker product of the per-factor states, shape ``(2**n, 1)``."""
        return self._kron_all(self._single_states())

    def psi(self, a, b):
        """Return the column ``[[a], [b]]``.

        Equivalent to ``a*[1, 0] + b*[0, 1]``, but it stays on ``a``/``b``'s
        device and dtype.
        """
        return torch.stack([a, b]).reshape(2, 1)

    def compute_p_root(self, n, prime_index):
        """Return ``2**(n+1) mod p``, where ``p`` is the ``prime_index``-th prime (1-based)."""
        prime = int(sympy.prime(prime_index))
        return pow(2, n + 1, prime)

And here is the test script for predicting potential energy surfaces (PES):

import os

from icecream import ic

import pandas as pd

import numpy as np

import torch

from model import *

import matplotlib.pyplot as plt

from utils import *

import plotly.express as px

from sklearn.model_selection import train_test_split

import h5py

from openfermion.chem import MolecularData

from openfermion.transforms import get_fermion_operator, jordan_wigner

import matplotlib.pyplot as plt

import torch

dirc = â**dataset**/3.3.2_PES/datasetâ

out_dirc = âQuantumChem_PESâ

try:

os.mkdir(out_dirc)

except:

pass

seed = 11

np.random.seed(seed)

torch.manual_seed(seed)

n_states = 2

num_epochs = 500

ch_idx = 0

for ch_idx in range(7):

try:

```
files = [os.path.join(dirc, x) for x in sorted(os.listdir(dirc)) if '.DS_Store' not in x]
dataset = ['{}/{}'.format(files[ch_idx],x) for x in sorted(os.listdir(files[ch_idx]))]
ENERGY = []
BOND_LENGTH = []
for f_idx in range(len(dataset)):
molecular_data = MolecularData(filename=dataset[f_idx]) # load hdf5 file
molecular_hamiltonian = get_fermion_operator(molecular_data.get_molecular_hamiltonian()) # get an instance of second quantized hamiltonian
BOND_LENGTH.append(molecular_data.general_calculations['bond_length'])
ENERGY.append(molecular_data.general_calculations['1st_excited_energy'])
ENERGY =torch.tensor(ENERGY)
BOND_LENGTH = torch.tensor(BOND_LENGTH).unsqueeze(1)
ENERGY = normalize(ENERGY)
X = torch.cat([torch.roll(BOND_LENGTH, shifts=-i).unsqueeze(1) for i in range(2**n_states)], dim = 1)
X = X.squeeze(2)
print(ENERGY.size())
print(X.size())
model = Ansatz(n_states)
criterion = nn.MSELoss() # loss function, for classification tasks
optimizer = optim.AdamW(model.parameters(), lr=0.03, weight_decay=3e-3) # optimizer
for name, param in model.state_dict().items():
print('Layer:', name)
print('Size:', param.size())
print('Values:', param, '\n')
best_loss = float('inf')
for epoch in range(num_epochs):
# TRAIN
model.train() # switch to training mode
optimizer.zero_grad()
# forward + backward + optimize
outputs = 1/model(X)
outputs = normalize(outputs)
arch_loss = model.arch_loss()
loss = criterion(outputs,
ENERGY.unsqueeze(1))
loss = loss + arch_loss
loss.backward(retain_graph=True) # backward pass
optimizer.step() # optimization step
print('Epoch: {}|Train Loss: {}'.format(epoch, loss.item()))
if loss < best_loss:
best_pred = outputs
best_loss = loss
torch.save(model.state_dict(), '{}/best_model_seed_{}_{}.pkl'.format(out_dirc,seed,files[ch_idx].split('/')[-1].split('.')[0]))
print('Epoch: {}|Test Loss: {}'.format(epoch, loss.item()))
chem_name = files[ch_idx].split('/')[-1].split('.')[0].split('_')[0]
# Assuming you have the tensors BOND_LENGTH and ENERGY
BOND_LENGTH = BOND_LENGTH.detach().numpy()
ENERGY = ENERGY.detach().numpy()
best_pred = best_pred.detach().numpy()
# Create the plot
plt.figure(figsize=(7, 7))
plt.plot(BOND_LENGTH[:-3], ENERGY[:-3], marker='o',
label='Actual Energy', markersize = 3)
plt.plot(BOND_LENGTH[:-3], best_pred[:-3], marker='s',
label='Predicted Energy', markersize = 3)
plt.xlabel('Bond Length')
plt.ylabel('Normalized Energy')
plt.title('{}\n (MSE: {:.4f})'.format(chem_name, best_loss.item()))
plt.legend()
plt.savefig('{}/{}.jpg'.format(out_dirc,chem_name), dpi = 600)
#plt.show() #
except:
pass
```

It is too early to tell whether this approach is better than others, but I think implementing group arithmetic could be useful for some problems related to primes.