import dataclasses
import pathlib
import ase.io
import ase.units
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import tqdm
import zntrack
from ase.md import Langevin
from mlipx.abc import (
ComparisonResults,
DynamicsModifier,
DynamicsObserver,
NodeWithCalculator,
NodeWithMolecularDynamics,
)
[docs]
@dataclasses.dataclass
class LangevinConfig:
"""Configure a Langevin thermostat for molecular dynamics.
Parameters
----------
timestep : float
Time step for the molecular dynamics simulation in fs.
temperature : float
Temperature of the thermostat.
friction : float
Friction coefficient of the thermostat.
"""
timestep: float
temperature: float
friction: float
def get_molecular_dynamics(self, atoms) -> Langevin:
return Langevin(
atoms,
timestep=self.timestep * ase.units.fs,
temperature_K=self.temperature,
friction=self.friction,
)
[docs]
class MolecularDynamics(zntrack.Node):
"""Run molecular dynamics simulation.
Parameters
----------
model : NodeWithCalculator
Node providing the calculator object for the simulation.
thermostat : LangevinConfig
Node providing the thermostat object for the simulation.
data : list[ase.Atoms]
Initial configurations for the simulation.
data_id : int, default=-1
Index of the initial configuration to use.
steps : int, default=100
Number of steps to run the simulation.
"""
model: NodeWithCalculator = zntrack.deps()
thermostat: NodeWithMolecularDynamics = zntrack.deps()
data: list[ase.Atoms] = zntrack.deps()
data_id: int = zntrack.params(-1)
steps: int = zntrack.params(100)
observers: list[DynamicsObserver] = zntrack.deps(None)
modifiers: list[DynamicsModifier] = zntrack.deps(None)
observer_metrics: dict = zntrack.metrics()
plots: pd.DataFrame = zntrack.plots(y=["energy", "fmax"], autosave=True)
frames_path: pathlib.Path = zntrack.outs_path(zntrack.nwd / "frames.xyz")
def run(self):
if self.observers is None:
self.observers = []
if self.modifiers is None:
self.modifiers = []
atoms = self.data[self.data_id]
atoms.calc = self.model.get_calculator()
dyn = self.thermostat.get_molecular_dynamics(atoms)
for obs in self.observers:
obs.initialize(atoms)
self.observer_metrics = {}
self.plots = pd.DataFrame(columns=["energy", "fmax", "fnorm"])
for idx, _ in enumerate(
tqdm.tqdm(dyn.irun(steps=self.steps), total=self.steps)
):
ase.io.write(self.frames_path, atoms, append=True)
plots = {
"energy": atoms.get_potential_energy(),
"fmax": np.max(np.linalg.norm(atoms.get_forces(), axis=1)),
"fnorm": np.linalg.norm(atoms.get_forces()),
}
self.plots.loc[len(self.plots)] = plots
for obs in self.observers:
if obs.check(atoms):
self.observer_metrics[obs.name] = idx
if len(self.observer_metrics) > 0:
break
for mod in self.modifiers:
mod.modify(dyn, idx)
for obs in self.observers:
# document all attached observers
self.observer_metrics[obs.name] = self.observer_metrics.get(obs.name, -1)
@property
def frames(self) -> list[ase.Atoms]:
with self.state.fs.open(self.frames_path, "r") as f:
return list(ase.io.iread(f, format="extxyz"))
@property
def figures(self) -> dict[str, go.Figure]:
plots = {}
for key in self.plots.columns:
fig = px.line(
self.plots,
x=self.plots.index,
y=key,
title=key,
)
fig.update_traces(
customdata=np.stack([np.arange(len(self.plots))], axis=1),
)
plots[key] = fig
return plots
@staticmethod
def compare(*nodes: "MolecularDynamics") -> ComparisonResults:
frames = sum([node.frames for node in nodes], [])
offset = 0
fig = go.Figure()
for _, node in enumerate(nodes):
energies = [atoms.get_potential_energy() for atoms in node.frames]
fig.add_trace(
go.Scatter(
x=list(range(len(energies))),
y=energies,
mode="lines+markers",
name=node.name.replace(f"_{node.__class__.__name__}", ""),
customdata=np.stack([np.arange(len(energies)) + offset], axis=1),
)
)
offset += len(energies)
fig.update_layout(
title="Energy vs. step",
xaxis_title="Step",
yaxis_title="Energy / eV",
)
fig.update_layout(
plot_bgcolor="rgba(0, 0, 0, 0)",
paper_bgcolor="rgba(0, 0, 0, 0)",
)
fig.update_xaxes(
showgrid=True,
gridwidth=1,
gridcolor="rgba(120, 120, 120, 0.3)",
zeroline=False,
)
fig.update_yaxes(
showgrid=True,
gridwidth=1,
gridcolor="rgba(120, 120, 120, 0.3)",
zeroline=False,
)
# Now we set the first energy to zero for better compareability.
offset = 0
fig_adjusted = go.Figure()
for _, node in enumerate(nodes):
energies = np.array([atoms.get_potential_energy() for atoms in node.frames])
energies -= energies[0]
fig_adjusted.add_trace(
go.Scatter(
x=list(range(len(energies))),
y=energies,
mode="lines+markers",
name=node.name.replace(f"_{node.__class__.__name__}", ""),
customdata=np.stack([np.arange(len(energies)) + offset], axis=1),
)
)
offset += len(energies)
fig_adjusted.update_layout(
title="Adjusted energy vs. step",
xaxis_title="Step",
yaxis_title="Adjusted energy / eV",
)
fig_adjusted.update_layout(
plot_bgcolor="rgba(0, 0, 0, 0)",
paper_bgcolor="rgba(0, 0, 0, 0)",
)
fig_adjusted.update_xaxes(
showgrid=True,
gridwidth=1,
gridcolor="rgba(120, 120, 120, 0.3)",
zeroline=False,
)
fig_adjusted.update_yaxes(
showgrid=True,
gridwidth=1,
gridcolor="rgba(120, 120, 120, 0.3)",
zeroline=False,
)
return ComparisonResults(
frames=frames,
figures={"energy_vs_steps": fig, "energy_vs_steps_adjusted": fig_adjusted},
)