# -*- coding: utf-8 -*-
"""
Created on 2020-06-16
@author: Vincent Michaud-Rioux
"""
from pint import Quantity, Unit
import h5py
import numpy as np
import os
import re
import scipy.io as sio
import pint
# Shared pint unit registry used by the helpers in this module (``ureg(...)``,
# ``ureg.eV``, ``ureg.bohr``, ...). It is created with pint's "atomic" unit
# system selected as its default system.
ureg = pint.UnitRegistry(system="atomic")
def convert_quantity(atr, unit_system):
    """Convert a pint quantity in place to the length/energy units of a unit system.

    Each unit component of ``atr`` keeps its exponent, but its name is
    swapped as follows:

    * ``unit_system == "si"``: length-compatible units become ``angstrom``
      and energy-compatible units become ``eV``;
    * ``unit_system == "atomic"``: length-compatible units become ``bohr``
      and energy-compatible units become ``hartree``;
    * any other ``unit_system``: names are left unchanged.

    Components of any other dimension are kept as they are.

    Args:
        atr (Quantity): quantity to convert; modified in place via ``ito``.
        unit_system (str): ``"si"`` or ``"atomic"``.

    Returns:
        None
    """
    if unit_system == "si":
        length_target, energy_target = "angstrom", "eV"
    elif unit_system == "atomic":
        length_target, energy_target = "bohr", "hartree"
    else:
        length_target = energy_target = None

    factors = []
    for name, power in atr.to_tuple()[1]:
        component = Unit(name)
        target = name
        if length_target is not None and component.is_compatible_with(length_target):
            target = length_target
        if energy_target is not None and component.is_compatible_with(energy_target):
            target = energy_target
        factors.append(f"{target}**{power}")

    atr.ito("*".join(factors))
def to_quantity(x, units=None, allow_none=True, allow_string=False, shape=None):
    """Coerce ``x`` into a pint ``Quantity``.

    Accepted inputs:

    * ``None`` (only if ``allow_none``): returned as ``None``;
    * a ``Quantity``: returned as is (its units are not converted);
    * a non-empty list whose items are all ``Quantity``: merged with
      ``Quantity.from_list``;
    * a ``(magnitude, units)`` tuple, or a dict with
      ``"__class__" == "Quantity"`` and ``"magnitude"``/``"units"`` keys:
      converted recursively;
    * a string (only if ``allow_string``): returned as the tuple
      ``(x, units)``, not as a ``Quantity``;
    * anything else (number, array, list of numbers): multiplied by
      ``ureg(units)``.

    Args:
        x: value to convert.
        units (str, optional): units attached to bare magnitudes.
        allow_none (bool): if False, ``None`` raises.
        allow_string (bool): if False, strings raise.
        shape (tuple, optional): if given and the result is not ``None``,
            the result is cast to float and reshaped in Fortran order.

    Returns:
        Quantity, tuple or None (see above).

    Raises:
        Exception: disallowed ``None``/string input, a string or a bare
            magnitude without ``units``, or a dict whose ``__class__`` is not
            ``"Quantity"``.
    """
    if x is None and not allow_none:
        raise Exception("Entry x cannot be None.")
    if isinstance(x, str) and not allow_string:
        raise Exception("Entry x cannot be a string.")
    # An empty list is left to the generic branch below: from_list cannot
    # infer units from zero items.
    if isinstance(x, list) and x and all(isinstance(e, Quantity) for e in x):
        x = Quantity.from_list(x)
    if x is None or isinstance(x, Quantity):
        pass
    elif isinstance(x, str):
        if units is None:
            raise Exception("x cannot be a string without units.")
        return (x, units)
    elif isinstance(x, dict):
        if x["__class__"] == "Quantity":
            return to_quantity(
                x["magnitude"],
                x["units"],
                allow_string=allow_string,
                allow_none=allow_none,
                shape=shape,
            )
        else:
            cls = x["__class__"]
            raise Exception(f"Invalid class {cls}.")
    elif isinstance(x, tuple):
        return to_quantity(
            x[0],
            units=x[1],
            allow_string=allow_string,
            allow_none=allow_none,
            shape=shape,
        )
    elif units is None:
        raise Exception("Invalid None units.")
    else:
        # Not ``x *= ...``: for an ndarray the in-place form writes into the
        # caller's buffer and the resulting Quantity aliases it. A later
        # in-place conversion (``ito``) would then silently modify the
        # caller's data.
        x = x * ureg(units)
    # None is a legal (allow_none) value and has nothing to reshape.
    if shape is not None and x is not None:
        if isinstance(x, list):
            x = Quantity.from_list(x)
        x = np.reshape(x.astype(float), shape, order="F")
    return x
