Source code for mlipx.nodes.diatomics

import functools
import pathlib
import typing as t

import ase
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import tqdm
import zntrack
from ase.data import atomic_numbers, covalent_radii

from mlipx.abc import ComparisonResults, NodeWithCalculator
from mlipx.utils import freeze_copy_atoms


class HomonuclearDiatomics(zntrack.Node):
    """Compute energy-bondlength curves for homonuclear diatomic molecules.

    Parameters
    ----------
    elements : list[str]
        List of elements to consider. For example, ["H", "He", "Li"].
    model : NodeWithCalculator
        Node providing the calculator object for the energy calculations.
    n_points : int, default=100
        Number of points to sample for the bond length between
        min_distance and max_distance.
    min_distance : float, default=0.5
        Lower end of the sampled bond-length range. With the default
        ``eq_distance`` this is a multiple of the element's covalent
        radius; otherwise it is used directly as the distance passed to
        ``build_molecule``.
    max_distance : float, default=2.0
        Upper end of the sampled bond-length range, interpreted the same
        way as ``min_distance``.
    eq_distance : "covalent-radiuis" | float, default="covalent-radiuis"
        Selects how ``min_distance`` / ``max_distance`` are interpreted.
        For the string value (the spelling is the stored parameter value
        and is therefore kept as is) both bounds are multiplied by the
        covalent radius from ``ase.data``. For any other value the bounds
        are used unscaled; the numeric value itself is not read by this
        class.
    data : list[ase.Atoms]|None
        Optional list of ase.Atoms. Diatomics for each element in
        this list will be added to `elements`.
    model_outs:
        Path to store the outputs of the model.
        Some models, like DFT calculators, generate
        files that will be stored in this path.

    Attributes
    ----------
    frames : list[ase.Atoms]
        List of frames with the bond length varied. ``run`` appends
        ``n_points`` frames per element, in the order of the ``results``
        columns.
    results : pd.DataFrame
        DataFrame with the energy values for each bond length
        (index: distance, one column per element).
    """

    model: NodeWithCalculator = zntrack.deps()
    elements: list[str] = zntrack.params(("H", "He", "Li"))
    data: list[ase.Atoms] | None = zntrack.deps(None)

    n_points: int = zntrack.params(100)
    min_distance: float = zntrack.params(0.5)
    max_distance: float = zntrack.params(2.0)
    eq_distance: t.Union[t.Literal["covalent-radiuis"], float] = zntrack.params(
        "covalent-radiuis"
    )

    frames: list[ase.Atoms] = zntrack.outs()  # TODO: change to h5md out
    results: pd.DataFrame = zntrack.plots()

    model_outs: pathlib.Path = zntrack.outs_path(zntrack.nwd / "model_outs")

    def build_molecule(self, element: str, distance: float) -> ase.Atoms:
        """Return a two-atom molecule of ``element`` separated along z.

        Parameters
        ----------
        element : str
            Chemical symbol used for both atoms.
        distance : float
            z-coordinate of the second atom; the first sits at the origin.
        """
        return ase.Atoms([element, element], positions=[(0, 0, 0), (0, 0, distance)])

    def run(self) -> None:
        """Evaluate the energy of every diatomic over its distance grid.

        Fills ``self.frames`` (one frozen copy per evaluated geometry) and
        ``self.results`` (outer-joined per-element energy columns). If no
        element is requested, both stay empty.
        """
        self.frames = []
        self.results = pd.DataFrame()
        self.model_outs.mkdir(exist_ok=True, parents=True)
        # A small marker file is always written into the output directory
        # (presumably so the tracked directory is never empty).
        (self.model_outs / "mlipx.txt").write_text("Thank you for using MLIPX!")
        calc = self.model.get_calculator(directory=self.model_outs)

        # Deduplicate while keeping first-seen order. A plain ``set`` would
        # iterate in a hash-dependent order that differs between interpreter
        # runs, making the order of ``frames`` and of the ``results`` columns
        # non-reproducible.
        elements = dict.fromkeys(self.elements)
        if self.data is not None:
            for atoms in self.data:
                elements.update(dict.fromkeys(atoms.symbols))

        e_v = {}
        for element in elements:
            if self.eq_distance == "covalent-radiuis":
                # Scale the bounds by the element's covalent radius.
                radius = covalent_radii[atomic_numbers[element]]
                distances = np.linspace(
                    self.min_distance * radius,
                    self.max_distance * radius,
                    self.n_points,
                )
            else:
                distances = np.linspace(
                    self.min_distance, self.max_distance, self.n_points
                )

            energies = []
            tbar = tqdm.tqdm(
                distances, desc=f"{element}-{element} bond ({distances[0]:.2f} Å)"
            )
            for distance in tbar:
                tbar.set_description(f"{element}-{element} bond ({distance:.2f} Å)")
                molecule = self.build_molecule(element, distance)
                molecule.calc = calc
                energies.append(molecule.get_potential_energy())
                self.frames.append(freeze_copy_atoms(molecule))
            e_v[element] = pd.DataFrame(energies, index=distances, columns=[element])

        # ``reduce`` without an initializer raises TypeError on an empty
        # iterable; with nothing to merge keep the empty frame set above.
        if e_v:
            self.results = functools.reduce(
                lambda x, y: pd.merge(
                    x, y, left_index=True, right_index=True, how="outer"
                ),
                e_v.values(),
            )

    @staticmethod
    def _frame_indices(
        results: pd.DataFrame, element: str, n_points: int
    ) -> np.ndarray:
        """Return the ``(n_points, 1)`` frame indices belonging to ``element``.

        ``run`` stores ``n_points`` consecutive frames per element in column
        order, so the block for ``element`` starts at
        ``column_position * n_points``.
        """
        offset = list(results.columns).index(element) * n_points
        return np.stack([np.arange(n_points) + offset], axis=1)

    @staticmethod
    def _add_comparison_trace(
        figures: dict,
        key: str,
        title: str,
        y_title: str,
        curve: pd.Series,
        name: str,
        customdata: np.ndarray,
    ) -> None:
        """Add one line trace to ``figures[key]``, creating the figure if needed."""
        fig = figures.get(key)
        if fig is None:
            fig = go.Figure()
            fig.update_layout(title=title)
            fig.update_xaxes(title="Distance / Å")
            fig.update_yaxes(title=y_title)

        fig.add_trace(go.Scatter(x=curve.index, y=curve, mode="lines", name=name))
        # NOTE(review): without a selector this rewrites customdata on every
        # trace already in the figure, so all traces end up carrying the
        # indices of the node added last -- confirm this is intended.
        fig.update_traces(customdata=customdata)
        fig.update_layout(
            plot_bgcolor="rgba(0, 0, 0, 0)",
            paper_bgcolor="rgba(0, 0, 0, 0)",
        )
        fig.update_xaxes(
            showgrid=True,
            gridwidth=1,
            gridcolor="rgba(120, 120, 120, 0.3)",
            zeroline=False,
        )
        fig.update_yaxes(
            showgrid=True,
            gridwidth=1,
            gridcolor="rgba(120, 120, 120, 0.3)",
            zeroline=False,
        )
        figures[key] = fig

    @property
    def figures(self) -> dict:
        """Return one energy-vs-distance line plot per element.

        Keys have the form ``"<element>-<element> bond"``. Each trace carries
        the matching indices into ``self.frames`` as ``customdata``.
        """
        plots = {}
        for element in self.results.columns:
            # Rows that exist only for other elements' distance grids are NaN.
            curve = self.results[element].dropna()
            fig = go.Figure()
            fig.add_trace(go.Scatter(x=curve.index, y=curve, mode="lines"))
            fig.update_traces(
                customdata=self._frame_indices(self.results, element, self.n_points),
            )
            plots[f"{element}-{element} bond"] = fig
        return plots

    @classmethod
    def compare(cls, *nodes: "HomonuclearDiatomics") -> ComparisonResults:
        """Compare the energy-bondlength curves for homonuclear diatomic molecules.

        For every element two figures are produced: the raw energies and an
        "adjusted" variant in which each curve is shifted so that the sample
        closest to the element's covalent radius has zero energy.

        Parameters
        ----------
        nodes : HomonuclearDiatomics
            Nodes to compare.

        Returns
        -------
        ComparisonResults
            Mapping with ``"frames"`` (the frames of the first node) and
            ``"figures"`` (figure per element, plain and adjusted).
        """
        figures = {}
        for node in nodes:
            # Legend label: the node name without the class-name suffix.
            label = node.name.replace(f"_{cls.__name__}", "")
            for element in node.results.columns:
                curve = node.results[element].dropna()
                customdata = cls._frame_indices(node.results, element, node.n_points)

                cls._add_comparison_trace(
                    figures,
                    key=f"{element}-{element} bond",
                    title=f"{element}-{element} bond",
                    y_title="Energy / eV",
                    curve=curve,
                    name=label,
                    customdata=customdata,
                )

                # Now with adjusted
                # find the closest to the cov. dist. index to set the energy to zero
                one_idx = np.abs(
                    curve.index - covalent_radii[atomic_numbers[element]]
                ).argmin()
                cls._add_comparison_trace(
                    figures,
                    key=f"{element}-{element} bond (adjusted)",
                    title=f"{element}-{element} bond",
                    y_title="Adjusted energy / eV",
                    curve=curve - curve.iloc[one_idx],
                    name=label,
                    customdata=customdata,
                )
        return {"frames": nodes[0].frames, "figures": figures}