import os
from typing import Union
import attr
import sys
from ase import Atoms, io
from ase.io import read
from ase.visualize import view
import matplotlib.pyplot as plt
import numpy as np
from ase.constraints import FixAtoms
from ase.optimize import BFGS
import warnings
from chgnet.model.model import CHGNet
from chgnet.model.dynamics import CHGNetCalculator
# Suppress warnings raised from modules whose names begin with "pymatgen" or
# "ase" so that third-party noise does not clutter the relaxation output.
warnings.filterwarnings("ignore", module="pymatgen")
warnings.filterwarnings("ignore", module="ase")
# (Sphinx viewcode "[docs]" link artefact removed; it is not part of the module source.)
@attr.s
class NBRelaxer:
    """Relax the central region of a two-probe system with CHGNet and ASE.

    The leads and the central region are supplied as AtomCell objects and
    converted to ASE ``Atoms`` objects. Atoms of the central region that
    coincide with the lead atoms (the buffer layers) are held fixed, and the
    remaining atoms are relaxed with ASE's BFGS optimizer, using the
    pretrained CHGNet universal neural network potential as the calculator.

    Attributes:
        atomcell_L (AtomCell): Atomic structure of the left lead.
        atomcell_R (AtomCell): Atomic structure of the right lead.
        atomcell_C (AtomCell): Atomic structure of the central region.
        transport_axis (str): Transport direction along which the relaxation
            is performed. Must be "X", "Y", or "Z". Defaults to "Z".
        tolerance (float): Atomic force tolerance for the relaxation, in
            eV/Å. Defaults to 0.05.
        max_steps (int): Maximum number of optimization steps. Defaults to 100.
        fix_2_layer_leads (bool): If True, two lead periods (instead of one)
            are matched against the central region and fixed on each side.
            Defaults to False.
        ase_atoms_L (Atoms): ASE version of the left lead; set by ``relax``.
        ase_atoms_R (Atoms): ASE version of the right lead; set by ``relax``.
        ase_atoms_C (Atoms): ASE version of the central region; set by
            ``relax`` and relaxed in place by ``relax_central_region``.
        atoms_C_relaxed (Atoms): Copy of the central region taken after the
            optimizer has run; None until then.
    """

    atomcell_L = attr.ib()
    atomcell_R = attr.ib()
    atomcell_C = attr.ib()
    transport_axis = attr.ib(type=str, default="Z")
    tolerance = attr.ib(type=float, default=0.05)
    max_steps = attr.ib(type=int, default=100)
    fix_2_layer_leads = attr.ib(type=bool, default=False)
    ase_atoms_L = attr.ib(default=None)
    ase_atoms_R = attr.ib(default=None)
    ase_atoms_C = attr.ib(default=None)
    atoms_C_relaxed = attr.ib(default=None)

    # Plain class attribute (not an ``attr.ib``), so attrs does not turn it
    # into a constructor argument.
    _AXIS_TO_INDEX = {"X": 0, "Y": 1, "Z": 2}

    @staticmethod
    def _require_axis_index(value, label="direction"):
        """Raise ValueError unless ``value`` is 0, 1, or 2."""
        if value not in (0, 1, 2):
            raise ValueError(f"Invalid {label}. Expected 0, 1, or 2. Got: {value}")

    @staticmethod
    def _exit_insufficient_buffer(reason):
        """Report a buffer-layer mismatch and terminate with exit status 1."""
        print(
            f"Warning: Insufficient buffer atoms in the center region. Check the structure. Error: {reason}"
        )
        sys.exit(1)

    def transport_direction_to_index(self, direction: str) -> int:
        """Convert a transport direction ("X", "Y", or "Z") to an index.

        Args:
            direction (str): The transport direction. Must be "X", "Y", or "Z"
                (upper case).

        Returns:
            int: 0 for "X", 1 for "Y", and 2 for "Z".

        Raises:
            ValueError: If `direction` is not "X", "Y", or "Z".
        """
        # The isinstance guard keeps unhashable inputs on the ValueError path
        # instead of failing inside the dict lookup.
        index = (
            self._AXIS_TO_INDEX.get(direction) if isinstance(direction, str) else None
        )
        if index is None:
            raise ValueError(
                f"Invalid transport direction. Use 'X', 'Y', or 'Z'. Got: {direction}"
            )
        return index

    def generate_pbc_for_direction(self, direction: int) -> list:
        """Generate periodic boundary conditions for a transport direction.

        Args:
            direction (int): The transport direction. 0 for X, 1 for Y, 2 for Z.

        Returns:
            list: Three booleans; the entry at `direction` is False (open
            boundary along transport) and the other two are True.

        Raises:
            ValueError: If `direction` is not 0, 1, or 2.
        """
        self._require_axis_index(direction)
        pbc = [True, True, True]
        pbc[direction] = False
        return pbc

    def sort_atoms_along_direction(self, atoms: Atoms, direction_index: int) -> Atoms:
        """Return a copy of `atoms` ordered by coordinate along one axis.

        Args:
            atoms (Atoms): ASE Atoms object. It is not modified.
            direction_index (int): Cartesian axis to sort along: 0, 1, or 2.

        Returns:
            Atoms: New ASE Atoms object whose atoms are in ascending order of
            their coordinate along `direction_index`.

        Raises:
            ValueError: If `direction_index` is not 0, 1, or 2.
        """
        self._require_axis_index(direction_index, label="direction index")
        # A stable sort keeps the input order of atoms that share the same
        # coordinate, which makes the result reproducible.
        order = np.argsort(atoms.positions[:, direction_index], kind="stable")
        # Indexing the Atoms object reorders every per-atom array together
        # (not only positions and symbols) and keeps cell and pbc.
        return atoms[order]

    def parse_xyz_file(self, atom_cell, pbc=None):
        """Convert an AtomCell object to an ASE Atoms object.

        The atoms are sorted along ``self.transport_axis`` before the cell and
        the periodic boundary conditions are attached.

        Args:
            atom_cell (AtomCell): Object providing ``formula``, ``positions``
                and ``avec``; the latter two are converted to angstrom.
            pbc (list of bool, optional): The periodic boundary conditions, a
                list (or tuple) of three booleans. Defaults to
                [True, True, True].

        Returns:
            Atoms: An ASE Atoms object, or None if `atom_cell` is falsy.

        Raises:
            ValueError: If `pbc` is not a sequence of three booleans.
        """
        if pbc is None:
            pbc = [True, True, True]
        pbc_is_valid = (
            isinstance(pbc, (list, tuple))
            and len(pbc) == 3
            and all(isinstance(flag, bool) for flag in pbc)
        )
        if not pbc_is_valid:
            raise ValueError(
                f"Invalid pbc. Expected a list of three booleans. Got: {pbc}"
            )
        if not atom_cell:
            return None

        positions = atom_cell.positions.to("angstrom").m
        unit_cell = atom_cell.avec.to("angstrom").m
        atoms = Atoms(symbols=atom_cell.formula, positions=positions)
        # Sort along the transport direction so that the first/last atoms of
        # the central region can be compared with the leads.
        atoms = self.sort_atoms_along_direction(
            atoms, self.transport_direction_to_index(self.transport_axis)
        )
        atoms.cell = unit_cell
        atoms.pbc = list(pbc)
        return atoms

    def translate_atoms_along_direction(
        self, atoms: Atoms, direction: int, translation_distance: float
    ) -> Atoms:
        """Return a copy of `atoms` shifted along one Cartesian axis.

        Args:
            atoms (Atoms): ASE Atoms object. It is not modified.
            direction (int): Axis to translate along: 0 for X, 1 for Y, 2 for Z.
            translation_distance (float): Distance to translate along that axis.

        Returns:
            Atoms: ASE Atoms object with translated coordinates.

        Raises:
            ValueError: If `direction` is not 0, 1, or 2.
        """
        self._require_axis_index(direction)
        translation_vector = [0, 0, 0]
        translation_vector[direction] = translation_distance
        translated_atoms = atoms.copy()
        translated_atoms.translate(translation_vector)
        return translated_atoms

    def scale_atoms(self, atoms, scaling_factor, transport_direction):
        """Repeat the atoms (and cell) along the transport direction.

        Despite the name, no coordinates are rescaled: the structure is
        repeated `scaling_factor` times along `transport_direction` and once
        along the other two axes.

        Args:
            atoms (Atoms): ASE Atoms object. It is not modified.
            scaling_factor (int): Number of repetitions along the transport
                direction. ASE repetition expects an integer.
            transport_direction (int): Axis to repeat along: 0 for X, 1 for Y,
                or 2 for Z.

        Returns:
            Atoms: Repeated copy of `atoms`.

        Raises:
            ValueError: If `transport_direction` is not 0, 1, or 2.
        """
        self._require_axis_index(transport_direction)
        repetitions = [1, 1, 1]
        repetitions[transport_direction] = scaling_factor
        repeated_atoms = atoms.copy()  # keep the caller's object untouched
        repeated_atoms *= repetitions  # in-place repetition of the copy
        return repeated_atoms

    def check_overlap_side(self, atoms_C_side, atoms_side, atol=2e-1):
        """Find the atoms of one set that coincide with atoms of another set.

        Two atoms count as overlapping when all three Cartesian coordinates
        agree within `atol` (plus numpy's default relative tolerance).
        Chemical species are not compared.

        Args:
            atoms_C_side (Atoms): First set of ASE Atoms.
            atoms_side (Atoms): Second set of ASE Atoms.
            atol (float, optional): Absolute tolerance per coordinate, in the
                units of the positions. Defaults to 0.2.

        Returns:
            ndarray: Indices of atoms in `atoms_side` that overlap with any
            atom in `atoms_C_side`.
        """
        # Broadcast to (n_C, n_side, 3): every atom pair, coordinate by
        # coordinate.
        coordinate_match = np.isclose(
            atoms_C_side.positions[:, np.newaxis, :],
            atoms_side.positions,
            atol=atol,
        )
        pair_match = np.all(coordinate_match, axis=2)
        return np.where(np.any(pair_match, axis=0))[0]

    def check_overlap(self):
        """Check that the ends of the center region reproduce the lead atoms.

        The first ``len(left lead)`` atoms of the center region are compared
        with the left lead, and the last ``len(right lead)`` atoms with the
        right lead shifted to the far end of the center cell. With
        ``fix_2_layer_leads`` each lead is first repeated twice along the
        transport direction; the stored ``ase_atoms_L``/``ase_atoms_R`` are
        left unchanged, so the method can be called repeatedly.

        Returns:
            tuple: On success, a tuple of two sorted lists of center-region
            indices:
                - indices of the buffer atoms that overlap with the leads
                - indices of all remaining atoms of the center region
            An empty list is returned instead when the center region has
            fewer atoms than both leads together, or when either side does
            not match (a warning is issued in the latter case).

        Raises:
            ValueError: If `ase_atoms_C`, `ase_atoms_L`, or `ase_atoms_R` is None.
        """
        if (
            self.ase_atoms_C is None
            or self.ase_atoms_L is None
            or self.ase_atoms_R is None
        ):
            raise ValueError(
                "ase_atoms_C, ase_atoms_L, and ase_atoms_R must not be None"
            )
        axis = self.transport_direction_to_index(self.transport_axis)
        center = self.ase_atoms_C
        left = self.ase_atoms_L
        right = self.ase_atoms_R

        # Option to fix more buffer atoms: match two lead periods per side.
        # Local copies are used so repeated calls do not keep doubling the
        # leads stored on the instance.
        if self.fix_2_layer_leads:
            left = self.scale_atoms(left, 2, axis)
            right = self.scale_atoms(right, 2, axis)

        n_center, n_left, n_right = len(center), len(left), len(right)
        # The center region must be able to hold both leads.
        if n_center < n_left + n_right:
            return []

        left_overlap = self.check_overlap_side(center[:n_left], left)
        if len(left_overlap) != n_left:
            warnings.warn("left buffer atoms not aligned")
            return []
        print("left buffer atoms checked")

        # Move the right lead to the far end of the center cell. Only the
        # diagonal cell component along the transport axis is used, which
        # presumably assumes that cell vector is aligned with the Cartesian
        # axis -- TODO confirm for non-orthogonal cells.
        shift = center.cell[axis, axis] - right.cell[axis, axis]
        right_shifted = self.translate_atoms_along_direction(right, axis, shift)
        # Explicit start index: ``center[-n_right:]`` would select the whole
        # center region when n_right is 0.
        first_right = n_center - n_right
        right_overlap = (
            self.check_overlap_side(center[first_right:], right_shifted) + first_right
        )
        if len(right_overlap) != n_right:
            warnings.warn("right buffer atoms not aligned")
            return []
        print("right buffer atoms checked")

        fixed = {int(i) for i in left_overlap} | {int(i) for i in right_overlap}
        free = set(range(n_center)) - fixed
        return sorted(fixed), sorted(free)

    def relax_central_region(self):
        """Relax the central region of the structure.

        The buffer atoms found by ``check_overlap`` are fixed, a CHGNet
        calculator (running on the CPU) is attached to ``ase_atoms_C``, and
        the structure is relaxed with ASE's BFGS optimizer. The trajectory is
        written to "relax.traj" and the optimizer log to "relax.log". A copy
        of the final structure is stored in ``atoms_C_relaxed``.

        Raises:
            SystemExit: If there are insufficient buffer atoms in the center
                region, i.e. the leads are missing or do not match its ends.

        Returns:
            None
        """
        # The buffer atoms in the center region must contain at least the
        # leads on both sides.
        try:
            overlap = self.check_overlap()
        except ValueError as e:
            self._exit_insufficient_buffer(e)
        if not overlap:
            self._exit_insufficient_buffer(
                "the ends of the center region do not match the lead atoms"
            )
        fixed_indices, _ = overlap

        # Set up the calculator and freeze the buffer atoms.
        use_device = "cpu"  # "cpu", "cuda", "mps"
        chgnet = CHGNet.load()
        calculator = CHGNetCalculator(model=chgnet, use_device=use_device)
        self.ase_atoms_C.calc = calculator
        self.ase_atoms_C.set_constraint(FixAtoms(indices=fixed_indices))

        # Perform relaxation using the BFGS minimizer of the ASE package.
        optimizer = BFGS(self.ase_atoms_C, trajectory="relax.traj", logfile="relax.log")
        converged = optimizer.run(fmax=self.tolerance, steps=self.max_steps)
        if converged is not None and not converged:
            warnings.warn(
                f"Relaxation stopped after {self.max_steps} steps without reaching "
                f"the force tolerance of {self.tolerance} eV/Å"
            )

        # Retrieve the relaxed atoms.
        self.atoms_C_relaxed = self.ase_atoms_C.copy()

    def relax(self):
        """Relax the structure.

        Converts the three AtomCell objects to ASE Atoms objects (the center
        region is made non-periodic along the transport direction, the leads
        stay fully periodic) and relaxes the central region.

        Raises:
            ValueError: If `atomcell_L`, `atomcell_R`, or `atomcell_C` is None.
            SystemExit: If the center region lacks matching buffer atoms.

        Returns:
            None
        """
        if (
            self.atomcell_L is None
            or self.atomcell_R is None
            or self.atomcell_C is None
        ):
            raise ValueError("atomcell_L, atomcell_R, and atomcell_C must not be None")
        # Open boundary along the transport direction for the center region.
        pbc_C = self.generate_pbc_for_direction(
            self.transport_direction_to_index(self.transport_axis)
        )
        self.ase_atoms_L = self.parse_xyz_file(self.atomcell_L)
        self.ase_atoms_R = self.parse_xyz_file(self.atomcell_R)
        self.ase_atoms_C = self.parse_xyz_file(self.atomcell_C, pbc=pbc_C)
        self.relax_central_region()
        # Compare with None: the truth value of an Atoms object is its length.
        if self.atoms_C_relaxed is not None:
            print("Relaxation finished")

    def write(
        self,
        filename: Union[str, bytes, os.PathLike, None] = None,
        output_format: str = "xyz",
    ):
        """Write the relaxed atoms to a file in a specified format.

        The format name is always appended to `filename` as its extension,
        e.g. the defaults produce "relaxed.xyz".

        Args:
            filename (os.PathLike, optional): File name without extension.
                Defaults to "relaxed".
            output_format (str, optional): ASE I/O format to write, also used
                as the file extension. Defaults to "xyz".

        Raises:
            ValueError: If `self.atoms_C_relaxed` is None or if `output_format`
                is not a format known to ASE.

        Returns:
            None
        """
        if filename is None:
            filename = "relaxed"
        if self.atoms_C_relaxed is None:
            raise ValueError("No relaxed atoms to write.")
        if output_format not in io.formats.ioformats:
            raise ValueError("Invalid output format.")
        # os.fsdecode handles str, bytes and PathLike alike; str() on bytes
        # would produce a "b'...'" file name.
        filename = f"{os.fsdecode(filename)}.{output_format}"
        io.write(filename, self.atoms_C_relaxed, format=output_format)
        print(f"Atoms written to {filename}")

    def plot_energy(
        self,
        filename: Union[str, bytes, os.PathLike, None] = None,
        show=False,
        trajectory_filename: str = "relax.traj",
    ):
        """Plot the energy during the optimization and save the plot to a file.

        Args:
            filename (os.PathLike, optional): File to save the plot to. If it
                has no extension, ".png" is appended; otherwise the extension
                selects the image format. Defaults to "Potential energy".
            show (bool, optional): Whether to display the plot. Defaults to False.
            trajectory_filename (str): The trajectory file to load.

        Raises:
            SystemExit: If the trajectory file is not found; a warning message
                is printed and the program exits with status code 1.

        Returns:
            None
        """
        if filename is None:
            filename = "Potential energy"
        output_path = os.fsdecode(filename)
        if os.path.splitext(output_path)[1] in ("", "."):
            output_path = output_path.rstrip(".") + ".png"

        # Only the read is guarded, so that a failure while saving the figure
        # is not misreported as a missing trajectory.
        try:
            trajectory = read(trajectory_filename, index=":")
        except FileNotFoundError as e:
            print(f"Warning: Trajectory file not found: {e}")
            sys.exit(1)

        energies = [atoms.get_potential_energy() for atoms in trajectory]
        steps = range(len(energies))
        fig, ax = plt.subplots(figsize=(8, 6))
        try:
            ax.plot(steps, energies, "o-", label="Potential energy")
            ax.set_xlabel("Optimization Step")
            ax.set_ylabel("Potential energy (eV)")
            ax.legend()
            ax.set_title("Potential energy during optimization")
            fig.savefig(output_path)
            print(f"Potential energy optimization saved to {output_path}")
            if show:
                plt.show()
        finally:
            # Release the figure so repeated calls do not accumulate figures.
            plt.close(fig)

    def visualize_trajectory(self, trajectory_filename: str = "relax.traj"):
        """Read a trajectory file and display it with the ASE viewer.

        Args:
            trajectory_filename (str): The trajectory file to load.

        Raises:
            SystemExit: If the trajectory file is not found; a warning message
                is printed and the program exits with status code 1.

        Returns:
            None
        """
        try:
            trajectory = read(trajectory_filename, index=":")
        except FileNotFoundError as e:
            print(f"Warning: Trajectory file not found: {e}")
            sys.exit(1)
        view(trajectory)