from attr import define, field, validators
from joblib import Parallel, delayed
from matplotlib.figure import Figure
from pathlib import Path
from nanotools import TotalEnergy, Energy
from nanotools.base import Base, Quantity
from nanotools.utils import to_quantity, ureg
from nanotools.jsonio import json_read
from scipy import interpolate
from scipy import optimize
import attr
import matplotlib.pyplot as plt
import numpy as np
import os
@define
class Displacement(Base):
    """``Displacement`` class.

    One point of a two-atom binding curve: an interatomic separation plus the
    working directory holding the corresponding total-energy calculation.

    Attributes:
        separation: Quantity
            Distance between the two atoms (passed through
            ``to_quantity(x, "angstrom")``).
        free_energy: Quantity
            Free energy of the system (passed through ``to_quantity(x, "eV")``).
        workdir: str
            Working directory for calculations. Callers may pass a
            ``pathlib.Path`` instead of a ``str``; every use below wraps it in
            ``Path(...)`` so both work.
    """

    separation: Quantity = field(
        converter=lambda x: to_quantity(x, "angstrom"),
        validator=attr.validators.optional(attr.validators.instance_of(Quantity)),
    )
    free_energy: Quantity = field(
        default=None,
        converter=lambda x: to_quantity(x, "eV"),
        validator=attr.validators.optional(attr.validators.instance_of(Quantity)),
    )
    workdir: str = field(default="/")

    def get_calc(self, calc):
        """Returns a total energy calculator with the atoms at ``separation``.

        The calculator is copied with ``calc.copy()``. Atom 0 is kept fixed, and
        atom 1 is moved along the current atom0 -> atom1 direction until it sits
        ``self.separation`` away from atom 0.

        Args:
            calc (TotalEnergy): template total energy calculator.

        Raises:
            ValueError: the two atoms coincide, so the bond direction is undefined.

        Returns:
            TotalEnergy: the displaced calculator.
        """
        etot = calc.copy()
        pos = etot.system.atoms.positions
        dx = pos[1, :] - pos[0, :]
        norm = np.linalg.norm(dx.m)
        if norm == 0:
            # A zero-length bond vector has no direction; dividing by it would
            # silently yield NaN positions.
            raise ValueError("The two atoms coincide; the bond direction is undefined.")
        nx = norm * dx.u
        pos[1, :] = pos[0, :] + dx / nx * self.separation
        etot.system.set_positions(pos)
        return etot

    def get_output_paths(self):
        """Returns the JSON and HDF5 output paths.

        Side effect: creates ``workdir`` (including parents) if it is missing.

        Returns:
            tuple[Path, Path]: ``(workdir / "nano_scf_out.json",
            workdir / "nano_scf_out.h5")``.
        """
        # workdir may be a plain str (the default "/"), which has no mkdir().
        workdir = Path(self.workdir)
        workdir.mkdir(parents=True, exist_ok=True)
        output = workdir / "nano_scf_out.json"
        outhdf = workdir / "nano_scf_out.h5"
        return output, outhdf

    def get_free_energy(self):
        """Returns the free energy of the system.

        Reads the ``"energy"`` entry of the JSON output, builds an ``Energy``
        object from it, switches it to ``"si"`` units, and returns its
        ``efree`` attribute.

        Raises:
            FileNotFoundError: the JSON or the HDF5 output file is missing.

        Returns:
            Quantity: free energy (presumably a pint Quantity; confirm against
            ``Energy.efree``).
        """
        output, outhdf = self.get_output_paths()
        for path in (output, outhdf):
            if not path.exists():
                # Report the file that is actually missing.
                raise FileNotFoundError(f"Output file {path.absolute()} not found.")
        with open(output, "r", encoding="utf-8") as f:
            d = json_read(f)
        energy = Energy(**d["energy"])
        energy.set_units("si")
        return energy.efree

    def solve(self, calc):
        """Performs a total energy calculation.

        Existing outputs in ``workdir`` are reloaded. If they contain a
        converged calculation, nothing is recomputed. Otherwise the solver runs
        inside ``workdir``, and the previous working directory is always
        restored.

        Args:
            calc (TotalEnergy): total energy calculator.
        """
        root = os.getcwd()
        output, outhdf = self.get_output_paths()
        if output.exists() and outhdf.exists():
            etot = TotalEnergy.read(output)
        else:
            etot = self.get_calc(calc)
        if etot.solver.mix.converged:
            print(f"Found converged calculation {output.absolute()}")
            return
        # Resolve before chdir: a relative path would otherwise be
        # re-interpreted relative to workdir.
        output = output.absolute()
        os.chdir(self.workdir)
        try:
            etot.solve(output=output)
        finally:
            # chdir is process-wide; never leave it changed after a failure.
            os.chdir(root)
