# Source code for nanotools.utils

# -*- coding: utf-8 -*-
"""
Created on 2020-06-16

@author: Vincent Michaud-Rioux
"""

from pint import Quantity, Unit
import h5py
import numpy as np
import os
import re
import scipy.io as sio
import pint


# Module-wide pint unit registry; bare magnitudes in this module (e.g. in
# to_quantity, fermi_dirac and read_field) get their units from it.
# ``system="atomic"`` is passed to pint to select its atomic unit system.
ureg = pint.UnitRegistry(system="atomic")


def convert_quantity(atr, unit_system):
    """Convert a pint quantity, in place, to a unit system's length/energy units.

    Every factor of ``atr``'s units is examined on its own and keeps its
    exponent. With ``unit_system == "si"``, length-compatible factors become
    ``angstrom`` and energy-compatible factors become ``eV``. With
    ``unit_system == "atomic"``, they become ``bohr`` and ``hartree``. Other
    factors, and any other ``unit_system`` value, keep their unit name.

    Args:
        atr (pint.Quantity): quantity to convert; modified via ``ito``.
        unit_system (str): ``"si"`` or ``"atomic"``.
    """
    if unit_system == "si":
        targets = ("angstrom", "eV")
    elif unit_system == "atomic":
        targets = ("bohr", "hartree")
    else:
        targets = None

    factors = []
    for name, power in atr.to_tuple()[1]:
        unit = Unit(name)
        if targets is not None:
            length_unit, energy_unit = targets
            if unit.is_compatible_with(length_unit):
                name = length_unit
            if unit.is_compatible_with(energy_unit):
                name = energy_unit
        factors.append(f"{name}**{power}")
    atr.ito("*".join(factors))


def to_quantity(x, units=None, allow_none=True, allow_string=False, shape=None):
    """Coerce ``x`` to a pint ``Quantity``.

    Accepted forms of ``x``:

    * ``None``: returned unchanged (rejected unless ``allow_none``).
    * ``Quantity``: returned as is (its units are not converted).
    * non-empty list whose items are all ``Quantity``: stacked with
      ``Quantity.from_list``.
    * ``(magnitude, units)`` tuple: same as passing ``units`` separately.
    * dict with ``"__class__": "Quantity"`` and ``"magnitude"``/``"units"``
      entries: rebuilt from those entries.
    * string (rejected unless ``allow_string``): returned as the plain tuple
      ``(x, units)``; ``shape`` is not applied.
    * anything else (number, array, list): multiplied by ``ureg(units)``.

    Args:
        x: value to convert.
        units (str, optional): units attached to a bare magnitude or string.
        allow_none (bool): accept ``x is None``.
        allow_string (bool): accept a string ``x``.
        shape (tuple, optional): if given and the result is not ``None``,
            cast the magnitude to float and reshape it in Fortran
            (column-major) order.

    Returns:
        Quantity, ``None``, or ``(str, units)`` for string input.

    Raises:
        Exception: on a rejected ``None`` or string, a string or bare
            magnitude without units, or a dict whose ``"__class__"`` is not
            ``"Quantity"``.
    """
    if x is None and not allow_none:
        raise Exception("Entry x cannot be None.")
    if isinstance(x, str) and not allow_string:
        raise Exception("Entry x cannot be a string.")
    # Empty lists are left to the generic branch below: Quantity.from_list
    # cannot infer units from an empty sequence.
    if isinstance(x, list) and x and all(isinstance(e, Quantity) for e in x):
        x = Quantity.from_list(x)
    if x is None or isinstance(x, Quantity):
        pass
    elif isinstance(x, str):
        if units is None:
            raise Exception("x cannot be a string without units.")
        return (x, units)
    elif isinstance(x, dict):
        if x["__class__"] == "Quantity":
            return to_quantity(
                x["magnitude"],
                x["units"],
                allow_string=allow_string,
                allow_none=allow_none,
                shape=shape,
            )
        cls = x["__class__"]
        raise Exception(f"Invalid class {cls}.")
    elif isinstance(x, tuple):
        return to_quantity(
            x[0],
            units=x[1],
            allow_string=allow_string,
            allow_none=allow_none,
            shape=shape,
        )
    elif units is None:
        raise Exception("Invalid None units.")
    else:
        # Out-of-place product: ``x *= ...`` on an ndarray would write into
        # the caller's array and alias it inside the returned Quantity.
        x = x * ureg(units)
    # At this point x is either None or a Quantity; None has nothing to reshape.
    if shape is not None and x is not None:
        x = np.reshape(x.astype(float), shape, order="F")
    return x


