# Source code for mlipx.nodes.vibrational_analysis

import pathlib
import typing as t

import ase.io
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import zntrack
from ase import units
from ase.constraints import FixAtoms
from ase.thermochemistry import HarmonicThermo, IdealGasThermo
from ase.vibrations import Vibrations

# from copy import deepcopy
from tqdm import tqdm

from mlipx.abc import ComparisonResults, NodeWithCalculator


# Conversion factor from cm^-1 (the unit of ``Vibrations.get_frequencies``)
# to eV, used to turn frequencies into vibrational energies.
_INVCM_TO_EV = 0.0001239843

# System types treated as ideal gases (Gibbs energy via ``IdealGasThermo``);
# every other system type is treated harmonically (Helmholtz energy via
# ``HarmonicThermo``).
_GAS_PHASE_SYSTEMS = ("mol", "molecule", "linear-molecule", "isolated-atom")


class VibrationalAnalysis(zntrack.Node):
    """Vibrational Analysis Node.

    This node performs a finite-difference vibrational analysis on each of the
    provided images and derives a vibrational free energy from the resulting
    frequencies using ASE's thermochemistry classes.

    Parameters
    ----------
    data : list[ase.Atoms]
        List of images to perform vibrational analysis on. The images are
        copied before use, so the upstream data is not modified.
    model : NodeWithCalculator
        Model node with calculator to perform vibrational analysis.
    displacement : float
        Finite-difference displacement, passed to
        ``ase.vibrations.Vibrations`` as ``delta``.
    nfree : int
        Passed to ``ase.vibrations.Vibrations`` as ``nfree``: the number of
        displacements per atom and Cartesian direction (not a number of atoms).
    temperature : float
        Temperature in Kelvin at which ``results`` is evaluated.
    lower_freq_threshold : float
        Frequencies (cm^-1) that are not above this value -- including
        imaginary ones -- are replaced by this value before the
        thermochemistry step.
    frames_path : pathlib.Path
        Path to save the analysed frames (with free energies in ``info``).
    modes_path : pathlib.Path
        Path to save the animated vibrational modes of all frames.
    modes_cache : pathlib.Path
        Directory holding one trajectory per frame and mode.
    vib_cache : pathlib.Path
        Directory holding the ``Vibrations`` cache of every frame.
    free_indices : list[int] | None
        Indices of the atoms that are displaced; all other atoms are fixed with
        ``FixAtoms``. If ``None``, ``atoms.info["free_indices"]`` is used, and
        if that is missing, all atoms are free.
    system : str | None
        System type. ``"molecule"``, ``"linear-molecule"`` and
        ``"isolated-atom"`` (and ``"mol"`` via ``atoms.info``) are treated as
        ideal gases; anything else harmonically. If ``None``,
        ``atoms.info["type"]`` is used and a ``KeyError`` is raised if it is
        missing.
    calc_type : str | None
        ``"ts"`` drops the first frequency (transition state); anything else
        keeps all. If ``None``, ``atoms.info["calc_type"]`` is used, defaulting
        to ``"relax"``.

    Attributes
    ----------
    results : pd.DataFrame
        One row per frame with the columns ``Frame`` and ``ddG``. ``ddG`` is
        the free energy in eV at ``temperature`` computed with
        ``potentialenergy=0``, i.e. only the vibrational/thermal contribution.

    Notes
    -----
    For ideal-gas systems the optional ``atoms.info`` keys
    ``"symmetry_number"`` (default 1), ``"pressure"`` (multiplied by 1e5 to
    obtain Pa, i.e. interpreted as bar; default 1) and ``"spin"`` (default 0)
    are honoured.
    """

    data: list[ase.Atoms] = zntrack.deps()
    model: NodeWithCalculator = zntrack.deps()
    displacement: float = zntrack.params(0.01)
    nfree: int = zntrack.params(4)
    temperature: float = zntrack.params(298.15)  # in Kelvin
    lower_freq_threshold: float = zntrack.params(12.0)

    frames_path: pathlib.Path = zntrack.outs_path(zntrack.nwd / "frames.xyz")
    modes_path: pathlib.Path = zntrack.outs_path(zntrack.nwd / "modes.xyz")
    modes_cache: pathlib.Path = zntrack.outs_path(zntrack.nwd / "modes")
    vib_cache: pathlib.Path = zntrack.outs_path(zntrack.nwd / "vib")

    results: pd.DataFrame = zntrack.plots(y="ddG", x="Frame")

    # by default freeze no index
    free_indices: list[int] | None = zntrack.params(None)
    system: (
        t.Literal["molecule"]
        | t.Literal["other"]
        | t.Literal["linear-molecule"]
        | t.Literal["isolated-atom"]
        | None
    ) = zntrack.params(None)
    calc_type: t.Literal["ts"] | t.Literal["relax"] | None = zntrack.params(None)

    def run(self):
        """Run the vibrational analysis and thermochemistry for every frame.

        Every frame's output is written as soon as it is produced:

        - The frame itself is appended to ``frames_path``.
        - Its animated modes go to ``modes_cache/<frame>/mode_<i>.traj``.
        - Its ``Vibrations`` cache goes to ``vib_cache/<frame>``.

        After the loop, all modes are written to ``modes_path`` and the
        per-frame free energies to ``results``.

        Raises
        ------
        KeyError
            If ``system`` is ``None`` and a frame has no ``info["type"]``.
        """
        calc = self.model.get_calculator()
        results = []
        modes = []
        # Amplitude used to animate the modes; identical for every frame.
        kT = units.kB * self.temperature

        for current_frame, original in tqdm(
            enumerate(self.data), total=len(self.data)
        ):
            # Work on a copy so the constraints, calculator and info entries
            # set below do not leak back into the upstream node's data.
            atoms = original.copy()
            system = self._resolve_system(atoms)
            free_indices = self._resolve_free_indices(atoms)
            calc_type = self._resolve_calc_type(atoms)

            cache = self.vib_cache / f"{current_frame}"
            cache.mkdir(parents=True, exist_ok=True)
            frame_modes_cache = self.modes_cache / f"{current_frame}"
            frame_modes_cache.mkdir(parents=True, exist_ok=True)

            # Fix every atom that is not displaced in the vibrational analysis.
            free_set = set(free_indices)
            atoms.constraints = FixAtoms(
                [idx for idx in range(len(atoms)) if idx not in free_set]
            )
            atoms.calc = calc
            atoms.get_potential_energy()
            atoms.get_forces()

            vib = Vibrations(
                atoms,
                nfree=self.nfree,
                name=cache,
                delta=self.displacement,
                indices=free_indices,
            )
            vib.run()

            freq = self._threshold_frequencies(vib.get_frequencies())
            if calc_type.lower() == "ts":
                # Drop the first mode of a transition state; this assumes the
                # imaginary reaction-coordinate mode is listed first.
                freq = freq[1:]

            thermo, p_pascal = self._build_thermo(atoms, system, freq)
            dg = self._free_energy(thermo, self.temperature, p_pascal)
            atoms.info[f"dg_{self.temperature}k"] = dg
            results.append({"Frame": current_frame, "ddG": dg})

            # Additionally store a coarse temperature sweep on the frame.
            for temp in np.linspace(10, 1000, 10):
                atoms.info[f"dg_{temp:.1f}k"] = self._free_energy(
                    thermo, temp, p_pascal
                )

            ase.io.write(self.frames_path, atoms, append=True)

            vibrations = vib.get_vibrations()
            for mode in range(len(free_indices) * 3):
                mode_cache = frame_modes_cache / f"mode_{mode}.traj"
                with ase.io.Trajectory(mode_cache, "w") as traj:
                    for image in vibrations.iter_animated_mode(
                        mode, temperature=kT, frames=30
                    ):
                        traj.write(image)
                modes += ase.io.read(mode_cache, index=":")

        ase.io.write(self.modes_path, modes)
        # Explicit columns keep the expected schema even for empty input.
        self.results = pd.DataFrame(results, columns=["Frame", "ddG"])

    def _resolve_system(self, atoms: ase.Atoms) -> str:
        """Return the system type from the parameter or ``atoms.info["type"]``."""
        if self.system is not None:
            return self.system
        try:
            return atoms.info["type"]
        except KeyError as err:
            raise KeyError(
                "Unable to determine system type from `atoms.info`."
                " Please set the 'system' parameter."
            ) from err

    def _resolve_free_indices(self, atoms: ase.Atoms) -> list[int]:
        """Return the free atom indices (parameter, then info, then all atoms)."""
        if self.free_indices is not None:
            return self.free_indices
        if "free_indices" in atoms.info:
            return atoms.info["free_indices"]
        return list(range(len(atoms)))

    def _resolve_calc_type(self, atoms: ase.Atoms) -> str:
        """Return the calculation type (parameter, then info, then ``"relax"``)."""
        if self.calc_type is not None:
            return self.calc_type
        return atoms.info.get("calc_type", "relax")

    def _threshold_frequencies(self, frequencies) -> list:
        """Replace every frequency not above ``lower_freq_threshold`` by it."""
        floor = complex(self.lower_freq_threshold)
        return [
            f if f > self.lower_freq_threshold else floor for f in frequencies
        ]

    @staticmethod
    def _build_thermo(atoms: ase.Atoms, system: str, freq: list):
        """Create the thermochemistry object for one frame.

        Returns
        -------
        tuple
            ``(thermo, pressure)``. ``pressure`` is in Pa for ideal-gas
            systems; it is ``None`` for harmonic systems, which do not use it.
        """
        system = system.lower()
        if system not in _GAS_PHASE_SYSTEMS:
            vib_energies = [f * _INVCM_TO_EV for f in freq]
            thermo = HarmonicThermo(vib_energies=vib_energies, potentialenergy=0.0)
            return thermo, None

        # Discard the translational/rotational modes of gas-phase species.
        if system == "linear-molecule":
            freq, geometry = freq[5:], "linear"
        elif system == "isolated-atom":
            freq, geometry = [], "monatomic"
        else:
            freq, geometry = freq[6:], "nonlinear"
        vib_energies = [f * _INVCM_TO_EV for f in freq]

        p_pascal = 1e5
        if "pressure" in atoms.info:
            p_pascal = atoms.info["pressure"] * 1e5
        thermo = IdealGasThermo(
            atoms=atoms,
            vib_energies=vib_energies,
            geometry=geometry,
            potentialenergy=0.0,
            symmetrynumber=atoms.info.get("symmetry_number", 1),
            spin=atoms.info.get("spin", 0),
        )
        return thermo, p_pascal

    @staticmethod
    def _free_energy(thermo, temperature: float, pressure: float | None) -> float:
        """Gibbs energy for ideal-gas thermo objects, Helmholtz energy otherwise."""
        if isinstance(thermo, IdealGasThermo):
            return thermo.get_gibbs_energy(temperature, pressure, verbose=True)
        return thermo.get_helmholtz_energy(temperature, verbose=True)

    @property
    def frames(self) -> list[ase.Atoms]:
        """Analysed frames read back from ``frames_path``."""
        with self.state.fs.open(self.frames_path, "r") as f:
            return list(ase.io.iread(f, format="extxyz"))

    @property
    def modes(self) -> list[ase.Atoms]:
        """Animated vibrational modes read back from ``modes_path``."""
        with self.state.fs.open(self.modes_path, "r") as f:
            return list(ase.io.iread(f, format="extxyz"))

    @property
    def figures(self) -> dict[str, go.Figure]:
        """Line plot of ``ddG`` per frame, keyed ``"Gibbs"``."""
        fig = px.line(self.results, x="Frame", y="ddG", markers=True)
        fig.update_layout(
            title=f"Gibbs Free Energy at {self.temperature}K",
            xaxis_title="Frame",
            yaxis_title="ddG / eV",
        )
        fig.update_traces(customdata=np.stack([np.arange(len(self.results))], axis=-1))
        return {"Gibbs": fig}

    @staticmethod
    def compare(*nodes: "VibrationalAnalysis") -> ComparisonResults:
        """Overlay the ``ddG`` curves of several nodes in one figure.

        The ``customdata`` of each trace indexes into the concatenated
        ``frames`` of all nodes. The title reports the temperature of the
        last node, so the nodes are assumed to share the same temperature.
        """
        frames = [frame for node in nodes for frame in node.frames]
        fig = go.Figure()
        offset = 0
        for node in nodes:
            n_rows = len(node.results["ddG"])
            fig.add_trace(
                go.Scatter(
                    x=node.results["Frame"],
                    y=node.results["ddG"],
                    mode="lines+markers",
                    name=node.name,
                    customdata=np.stack([np.arange(n_rows) + offset], axis=1),
                )
            )
            offset += n_rows

        title = "Comparison of Gibbs Free Energies"
        if nodes:
            title += f" at {nodes[-1].temperature}K"
        fig.update_layout(
            title=title,
            xaxis_title="Frame",
            yaxis_title="ddG / eV",
            plot_bgcolor="rgba(0, 0, 0, 0)",
            paper_bgcolor="rgba(0, 0, 0, 0)",
        )
        fig.update_xaxes(
            showgrid=True,
            gridwidth=1,
            gridcolor="rgba(120, 120, 120, 0.3)",
            zeroline=False,
        )
        fig.update_yaxes(
            showgrid=True,
            gridwidth=1,
            gridcolor="rgba(120, 120, 120, 0.3)",
            zeroline=False,
        )
        return ComparisonResults(frames=frames, figures={"Gibbs-Comparison": fig})