def add_ext(s, ext):
    """Replace the extension of path ``s`` with ``ext``.

    Only the last extension is replaced (as split by ``os.path.splitext``).
    A path without an extension simply gets ``ext`` appended. ``ext`` may be
    given with or without its leading dot; previously ``".h5"`` produced
    ``"name..h5"``.

    Args:
        s (str): path or file name.
        ext (str): new extension, e.g. ``"h5"`` or ``".h5"``.

    Returns:
        str: ``s`` with its extension replaced by ``ext``.
    """
    return os.path.splitext(s)[0] + "." + ext.lstrip(".")
def dict_converter(d, obj):
    """Return ``d`` as an instance of ``obj``.

    A dict is expanded into keyword arguments for ``obj``. An existing
    instance of ``obj`` is passed through unchanged. Anything else is
    rejected.

    Args:
        d: a dict of constructor keyword arguments, or an ``obj`` instance.
        obj (type): target class.

    Returns:
        An instance of ``obj``.

    Raises:
        TypeError: if ``d`` is neither a dict nor an instance of ``obj``.
    """
    if isinstance(d, dict):
        return obj(**d)
    if not isinstance(d, obj):
        raise TypeError(f"Object of {d.__class__} must be of {obj}")
    return d
def is_array_like(var):
    """Tell whether ``var`` is a sequence/array (possibly wrapped in a Quantity).

    Args:
        var: any object.

    Returns:
        bool: True for numpy arrays, lists, tuples, and ``Quantity`` objects
        whose magnitude is itself array-like; False otherwise.
    """
    if isinstance(var, (np.ndarray, list, tuple)):
        return True
    return isinstance(var, Quantity) and is_array_like(var.m)
def is_row_vector(a, len=3):
    """Check whether ``a`` holds exactly ``len`` floats along its first axis.

    ``a`` is converted with ``np.array(a, dtype=object)``. Its total size
    must equal ``len``, and every entry ``arr[k]`` (for k < size) must be an
    instance of ``float``. Integers and nested sub-arrays therefore fail the
    check.

    Args:
        a: array-like candidate.
        len (int): required number of elements (default 3).

    Returns:
        bool: True if ``a`` is a flat vector of ``len`` floats.
    """
    arr = np.array(a, dtype=object)
    if arr.size != len:
        return False
    # all() stops at the first non-float, just like the explicit loop it replaces.
    return all(isinstance(arr[k], float) for k in range(arr.size))
def fermi_dirac(dE_in, T_in):
    """Fermi-Dirac occupation ``1 / (1 + exp(dE / (kB T)))``.

    Both arguments go through ``to_quantity``, so bare numbers are read as
    eV (``dE_in``) and kelvin (``T_in``); quantities keep their own units.
    Beyond ``|dE / kT| > 10`` the occupation is saturated to exactly 0.0
    (positive side) or 1.0 (negative side).

    Args:
        dE_in: energy difference (presumably ``E - mu``; confirm with
            callers), scalar or array, eV if unitless.
        T_in: temperature, scalar or array, kelvin if unitless.

    Returns:
        float for scalar inputs, ``numpy.ndarray`` for array inputs. Always
        unitless; previously the non-saturated branch returned a
        dimensionless Quantity while the saturated branches returned floats.
    """
    dE = to_quantity(dE_in, "eV")
    T = to_quantity(T_in, "kelvin")
    kb = 8.617333262e-5 * ureg.eV / ureg.kelvin  # Boltzmann constant (eV/K)
    # Reduce to a plain magnitude so comparisons and exp act on numbers,
    # whatever units dE and T were given in.
    x = np.asarray((dE / (kb * T)).to("dimensionless").magnitude, dtype=float)
    # Clip before exp so it never overflows. Entries outside [-10, 10] are
    # overwritten by the saturated values right after.
    occ = 1.0 / (1.0 + np.exp(np.clip(x, -10.0, 10.0)))
    occ = np.where(x > 10.0, 0.0, np.where(x < -10.0, 1.0, occ))
    if occ.ndim == 0:
        return float(occ)
    return occ