def add_ext(s, ext):
    """Replace the extension of path ``s`` with ``ext``.

    Only the last extension is replaced (``"a.tar.gz"`` -> ``"a.tar.zip"``);
    a path without an extension simply gets one appended. ``ext`` may be
    given with or without its leading dot (``"h5"`` and ``".h5"`` give the
    same result).

    Args:
        s (str): file path.
        ext (str): new extension.

    Returns:
        str: the path with its extension replaced.
    """
    # Accept ".h5" as well as "h5" instead of producing "name..h5".
    if ext.startswith("."):
        ext = ext[1:]
    return os.path.splitext(s)[0] + "." + ext


def dict_converter(d, obj):
    """Return ``d`` as an instance of ``obj``.

    A dict is expanded into keyword arguments of ``obj``; an existing
    instance of ``obj`` is passed through untouched.

    Args:
        d (dict or obj): value to convert.
        obj (type): target class.

    Raises:
        TypeError: if ``d`` is neither a dict nor an instance of ``obj``.
    """
    if not isinstance(d, dict):
        if isinstance(d, obj):
            return d
        raise TypeError(f"Object of {d.__class__} must be of {obj}")
    return obj(**d)


def is_array_like(var):
    """Return True for ndarrays, lists, tuples, and Quantities wrapping one of these.

    Args:
        var: object to test.

    Returns:
        bool
    """
    if isinstance(var, (np.ndarray, list, tuple)):
        return True
    # A Quantity counts only if its magnitude is itself array-like.
    return isinstance(var, Quantity) and is_array_like(var.m)


def is_row_vector(a, len=3):
    """Return True if ``a`` is a flat sequence of exactly ``len`` floats.

    ``a`` is converted with ``np.array(a, dtype=object)`` so element types
    are preserved. It qualifies only if the result is one-dimensional, has
    ``len`` elements, and every element is a ``float`` (a Python float or a
    float subclass such as ``np.float64``); integers do not qualify.

    Args:
        a: candidate sequence (list, tuple, ndarray, ...).
        len (int): required number of elements. The name shadows the
            builtin inside this function; kept for backward compatibility.

    Returns:
        bool
    """
    na = np.array(a, dtype=object)
    # Scalars (0-d) and nested (2-d+) inputs are not row vectors; the 0-d
    # check also avoids an IndexError when indexing a scalar array.
    if na.ndim != 1 or na.size != len:
        return False
    return all(isinstance(e, float) for e in na)


def fermi_dirac(dE_in, T_in):
    """Fermi-Dirac occupation ``1 / (1 + exp(dE / (kB * T)))``.

    Args:
        dE_in: energy relative to the chemical potential; bare numbers are
            interpreted as eV (anything ``to_quantity`` accepts).
        T_in: temperature; bare numbers are interpreted as kelvin.

    Returns:
        float for scalar inputs, otherwise a float ndarray (inputs are
        broadcast). Where ``dE / kT > 10`` the value is exactly 0.0, and
        where ``dE / kT < -10`` it is exactly 1.0.
    """
    dE = to_quantity(dE_in, "eV")
    T = to_quantity(T_in, "kelvin")
    kb = 8.617333262e-5 * ureg.eV / ureg.kelvin  # Boltzmann constant
    # Reduce to a plain dimensionless number so that mixed input units
    # (e.g. meV with K) are scaled correctly and comparisons below are
    # element-wise instead of relying on the truth value of an array.
    x = np.asarray((dE / (kb * T)).m_as("dimensionless"), dtype=float)
    with np.errstate(over="ignore"):
        # exp may overflow to inf for large x; 1 / (1 + inf) is 0 anyway.
        occ = 1.0 / (1.0 + np.exp(x))
    # Hard cut-offs: exactly 0 / 1 once |dE / kT| exceeds 10.
    occ = np.where(x > 10.0, 0.0, np.where(x < -10.0, 1.0, occ))
    return float(occ) if occ.ndim == 0 else occ


