"""
two-probe transport calculator
"""
from matplotlib import pyplot as plt
from pathlib import Path
from nanotools.base import Base, Quantity
from nanotools.utils import to_quantity, ureg
from nanotools.dos import Dos
from nanotools.totalenergy import TotalEnergy
from nanotools.twoprobe import TwoProbe, get_transport_dir
from nanotools.utils import dict_converter
from attr import field
import attr
import copy
import numpy as np
import os
import shutil
[docs]
@attr.s
class Transmission(Base):
    """Transmission class.
    Attributes:
        dos: density of states
        transport_axis
        left_equal_right: boolean variable indicating whether the two lead structures are identical.
    """
    center: TotalEnergy = attr.ib(
        converter=lambda d: dict_converter(d, TotalEnergy),
        validator=attr.validators.instance_of(TotalEnergy),
    )
    left: TotalEnergy = attr.ib(
        default=None,
        converter=attr.converters.optional(lambda d: dict_converter(d, TotalEnergy)),
        validator=attr.validators.optional(attr.validators.instance_of(TotalEnergy)),
    )
    right: TotalEnergy = attr.ib(
        default=None,
        converter=attr.converters.optional(lambda d: dict_converter(d, TotalEnergy)),
        validator=attr.validators.optional(attr.validators.instance_of(TotalEnergy)),
    )
    left_equal_right: bool = attr.ib(default=False)  # are lead structures identical?
    dos: Dos = attr.ib(
        factory=Dos,
        converter=lambda d: dict_converter(d, Dos),
        validator=attr.validators.instance_of(Dos),
    )
    # small eta for Green function calculations
    eta: Quantity = field(
        default=1.0e-6 * ureg.hartree,
        converter=attr.converters.optional(lambda x: to_quantity(x, "eV")),
        validator=attr.validators.optional(
            attr.validators.instance_of(Quantity),
        ),
    )
    transport_axis: int = attr.ib(
        default=-1, validator=attr.validators.in_([-1, 0, 1, 2])
    )
    classname: str = attr.ib()
    @classname.default
    def _classname_default_value(self):
        return self.__class__.__name__
    def __attrs_post_init__(self):
        if self.transport_axis < 0:
            self.transport_axis = get_transport_dir(self.center.system.cell.boundary)
        if self.transport_axis not in [0, 1, 2]:
            raise Exception("transport direction not specified.")
        self.center = copy.deepcopy(self.center)
        if self.left is None:
            self.left = copy.deepcopy(self.center)
        else:
            self.left = copy.deepcopy(self.left)
        if self.right is None:
            self.left_equal_right = True
        if self.left_equal_right:
            print("Left and right lead structures are same.")
            self.right = copy.deepcopy(self.left)
        else:
            self.right = copy.deepcopy(self.right)
        # self.center.solver.restart.densityPath = "center_out.h5"
    def _set_dos_grid(self, energies):
        self.dos.set_energy(energies)
    def set_eta(self, eta):
        self.eta = to_quantity(eta, "eV")
[docs]
    @classmethod
    def from_twoprobe(cls, twoprb, **kwargs):
        """initializes instance from TwoProbe object"""
        if not isinstance(twoprb, TwoProbe):
            raise Exception("Reading from not a TwoProbe object.")
        calc = cls(
            center=twoprb.center,
            left=twoprb.left,
            right=twoprb.right,
            transport_axis=twoprb.transport_axis,
            left_equal_right=twoprb.left_equal_right,
            **kwargs,
        )
        return calc 
    def solve(self, energies=None, input="nano_trsm_in", output="nano_trsm_out"):
        self.calc_transmission(energies=energies, input=input, output=output)
[docs]
    def calc_transmission(
        self, energies=None, input="nano_trsm_in", output="nano_trsm_out"
    ):
        """Calculates transmission
        This method triggers the nanodcalplus executable.
        Result is stored in self.dos.transmission
        Args:
            energies (float / iterable): energy grid over which the transmission is to be calculated.
                eg. energies=0.1
                eg. energies=[0.1,0.2,0.3]
            unit (string): energy unit
        """
        if energies is not None:
            self._set_dos_grid(energies)
        if not (Path(self.center.solver.restart.densityPath).is_file()):
            raise Exception("Error in transmission.solve: densityPath file not found.")
        if not (Path(self.left.solver.restart.densityPath).is_file()):
            raise Exception("Error in transmission.solve: densityPath file not found.")
        if not self.left_equal_right:  # if different structure
            if not (Path(self.right.solver.restart.densityPath).is_file()):
                raise Exception(
                    "Error in transmission.solve: densityPath file not found."
                )
        inputname = os.path.splitext(input)[0] + ".json"
        self.left.solver.mpidist = self.center.solver.mpidist
        self.right.solver.mpidist = self.center.solver.mpidist
        self.write(inputname)
        command, binname = self.center.solver.cmd.get_cmd("trsm")
        ret = command(inputname)
        ret.check_returncode()
        shutil.move("nano_trsm_out.json", output + ".json")
        self._update(output + ".json")
        self.set_units("si") 
    def get_transmission(self):
        return self.dos.transmission
    def get_energies(self):
        return self.dos.energy
