import contextlib
import ase
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import tqdm
import zntrack
from ase.calculators.calculator import PropertyNotImplementedError
from mlipx.abc import ComparisonResults
from mlipx.utils import shallow_copy_atoms
def get_figure(
key: str, nodes: list["EvaluateCalculatorResults"], yaxis_title: str
) -> go.Figure:
fig = go.Figure()
for node in nodes:
fig.add_trace(
go.Scatter(
x=node.plots.index,
y=node.plots[key],
mode="lines+markers",
name=node.name.replace(f"_{node.__class__.__name__}", ""),
)
)
fig.update_traces(customdata=np.stack([np.arange(len(node.plots.index))], axis=1))
fig.update_layout(
title=key,
plot_bgcolor="rgba(0, 0, 0, 0)",
paper_bgcolor="rgba(0, 0, 0, 0)",
)
fig.update_xaxes(
showgrid=True,
gridwidth=1,
gridcolor="rgba(120, 120, 120, 0.3)",
zeroline=False,
title="Index",
)
fig.update_yaxes(
showgrid=True,
gridwidth=1,
gridcolor="rgba(120, 120, 120, 0.3)",
zeroline=False,
title=yaxis_title,
)
return fig
[docs]
class EvaluateCalculatorResults(zntrack.Node):
"""
Evaluate the results of a calculator.
Parameters
----------
data : list[ase.Atoms]
List of atoms objects.
"""
data: list[ase.Atoms] = zntrack.deps()
plots: pd.DataFrame = zntrack.plots(
y=["fmax", "fnorm", "energy"], independent=True, autosave=True
)
def run(self):
self.plots = pd.DataFrame()
frame_data = []
for idx in tqdm.tqdm(range(len(self.data))):
atoms = self.data[idx]
forces = atoms.get_forces()
fmax = np.max(np.linalg.norm(forces, axis=1))
fnorm = np.linalg.norm(forces)
energy = atoms.get_potential_energy()
# eform = atoms.info.get(ASEKeys.formation_energy.value, -1)
n_atoms = len(atoms)
# have energy and formation energy in the plot
plots = {
"fmax": fmax,
"fnorm": fnorm,
"energy": energy,
# "eform": eform,
"n_atoms": n_atoms,
"energy_per_atom": energy / n_atoms,
# "eform_per_atom": eform / n_atoms,
}
frame_data.append(plots)
self.plots = pd.DataFrame(frame_data)
@property
def frames(self):
return self.data
def __run_note__(self) -> str:
return f"""# {self.name}
Results from {self.state.remote} at {self.state.rev}.
View the trajectory via zndraw:
```bash
zndraw {self.name}.frames --rev {self.state.rev} --remote {self.state.remote} --url https://app-dev.roqs.basf.net/zndraw_app
```
"""
@property
def figures(self) -> dict:
# TODO: remove index column
plots = {}
for key in self.plots.columns:
fig = px.line(
self.plots,
x=self.plots.index,
y=key,
title=key,
)
fig.update_traces(
customdata=np.stack([np.arange(len(self.plots))], axis=1),
)
plots[key] = fig
return plots
@staticmethod
def compare( # noqa: C901
*nodes: "EvaluateCalculatorResults", reference: str | None = None
) -> ComparisonResults:
# TODO: if reference, shift energies by
# rmse(val, reference) and plot as energy_adjusted
figures = {}
frames_info = {}
for key in nodes[0].plots.columns:
if not all(key in node.plots.columns for node in nodes):
raise ValueError(f"Key {key} not found in all nodes")
# check frames are the same
yaxis_title = key.replace("_error", "")
if "energy" in key:
yaxis_title += " / eV"
elif "fmax" in key or "fnorm" in key or "force" in key:
yaxis_title += " / eV/Å"
else:
yaxis_title += ""
figures[key] = get_figure(key, nodes, yaxis_title=yaxis_title)
for node in nodes:
for key in node.plots.columns:
frames_info[f"{node.name}_{key}"] = node.plots[key].values
# TODO: calculate the rmse difference between a fixed one
# and all the others and shift them accordingly
# and plot as energy_shifted
# plot error between curves
# mlipx pass additional flags to compare function
# have different compare methods also for correlation plots
frames = [shallow_copy_atoms(x) for x in nodes[0].frames]
for key, values in frames_info.items():
for atoms, value in zip(frames, values):
atoms.info[key] = value
for node in nodes:
for node_atoms, atoms in zip(node.frames, frames):
if len(node_atoms) != len(atoms):
raise ValueError("Atoms objects have different lengths")
with contextlib.suppress(RuntimeError, PropertyNotImplementedError):
atoms.info[f"{node.name}_energy"] = (
node_atoms.get_potential_energy()
)
atoms.arrays[f"{node.name}_forces"] = node_atoms.get_forces()
return {
"frames": frames,
"figures": figures,
}