def load_dcal(filename, varname=None):
    """Load the ``data`` variable from a calculation file.

    The file is first read with ``scipy.io.loadmat``. If scipy cannot read
    it (for example a MATLAB v7.3 file), it is opened as HDF5 with ``h5py``
    instead.

    Args:
        filename (str): path to the file.
        varname (str, optional): field of ``data`` to extract.

    Returns:
        tuple: ``(data, fmt)`` where ``fmt`` is ``"mat"`` or ``"h5"``.
        ``data`` is ``None`` when it has zero length, or when the HDF5 entry
        carries a truthy ``MATLAB_empty`` attribute.

    Raises:
        Exception: if ``varname`` is not a field of ``data``.
    """
    # Only the format detection is guarded: errors while extracting fields
    # from a successfully loaded .mat file are reported as such instead of
    # turning into a confusing HDF5 "cannot open file" error. ``Exception``
    # rather than a bare ``except`` so KeyboardInterrupt still propagates.
    try:
        mat = sio.loadmat(filename)
    except Exception:
        mat = None

    if mat is not None:
        fmt = "mat"
        data = mat["data"][0]
        if varname is not None:
            if varname not in (data.dtype.names or ()):
                raise Exception(f"{varname} not found in {filename}")
            data = data[varname][0]
    else:
        fmt = "h5"
        # The file is not closed here: the returned h5py object still
        # needs it to read its contents.
        data = h5py.File(filename, "r")["data"]
        if varname is not None:
            if varname not in data.keys():
                raise Exception(f"{varname} not found in {filename}")
            if data[varname].attrs.get("MATLAB_empty"):
                return None, fmt
            data = data[varname]
    if len(data) == 0:
        data = None
    return data, fmt


def load_dcal_var(data, varname, fmt, index):
    """Extract variable ``varname`` from data returned by ``load_dcal``.

    Args:
        data: loaded data. For ``"mat"``, ``data[0][varname]`` is the
            variable; otherwise ``data`` is indexed by variable name.
        varname (str): variable to read.
        fmt (str): ``"mat"``; any other value takes the HDF5 path.
        index (int or None): entry to read. For ``"mat"``, ``None`` means 0.
            For HDF5, ``None`` reads the whole dataset; otherwise
            ``data[varname][index][0]`` is used as a key into ``data``.

    Returns:
        ndarray: squeezed (``"mat"`` or whole HDF5 dataset) or flattened
        (indexed HDF5 entry).
    """
    if fmt == "mat":
        row = 0 if index is None else index
        return data[0][varname][row].squeeze()
    if index is None:
        return data[varname][0:].squeeze()
    # Presumably an HDF5 object reference (e.g. a MATLAB v7.3 cell entry);
    # it is resolved by indexing ``data`` with it.
    target = data[varname][index][0]
    return data[target][0:].flatten()


def load_dcal_parameter(data, varname, fmt, index):
    """Return field ``varname`` of the ``index``-th ``Parameter`` entry.

    Args:
        data: loaded data (see ``load_dcal``).
        varname (str): parameter field to read.
        fmt (str): ``"mat"``; any other value takes the HDF5 path.
        index (int): which ``Parameter`` entry to read.

    Returns:
        The field value (``[0][0]`` of the field for ``"mat"``, ``[0]`` for
        HDF5), or ``None`` if the entry has no such field.
    """
    if fmt == "mat":
        entry = data[0]["Parameter"][index]
        if varname not in list(entry.dtype.names):
            return None
        return entry[varname][0][0]
    # HDF5: the Parameter entry holds a key/reference into ``data``.
    group = data[data["Parameter"][index][0]]
    if varname not in group.keys():
        return None
    return group[varname][0]


def list_methods(obj):
    """Return the names of the callable attributes of ``obj``'s class.

    Names are taken from ``dir`` on the class (so they come sorted and
    include inherited and dunder attributes); a name is kept when
    ``getattr(cls, name)`` is callable.

    `Reference <https://stackoverflow.com/questions/1911281/how-do-i-get-list-of-methods-in-a-python-class>`_
    """
    klass = obj.__class__

    def _is_callable_attr(name):
        return callable(getattr(klass, name))

    return list(filter(_is_callable_attr, dir(klass)))


