"""Source code for nanotools.relaxation.two_probe: CHGNet-based relaxation of the central region of two-probe systems."""

import os
from typing import Union
import attr
import sys
from ase import Atoms, io
from ase.io import read
from ase.visualize import view
import matplotlib.pyplot as plt
import numpy as np
from ase.constraints import FixAtoms
from ase.optimize import BFGS
import warnings
from chgnet.model.model import CHGNet
from chgnet.model.dynamics import CHGNetCalculator

# Suppress warnings raised from inside the pymatgen and ase packages.
for _noisy_module in ("pymatgen", "ase"):
    warnings.filterwarnings("ignore", module=_noisy_module)
del _noisy_module


@attr.s
class NBRelaxer:
    """``NBRelaxer`` class.

    The ``NBRelaxer`` class is used for the relaxation of the center region of
    two-probe systems. It uses the pretrained universal neural network
    potential CHGNet to predict atomic forces and energies, and drives the
    structural relaxation with the Atomic Simulation Environment (ASE).

    Attributes:
        atomcell_L (AtomCell): The AtomCell object representing the atomic
            structure of the left lead.
        atomcell_R (AtomCell): The AtomCell object representing the atomic
            structure of the right lead.
        atomcell_C (AtomCell): The AtomCell object representing the atomic
            structure of the central region.
        transport_axis (str): The transport direction along which the
            relaxation is performed. Must be "X", "Y", or "Z". Default "Z".
        tolerance (float): The atomic force tolerance (``fmax``) for the
            relaxation, in eV/Å. Default 0.05.
        max_steps (int): The maximum number of optimization steps to perform.
            Default 100.
        fix_2_layer_leads (bool): If True, each lead is repeated twice along
            the transport axis before matching, so more buffer atoms of the
            center region are fixed. Default is False.
        ase_atoms_L (Atoms): ASE version of the left lead (set by ``relax``).
        ase_atoms_R (Atoms): ASE version of the right lead (set by ``relax``).
        ase_atoms_C (Atoms): ASE version of the central region (set by
            ``relax``), non-periodic along the transport axis.
        atoms_C_relaxed (Atoms): Copy of the central region after relaxation.
        use_device (str): Device handed to ``CHGNetCalculator``
            ("cpu", "cuda", or "mps"). Default "cpu".
    """

    atomcell_L = attr.ib()
    atomcell_R = attr.ib()
    atomcell_C = attr.ib()
    transport_axis = attr.ib(type=str, default="Z")
    tolerance = attr.ib(type=float, default=0.05)
    max_steps = attr.ib(type=int, default=100)
    fix_2_layer_leads = attr.ib(type=bool, default=False)
    ase_atoms_L = attr.ib(default=None)
    ase_atoms_R = attr.ib(default=None)
    ase_atoms_C = attr.ib(default=None)
    atoms_C_relaxed = attr.ib(default=None)
    # Appended last with a default so existing constructor calls keep working.
    use_device = attr.ib(type=str, default="cpu")

    def transport_direction_to_index(self, direction: str) -> int:
        """Convert a transport direction ("X", "Y", or "Z") to an index (0, 1, or 2).

        Args:
            direction (str): The transport direction. Must be "X", "Y", or "Z".

        Returns:
            int: 0 for "X", 1 for "Y", and 2 for "Z".

        Raises:
            ValueError: If `direction` is not "X", "Y", or "Z".
        """
        if direction not in ["X", "Y", "Z"]:
            raise ValueError(
                f"Invalid transport direction. Use 'X', 'Y', or 'Z'. Got: {direction}"
            )
        return {"X": 0, "Y": 1, "Z": 2}[direction]

    def generate_pbc_for_direction(self, direction: int) -> list:
        """Generate periodic boundary conditions that are open along one direction.

        Args:
            direction (int): The transport direction. 0 for X, 1 for Y, 2 for Z.

        Returns:
            list[bool]: Three booleans; the entry at `direction` is False and
            the others are True.

        Raises:
            ValueError: If `direction` is not 0, 1, or 2.
        """
        if direction not in [0, 1, 2]:
            raise ValueError(
                f"Invalid direction. Expected 0, 1, or 2. Got: {direction}"
            )
        pbc = [True, True, True]
        pbc[direction] = False
        return pbc

    def sort_atoms_along_direction(self, atoms: Atoms, direction_index: int) -> Atoms:
        """Return a copy of `atoms` ordered by coordinate along one Cartesian axis.

        Args:
            atoms (Atoms): ASE Atoms object.
            direction_index (int): The axis to sort along. Either 0, 1, or 2.

        Returns:
            Atoms: A new ASE Atoms object. It is reordered as a whole, so every
            per-atom array (symbols, positions, tags, masses, ...) stays
            consistent.

        Raises:
            ValueError: If `direction_index` is not 0, 1, or 2.
        """
        if direction_index not in [0, 1, 2]:
            raise ValueError(
                f"Invalid direction index. Expected 0, 1, or 2. Got: {direction_index}"
            )
        # A stable sort keeps the input order of atoms that share a coordinate,
        # which makes the result reproducible.
        order = np.argsort(atoms.positions[:, direction_index], kind="stable")
        # Indexing with an index array returns a new Atoms with every per-atom
        # array permuted together (previously only positions/symbols were).
        return atoms[order]

    def parse_xyz_file(self, atom_cell, pbc=None):
        """Convert an AtomCell object to an ASE Atoms object.

        The atoms are sorted along ``self.transport_axis``. The AtomCell's
        lattice vectors are used as the cell.

        Args:
            atom_cell (AtomCell): An AtomCell object. Presumably it exposes
                ``formula``, ``positions`` and ``avec``; the latter two are
                unit-aware quantities convertible with ``.to("angstrom")``.
            pbc (list of bool, optional): The periodic boundary conditions.
                Defaults to [True, True, True].

        Returns:
            Atoms: An ASE Atoms object, or None if `atom_cell` is falsy.

        Raises:
            ValueError: If `pbc` is not a list of three booleans.
        """
        if pbc is None:
            pbc = [True, True, True]
        if (
            not isinstance(pbc, list)
            or len(pbc) != 3
            or not all(isinstance(x, bool) for x in pbc)
        ):
            raise ValueError(
                f"Invalid pbc. Expected a list of three booleans. Got: {pbc}"
            )
        if not atom_cell:
            return None

        symbols = atom_cell.formula
        positions = atom_cell.positions.to("angstrom").m
        unit_cell = atom_cell.avec.to("angstrom").m
        atoms = Atoms(symbols=symbols, positions=positions)
        # Sort along the transport direction so that lead atoms line up with
        # the beginning / end of the center region in check_overlap.
        atoms = self.sort_atoms_along_direction(
            atoms, self.transport_direction_to_index(self.transport_axis)
        )
        atoms.cell = unit_cell
        atoms.pbc = pbc
        return atoms

    def translate_atoms_along_direction(
        self, atoms: Atoms, direction: int, translation_distance: float
    ) -> Atoms:
        """Return a copy of `atoms` shifted along one Cartesian axis.

        Args:
            atoms (Atoms): ASE Atoms object (not modified).
            direction (int): The axis to translate along. 0 for X, 1 for Y,
                or 2 for Z.
            translation_distance (float): The shift along that axis.

        Returns:
            Atoms: Translated copy of `atoms`.

        Raises:
            ValueError: If `direction` is not 0, 1, or 2.
        """
        if direction not in [0, 1, 2]:
            raise ValueError(
                f"Invalid direction. Expected 0, 1, or 2. Got: {direction}"
            )
        translated_atoms = atoms.copy()
        translation_vector = [0, 0, 0]
        translation_vector[direction] = translation_distance
        translated_atoms.translate(translation_vector)
        return translated_atoms

    def scale_atoms(self, atoms, scaling_factor, transport_direction):
        """Repeat (supercell) the atoms along the transport direction.

        Despite the name, no coordinates are rescaled. The structure and its
        cell are replicated `scaling_factor` times along one axis
        (``Atoms.repeat``).

        Args:
            atoms (Atoms): ASE Atoms object (not modified).
            scaling_factor (int): Number of repetitions along the axis.
            transport_direction (int): The axis to repeat along. Either 0 for
                X, 1 for Y, or 2 for Z.

        Returns:
            Atoms: New ASE Atoms object containing the repeated structure.

        Raises:
            ValueError: If `transport_direction` is not 0, 1, or 2.
        """
        if transport_direction not in [0, 1, 2]:
            raise ValueError(
                f"Invalid direction. Expected 0, 1, or 2. Got: {transport_direction}"
            )
        repetitions = [1, 1, 1]
        repetitions[transport_direction] = scaling_factor
        return atoms.repeat(repetitions)

    def check_overlap_side(self, atoms_C_side, atoms_side, atol=2e-1):
        """Find the atoms of `atoms_side` that coincide with some atom of `atoms_C_side`.

        Two atoms coincide when each of their Cartesian coordinates agrees
        within ``atol`` (``np.isclose``; absolute positions, no periodic
        images).

        Args:
            atoms_C_side (Atoms): First set of ASE Atoms (part of the center).
            atoms_side (Atoms): Second set of ASE Atoms (a lead).
            atol (float, optional): Absolute coordinate tolerance in Å.
                Defaults to 0.2.

        Returns:
            ndarray: Indices of atoms in `atoms_side` that overlap with any
            atom in `atoms_C_side`.
        """
        # Pairwise comparison: shape (n_C_side, n_side, 3).
        close = np.isclose(
            atoms_C_side.positions[:, np.newaxis, :],
            atoms_side.positions,
            atol=atol,
        )
        matched = np.any(np.all(close, axis=2), axis=0)
        return np.where(matched)[0]

    def check_overlap(self):
        """Check that the buffer atoms at both ends of the center region match the leads.

        The first ``len(lead_L)`` center atoms are compared with the left
        lead. The last ``len(lead_R)`` center atoms are compared with the
        right lead, after it is shifted to the right end of the center cell.
        With ``fix_2_layer_leads`` the leads are repeated twice along the
        transport axis first. The repetition happens on local copies, so
        ``self.ase_atoms_L`` and ``self.ase_atoms_R`` are not modified.

        Returns:
            tuple | list: On success, a tuple of two sorted lists:
                - indices of center atoms matched to lead atoms (buffer atoms)
                - indices of the remaining center atoms
            On failure, an empty list ``[]``, with a warning. Failure means the
            center region has fewer atoms than both leads together, or
            either lead does not line up.

        Raises:
            ValueError: If `ase_atoms_C`, `ase_atoms_L`, or `ase_atoms_R` is
                None, or the transport axis is invalid.
        """
        if (
            self.ase_atoms_C is None
            or self.ase_atoms_L is None
            or self.ase_atoms_R is None
        ):
            raise ValueError(
                "ase_atoms_C, ase_atoms_L, and ase_atoms_R must not be None"
            )

        trans_indx = self.transport_direction_to_index(self.transport_axis)
        atoms_C = self.ase_atoms_C
        atoms_L = self.ase_atoms_L
        atoms_R = self.ase_atoms_R

        # Option to fix more buffer atoms: repeat each lead twice. Work on
        # local copies so repeated calls don't keep doubling the stored leads.
        if self.fix_2_layer_leads:
            scaling_factor = 2
            atoms_L = self.scale_atoms(atoms_L, scaling_factor, trans_indx)
            atoms_R = self.scale_atoms(atoms_R, scaling_factor, trans_indx)

        n_C, n_L, n_R = len(atoms_C), len(atoms_L), len(atoms_R)
        if n_C < n_L + n_R:
            warnings.warn(
                "center region has fewer atoms than the left and right leads combined"
            )
            return []

        # Left lead vs. the first n_L center atoms. Because the slice starts
        # at 0, a full match means lead indices == center indices.
        left_overlap_indices = self.check_overlap_side(atoms_C[:n_L], atoms_L)
        if len(left_overlap_indices) != n_L:
            warnings.warn("left buffer atoms not aligned")
            return []
        print("left buffer atoms checked")

        # Shift the right lead to the right end of the center cell (uses the
        # diagonal cell element, i.e. assumes an orthogonal transport axis).
        translation_distance = (
            atoms_C.cell[trans_indx, trans_indx]
            - atoms_R.cell[trans_indx, trans_indx]
        )
        right_new = self.translate_atoms_along_direction(
            atoms_R, trans_indx, translation_distance
        )
        right_offset = n_C - n_R
        right_overlap_indices = (
            self.check_overlap_side(atoms_C[right_offset:], right_new) + right_offset
        )
        if len(right_overlap_indices) != n_R:
            warnings.warn("right buffer atoms not aligned")
            return []
        print("right buffer atoms checked")

        fixed = {int(i) for i in left_overlap_indices} | {
            int(i) for i in right_overlap_indices
        }
        free = [i for i in range(n_C) if i not in fixed]
        return sorted(fixed), free

    def relax_central_region(self):
        """Relax the central region with CHGNet and the ASE BFGS optimizer.

        The buffer atoms matched to the leads are fixed (``FixAtoms``). A
        CHGNet calculator is attached to ``self.ase_atoms_C``, and BFGS is run
        until ``fmax <= self.tolerance`` or ``self.max_steps`` is reached. The
        trajectory and log go to "relax.traj" and "relax.log" in the current
        working directory. Note that ``self.ase_atoms_C`` is modified in
        place.

        Raises:
            SystemExit: If there are insufficient or misaligned buffer atoms
                in the center region.

        Returns:
            None
        """
        # Check that the buffer atoms at both ends of the center region match
        # the leads.
        try:
            overlap = self.check_overlap()
        except ValueError as e:
            print(
                f"Warning: Insufficient buffer atoms in the center region. Check the structure. Error: {e}"
            )
            sys.exit(1)
        # check_overlap signals a failed match by returning an empty list.
        if not overlap:
            print(
                "Warning: Insufficient buffer atoms in the center region. Check the structure."
            )
            sys.exit(1)
        fixed_indices, _ = overlap

        # Set up the calculator and fix the buffer atoms.
        chgnet = CHGNet.load()
        calculator = CHGNetCalculator(model=chgnet, use_device=self.use_device)
        self.ase_atoms_C.calc = calculator
        self.ase_atoms_C.set_constraint(FixAtoms(indices=fixed_indices))

        # Perform relaxation using the BFGS minimizer of the ASE package.
        optimizer = BFGS(self.ase_atoms_C, trajectory="relax.traj", logfile="relax.log")
        optimizer.run(fmax=self.tolerance, steps=self.max_steps)

        # Keep a calculator-free snapshot of the relaxed structure.
        self.atoms_C_relaxed = self.ase_atoms_C.copy()

    def relax(self):
        """Convert the AtomCell inputs to ASE Atoms and relax the central region.

        The center region is made non-periodic along the transport direction.
        The leads keep full periodicity.

        Raises:
            ValueError: If `atomcell_L`, `atomcell_R`, or `atomcell_C` is None.
            SystemExit: Propagated from ``relax_central_region`` when the
                buffer atoms do not match the leads.

        Returns:
            None
        """
        if (
            self.atomcell_L is None
            or self.atomcell_R is None
            or self.atomcell_C is None
        ):
            raise ValueError("atomcell_L, atomcell_R, and atomcell_C must not be None")

        # The center region is open (non-periodic) along the transport axis.
        pbc_C = self.generate_pbc_for_direction(
            self.transport_direction_to_index(self.transport_axis)
        )

        # Convert the AtomCell objects to ASE Atoms objects.
        self.ase_atoms_L = self.parse_xyz_file(self.atomcell_L)
        self.ase_atoms_R = self.parse_xyz_file(self.atomcell_R)
        self.ase_atoms_C = self.parse_xyz_file(self.atomcell_C, pbc=pbc_C)

        self.relax_central_region()
        if self.atoms_C_relaxed is not None:
            print("Relaxation finished")

    def write(
        self,
        filename: Union[str, bytes, os.PathLike, None] = None,
        output_format: str = "xyz",
    ):
        """Write the relaxed atoms to ``<filename>.<output_format>``.

        Args:
            filename (str | bytes | os.PathLike, optional): File name without
                extension. Defaults to "relaxed".
            output_format (str, optional): An ASE I/O format name; it is also
                used as the file extension. Defaults to "xyz".

        Raises:
            ValueError: If `self.atoms_C_relaxed` is None or if
                `output_format` is not a known ASE format.

        Returns:
            None
        """
        if filename is None:
            filename = "relaxed"
        if self.atoms_C_relaxed is None:
            raise ValueError("No relaxed atoms to write.")
        if output_format not in io.formats.ioformats:
            raise ValueError("Invalid output format.")
        # os.fsdecode accepts str, bytes and PathLike; str() would turn
        # b"name" into the literal text "b'name'".
        filename = f"{os.fsdecode(filename)}.{output_format}"
        io.write(filename, self.atoms_C_relaxed, format=output_format)
        print(f"Atoms written to {filename}")

    def plot_energy(
        self,
        filename: Union[str, bytes, os.PathLike, None] = None,
        show=False,
        trajectory_filename: str = "relax.traj",
    ):
        """Plot the potential energy at each optimization step and save it.

        Args:
            filename (str | bytes | os.PathLike, optional): Output file. If it
                has no extension, matplotlib's default ``savefig.format`` is
                appended. Defaults to "Potential energy".
            show (bool, optional): Whether to display the plot. Defaults to False.
            trajectory_filename (str): The trajectory file to load.
                Defaults to "relax.traj".

        Raises:
            SystemExit: If the trajectory file is not found (a warning is
                printed first).

        Returns:
            None
        """
        if filename is None:
            filename = "Potential energy"
        # Only the read is guarded, so unrelated errors (e.g. an unwritable
        # output path) are not misreported as a missing trajectory.
        try:
            trajectory = read(trajectory_filename, index=":")
        except FileNotFoundError as e:
            print(f"Warning: Trajectory file not found: {e}")
            sys.exit(1)

        energies = [atoms.get_potential_energy() for atoms in trajectory]
        steps = range(len(energies))

        fig = plt.figure(figsize=(8, 6))
        plt.plot(steps, energies, "o-", label="Potential energy")
        plt.xlabel("Optimization Step")
        plt.ylabel("Potential energy (eV)")
        plt.legend()
        plt.title("Potential energy during optimization")

        # Resolve the name the same way matplotlib does, so the message
        # reports the file that was actually written.
        output = os.fsdecode(filename)
        if not os.path.splitext(output)[1][1:]:
            output = f"{output.rstrip('.')}.{plt.rcParams['savefig.format']}"
        fig.savefig(output)
        print(f"Potential energy optimization saved to {output}")
        if show:
            plt.show()
        # Release the figure so repeated calls don't accumulate open figures.
        plt.close(fig)

    def visualize_trajectory(self, trajectory_filename: str = "relax.traj"):
        """Read a trajectory file and display it with the ASE viewer.

        Args:
            trajectory_filename (str): The trajectory file to load.
                Defaults to "relax.traj".

        Raises:
            SystemExit: If the trajectory file is not found (a warning is
                printed first).

        Returns:
            None
        """
        try:
            trajectory = read(trajectory_filename, index=":")
        except FileNotFoundError as e:
            print(f"Warning: Trajectory file not found: {e}")
            sys.exit(1)
        view(trajectory)