# -*- coding: utf-8 -*-
"""This module defines the ``Energy`` class."""
from ase.atoms import Atoms
from attr import field
from pint import Quantity
from nanotools.base import Base
from nanotools.system import System
from nanotools.utils import to_quantity
import attr
import numpy as np
def converter_energy(x):
    """attrs converter: turn ``x`` into a quantity with energy unit ``eV``."""
    energy_unit = "eV"
    return to_quantity(x, energy_unit)
def converter_forces(x):
    """attrs converter: turn ``x`` into a force quantity in ``eV / angstrom``.

    The shape ``(-1, 3)`` is requested from ``to_quantity`` (one row of three
    components per atom).
    """
    per_atom_xyz = (-1, 3)
    return to_quantity(x, "eV / angstrom", shape=per_atom_xyz)
def converter_stress(x):
    """attrs converter: turn ``x`` into a stress quantity in ``eV / angstrom ** 3``.

    The shape ``(3, 3)`` (a full tensor) is requested from ``to_quantity``.
    """
    tensor_shape = (3, 3)
    return to_quantity(x, "eV / angstrom ** 3", shape=tensor_shape)
def _optional_quantity_field(converter):
    """Return an attrs field that defaults to ``None``.

    Non-``None`` values are passed through ``converter`` and must then be
    pint ``Quantity`` instances.
    """
    return field(
        default=None,
        converter=attr.converters.optional(converter),
        validator=attr.validators.optional(attr.validators.instance_of(Quantity)),
    )


def _bool_field():
    """Return an attrs field defaulting to ``False`` that is coerced with ``bool``."""
    return field(
        default=False,
        converter=bool,
        validator=attr.validators.instance_of(bool),
    )