@define
class BindingCurve(Base):
    """``BindingCurve`` class.

    Samples the free energy of a two-atom system at several separations (one
    ``Displacement`` per separation) and interpolates it with a cubic spline.

    Attributes:
        displacements: list[Displacement]
            Rebuilt from ``separations`` in ``__attrs_post_init__``.
        separations: Quantity
            Interatomic separations (passed through
            ``to_quantity(x, "angstrom")``).
        free_energies: Quantity
            Free energy per separation, in eV; NaN until ``get_free_energies``
            fills it.
        total_energy_calc: TotalEnergy
            Template calculator; its system must contain exactly two atoms.
        workdir: str
            Root directory. Each separation gets an ``etot_dist_<R>``
            subdirectory, where ``<R>`` has two decimals.
    """

    displacements = field(default=None)
    separations: Quantity = field(
        default=np.arange(3.0, 5.2, 0.2) * ureg.angstrom,
        converter=lambda x: to_quantity(x, "angstrom"),
        validator=attr.validators.optional(attr.validators.instance_of(Quantity)),
    )
    free_energies = field(default=None)
    total_energy_calc: TotalEnergy = field(default=None)
    workdir: str = field(
        default="./", converter=str, validator=validators.instance_of(str)
    )

    def __attrs_post_init__(self):
        """Validates the atom count and builds one ``Displacement`` per separation.

        Raises:
            ValueError: the template system does not contain exactly 2 atoms.
        """
        natm = self.total_energy_calc.system.get_number_of_atoms()
        if natm != 2:
            raise ValueError(f"The number of atoms should be 2, but is now {natm}.")
        root = Path(self.workdir).absolute()
        self.displacements = [
            Displacement(separation=d, workdir=root / ("etot_dist_%.2f" % (d.m)))
            for d in self.separations
        ]
        # NaN rather than np.empty: unread entries are recognisable, not garbage.
        self.free_energies = np.full(len(self.separations), np.nan) * ureg.eV

    def get_free_energies(self):
        """Returns the free energy of each separation.

        Reads every displacement's output files and stores the results in
        ``self.free_energies``. The same object is returned.

        Returns:
            Quantity: free energies, in eV, one per separation.
        """
        for i, d in enumerate(self.displacements):
            self.free_energies[i] = d.get_free_energy()
        return self.free_energies

    def get_curve_fit(self):
        """Returns a spline interpolation function.

        Returns:
            PPoly: ``scipy.interpolate.CubicSpline`` built on the magnitudes of
            (separations, free energies). Note that it extrapolates outside the
            sampled interval.
        """
        x = self.separations
        y = self.get_free_energies()
        return interpolate.CubicSpline(x.m, y.m)

    def get_fine_grid(self, resolution=0.001):
        """Returns a fine grid sampling the separation interval.

        Args:
            resolution (float, optional): target grid spacing, in the units of
                ``separations``. Defaults to 0.001.

        Returns:
            Quantity: evenly spaced points from min to max separation
            (inclusive), in the units of ``separations``.
        """
        x = self.separations
        lo = np.min(x.m)
        hi = np.max(x.m)
        # round() so a ratio like 2199.9999999 does not truncate to 2199.
        n = int(round((hi - lo) / resolution))
        # Build on plain magnitudes. Quantity endpoints followed by `* x.u`
        # would attach the unit twice (e.g. angstrom**2).
        return np.linspace(lo, hi, n + 1) * x.u

    def get_binding_curve(self):
        """Returns the fine grid and the binding curve sampled on the fine grid.

        Returns:
            tuple[Quantity, Quantity]: separations and interpolated free energies.
        """
        y = self.get_free_energies()
        spline = self.get_curve_fit()
        x1 = self.get_fine_grid()
        y1 = spline(x1.m) * y.u
        return x1, y1

    def get_bond_length(self):
        """Returns the equilibrium bond length.

        Minimises the spline, starting from the lowest sampled energy.

        Returns:
            Quantity: separation at the spline minimum, in the units of
            ``separations``.
        """
        x = self.separations
        y = self.get_free_energies()
        spline = self.get_curve_fit()
        x0 = x.m[np.argmin(y.m)]
        # Restrict the search to the sampled interval. Outside it the spline is
        # a free cubic extrapolation and the minimiser can run off to -inf.
        bounds = [(np.min(x.m), np.max(x.m))]
        res = optimize.minimize(lambda r: float(spline(r[0])), x0, bounds=bounds)
        return res.x[0] * x.u

    def get_force(self):
        """Returns the fine grid and the force sampled on the fine grid.

        The force is ``-dE/dR`` of the spline, in energy units per separation
        unit.

        Returns:
            tuple[Quantity, Quantity]: separations and forces.
        """
        y = self.get_free_energies()
        dspline = self.get_curve_fit().derivative()
        x1 = self.get_fine_grid()
        yp = -dspline(x1.m) * y.u / x1.u
        return x1, yp

    def plot(self, figure=None, filename=None, shift=False, show=True, label=None):
        """Plot the binding curve.

        Args:
            figure (Figure, optional): Pyplot figure handle to draw into. Defaults to None (new figure).
            filename (str, optional): if not None, the figure is saved to filename. Defaults to None.
            shift (bool, optional): shift energies so the one at the largest separation is zero. Defaults to False.
            show (bool, optional): display figure before returning. Defaults to True.
            label (str, optional): label for the legend. Defaults to None.

        Returns:
            Figure: Pyplot figure handle.
        """
        x = self.separations
        y = self.get_free_energies()
        fig = figure if isinstance(figure, Figure) else plt.figure()
        # Draw on this figure's axes, not whichever figure pyplot has current.
        ax = fig.gca()
        # Copy: y.m aliases self.free_energies' storage, and shifting it in
        # place would corrupt the stored energies.
        ym = np.array(y.m, dtype=float)
        xc, yc = self.get_binding_curve()
        ycm = np.array(yc.m, dtype=float)
        if shift:
            offset = ym[np.argmax(x.m)]
            ym -= offset
            ycm -= offset
        (points,) = ax.plot(x.m, ym, "o")
        ax.plot(xc.m, ycm, "-", label=label, color=points.get_color())
        ax.set_xlabel(f"R ({x.u})")
        ax.set_ylabel(f"Energy ({y.u})")
        ax.grid(axis="x")
        # Avoid matplotlib's "No artists with labels" warning.
        if ax.get_legend_handles_labels()[0]:
            ax.legend(loc="upper right")
        if filename is not None:
            # Save before show(): show() may close the figure, leaving a blank save.
            fig.savefig(filename)
        if show:
            plt.show()
        return fig

    def solve(self, n_jobs=1):
        """Perform a total energy calculation for each separation.

        Note: ``Displacement.solve`` calls ``os.chdir``, which is process-wide.
        Keep joblib's default process-based backend rather than a threading one.

        Args:
            n_jobs (int, optional): Perform n_jobs total energy calculations in parallel. Defaults to 1.
        """
        Parallel(n_jobs=n_jobs)(
            delayed(d.solve)(self.total_energy_calc) for d in self.displacements
        )