Source code for mlipx.nodes.invariances

import pathlib

import ase
import ase.io
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import tqdm
import zntrack

from mlipx.abc import ComparisonResults, NodeWithCalculator


class InvarianceNode(zntrack.Node):
    """Base class for testing invariances."""

    model: NodeWithCalculator = zntrack.deps()
    data: list[ase.Atoms] = zntrack.deps()
    data_id: int = zntrack.params(-1)
    n_points: int = zntrack.params(50)

    metrics: dict = zntrack.metrics()
    plots: pd.DataFrame = zntrack.plots()

    frames_path: pathlib.Path = zntrack.outs_path(zntrack.nwd / "frames.xyz")

    def run(self):
        """Common logic for invariance testing."""
        atoms = self.data[self.data_id]
        calc = self.model.get_calculator()
        atoms.calc = calc

        rng = np.random.default_rng()
        energies = []
        for _ in tqdm.tqdm(range(self.n_points)):
            self.apply_transformation(atoms, rng)
            energies.append(atoms.get_potential_energy())
            ase.io.write(self.frames_path, atoms, append=True)

        self.plots = pd.DataFrame(energies, columns=["energy"])

        self.metrics = {
            "mean": float(np.mean(energies)),
            "std": float(np.std(energies)),
        }

    @property
    def frames(self):
        with self.state.fs.open(self.frames_path, "r") as f:
            return list(ase.io.iread(f, ":"))

    def apply_transformation(self, atoms_copy: ase.Atoms, rng: np.random.Generator):
        """To be implemented by child classes."""
        raise NotImplementedError("Subclasses must implement apply_transformation()")

    @staticmethod
    def compare(*nodes: "InvarianceNode") -> ComparisonResults:
        if len(nodes) == 0:
            raise ValueError("No nodes to compare")

        fig = go.Figure()
        for node in nodes:
            fig.add_trace(
                go.Scatter(
                    x=np.arange(node.n_points),
                    y=node.plots["energy"] - node.metrics["mean"],
                    mode="markers",
                    name=node.name.replace(f"_{node.__class__.__name__}", ""),
                )
            )

        fig.update_layout(
            title=f"Energy vs step ({nodes[0].__class__.__name__})",
            xaxis_title="Steps",
            yaxis_title="Adjusted energy / eV",
            plot_bgcolor="rgba(0, 0, 0, 0)",
            paper_bgcolor="rgba(0, 0, 0, 0)",
        )
        fig.update_xaxes(
            showgrid=True,
            gridwidth=1,
            gridcolor="rgba(120, 120, 120, 0.3)",
            zeroline=False,
        )
        fig.update_yaxes(
            showgrid=True,
            gridwidth=1,
            gridcolor="rgba(120, 120, 120, 0.3)",
            zeroline=False,
        )

        return ComparisonResults(
            frames=nodes[0].frames, figures={"energy_vs_steps_adjusted": fig}
        )


[docs] class TranslationalInvariance(InvarianceNode): """Test translational invariance by random box translocation.""" def apply_transformation(self, atoms_copy: ase.Atoms, rng: np.random.Generator): translation = rng.uniform(-1, 1, 3) atoms_copy.positions += translation
[docs] class RotationalInvariance(InvarianceNode): """Test rotational invariance by random rotation of the box.""" def apply_transformation(self, atoms_copy: ase.Atoms, rng: np.random.Generator): angle = rng.uniform(-90, 90) direction = rng.choice(["x", "y", "z"]) atoms_copy.rotate(angle, direction, rotate_cell=any(atoms_copy.pbc))
[docs] class PermutationInvariance(InvarianceNode): """Test permutation invariance by permutation of atoms of the same species.""" def apply_transformation(self, atoms_copy: ase.Atoms, rng: np.random.Generator): species = np.unique(atoms_copy.get_chemical_symbols()) for s in species: indices = np.where(atoms_copy.get_chemical_symbols() == s)[0] rng.shuffle(indices) atoms_copy.positions[indices] = rng.permutation( atoms_copy.positions[indices], axis=0 )