[docs]
    def plot(self, filename=None, show=True):
        """visualizes the result"""
        sys = self.center.system
        nb = self.dos.energy.size
        nk = sys.kpoint.get_kpoint_num()
        ispin = sys.hamiltonian.ispin
        axes_cross = list(range(3))
        axes_cross.pop(self.transport_axis)
        grid_cross = list(sys.kpoint.grid)
        grid_cross.pop(self.transport_axis)
        labels_cross = ["kx", "ky", "kz"]
        labels_cross.pop(self.transport_axis)
        if nk == 1:
            x = self.dos.energy.m
            xunit = self.dos.energy.u
            y = self.dos.transmission
            # assert y.shape[2] == sys.hamiltonian.ispin
            el = 0.0 * xunit
            er = el - sys.pop.bias.to(xunit)
            Emin = min(el.m, er.m)
            Emax = max(el.m, er.m)
            fig = plt.figure()
            if ispin == 2:
                plt.plot(x, y[:, 0, 0], "-g", label="spin-up")
                plt.plot(x, y[:, 0, 1], "-r", label="spin-down")
            else:
                plt.plot(x, y[:, 0, 0], "-k")
            if abs(Emax - Emin) < 1.0e-10:
                plt.axvline(x=Emin, color="k", linestyle="--", label="Fermi energy")
            else:
                plt.axvspan(Emin, Emax, facecolor="r", alpha=0.25)
            plt.xlabel(f"Energy ({xunit})")
            plt.ylabel("Transmission")
            plt.legend()
            fig.tight_layout()
        elif sys.kpoint.type == "line" or np.count_nonzero(sys.kpoint.grid == 1) == 2:
            # kpoints are lined up
            raise Exception("We cannot plot this.")
        elif sys.kpoint.type == "full" and nb == 1:
            # plot transmission through the BZ at single energy point
            bvec_cross = sys.kpoint.bvec.m[axes_cross, :]
            bvec_unit = sys.kpoint.bvec.u
            k_coo = sys.kpoint.fractional_coordinates[:, axes_cross]
            ind = np.lexsort((k_coo[:, 0], k_coo[:, 1]))
            k_coo = k_coo[ind, :]
            x_coo = np.reshape(k_coo[:, 0], tuple(grid_cross), order="F")
            y_coo = np.reshape(k_coo[:, 1], tuple(grid_cross), order="F")
            if ispin == 1 or ispin == 4:
                fig, ax0 = plt.subplots()
            else:
                fig, (ax0, ax1) = plt.subplots(ncols=2)
            # spin-maj
            t_max = np.amax(self.dos.transmission)
            t_min = np.amin(self.dos.transmission)
            trans = self.dos.transmission[0, ind, 0]
            trans = np.reshape(trans, tuple(grid_cross), order="F")
            cs = ax0.contourf(
                y_coo[0, :] * np.linalg.norm(bvec_cross[1]),
                x_coo[:, 0] * np.linalg.norm(bvec_cross[0]),
                trans,
                vmin=t_min,
                vmax=t_max,
            )
            ax0.set(
                xlabel=labels_cross[1] + f" ({bvec_unit})",
                ylabel=labels_cross[0] + f" ({bvec_unit})",
            )
            fig.colorbar(cs, ax=ax0)
            # spin-min
            if ispin == 2:
                ax0.set_title("spin-majority")
                trans = self.dos.transmission[0, ind, 1]
                trans = np.reshape(trans, tuple(grid_cross), order="F")
                cs = ax1.contourf(
                    y_coo[0, :] * np.linalg.norm(bvec_cross[1]),
                    x_coo[:, 0] * np.linalg.norm(bvec_cross[0]),
                    trans,
                    vmin=t_min,
                    vmax=t_max,
                )
                ax1.set(
                    xlabel=labels_cross[1] + f" ({bvec_unit})",
                    ylabel=labels_cross[0] + f" ({bvec_unit})",
                )
                ax1.set_title("spin-minority")
                fig.colorbar(cs, ax=ax1)
        else:
            raise Exception("We cannot plot this.")
        if show:
            plt.show()
        if filename is not None:
            fig.savefig(filename)
        return fig