import os
import sys
import warnings
from typing import Union

import attr
import attr
import matplotlib.pyplot as plt
from ase import Atoms, io
from ase.constraints import FixAtoms
from ase.constraints import StrainFilter
from ase.io import read
from ase.optimize import BFGS
from ase.visualize import view
from chgnet.model.dynamics import CHGNetCalculator
from chgnet.model.model import CHGNet
@attr.s
class NBRelaxer:
    """
    ``NBRelaxer`` class.

    The ``NBRelaxer`` class is used for the relaxation of atomic structures of general material systems.
    It uses the pretrained universal neural network potential, CHGNet, to predict atomic forces and
    energies, and drives the structural relaxation with the BFGS optimizer of the Atomic Simulation
    Environment (ASE). Every relaxation writes its trajectory to ``relax.traj`` and its optimizer log
    to ``relax.log`` in the current working directory.

    Attributes:
        atomcell (AtomCell):
            The AtomCell object representing the atomic structure.
        relaxation_type (str): The type of relaxation to perform. Available options are:
            "atom_partial_relax": Relax the atomic positions while fixing certain atoms.
                The indices of the atoms to be fixed are specified using the ``fix_atoms_index``
                parameter. By default, the first atom is fixed. The cell is not changed.
            "cell_fractional_relax": Relax the unit cell (through ASE's ``StrainFilter``) while
                keeping the scaled positions fixed.
            "unconstrained_relax": Relax all atomic positions without any constraints.
                The cell is not changed.
        tolerance (float):
            Convergence criterion ``fmax`` passed to the optimizer. For the atomic relaxations this
            is the maximum atomic force in eV/Å; for "cell_fractional_relax" it is applied to the
            generalized forces reported by ``StrainFilter``.
        max_steps (int):
            The maximum number of optimization steps to perform.
        fix_atoms_index (list):
            The indices of atoms to fix during "atom_partial_relax". Defaults to ``[0]`` (the first
            atom); each instance gets its own list.
        ase_atoms (Atoms or None):
            The ASE structure being relaxed; populated by ``relax``.
        atoms_relaxed (Atoms or None):
            A copy of the structure after relaxation; ``None`` until a relaxation has run.
    """

    atomcell = attr.ib()
    relaxation_type = attr.ib(default="atom_partial_relax")
    tolerance = attr.ib(default=0.05)
    max_steps = attr.ib(default=100)
    # A factory gives every instance its own list; a literal ``[0]`` default would be a
    # single list object shared (and mutable) across all instances.
    fix_atoms_index = attr.ib(default=attr.Factory(lambda: [0]))
    ase_atoms = attr.ib(default=None)
    atoms_relaxed = attr.ib(default=None)

    def parse_xyz_file(self, atom_cell):
        """Convert an AtomCell object to an ASE Atoms object.

        If both ``fractional_positions`` and ``positions`` are set, the fractional (scaled)
        positions are used. A ``None`` ``atom_cell.pbc`` is treated as fully periodic; the
        passed-in ``atom_cell`` is not modified.

        Args:
            atom_cell (AtomCell): An AtomCell object.

        Returns:
            Atoms: An ASE Atoms object. Returns None if `atom_cell` is None.

        Raises:
            ValueError: If `atom_cell.pbc` is not a list of three booleans, or if `atom_cell`
                defines neither `positions` nor `fractional_positions`.
        """
        if atom_cell is None:
            return None

        pbc = atom_cell.pbc
        if pbc is None:
            pbc = [True, True, True]
        if not isinstance(pbc, list) or len(pbc) != 3 or not all(isinstance(x, bool) for x in pbc):
            raise ValueError(f"Invalid atom_cell.pbc. Expected a list of three booleans. Got: {pbc}")

        if atom_cell.fractional_positions is None and atom_cell.positions is None:
            raise ValueError("atom_cell must define either positions or fractional_positions.")

        symbols = atom_cell.formula
        cell = atom_cell.avec.to("angstrom").m
        if atom_cell.fractional_positions is not None:
            return Atoms(
                symbols=symbols,
                scaled_positions=atom_cell.fractional_positions,
                cell=cell,
                pbc=pbc,
            )
        return Atoms(
            symbols=symbols,
            positions=atom_cell.positions.to("angstrom").m,
            cell=cell,
            pbc=pbc,
        )

    def _run_bfgs(self, target):
        """Run BFGS on `target` and store a copy of the relaxed structure in `atoms_relaxed`.

        `target` is either `self.ase_atoms` or an ASE filter wrapping it; in both cases the
        optimizer updates `self.ase_atoms` in place, so that object is what gets copied.
        """
        optimizer = BFGS(target, trajectory="relax.traj", logfile="relax.log")
        optimizer.run(fmax=self.tolerance, steps=self.max_steps)
        self.atoms_relaxed = self.ase_atoms.copy()

    def relax_position(self, calculator):
        """Relax the atomic positions while keeping the atoms in `fix_atoms_index` fixed.

        Args:
            calculator: The ASE calculator to use for the relaxation.

        Returns:
            None
        """
        self.ase_atoms.calc = calculator
        self.ase_atoms.set_constraint(FixAtoms(indices=self.fix_atoms_index))
        self._run_bfgs(self.ase_atoms)

    def relax_strain(self, calculator):
        """Relax the unit cell until the stress is zero while keeping the scaled positions fixed.

        Args:
            calculator: The ASE calculator to use for the relaxation.

        Returns:
            None
        """
        self.ase_atoms.calc = calculator
        self._run_bfgs(StrainFilter(self.ase_atoms))

    def relax_full(self, calculator):
        """Relax all atomic positions without any constraints (the cell is kept fixed).

        Any constraint left on `ase_atoms` (e.g. by a previous `relax_position` call) is removed
        first so the relaxation is genuinely unconstrained.

        Args:
            calculator: The ASE calculator to use for the relaxation.

        Returns:
            None
        """
        self.ase_atoms.calc = calculator
        # Calling set_constraint() with no argument removes all constraints.
        self.ase_atoms.set_constraint()
        self._run_bfgs(self.ase_atoms)

    def relax(self, use_device="cpu"):
        """Perform relaxation based on `relaxation_type`.

        Args:
            use_device (str, optional): The device to use for the calculation. Defaults to "cpu".

        Returns:
            None

        Raises:
            ValueError: If `atomcell` is None.

        Warns:
            UserWarning: If `relaxation_type` is not recognized; no relaxation is performed.
        """
        # Convert the AtomCell to an ASE Atoms object.
        self.ase_atoms = self.parse_xyz_file(self.atomcell)
        if self.ase_atoms is None:
            raise ValueError("No structure to relax: atomcell is None.")

        relaxation_methods = {
            "atom_partial_relax": self.relax_position,
            "cell_fractional_relax": self.relax_strain,
            "unconstrained_relax": self.relax_full,
        }
        relaxation_method = relaxation_methods.get(self.relaxation_type)
        if relaxation_method is None:
            warnings.warn(f"Unrecognized relaxation type: {self.relaxation_type}.\nChoose from 'atom_partial_relax', 'cell_fractional_relax', 'unconstrained_relax'.")
            return

        # Load the model only after the relaxation type has been validated, so an invalid
        # type does not trigger a model load.
        chgnet = CHGNet.load()
        calculator = CHGNetCalculator(model=chgnet, use_device=use_device)
        relaxation_method(calculator)
        print("Relaxation finished.")

    def write(self, filename: Union[str, bytes, os.PathLike, None] = None, output_format: str = "xyz"):
        """Write the relaxed atoms to a file in a specified format.

        The format name is appended to `filename` as its extension, e.g. "relaxed.xyz".

        Args:
            filename (str, bytes or os.PathLike, optional): The base filename (without extension).
                Defaults to "relaxed".
            output_format (str, optional): The ASE format to write the file in. Defaults to "xyz".

        Raises:
            ValueError: If `self.atoms_relaxed` is None or if `output_format` is not a valid output format.

        Returns:
            None
        """
        if filename is None:
            filename = "relaxed"
        if self.atoms_relaxed is None:
            raise ValueError("No relaxed atoms to write.")
        if output_format not in io.formats.ioformats:
            raise ValueError("Invalid output format.")
        # os.fsdecode handles str, bytes and PathLike alike (str() would render bytes as "b'...'").
        filename = f"{os.fsdecode(filename)}.{output_format}"
        io.write(filename, self.atoms_relaxed, format=output_format)
        print(f"Atoms written to {filename}")

    def plot_energy(self, filename: Union[str, bytes, os.PathLike, None] = None, show=False, trajectory_filename: str = 'relax.traj'):
        """Plot the potential energy at each optimization step and save the plot to a file.

        Args:
            filename (str, bytes or os.PathLike, optional): The file to save the plot to. Defaults to
                "Energy". If it has no extension, matplotlib appends the default figure format
                (``rcParams["savefig.format"]``, "png" unless configured otherwise).
            show (bool, optional): Whether to display the plot. Defaults to False. When False, the
                figure is closed after saving.
            trajectory_filename (str): The filename of the trajectory file to load.

        Raises:
            SystemExit: With status 1 if the trajectory file is not found (a warning is printed first).

        Returns:
            None
        """
        if filename is None:
            filename = "Energy"
        # Only the trajectory read is guarded, so a FileNotFoundError raised while saving the
        # plot (e.g. a missing output directory) is not misreported as a missing trajectory.
        try:
            trajectory = read(trajectory_filename, index=':')
        except FileNotFoundError as e:
            print(f"Warning: Trajectory file not found: {e}")
            sys.exit(1)

        energies = [atoms.get_potential_energy() for atoms in trajectory]
        path = os.fsdecode(filename)

        fig, ax = plt.subplots(figsize=(8, 6))
        ax.plot(range(len(energies)), energies, 'o-', label='Potential energy')
        ax.set_xlabel('Optimization Step')
        ax.set_ylabel('Potential energy (eV)')
        ax.legend()
        ax.set_title('Potential energy during optimization')
        fig.savefig(path)

        # Report the file matplotlib actually wrote: it appends the default format only when
        # the given name has no extension.
        if os.path.splitext(path)[1][1:]:
            saved_path = path
        else:
            saved_path = f"{path.rstrip('.')}.{plt.rcParams['savefig.format']}"
        print(f"Potential energy during optimization saved to {saved_path}")

        if show:
            plt.show()
        else:
            # Release the figure so repeated calls do not accumulate open figures.
            plt.close(fig)

    def visualize_trajectory(self, trajectory_filename: str = 'relax.traj'):
        """Read a trajectory file and display it with ASE's viewer.

        Args:
            trajectory_filename (str): The filename of the trajectory file to load.

        Raises:
            SystemExit: With status 1 if the trajectory file is not found (a warning is printed first).

        Returns:
            None
        """
        try:
            trajectory = read(trajectory_filename, index=':')
        except FileNotFoundError as e:
            print(f"Warning: Trajectory file not found: {e}")
            sys.exit(1)
        view(trajectory)