def load_dcal(filename, varname=None):
    """Load the ``data`` struct of a dcal file (MATLAB ``.mat`` or HDF5).

    ``scipy.io.loadmat`` is tried first. If it cannot read the file (e.g.
    MATLAB v7.3 files, which are HDF5-based), the file is opened with
    ``h5py`` instead. Errors raised while indexing a file that loadmat *did*
    read (missing ``data`` or ``varname``) now propagate as-is. Previously
    they were swallowed by a bare ``except`` and resurfaced as a misleading
    h5py open error.

    Args:
        filename (str): path to the file.
        varname (str, optional): field of ``data`` to extract.

    Returns:
        tuple: ``(data, fmt)`` where ``fmt`` is ``"mat"`` or ``"h5"``.
        ``data`` is ``None`` when the selected entry has length 0, or (HDF5)
        when it carries a truthy ``MATLAB_empty`` attribute.

    Raises:
        Exception: if ``varname`` is not in the HDF5 ``data`` group.
    """
    try:
        mat = sio.loadmat(filename)
    except Exception:
        # Not readable by loadmat; fall back to h5py below.
        mat = None

    if mat is not None:
        fmt = "mat"
        data = mat["data"][0]
        if varname is not None:
            data = data[varname][0]
    else:
        fmt = "h5"
        # Left open on purpose: the returned h5py object reads from this
        # file lazily, so it cannot be closed here.
        data = h5py.File(filename, "r")
        data = data["data"]
        if varname is not None:
            if varname not in data.keys():
                raise Exception(f"{varname} not found in {filename}")
            if "MATLAB_empty" in data[varname].attrs.keys():
                if data[varname].attrs.get("MATLAB_empty"):
                    return None, fmt
            data = data[varname]
    if len(data) == 0:
        data = None
    return data, fmt
def load_dcal_var(data, varname, fmt, index):
    """Extract variable ``varname`` from data returned by ``load_dcal``.

    Args:
        data: container returned by ``load_dcal``.
        varname (str): variable to read.
        fmt (str): ``"mat"`` or any other value (treated as HDF5).
        index (int or None): entry to select; for ``"mat"``, ``None`` means 0.

    Returns:
        For ``"mat"``: ``data[0][varname][index]`` squeezed.
        Otherwise, with ``index=None``: ``data[varname]`` squeezed. With an
        index: the entry of ``data`` referenced by
        ``data[varname][index][0]``, flattened.
    """
    if fmt == "mat":
        row = 0 if index is None else index
        return data[0][varname][row].squeeze()
    if index is None:
        return data[varname][0:].squeeze()
    # Here data[varname] holds references into data (HDF5 object refs).
    target = data[varname][index][0]
    return data[target][0:].flatten()
def load_dcal_parameter(data, varname, fmt, index):
    """Read field ``varname`` of the ``index``-th ``Parameter`` entry.

    Args:
        data: container returned by ``load_dcal``.
        varname (str): parameter field to read.
        fmt (str): ``"mat"`` or any other value (treated as HDF5).
        index (int): which ``Parameter`` entry to inspect.

    Returns:
        The field value (``[0][0]`` of the field for ``"mat"``, ``[0]``
        otherwise), or ``None`` if the entry has no such field.
    """
    if fmt == "mat":
        entry = data[0]["Parameter"][index]
        names = list(entry.dtype.names)
        return entry[varname][0][0] if varname in names else None
    # Non-mat: data["Parameter"][index][0] is a reference to another entry of data.
    group = data[data["Parameter"][index][0]]
    return group[varname][0] if varname in group.keys() else None
def list_methods(obj):
    """Returns a list of the methods of an object.

    `Reference <https://stackoverflow.com/questions/1911281/how-do-i-get-list-of-methods-in-a-python-class>`_

    Names are taken from ``dir()`` of ``obj``'s class, so the result is in
    ``dir`` order (alphabetical). Names that cannot be retrieved from the
    class are skipped rather than raising.

    Args:
        obj: any object.

    Returns:
        list of str: names of the callable attributes of ``type(obj)``.
    """
    cls = obj.__class__
    # getattr with a default: dir() can list names that are not retrievable.
    # For example, dir(type) contains "__abstractmethods__", which raises
    # AttributeError, so list_methods(SomeClass) used to fail.
    return [name for name in dir(cls) if callable(getattr(cls, name, None))]
