# Source code for mlipx.nodes.phase_diagram

# skip linting for this file

import itertools
import os
import typing as t

import ase.io
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import zntrack
from ase.optimize import BFGS
from mp_api.client import MPRester
from pymatgen.analysis.phase_diagram import PDPlotter
from pymatgen.analysis.phase_diagram import PhaseDiagram as pmg_PhaseDiagram
from pymatgen.entries.compatibility import (
    MaterialsProject2020Compatibility,
)
from pymatgen.entries.computed_entries import (
    ComputedEntry,
)

from mlipx.abc import ComparisonResults, NodeWithCalculator


class PhaseDiagram(zntrack.Node):
    """Compute the phase diagram for a given set of structures.

    Parameters
    ----------
    data : list[ase.Atoms]
        List of structures to evaluate.
    model : NodeWithCalculator
        Node providing the calculator object for the energy calculations.
        If ``None``, the calculator already attached to each structure is used.
    chemsys : list[str], default=None
        The set of chemical symbols to construct phase diagram. If ``None``,
        the union of all symbols found in the evaluated structures is used.
    data_ids : list[int], default=None
        Index of the structure to evaluate. If ``None``, all structures in
        ``data`` are evaluated.
    geo_opt : bool, default=False
        Whether to perform geometry optimization before calculating the
        formation energy of each structure.
    fmax : float, default=0.05
        The maximum force stopping rule for geometry optimizations.

    Attributes
    ----------
    frames_path : str
        Path of the ``frames.xyz`` file the evaluated structures are written to.
    results : pd.DataFrame
        DataFrame with the data_id, potential energy and formation energy.
    phase_diagram : pymatgen PhaseDiagram
        Phase diagram built from the computed entries and, when the
        ``MP_API_KEY`` environment variable is set, the entries fetched
        through ``MPRester`` for the chemical system.
    figures : dict[str, go.Figure]
        Dictionary with the phase diagram, the formation energy plot and a
        bar chart with the number of stable phases.
    """

    model: NodeWithCalculator = zntrack.deps()
    data: list[ase.Atoms] = zntrack.deps()
    chemsys: list[str] = zntrack.params(None)
    data_ids: list[int] = zntrack.params(None)
    geo_opt: bool = zntrack.params(False)
    fmax: float = zntrack.params(0.05)

    frames_path: str = zntrack.outs_path(zntrack.nwd / "frames.xyz")
    results: pd.DataFrame = zntrack.plots(x="data_id", y="formation_energy")
    phase_diagram: t.Any = zntrack.outs()

    def run(self):  # noqa C901
        """Evaluate energies, build the phase diagram and fill ``results``.

        For every selected structure the potential energy is computed
        (optionally after a BFGS relaxation), the structure is appended to
        ``frames_path`` and a ``ComputedEntry`` is created. The entries are
        processed with ``MaterialsProject2020Compatibility`` and combined with
        reference entries from ``MPRester`` when ``MP_API_KEY`` is set.
        """
        # Select the structures to evaluate; ``data_ids`` indexes into ``data``.
        if self.data_ids is None:
            atoms_list = self.data
        else:
            atoms_list = [self.data[i] for i in self.data_ids]

        # Bind ``calc`` unconditionally so it is never an unbound local.
        calc = self.model.get_calculator() if self.model is not None else None

        # Per-element U values stored in the ``hubbards`` entry parameter.
        # NOTE(review): presumably mirrors the Materials Project GGA+U
        # settings -- confirm against the compatibility scheme in use.
        U_settings = {
            "Co": 3.32,
            "Cr": 3.7,
            "Fe": 5.3,
            "Mn": 3.9,
            "Mo": 4.38,
            "Ni": 6.2,
            "V": 3.25,
            "W": 6.2,
        }
        U_metal_set = set(U_settings)

        # Without an API key no reference entries are downloaded.
        api_key = os.environ.get("MP_API_KEY")

        entries, epots = [], []
        for atoms in atoms_list:
            metals = [s for s in set(atoms.symbols) if s not in ("O", "H")]
            if set(metals) & U_metal_set:
                run_type = "GGA+U"
                is_hubbard = True
                # Elements without a tabulated U get 0.
                hubbards = {m: U_settings.get(m, 0) for m in metals}
            else:
                run_type = "GGA"
                is_hubbard = False
                hubbards = {}

            if self.model is not None:
                atoms.calc = calc
            if self.geo_opt:
                dyn = BFGS(atoms)
                dyn.run(fmax=self.fmax)

            epot = atoms.get_potential_energy()
            ase.io.write(self.frames_path, atoms, append=True)
            epots.append(epot)

            # Count atoms per element in a single pass over the structure.
            amt_dict = {}
            for symbol in atoms.symbols:
                amt_dict[symbol] = amt_dict.get(symbol, 0) + 1

            entries.append(
                ComputedEntry(
                    composition=amt_dict,
                    energy=epot,
                    parameters={
                        "run_type": run_type,
                        "software": "N/A",
                        "oxide_type": "oxide",
                        "is_hubbard": is_hubbard,
                        "hubbards": hubbards,
                    },
                )
            )

        compat = MaterialsProject2020Compatibility()
        computed_entries = compat.process_entries(entries)

        if api_key is None:
            mp_entries = []
        else:
            mpr = MPRester(api_key)
            if self.chemsys is None:
                chemsys = set(
                    itertools.chain.from_iterable(atoms.symbols for atoms in atoms_list)
                )
            else:
                chemsys = self.chemsys
            mp_entries = mpr.get_entries_in_chemsys(chemsys)

        all_entries = computed_entries + mp_entries
        self.phase_diagram = pmg_PhaseDiagram(all_entries)

        # NOTE(review): rows are matched to ``epots`` / ``data_ids`` by
        # position; this assumes ``process_entries`` returns one entry per
        # input in the same order -- confirm it never drops entries here.
        row_dicts = []
        for i, entry in enumerate(computed_entries):
            data_id = i if self.data_ids is None else self.data_ids[i]
            eform = self.phase_diagram.get_form_energy_per_atom(entry)
            row_dicts.append(
                {
                    "data_id": data_id,
                    "potential_energy": epots[i],
                    "formation_energy": eform,
                },
            )
        self.results = pd.DataFrame(row_dicts)

    @property
    def figures(self) -> dict[str, go.Figure]:
        """Return the phase diagram, formation energy and stable-phase figures."""
        plotter = PDPlotter(self.phase_diagram)
        fig1 = plotter.get_plot()

        fig2 = px.line(self.results, x="data_id", y="formation_energy")
        fig2.update_layout(title="Formation Energy Plot")

        pd_df = pd.DataFrame(
            [len(self.phase_diagram.stable_entries)], columns=["Stable_phases"]
        )
        fig3 = px.bar(pd_df, y="Stable_phases")

        return {
            "phase-diagram": fig1,
            "formation-energy-plot": fig2,
            "stable_phases": fig3,
        }

    @staticmethod
    def compare(*nodes: "PhaseDiagram") -> ComparisonResults:
        """Collect the figures of several nodes under node-specific keys.

        The returned ``frames`` are those of the first node only; the figures
        of every node are restyled and stored as ``"<identifier>-<key>"``.
        """
        figures = {}
        for node in nodes:
            # Extract a unique identifier for the node
            node_identifier = node.name.replace(f"_{node.__class__.__name__}", "")

            # Update and store the figures directly
            for key, fig in node.figures.items():
                fig.update_layout(
                    title=node_identifier,
                    plot_bgcolor="rgba(0, 0, 0, 0)",
                    paper_bgcolor="rgba(0, 0, 0, 0)",
                )
                fig.update_xaxes(
                    showgrid=True,
                    gridwidth=1,
                    gridcolor="rgba(120, 120, 120, 0.3)",
                    zeroline=False,
                )
                fig.update_yaxes(
                    showgrid=True,
                    gridwidth=1,
                    gridcolor="rgba(120, 120, 120, 0.3)",
                    zeroline=False,
                )
                figures[f"{node_identifier}-{key}"] = fig

        return {
            "frames": nodes[0].frames,
            "figures": figures,
        }

    @property
    def frames(self) -> list[ase.Atoms]:
        """Read back the structures written to ``frames_path`` during ``run``."""
        with self.state.fs.open(self.frames_path, "r") as f:
            return list(ase.io.iread(f, format="extxyz"))