Source code for mlipx.nodes.formation_energy
import typing as t
import ase
import pandas as pd
import zntrack
from tqdm import tqdm, trange
from mlipx.abc import ASEKeys, NodeWithCalculator
from mlipx.utils import rmse
[docs]
class CalculateFormationEnergy(zntrack.Node):
"""
Calculate formation energy.
Parameters
----------
data : list[ase.Atoms]
ASE atoms object with appropriate tags in info
"""
data: list[ase.Atoms] = zntrack.deps()
model: t.Optional[NodeWithCalculator] = zntrack.deps(None)
formation_energy: list = zntrack.outs(independent=True)
isolated_energies: dict = zntrack.outs(independent=True)
plots: pd.DataFrame = zntrack.plots(
y=["eform", "n_atoms"], independent=True, autosave=True
)
def get_isolated_energies(self) -> dict[str, float]:
# get all unique elements
isolated_energies = {}
for atoms in tqdm(self.data, desc="Getting isolated energies"):
for element in set(atoms.get_chemical_symbols()):
if self.model is None:
if element not in isolated_energies:
isolated_energies[element] = atoms.info[
ASEKeys.isolated_energies.value
][element]
else:
assert (
isolated_energies[element]
== atoms.info[ASEKeys.isolated_energies.value][element]
)
else:
if element not in isolated_energies:
box = ase.Atoms(
element,
positions=[[50, 50, 50]],
cell=[100, 100, 100],
pbc=True,
)
box.calc = self.model.get_calculator()
isolated_energies[element] = box.get_potential_energy()
return isolated_energies
def run(self):
self.formation_energy = []
self.isolated_energies = self.get_isolated_energies()
plots = []
for atoms in self.data:
chem = atoms.get_chemical_symbols()
reference_energy = 0
for element in chem:
reference_energy += self.isolated_energies[element]
E_form = atoms.get_potential_energy() - reference_energy
self.formation_energy.append(E_form)
plots.append({"eform": E_form, "n_atoms": len(atoms)})
self.plots = pd.DataFrame(plots)
@property
def frames(self):
for atom, energy in zip(self.data, self.formation_energy):
atom.info[ASEKeys.formation_energy.value] = energy
return self.data
[docs]
class CompareFormationEnergy(zntrack.Node):
data: CalculateFormationEnergy = zntrack.deps()
reference: CalculateFormationEnergy = zntrack.deps()
plots: pd.DataFrame = zntrack.plots(autosave=True)
rmse: dict = zntrack.metrics()
error: dict = zntrack.metrics()
def run(self):
eform_rmse = rmse(self.data.plots["eform"], self.reference.plots["eform"])
# e_rmse = rmse(self.data.plots["energy"], self.reference.plots["energy"])
self.rmse = {
"eform": eform_rmse,
"eform_per_atom": eform_rmse / len(self.data.plots),
}
all_plots = []
for row_idx in trange(len(self.data.plots)):
plots = {}
plots["adjusted_eform_error"] = (
self.data.plots["eform"].iloc[row_idx] - eform_rmse
) - self.reference.plots["eform"].iloc[row_idx]
plots["adjusted_eform"] = (
self.data.plots["eform"].iloc[row_idx] - eform_rmse
)
plots["adjusted_eform_error_per_atom"] = (
plots["adjusted_eform_error"] / self.data.plots["n_atoms"].iloc[row_idx]
)
all_plots.append(plots)
self.plots = pd.DataFrame(all_plots)
# iterate over plots and save min/max
self.error = {}
for key in self.plots.columns:
if "_error" in key:
stripped_key = key.replace("_error", "")
self.error[f"{stripped_key}_max"] = self.plots[key].max()
self.error[f"{stripped_key}_min"] = self.plots[key].min()