Source code for nanotools.checkconv

# -*- coding: utf-8 -*-
"""
Created on 2021-06-16

@author: Vincent Michaud-Rioux
"""

from nanotools.base import Base, Quantity
from nanotools.utils import to_quantity, ureg
from nanotools.kpoint import increase_ksampling
from nanotools.totalenergy import TotalEnergy
from nanotools.utils import dict_converter
from typing import List
import attr
import matplotlib.pyplot as plt
import numpy as np


@attr.s
class CheckPrecision(Base):
    """``CheckPrecision`` class.

    Converges a discretization parameter of a ``TotalEnergy`` calculation.
    Starting from the reference calculator, calculations are run with an
    increasingly fine real-space resolution (scaled by 0.9 at each step) or an
    increasingly dense k-point grid (via ``increase_ksampling``) until the
    total energy per atom changes by at most ``etol`` between two consecutive
    steps.

    Examples::

        from nanotools.checkconv import CheckPrecision
        from nanotools.totalenergy import TotalEnergy

        ecalc = TotalEnergy.read("nano_scf_out.json")
        ecalc.solver.set_stdout("resculog.out")
        calc = CheckPrecision(ecalc, parameter="resolution", etol=1.e-3)
        calc.solve()

    Attributes:
        reference_calculator: ``TotalEnergy`` calculator (a dict is converted
            with ``dict_converter``). It is copied for every step, and its
            resolution / k-point grid is the starting point of the study.
        calculators (list): ``TotalEnergy`` calculators, one per value of the
            converged parameter; ``solve`` appends to this list.
        etol: Tolerance on the total energy per atom; the input is passed
            through ``to_quantity(x, "eV")``.
        parameter: Parameter to converge ("k-sampling" or "resolution").
    """

    # The input may also be a dictionary; dict_converter builds the calculator.
    reference_calculator: TotalEnergy = attr.ib(
        converter=lambda d: dict_converter(d, TotalEnergy),
        validator=attr.validators.instance_of(TotalEnergy),
    )
    # factory=list gives every instance its own list instead of one shared
    # mutable default object.
    calculators: List[TotalEnergy] = attr.ib(factory=list)
    etol: Quantity = attr.ib(
        default=1.0e-3 * ureg.eV,
        converter=lambda x: to_quantity(x, "eV"),
        validator=attr.validators.instance_of(Quantity),
    )
    parameter: str = attr.ib(
        default="resolution",
        validator=attr.validators.instance_of(str),
    )

    def __attrs_post_init__(self):
        """Normalize ``calculators`` and validate ``parameter``.

        Raises:
            TypeError: If an entry of ``calculators`` is neither a dict nor a
                ``TotalEnergy``.
            ValueError: If ``parameter`` is not "k-sampling" or "resolution".
        """
        # None would make solve() fail on append; treat it as "no results yet".
        if self.calculators is None:
            self.calculators = []
        calcs = []
        for s in self.calculators:
            if isinstance(s, dict):
                calcs.append(TotalEnergy(**s))
            elif isinstance(s, TotalEnergy):
                calcs.append(s)
            else:
                raise TypeError("Invalid TotalEnergy calculators.")
        self.calculators = calcs
        if self.parameter not in ("k-sampling", "resolution"):
            raise ValueError(
                'The parameter attribute must be "k-sampling" or "resolution"'
            )

    @classmethod
    def from_totalenergy(cls, totalenergy, **kwargs):
        """Build a ``CheckPrecision`` around a reference calculator.

        Args:
            totalenergy (TotalEnergy or str): Reference calculator, or an
                argument for ``TotalEnergy.read`` (e.g. a file name).
            **kwargs: Forwarded to the constructor (``etol``, ``parameter``...).

        Returns:
            CheckPrecision: The new instance.
        """
        if not isinstance(totalenergy, TotalEnergy):
            totalenergy = TotalEnergy.read(totalenergy)
        return cls(reference_calculator=totalenergy, **kwargs)

    def get_delta_etots(self):
        """Returns the energy errors taking the last calculator as reference.

        When ``calculators`` was filled by ``solve``, the last calculator is
        the most precise one (finest resolution or densest k-sampling).
        """
        etot = self.get_etots()
        return etot - etot[-1]

    def get_etots(self):
        """Returns the total energy per atom of every calculator."""
        return Quantity.from_list(
            [c.get_total_energy_per_atom() for c in self.calculators]
        )

    def get_ksamplings(self):
        """Returns the k-point grid of every calculator."""
        return [c.system.kpoint.grid for c in self.calculators]

    def get_resolutions(self):
        """Returns the real-space resolution of every calculator."""
        return Quantity.from_list([c.system.cell.resolution for c in self.calculators])

    def plot(self, filename=None, show=True):
        """Generates a plot of the error as a function of parameter.

        Args:
            filename (str, optional):
                If not None, then the figure is saved to filename.
            show (bool, optional):
                If True block and show figure.
                If False, do not show figure.

        Returns:
            fig (:obj:`matplotlib.figure.Figure`):
                A figure handle.
        """
        detot = self.get_delta_etots()
        if self.parameter == "resolution":
            x = self.get_resolutions()
            units = x.u
            x = x.m
        else:
            # For k-sampling, the x axis is the total number of k-points.
            kall = self.get_ksamplings()
            x = np.array([np.prod(k) for k in kall])
            units = "n"
        fig = plt.figure()
        plt.semilogy(x, np.abs(detot.m), "bo--")
        # Horizontal line marking the tolerance, in the same units as detot.
        ef = np.ones(x.size) * self.etol.to(detot.u)
        plt.plot(x, ef.m, "--k")
        plt.xlabel(f"{self.parameter} ({units})")
        plt.ylabel(f"delta energy ({detot.u})")
        if filename is not None:
            fig.savefig(filename)
        if show:
            plt.show()
        return fig

    def print_summary(self):
        """Print energy per atom and error for every calculator."""
        if self.parameter == "resolution":
            self._print_summary_res()
        else:
            self._print_summary_kpt()

    def solve(self):
        """Run the convergence study selected by ``parameter``."""
        if self.parameter == "resolution":
            self._solve_res()
        else:
            self._solve_kpt()

    def _print_summary_kpt(self):
        """Print the summary table for a k-sampling study."""
        detot = self.get_delta_etots()
        etot = self.get_etots()
        kall = self.get_ksamplings()
        line = "%20s | %25s | %20s" % (
            "k-sampling",
            "total energy/atom (eV)",
            "energy error (eV)",
        )
        print(line)
        for k, e, d in zip(kall, etot, detot):
            # %+25.8f keeps the energy column aligned with its %25s header.
            line = "%6d %6d %6d | %+25.8f | %+20.8f" % (
                k[0],
                k[1],
                k[2],
                e.to("eV").m,
                d.to("eV").m,
            )
            print(line)

    def _print_summary_res(self):
        """Print the summary table for a resolution study."""
        detot = self.get_delta_etots()
        etot = self.get_etots()
        resolution = self.get_resolutions()
        line = "%20s | %25s | %20s" % (
            "resolution (ang)",
            "total energy/atom (eV)",
            "energy error (eV)",
        )
        print(line)
        for r, e, d in zip(resolution, etot, detot):
            line = "%20.8f | %+25.8f | %+20.8f" % (
                r.to("angstrom").m,
                e.to("eV").m,
                d.to("eV").m,
            )
            print(line)

    def _solve_kpt(self):
        """
        Perform a series of calculations of increasing k-sampling until the
        total energy per atom varies by less than a prescribed tolerance.

        At least two calculations are performed so that convergence is judged
        on an energy difference. The object is written to
        ``nano_check_ksampling.json`` after every step.
        """
        kref = self.reference_calculator.system.kpoint.grid
        k1 = kref
        etot = []
        count = 0
        line = "%20s | %25s | %20s" % (
            "k-sampling",
            "total energy/atom (eV)",
            "delta energy (eV)",
        )
        print(line)
        # On the first pass detot is the total energy itself, not a
        # difference, so it must not be tested against the tolerance.
        while count < 2 or abs(detot) > self.etol:
            ecalc = self.reference_calculator.copy()
            ecalc.system.kpoint.set_grid(k1)
            ecalc.solve(output="kpt_scf_out_" + str(count))
            self.calculators.append(ecalc)
            self.write("nano_check_ksampling.json")
            etot.append(ecalc.get_total_energy_per_atom())
            # Out-of-place subtraction: "-=" could modify etot[-1] in place
            # through the alias.
            detot = etot[-1] - etot[-2] if count > 0 else etot[-1]
            line = "%6d %6d %6d | %+25.8f | %+20.8f" % (
                k1[0],
                k1[1],
                k1[2],
                etot[-1].to("eV").m,
                detot.to("eV").m,
            )
            print(line)
            k1 = increase_ksampling(k1, kref)
            count += 1
        self.print_summary()

    def _solve_res(self):
        """
        Perform a series of calculations of increasing real space resolution
        until the total energy per atom varies by less than a prescribed
        tolerance.

        The resolution is scaled by 0.9 at each step. At least two
        calculations are performed so that convergence is judged on an energy
        difference. The object is written to ``nano_check_resolution.json``
        after every step.
        """
        res1 = self.reference_calculator.system.cell.resolution
        etot = []
        count = 0
        line = "%20s | %25s | %20s" % (
            "resolution (ang)",
            "total energy/atom (eV)",
            "delta energy (eV)",
        )
        print(line)
        # On the first pass detot is the total energy itself, not a
        # difference, so it must not be tested against the tolerance.
        while count < 2 or abs(detot) > self.etol:
            ecalc = self.reference_calculator.copy()
            ecalc.system.cell.set_resolution(res1)
            ecalc.solve(output="res_scf_out_" + str(count))
            self.calculators.append(ecalc)
            self.write("nano_check_resolution.json")
            etot.append(ecalc.get_total_energy_per_atom())
            # Out-of-place subtraction: "-=" could modify etot[-1] in place
            # through the alias.
            detot = etot[-1] - etot[-2] if count > 0 else etot[-1]
            line = "%20.8f | %+25.8f | %+20.8f" % (
                ecalc.system.cell.resolution.to("angstrom").m,
                etot[-1].to("eV").m,
                detot.to("eV").m,
            )
            print(line)
            # Out-of-place product: res1 initially is the reference
            # calculator's own quantity, and "*=" may scale it in place
            # (e.g. pint quantities with array magnitudes).
            res1 = res1 * 0.9
            count += 1
        self.print_summary()