# skip linting for this file
import itertools
import os
import typing as t
import warnings
import ase.io
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import zntrack
from ase.optimize import BFGS
from mp_api.client import MPRester
from pymatgen.analysis.phase_diagram import PhaseDiagram as pmg_PhaseDiagram
from pymatgen.analysis.pourbaix_diagram import PourbaixDiagram as pmg_PourbaixDiagram
from pymatgen.analysis.pourbaix_diagram import (
PourbaixEntry,
PourbaixPlotter,
)
from pymatgen.core import Element
from pymatgen.core.ion import Ion
from pymatgen.entries.compatibility import (
MaterialsProject2020Compatibility,
MaterialsProjectAqueousCompatibility,
)
from pymatgen.entries.computed_entries import (
ComputedEntry,
GibbsComputedStructureEntry,
)
from mlipx.abc import ComparisonResults, NodeWithCalculator
def create_pourbaix_plot(
    self,
    limits=None,
    title="Pourbaix Diagram",
    label_domains=True,
    label_fontsize=12,
    show_water_lines=True,
    show_neutral_axes=True,
) -> go.Figure:
    """Draw the stable domains of a Pourbaix diagram as a Plotly figure.

    Parameters
    ----------
    self
        Object holding the diagram as ``self._pbx``. The stable domains are
        read from ``self._pbx._stable_domain_vertices``, a mapping of
        entry -> sequence of ``(pH, V)`` vertices. In this module the function
        is called with a ``PourbaixPlotter`` instance.
    limits : list[list[float]] | None
        ``[[pH_min, pH_max], [V_min, V_max]]`` axis ranges.
        Defaults to ``[[-2, 16], [-3, 3]]``.
    title : str
        Figure title.
    label_domains : bool
        If True, write each entry's ``to_pretty_string()`` at the centre of
        its domain.
    label_fontsize : int
        Font size of the domain labels.
    show_water_lines : bool
        If True, draw the two dashed water stability lines.
    show_neutral_axes : bool
        If True, draw dotted guide lines at pH = 7 and V = 0.

    Returns
    -------
    go.Figure
        The assembled figure.
    """
    # Slope of the water stability lines in V per pH unit.
    PREFAC = 0.0591
    # Vertical offset (in V) between the reduction and the oxidation line.
    WATER_OXIDATION_OFFSET = 1.23

    if limits is None:
        limits = [[-2, 16], [-3, 3]]
    xlim, ylim = limits

    fig = go.Figure()

    if show_water_lines:
        ph_values = np.linspace(xlim[0], xlim[1], 100)
        reduction_v = -ph_values * PREFAC
        oxidation_v = -ph_values * PREFAC + WATER_OXIDATION_OFFSET
        for trace_name, line_v in (
            ("H2O Reduction", reduction_v),
            ("H2O Oxidation", oxidation_v),
        ):
            fig.add_trace(
                go.Scatter(
                    x=ph_values,
                    y=line_v,
                    mode="lines",
                    line={"color": "red", "dash": "dash"},
                    name=trace_name,
                )
            )

    if show_neutral_axes:
        fig.add_trace(
            go.Scatter(
                x=[7, 7],
                y=ylim,
                mode="lines",
                line={"color": "grey", "dash": "dot"},
                name="Neutral Axis",
            )
        )
        fig.add_trace(
            go.Scatter(
                x=xlim,
                y=[0, 0],
                mode="lines",
                line={"color": "grey", "dash": "dot"},
                name="V=0 Line",
            )
        )

    for entry, vertices in self._pbx._stable_domain_vertices.items():
        # Label anchor = mean of the *distinct* vertices. It has to be taken
        # before the polygon is closed; otherwise the first vertex is counted
        # twice and the label drifts towards it.
        center = np.mean(vertices, axis=0)

        # Close the outline by repeating the first vertex at the end.
        outline = np.vstack([vertices, vertices[0]])
        fig.add_trace(
            go.Scatter(
                x=outline[:, 0],
                y=outline[:, 1],
                mode="lines",
                line={"color": "grey"},
                name=f"Domain {entry.name}",
            )
        )

        if label_domains:
            fig.add_trace(
                go.Scatter(
                    x=[center[0]],
                    y=[center[1]],
                    mode="text",
                    text=[entry.to_pretty_string()],
                    textfont={"size": label_fontsize, "color": "blue"},
                    name="Domain Label",
                )
            )

    fig.update_layout(
        title={
            "text": title,
            "font": {"size": 20, "family": "Arial", "weight": "bold"},
        },
        xaxis={"title": "pH", "range": xlim},
        yaxis={"title": "E (V)", "range": ylim},
        # Fully transparent plot and paper backgrounds.
        plot_bgcolor="rgba(0, 0, 0, 0)",
        paper_bgcolor="rgba(0, 0, 0, 0)",
    )
    fig.update_xaxes(showgrid=False, zeroline=False)
    fig.update_yaxes(showgrid=False, zeroline=False)
    return fig