@attr.s(auto_detect=True, eq=False)
class Energy(Base):
    """``Energy`` class.

    The ``Energy`` class stores energy data. It is typically empty before a
    calculation and gets overwritten during a calculation.

    Energies are converted with ``to_quantity(..., "eV")``, forces with
    ``"eV / angstrom"`` and stresses with ``"eV / angstrom ** 3"``.

    Attributes:
        esr:
            esr is the short-range energy.
            Example::
                esr = energy.esr
        ebg:
            Energy term stored as ``ebg`` (converted to eV).
            NOTE(review): its physical meaning is not documented here — confirm.
        ebs:
            ebs is the band structure energy.
            Example::
                ebs = energy.ebs
        edh:
            edh is the delta Hartree energy.
            Example::
                edh = energy.edh
        efermi:
            efermi is the Fermi energy.
            Example::
                efermi = energy.efermi
        etot:
            etot is the total energy.
            Example::
                etot = energy.etot
        efree:
            efree is the free energy (returned by ``get_free_energy``).
        entropy:
            Entropy term, stored with the energy converter (eV).
        evxc:
            evxc is the exchange-correlation potential energy.
            Example::
                evxc = energy.evxc
        exc:
            exc is the exchange-correlation energy.
            Example::
                exc = energy.exc
        eigenvalues:
            eigenvalues is a three-dimensional array containing the Kohn-Sham
            energies (the eigenvalues of the Kohn-Sham equation). The dimensions
            are the following: bands, k-point, spin.
            Example::
                eigenvalues = energy.eigenvalues
        forces:
            forces is a two-dimensional array containing the atomic forces
            (as calculated by the Hellmann-Feynman theorem).
            Example::
                forces = energy.forces
        forces_return:
            If True, the forces are computed and written to energy.forces.
            They are not computed otherwise.
            Example::
                energy.forces_return = True
        stress:
            stress is a two-dimensional array containing the stress tensor
            (as calculated by the Hellmann-Feynman theorem).
            Example::
                stress = energy.stress
        stress_return:
            If True, the stress tensor is computed and written to energy.stress.
            It is not computed otherwise.
            Example::
                energy.stress_return = True
        frc_t, frc_s, frc_vnl, frc_veff, frc_sr, frc_vna, frc_vdh, frc_rpc:
            Force arrays stored with the same converter as ``forces``
            (presumably the individual contributions to ``forces`` — confirm).
        edftd3:
            Grimme's DFT-D3 energy.
            It is evaluated with ASE's
            `DFTD3 calculator <https://wiki.fysik.dtu.dk/ase/ase/calculators/dftd3.html>`_.
            For additional details on the methods, please visit the Mulliken Center's
            `website <https://www.chemie.uni-bonn.de/pctc/mulliken-center/software/dft-d3/>`_.
        frc_dftd3:
            Grimme's DFT-D3 forces.
        stress_dftd3:
            Grimme's DFT-D3 stress.
        include_dftd3:
            Grimme's DFT-D3 switch.
            If True, include D3 dispersion corrections in energy, forces and stress.
            Do nothing otherwise.
        dftd3_kwargs:
            Keyword arguments forwarded to ASE's ``DFTD3`` calculator.
    """

    # The field order below defines the positional order of ``__init__``
    # arguments and must not change.
    esr: Quantity = _optional_quantity_field(converter_energy)
    ebg: Quantity = _optional_quantity_field(converter_energy)
    ebs: Quantity = _optional_quantity_field(converter_energy)
    edh: Quantity = _optional_quantity_field(converter_energy)
    exc: Quantity = _optional_quantity_field(converter_energy)
    evxc: Quantity = _optional_quantity_field(converter_energy)
    etot: Quantity = _optional_quantity_field(converter_energy)
    efermi: Quantity = _optional_quantity_field(converter_energy)
    eigenvalues: Quantity = _optional_quantity_field(converter_energy)
    efree: Quantity = _optional_quantity_field(converter_energy)
    entropy: Quantity = _optional_quantity_field(converter_energy)
    forces: Quantity = _optional_quantity_field(converter_forces)
    forces_return: bool = _bool_field()
    stress: Quantity = _optional_quantity_field(converter_stress)
    stress_return: bool = _bool_field()
    frc_t: Quantity = _optional_quantity_field(converter_forces)
    frc_s: Quantity = _optional_quantity_field(converter_forces)
    frc_vnl: Quantity = _optional_quantity_field(converter_forces)
    frc_veff: Quantity = _optional_quantity_field(converter_forces)
    frc_sr: Quantity = _optional_quantity_field(converter_forces)
    frc_vna: Quantity = _optional_quantity_field(converter_forces)
    frc_vdh: Quantity = _optional_quantity_field(converter_forces)
    frc_rpc: Quantity = _optional_quantity_field(converter_forces)
    edftd3: Quantity = _optional_quantity_field(converter_energy)
    frc_dftd3: Quantity = _optional_quantity_field(converter_forces)
    stress_dftd3: Quantity = _optional_quantity_field(converter_stress)
    include_dftd3: bool = _bool_field()
    dftd3_kwargs: dict = attr.ib(default=attr.Factory(dict))

    def __attrs_post_init__(self):
        # No post-initialization work is needed; kept as an explicit hook.
        pass

    def __eq__(self, other):
        """Compare two ``Energy`` objects.

        Each compared quantity must either be ``None`` on both sides, or have
        identical units, identical shapes and numerically close magnitudes
        (``np.allclose``). ``forces_return`` and ``stress_return`` must match
        exactly. Force components (``frc_*``) and DFT-D3 data are not compared.

        Returns:
            bool, or ``NotImplemented`` when ``other`` is not of the same class.
        """
        if other.__class__ is not self.__class__:
            return NotImplemented
        compared = (
            "eigenvalues",  # arrays of floats
            "forces",
            "stress",
            "esr",  # floats
            "ebg",
            "ebs",
            "edh",
            "exc",
            "evxc",
            "etot",
            "efermi",
            "efree",
            "entropy",
        )
        for name in compared:
            mine = getattr(self, name)
            theirs = getattr(other, name)
            if mine is None or theirs is None:
                # Unset on exactly one side -> different; unset on both -> equal.
                if mine is not theirs:
                    return False
                continue
            if mine.u != theirs.u:
                return False
            # Check shapes explicitly: np.allclose would broadcast or raise.
            if np.shape(mine.m) != np.shape(theirs.m):
                return False
            if not np.allclose(mine.m, theirs.m):
                return False
        return (self.forces_return, self.stress_return) == (
            other.forces_return,
            other.stress_return,
        )

    def get_band_edges(self):
        """Return the eigenvalue range widened by one unit on each side.

        Returns:
            tuple: ``(emin, emax)`` where ``emin`` is the lowest eigenvalue minus
            the padding and ``emax`` is the highest eigenvalue plus the padding.
            The padding is 1 eV when the eigenvalues are pint quantities, and
            a plain ``1`` otherwise.
        """
        emin = np.min(self.eigenvalues)
        emax = np.max(self.eigenvalues)
        # pint refuses to add a bare number to a dimensioned quantity, so the
        # padding must carry energy units whenever the eigenvalues do.
        pad = to_quantity(1.0, "eV") if isinstance(emin, Quantity) else 1
        return emin - pad, emax + pad

    def get_cbm(self):
        """
        Returns the conduction band minimum.

        This is the lowest eigenvalue strictly above ``efermi``. If partially
        occupied bands are detected, as in metals, then the function returns
        ``None``.

        Returns:
            Quantity or None: conduction band minimum
        """
        # NOTE(review): ``ismetal`` is not defined in this class; presumably
        # inherited from ``Base`` — confirm.
        if self.ismetal():
            return None
        ev = self.eigenvalues
        return np.min(ev[ev > self.efermi])

    def get_free_energy(self):
        """Returns the free energy (``efree``)."""
        return self.efree

    def get_total_energy(self):
        """Returns the total energy (``etot``)."""
        return self.etot

    def get_vbm(self):
        """
        Returns the valence band maximum.

        This is the highest eigenvalue strictly below ``efermi``. If partially
        occupied bands are detected, as in metals, then the function returns
        ``None``.

        Returns:
            Quantity or None: valence band maximum
        """
        # NOTE(review): ``ismetal`` is not defined in this class; presumably
        # inherited from ``Base`` — confirm.
        if self.ismetal():
            return None
        ev = self.eigenvalues
        return np.max(ev[ev < self.efermi])

    def set_stress_return(self, stress_return):
        """Sets the stress_return attribute.

        If stress_return is True, then forces_return is also set to True.

        Args:
            stress_return (bool): stress_return value

        Raises:
            TypeError: if ``stress_return`` is not a ``bool``.
        """
        if not isinstance(stress_return, bool):
            raise TypeError(
                f"Invalid class {stress_return.__class__} for arg stress_return."
            )
        self.stress_return = stress_return
        if self.stress_return:
            self.forces_return = True

    def set_dftd3_energy(self, atoms, include_dftd3=True, **kwargs):
        """Evaluate Grimme's DFT-D3 energy, forces and stress with ASE.

        The results are stored in ``edftd3``, ``frc_dftd3`` and
        ``stress_dftd3``. If ``include_dftd3`` is true, they are then added to
        ``etot``, ``efree``, ``forces`` and ``stress``.

        Args:
            atoms (System or ase.Atoms): structure to evaluate. A ``System`` is
                converted with ``to_ase_atoms()``. An ASE ``Atoms`` object is
                copied, so the caller's calculator is left untouched.
            include_dftd3 (bool): value stored in ``self.include_dftd3``.
            **kwargs: merged into ``self.dftd3_kwargs``, which is forwarded to
                ASE's ``DFTD3`` calculator.

        Raises:
            TypeError: if ``atoms`` is neither a ``System`` nor an ASE ``Atoms``.

        Note:
            The corrections are added again on every call while
            ``include_dftd3`` is true, so repeated calls (including through the
            ``set_dftd3_forces``/``set_dftd3_stress`` aliases) accumulate them.
        """
        from ase.calculators.dftd3 import DFTD3

        if isinstance(atoms, System):
            atoms = atoms.to_ase_atoms()
        elif isinstance(atoms, Atoms):
            # Work on a copy so that attaching the D3 calculator below does not
            # replace the calculator of the caller's Atoms object.
            atoms = atoms.copy()
        else:
            raise TypeError("Argument atoms must be of class System or ASE Atoms.")
        self.dftd3_kwargs.update(kwargs)
        self.include_dftd3 = include_dftd3
        atoms.calc = DFTD3(**self.dftd3_kwargs)
        self.edftd3 = converter_energy(atoms.get_potential_energy())
        self.frc_dftd3 = converter_forces(atoms.get_forces())
        self.stress_dftd3 = converter_stress(atoms.get_stress(voigt=False))
        self._update_energy_forces_stress()

    # All three D3 quantities are computed together, so the aliases share one implementation.
    set_dftd3_forces = set_dftd3_stress = set_dftd3_energy

    def _update_energy_forces_stress(self):
        """Add the stored DFT-D3 corrections to the total quantities.

        Does nothing unless ``include_dftd3`` is true. A target that is still
        ``None`` is initialized with the correction itself.
        """
        if not self.include_dftd3:
            return
        corrections = (
            ("etot", "edftd3"),
            ("efree", "edftd3"),
            ("forces", "frc_dftd3"),
            ("stress", "stress_dftd3"),
        )
        for target, source in corrections:
            current = getattr(self, target)
            delta = getattr(self, source)
            setattr(self, target, delta if current is None else current + delta)