def update_recursive(d, u):
    """Update a dictionary recursively.

    `Source <https://stackoverflow.com/questions/3232943/update-value-of-a-nested-dictionary-of-varying-depth>`_

    Mappings in ``u`` are merged into the matching entries of ``d`` instead
    of replacing them; any other value in ``u`` overwrites the entry in
    ``d``. If ``d`` holds a non-mapping value where ``u`` has a mapping, that
    value is replaced by a fresh dict (previously this raised TypeError).
    ``d`` is modified in place and also returned.

    Args:
        d (dict): original dict
        u (dict): new dict

    Returns:
        dict: updated dict (the same object as ``d``)
    """
    import collections.abc

    for k, v in u.items():
        if isinstance(v, collections.abc.Mapping):
            current = d.get(k)
            if not isinstance(current, collections.abc.Mapping):
                # Missing key or a scalar that cannot be merged into.
                current = {}
            d[k] = update_recursive(current, v)
        else:
            d[k] = v
    return d
def print_to_console_and_file(fileobj, line):
    """Echo ``line`` to standard output and append it to ``fileobj``.

    Args:
        fileobj: writable text stream; receives ``line`` followed by a newline.
        line (str): text to emit.
    """
    print(line)
    terminated = line + "\n"
    fileobj.write(terminated)
def read_field(filename, fieldname, convert=True):
    """
    Read a field from an HDF5 file.

    Args:
        filename (str):
            Path to the HDF5 file. For example, "nano_scf_out.h5".
        fieldname (str):
            Path of the field in the HDF5 file. For example, "potential/effective".
        convert (bool):
            If False, return the raw array without units. If True (default),
            attach units chosen from ``fieldname`` and convert them:
            paths starting with "potential" go from hartree to eV; paths
            starting with "density" or "ldos" go from bohr**-3 to
            angstrom**-3; paths containing "wavefunction" go from
            bohr**-1.5 to angstrom**-1.5. For wavefunctions, entries along
            the first axis are paired into complex values (even index =
            real part, odd index = imaginary part).

    Returns:
        fld (ndarray or Quantity):
            Fortran-ordered array with the dataset's axis order reversed
            (e.g. a 3D field); a pint Quantity when ``convert`` is True.

    Raises:
        Exception: if ``convert`` is True and ``fieldname`` matches none of
            the known field types.
    """
    # The slice reads the whole dataset into memory, so the file can be
    # closed immediately instead of leaking the handle.
    with h5py.File(filename, mode="r") as f:
        fld = f[fieldname][0:]
    # np.transpose without ``axes`` reverses the axis order, which is what
    # the explicit [ndim-1, ..., 0] permutation did before.
    fld = np.asfortranarray(np.transpose(fld))
    if not convert:
        return fld
    if re.match("potential", fieldname):
        fld = fld * ureg.hartree
        fld.ito("eV")
    elif re.match("density", fieldname) or re.match("ldos", fieldname):
        fld = fld / ureg.bohr**3
        fld.ito("angstrom ** -3")
    elif re.search("wavefunction", fieldname):
        fld = fld / ureg.bohr**1.5
        # Interleaved storage along the first axis: real, imag, real, imag, ...
        fld = fld[::2, :] + 1j * fld[1::2, :]
        fld.ito("angstrom ** -1.5")
    else:
        raise Exception("Unknown field type.")
    return fld
def get_chemical_symbols():
    """Return the chemical symbols ordered by atomic number.

    Index ``Z`` holds the symbol of the element with atomic number ``Z``
    (1 = "H" ... 118 = "Og"). Index 0 is the placeholder ``"X"`` for a
    species that is not found. A new list is built on every call, so
    callers may modify it freely.

    Returns:
        list of str: 119 symbols, ``"X"`` followed by H through Og.
    """
    rows = (
        "X",  # placeholder: not found
        "H He",  # row-1
        "Li Be B C N O F Ne",  # row-2
        "Na Mg Al Si P S Cl Ar",  # row-3
        "K Ca Sc Ti V Cr Mn Fe Co Ni Cu Zn Ga Ge As Se Br Kr",  # row-4
        "Rb Sr Y Zr Nb Mo Tc Ru Rh Pd Ag Cd In Sn Sb Te I Xe",  # row-5
        "Cs Ba La Ce Pr Nd Pm Sm Eu Gd Tb Dy Ho Er Tm Yb Lu"
        " Hf Ta W Re Os Ir Pt Au Hg Tl Pb Bi Po At Rn",  # row-6
        "Fr Ra Ac Th Pa U Np Pu Am Cm Bk Cf Es Fm Md No Lr"
        " Rf Db Sg Bh Hs Mt Ds Rg Cn Nh Fl Mc Lv Ts Og",  # row-7
    )
    return [symbol for row in rows for symbol in row.split()]