# [docs]  (apparently a leftover Sphinx source-link marker from HTML extraction; not code)
class PourbaixDiagram(zntrack.Node):
    """Compute the Pourbaix diagram for a given set of structures.

    Parameters
    ----------
    data : list[ase.Atoms]
        List of structures to evaluate.
    model : NodeWithCalculator
        Node providing the calculator object for the energy calculations.
    pH : float
        pH where the Pourbaix stability is evaluated.
    V : float
        Electrode potential where the Pourbaix stability is evaluated.
    use_gibbs : bool, default=False
        Set to 300 (for 300 Kelvin) to use a machine learning model to
        estimate solid free energy from DFT energy (see
        GibbsComputedStructureEntry). This can slightly improve the accuracy
        of the Pourbaix diagram in some cases. Note that temperatures other
        than 300K are not permitted here, because
        MaterialsProjectAqueousCompatibility corrections, used in Pourbaix
        diagram construction, are calculated based on 300 K data.
        The value is forwarded both to ``MPRester.get_entries_in_chemsys``
        and, as ``temp``, to ``GibbsComputedStructureEntry.from_entries``.
    data_ids : list[int], default=None
        Indices into ``data`` of the structures to evaluate. If None, all
        structures are evaluated.
    geo_opt: bool, default=False
        Whether to perform geometry optimization before calculating the
        Pourbaix decomposition energy of each structure.
    fmax: float, default=0.05
        The maximum force stopping rule for geometry optimizations.

    Attributes
    ----------
    results : pd.DataFrame
        DataFrame with the data_id, potential energy and Pourbaix
        decomposition energy.
    pourbaix_diagram
        The pymatgen Pourbaix diagram built for the structure with the lowest
        decomposition energy found in ``run``.
    figures : dict[str, go.Figure]
        Dictionary with the Pourbaix diagram and the Pourbaix decomposition
        energy plot.
    frames : list[ase.Atoms]
        The structures written to ``frames_path`` during ``run``.

    Notes
    -----
    ``run`` requires the environment variable ``MP_API_KEY``.
    """

    model: NodeWithCalculator = zntrack.deps()
    data: list[ase.Atoms] = zntrack.deps()
    pH: float = zntrack.params()
    V: float = zntrack.params()
    use_gibbs: bool = zntrack.params(False)
    data_ids: list[int] = zntrack.params(None)
    geo_opt: bool = zntrack.params(False)
    fmax: float = zntrack.params(0.05)
    frames_path: str = zntrack.outs_path(zntrack.nwd / "frames.xyz")
    results: pd.DataFrame = zntrack.plots(
        x="data_id", y="pourbaix_decomposition_energy"
    )
    pourbaix_diagram: t.Any = zntrack.outs()

    def run(self):  # noqa: C901
        """Evaluate all selected structures and fill ``results``.

        For every selected structure the potential energy is computed
        (optionally after a BFGS relaxation), the structure is appended to
        ``frames_path``, and its Pourbaix decomposition energy at
        (``pH``, ``V``) is computed against reference data from
        ``MPRester``.

        Raises
        ------
        KeyError
            If the environment variable ``MP_API_KEY`` is not set.
        """
        if self.data_ids is None:
            atoms_list = self.data
        else:
            # Fixed: the field is named ``data_ids`` (``self.data_id`` does
            # not exist and raised AttributeError).
            atoms_list = [self.data[i] for i in self.data_ids]
        if self.model is not None:
            calc = self.model.get_calculator()

        try:
            api_key = os.environ["MP_API_KEY"]
        except KeyError as err:
            raise KeyError(
                "Please set the environment variable `MP_API_KEY`."
            ) from err

        mpr = MPRester(api_key)

        # Elements that switch an entry to run_type "GGA+U", and the Hubbard
        # values recorded in the entry parameters for them.
        U_metal_set = {"Co", "Cr", "Fe", "Mn", "Mo", "Ni", "V", "W"}
        U_settings = {
            "Co": 3.32,
            "Cr": 3.7,
            "Fe": 5.3,
            "Mn": 3.9,
            "Mo": 4.38,
            "Ni": 6.2,
            "V": 3.25,
            "W": 6.2,
        }
        solid_compat = MaterialsProject2020Compatibility()

        # Sorted, capitalized list of every element symbol in the structures.
        chemsys = set(
            itertools.chain.from_iterable(atoms.symbols for atoms in atoms_list)
        )
        chemsys = sorted(e.capitalize() for e in chemsys)

        # download the ion reference data from MPContribs
        ion_data = mpr.get_ion_reference_data_for_chemsys(chemsys)

        # build the PhaseDiagram for get_ion_entries
        ion_ref_comps = [
            Ion.from_formula(d["data"]["RefSolid"]).composition for d in ion_data
        ]
        ion_ref_elts = set(
            itertools.chain.from_iterable(i.elements for i in ion_ref_comps)
        )
        # TODO - ideally the Gibbs entries could be requested and processed
        # directly. However for some reason you cannot process
        # GibbsComputedStructureEntry with MaterialsProjectAqueousCompatibility,
        # so the conversion is done after process_entries below.
        ion_ref_entries = mpr.get_entries_in_chemsys(
            [str(e) for e in ion_ref_elts] + ["O", "H"], use_gibbs=self.use_gibbs
        )

        epots, new_ion_ref_entries, metal_comp_dicts, metallic_ids = [], [], [], []
        for i, atoms in enumerate(atoms_list):
            # "Metals" here means every element other than O and H.
            metals = [s for s in set(atoms.symbols) if s not in ["O", "H"]]
            is_hubbard = bool(set(metals) & U_metal_set)
            if is_hubbard:
                run_type = "GGA+U"
                hubbards = {m: U_settings.get(m, 0) for m in metals}
            else:
                run_type = "GGA"
                hubbards = {}

            if self.model is not None:
                atoms.calc = calc
            if self.geo_opt:
                dyn = BFGS(atoms)
                dyn.run(fmax=self.fmax)
            epot = atoms.get_potential_energy()
            ase.io.write(self.frames_path, atoms, append=True)
            epots.append(epot)

            # Number of atoms per element symbol.
            amt_dict = {
                m: len([a for a in atoms if a.symbol == m]) for m in set(atoms.symbols)
            }
            n_metals = len([a for a in atoms if a.symbol not in ["O", "H"]])
            if n_metals > 0:
                # Fraction of each non-O/H element among the non-O/H atoms.
                metal_comp_dict = {m: amt_dict[m] / n_metals for m in metals}
                metallic_ids.append(i)
                metal_comp_dicts.append(metal_comp_dict)

            # NOTE(review): an entry is created for every structure; entries
            # containing only O/H are filtered out again when the solid
            # Pourbaix entries are built below. Confirm this placement
            # against the upstream source (indentation was lost).
            entry = ComputedEntry(
                composition=amt_dict,
                energy=epot,
                parameters={
                    "run_type": run_type,
                    "software": "N/A",
                    "oxide_type": "oxide",
                    "is_hubbard": is_hubbard,
                    "hubbards": hubbards,
                },
            )
            new_ion_ref_entries.append(entry)

        # The evaluated structures go first so they can be found by position
        # later on.
        ion_ref_entries = new_ion_ref_entries + ion_ref_entries

        # suppress the warning about supplying the required energies;
        # they will be calculated from the
        # entries we get from MPRester
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message="You did not provide the required O2 and H2O energies.",
            )
            compat = MaterialsProjectAqueousCompatibility(solid_compat=solid_compat)
        # suppress the warning about missing oxidation states
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore", message="Failed to guess oxidation states.*"
            )
            ion_ref_entries = compat.process_entries(ion_ref_entries)  # type: ignore
        # TODO - if Gibbs entries could be processed directly (see above),
        # this conditional block could be removed
        if self.use_gibbs:
            # replace the entries with GibbsComputedStructureEntry
            ion_ref_entries = GibbsComputedStructureEntry.from_entries(
                ion_ref_entries, temp=self.use_gibbs
            )
        ion_ref_pd = pmg_PhaseDiagram(ion_ref_entries)

        ion_entries = mpr.get_ion_entries(ion_ref_pd, ion_ref_data=ion_data)
        pbx_entries = [PourbaixEntry(e, f"ion-{n}") for n, e in enumerate(ion_entries)]

        # Construct the solid pourbaix entries from filtered ion_ref entries
        extra_elts = (
            set(ion_ref_elts)
            - {Element(s) for s in chemsys}
            - {Element("H"), Element("O")}
        )
        new_pbx_entries = []
        for entry in ion_ref_entries:
            entry_elts = set(entry.composition.elements)
            # Ensure no OH chemsys or extraneous elements from ion references
            if not (
                entry_elts <= {Element("H"), Element("O")}
                or extra_elts.intersection(entry_elts)
            ):
                # Create new computed entry
                eform = ion_ref_pd.get_form_energy(entry)  # type: ignore
                new_entry = ComputedEntry(
                    entry.composition, eform, entry_id=entry.entry_id
                )
                pbx_entry = PourbaixEntry(new_entry)
                new_pbx_entries.append(pbx_entry)
        # Solid entries first, ion entries last.
        pbx_entries = new_pbx_entries + pbx_entries

        row_dicts = []
        epbx_min = 10000.0
        for i, atoms in enumerate(atoms_list):
            if self.data_ids is None:
                data_id = i
            else:
                # Fixed: was ``self.data_id[i]`` (AttributeError).
                data_id = self.data_ids[i]

            if i in metallic_ids:
                idx = metallic_ids.index(i)
                # NOTE(review): this assumes the idx-th entry of pbx_entries
                # is the idx-th metal-bearing input structure, i.e. that
                # process_entries / from_entries neither dropped nor
                # reordered the user entries - confirm.
                entry = pbx_entries[idx]
                pbx_dia = pmg_PourbaixDiagram(
                    pbx_entries, comp_dict=metal_comp_dicts[idx]
                )
                epbx = pbx_dia.get_decomposition_energy(entry, pH=self.pH, V=self.V)
                # Keep the diagram belonging to the lowest decomposition
                # energy seen so far.
                if epbx < epbx_min:
                    self.pourbaix_diagram = pbx_dia
                    epbx_min = epbx
            else:
                # Structures without any non-O/H element get 0.0.
                epbx = 0.0

            row_dicts.append(
                {
                    "data_id": data_id,
                    "potential_energy": epots[i],
                    "pourbaix_decomposition_energy": epbx,
                },
            )
        self.results = pd.DataFrame(row_dicts)

    @property
    def figures(self) -> dict[str, go.Figure]:
        """Plotly figures: the Pourbaix diagram and the decomposition energies.

        Returns
        -------
        dict[str, go.Figure]
            ``"pourbaix-diagram"`` and ``"pourbaix-decomposition-energy-plot"``.
        """
        # The PourbaixPlotter is only used as a holder of the diagram; the
        # drawing itself is done with Plotly by create_pourbaix_plot.
        plotter = PourbaixPlotter(self.pourbaix_diagram)
        return {
            "pourbaix-diagram": create_pourbaix_plot(plotter),
            "pourbaix-decomposition-energy-plot": px.line(
                self.results, x="data_id", y="pourbaix_decomposition_energy"
            ),
        }

    @staticmethod
    def compare(*nodes: "PourbaixDiagram") -> ComparisonResults:
        """Collect the figures of several nodes into one comparison result.

        Parameters
        ----------
        *nodes : PourbaixDiagram
            Nodes to compare; at least one is required because the frames of
            the first node are returned.

        Returns
        -------
        ComparisonResults
            ``"frames"`` of the first node and all ``"figures"``, keyed by
            ``"<node identifier>-<figure key>"``.
        """
        figures = {}
        for node in nodes:
            # Node name with the "_<ClassName>" part removed.
            node_identifier = node.name.replace(f"_{node.__class__.__name__}", "")
            # Retitle each figure with the node identifier and store it.
            for key, fig in node.figures.items():
                fig.update_layout(title=node_identifier)
                figures[f"{node_identifier}-{key}"] = fig
        return {
            "frames": nodes[0].frames,
            "figures": figures,
        }

    @property
    def frames(self) -> list[ase.Atoms]:
        """Structures read back from ``frames_path`` (extxyz format)."""
        with self.state.fs.open(self.frames_path, "r") as f:
            return list(ase.io.iread(f, format="extxyz"))