def update_recursive(d, u):
    """Update a dictionary recursively.

    `Source <https://stackoverflow.com/questions/3232943/update-value-of-a-nested-dictionary-of-varying-depth>`_

    Nested mappings in ``u`` are merged into the matching entries of ``d``
    instead of replacing them; any other value in ``u`` overwrites. When
    ``d`` has no entry, or a non-mapping entry (e.g. a number or ``None``),
    where ``u`` has a mapping, that entry is replaced by a new dict.

    Args:
        d (dict): original dict; modified in place.
        u (dict): new dict

    Returns:
        dict: updated dict (the same object as ``d``)
    """
    import collections.abc

    for k, v in u.items():
        if isinstance(v, collections.abc.Mapping):
            base = d.get(k)
            if not isinstance(base, collections.abc.Mapping):
                # Missing or scalar entry: merging into it would fail with
                # TypeError, so start from an empty sub-dict instead.
                base = {}
            d[k] = update_recursive(base, v)
        else:
            d[k] = v
    return d


def print_to_console_and_file(fileobj, line):
    """Echo ``line`` to standard output and write it to ``fileobj``.

    Args:
        fileobj: writable text file object.
        line (str): text without a trailing newline; one is added for the
            file, and ``print`` adds its own for the console.
    """
    print(line)
    newline_terminated = line + "\n"
    fileobj.write(newline_terminated)


def read_field(filename, fieldname, convert=True):
    """Read a field from an HDF5 file.

    The dataset is read fully into memory, its axis order is reversed, and
    the result is made Fortran-contiguous (presumably to match a
    column-major writer -- confirm against the producing code).

    Args:
        filename (str): Path to the HDF5 file. For example, "nano_scf_out.h5".
        fieldname (str): Path of the field in the HDF5 file. For example,
            "potential/effective".
        convert (bool): If True (default), attach units and convert them:
            paths starting with "potential" go from hartree to eV; paths
            starting with "density" or "ldos" from bohr**-3 to
            angstrom**-3; paths containing "wavefunction" from bohr**-1.5 to
            angstrom**-1.5, with even/odd rows along the first axis combined
            as real/imaginary parts. If False, the plain array is returned.

    Returns:
        ndarray or pint.Quantity: the field (a Quantity when ``convert``).

    Raises:
        Exception: if ``convert`` is True and the field type is unknown.
    """
    # Close the file as soon as the data is in memory.
    with h5py.File(filename, mode="r") as f:
        fld = f[fieldname][0:]
    # np.transpose without axes reverses the axis order.
    fld = np.asfortranarray(np.transpose(fld))
    if not convert:
        return fld
    if re.match("potential", fieldname):
        fld = fld * ureg.hartree
        fld.ito("eV")
    elif re.match("density", fieldname) or re.match("ldos", fieldname):
        fld = fld / ureg.bohr**3
        fld.ito("angstrom ** -3")
    elif re.search("wavefunction", fieldname):
        fld = fld / ureg.bohr**1.5
        # Interleaved storage: even rows are real parts, odd rows imaginary.
        fld = fld[::2, :] + 1j * fld[1::2, :]
        fld.ito("angstrom ** -1.5")
    else:
        raise Exception("Unknown field type.")
    return fld


def get_chemical_symbols():
    """Returns an ordered list of atomic species.

    Element symbols are listed in order of atomic number, preceded by the
    placeholder ``"X"`` (for "not found"), so that ``symbols[1] == "H"`` and
    ``symbols[118] == "Og"``. A new list is built on every call.
    """
    periods = (
        "X",  # placeholder: not found
        "H He",
        "Li Be B C N O F Ne",
        "Na Mg Al Si P S Cl Ar",
        "K Ca Sc Ti V Cr Mn Fe Co Ni Cu Zn Ga Ge As Se Br Kr",
        "Rb Sr Y Zr Nb Mo Tc Ru Rh Pd Ag Cd In Sn Sb Te I Xe",
        "Cs Ba La Ce Pr Nd Pm Sm Eu Gd Tb Dy Ho Er Tm Yb Lu "
        "Hf Ta W Re Os Ir Pt Au Hg Tl Pb Bi Po At Rn",
        "Fr Ra Ac Th Pa U Np Pu Am Cm Bk Cf Es Fm Md No Lr "
        "Rf Db Sg Bh Hs Mt Ds Rg Cn Nh Fl Mc Lv Ts Og",
    )
    return [symbol for period in periods for symbol in period.split()]