# -*- coding: utf-8 -*-
"""
Created on 2021-06-16
@author: Vincent Michaud-Rioux
"""
from nanotools.base import Base, Quantity
from nanotools.utils import to_quantity, ureg
from nanotools.kpoint import increase_ksampling
from nanotools.totalenergy import TotalEnergy
from nanotools.utils import dict_converter
from typing import List
import attr
import matplotlib.pyplot as plt
import numpy as np
@attr.s
class CheckPrecision(Base):
    """``CheckPrecision`` class.

    Runs a series of ``TotalEnergy`` calculations of increasing precision
    (finer real-space resolution or denser k-sampling) until the total energy
    per atom of two consecutive calculations differs by at most ``etol``.

    Examples::

        from nanotools.checkprecision import CheckPrecision
        from nanotools.totalenergy import TotalEnergy
        ecalc = TotalEnergy.read("nano_scf_out.json")
        ecalc.solver.set_stdout("resculog.out")
        calc = CheckPrecision(ecalc, parameter="resolution", etol=1.e-3)
        calc.solve()

    Attributes:
        reference_calculator (TotalEnergy):
            Template calculator; every step of ``solve`` works on a copy of it.
            Dicts are converted with ``dict_converter``.
        calculators (list):
            Calculators solved so far, one per value of ``parameter``, in the
            order they were computed. Dict entries are converted to
            ``TotalEnergy`` objects.
        etol:
            Total energy tolerance; the input goes through
            ``to_quantity(x, "eV")``.
        parameter:
            Parameter to converge ("k-sampling" or "resolution").
    """

    reference_calculator: TotalEnergy = attr.ib(
        converter=lambda d: dict_converter(d, TotalEnergy),
        validator=attr.validators.instance_of(TotalEnergy),
    )
    # Accepts TotalEnergy objects or their dict representations (converted in
    # __attrs_post_init__). ``factory`` gives every instance its own list
    # instead of one shared mutable default.
    calculators: List[TotalEnergy] = attr.ib(factory=list)
    etol: Quantity = attr.ib(
        default=1.0e-3 * ureg.eV,
        converter=lambda x: to_quantity(x, "eV"),
        validator=attr.validators.instance_of(Quantity),
    )
    parameter: str = attr.ib(
        default="resolution",
        validator=attr.validators.instance_of(str),
    )

    def __attrs_post_init__(self):
        """Normalize ``calculators`` and validate ``parameter``.

        Raises:
            TypeError: If an element of ``calculators`` is neither a dict nor
                a ``TotalEnergy``.
            ValueError: If ``parameter`` is not "k-sampling" or "resolution".
        """
        if self.calculators is None:
            # An explicit None means "no calculations yet"; solve() needs a
            # list it can append to.
            self.calculators = []
        calc = []
        for s in self.calculators:
            if isinstance(s, dict):
                calc.append(TotalEnergy(**s))
            elif isinstance(s, TotalEnergy):
                calc.append(s)
            else:
                raise TypeError("Invalid TotalEnergy calculators.")
        self.calculators = calc
        if self.parameter not in ("k-sampling", "resolution"):
            raise ValueError(
                'The parameter attribute must be "k-sampling" or "resolution"'
            )

    @classmethod
    def from_totalenergy(cls, totalenergy, **kwargs):
        """Create a ``CheckPrecision`` around a reference calculator.

        Args:
            totalenergy (TotalEnergy or str):
                A ``TotalEnergy`` object, or anything accepted by
                ``TotalEnergy.read`` (e.g. a JSON file name).
            **kwargs:
                Forwarded to the ``CheckPrecision`` constructor.

        Returns:
            CheckPrecision: The new instance.
        """
        if not isinstance(totalenergy, TotalEnergy):
            totalenergy = TotalEnergy.read(totalenergy)
        return cls(reference_calculator=totalenergy, **kwargs)

    def get_delta_etots(self):
        """Return the energy errors, taking the last calculation as reference.

        ``solve`` appends calculators in order of increasing precision, so the
        last one is the most precise.

        Returns:
            Quantity: ``get_etots()`` minus its last element.
        """
        etot = self.get_etots()
        return etot - etot[-1]

    def get_etots(self):
        """Return the total energy per atom of every calculator."""
        return Quantity.from_list(
            [c.get_total_energy_per_atom() for c in self.calculators]
        )

    def get_ksamplings(self):
        """Return the k-point grid of every calculator."""
        return [c.system.kpoint.grid for c in self.calculators]

    def get_resolutions(self):
        """Return the real-space resolution of every calculator."""
        return Quantity.from_list([c.system.cell.resolution for c in self.calculators])

    def plot(self, filename=None, show=True):
        """Generates a plot of the error as a function of parameter.

        Args:
            filename (str, optional):
                If not None, then the figure is saved to filename.
            show (bool, optional):
                If True block and show figure.
                If False, do not show figure.

        Returns:
            fig (:obj:`matplotlib.figure.Figure`):
                A figure handle.
        """
        detot = self.get_delta_etots()
        if self.parameter == "resolution":
            res = self.get_resolutions()
            units = res.u
            x = res.m
        else:
            # Each k-point grid is summarized by its total number of points.
            x = np.array([np.prod(k) for k in self.get_ksamplings()])
            units = "n"
        fig = plt.figure()
        # The reference point has a zero error, which a log axis cannot show.
        plt.semilogy(x, np.abs(detot.m), "bo--")
        # Tolerance line, expressed in the same units as the errors.
        ef = np.ones(x.size) * self.etol.to(detot.u)
        plt.plot(x, ef.m, "--k")
        plt.xlabel(f"{self.parameter} ({units})")
        plt.ylabel(f"delta energy ({detot.u})")
        if filename is not None:
            fig.savefig(filename)
        if show:
            plt.show()
        return fig

    def print_summary(self):
        """Print a table of energies and errors for the converged parameter."""
        if self.parameter == "resolution":
            self._print_summary_res()
        else:
            self._print_summary_kpt()

    def solve(self):
        """Run the convergence study for ``parameter``."""
        if self.parameter == "resolution":
            self._solve_res()
        else:
            self._solve_kpt()

    @staticmethod
    def _header(label, delta_label):
        """Return the table header shared by the solve and summary printouts."""
        return "%20s | %25s | %20s" % (
            label,
            "total energy/atom (eV)",
            delta_label,
        )

    @staticmethod
    def _kpt_row(k, e, d):
        """Format one k-sampling table row; energies are printed in eV."""
        # Widths match _header: 3 x 6 + 2 spaces = 20, then 25, then 20.
        return "%6d %6d %6d | %+25.8f | %+20.8f" % (
            k[0],
            k[1],
            k[2],
            e.to("eV").m,
            d.to("eV").m,
        )

    @staticmethod
    def _res_row(r, e, d):
        """Format one resolution table row (angstrom, eV, eV)."""
        return "%20.8f | %+25.8f | %+20.8f" % (
            r.to("angstrom").m,
            e.to("eV").m,
            d.to("eV").m,
        )

    def _print_summary_kpt(self):
        """Print the k-sampling summary table."""
        detot = self.get_delta_etots()
        etot = self.get_etots()
        kall = self.get_ksamplings()
        print(self._header("k-sampling", "energy error (eV)"))
        for k, e, d in zip(kall, etot, detot):
            print(self._kpt_row(k, e, d))

    def _print_summary_res(self):
        """Print the resolution summary table."""
        detot = self.get_delta_etots()
        etot = self.get_etots()
        resolution = self.get_resolutions()
        print(self._header("resolution (ang)", "energy error (eV)"))
        for r, e, d in zip(resolution, etot, detot):
            print(self._res_row(r, e, d))

    def _solve_kpt(self):
        """
        Perform a series of calculations of increasing k-sampling until the total energy
        varies by less than a prescribed tolerance.

        Each step solves a copy of the reference calculator on the grid returned
        by ``increase_ksampling``, appends it to ``calculators`` and writes the
        object to "nano_check_ksampling.json".
        """
        kref = self.reference_calculator.system.kpoint.grid
        k1 = kref
        etot = []
        print(self._header("k-sampling", "delta energy (eV)"))
        count = 0
        detot = None
        # The first step has no previous energy to compare with, so at least
        # two calculations are always performed before testing convergence.
        while count < 2 or abs(detot) > self.etol:
            ecalc = self.reference_calculator.copy()
            ecalc.system.kpoint.set_grid(k1)
            ecalc.solve(output="kpt_scf_out_" + str(count))
            self.calculators.append(ecalc)
            # Save progress after every step.
            self.write("nano_check_ksampling.json")
            etot.append(ecalc.get_total_energy_per_atom())
            # Out-of-place subtraction: ``-=`` may modify etot[-1] in place.
            detot = etot[-1] - etot[-2] if count > 0 else etot[-1]
            print(self._kpt_row(k1, etot[-1], detot))
            k1 = increase_ksampling(k1, kref)
            count += 1
        self.print_summary()

    def _solve_res(self):
        """
        Perform a series of calculations of increasing real space resolution until the total energy
        varies by less than a prescribed tolerance.

        Each step solves a copy of the reference calculator, shrinking the
        resolution by 10% between steps, appends it to ``calculators`` and
        writes the object to "nano_check_resolution.json".
        """
        res1 = self.reference_calculator.system.cell.resolution
        etot = []
        print(self._header("resolution (ang)", "delta energy (eV)"))
        count = 0
        detot = None
        # The first step has no previous energy to compare with, so at least
        # two calculations are always performed before testing convergence.
        while count < 2 or abs(detot) > self.etol:
            ecalc = self.reference_calculator.copy()
            ecalc.system.cell.set_resolution(res1)
            ecalc.solve(output="res_scf_out_" + str(count))
            self.calculators.append(ecalc)
            # Save progress after every step.
            self.write("nano_check_resolution.json")
            etot.append(ecalc.get_total_energy_per_atom())
            # Out-of-place subtraction: ``-=`` may modify etot[-1] in place.
            detot = etot[-1] - etot[-2] if count > 0 else etot[-1]
            print(self._res_row(ecalc.system.cell.resolution, etot[-1], detot))
            # Rebind instead of ``*=`` so the resolution object shared with the
            # reference (and previous) calculators is never mutated.
            res1 = res1 * 0.9
            count += 1
        self.print_summary()