def _get_yaxis_title(key: str) -> str:
    """Return the y-axis label (quantity plus unit) for a ``plots`` column name.

    The ``_error`` suffix is stripped from ``key``. A unit is then appended:
    ``eV/atom`` for per-atom energies, ``eV`` for other energies, and
    ``eV/Å`` for force-like quantities (``fmax``, ``fnorm``, ``force``).
    Any other name is returned without a unit.
    """
    title = key.replace("_error", "")
    # Order matters: "energy_per_atom" also contains "energy".
    if "energy_per_atom" in title:
        return title + " / eV/atom"
    if "energy" in title:
        return title + " / eV"
    if any(token in title for token in ("fmax", "fnorm", "force")):
        return title + " / eV/Å"
    return title


class CompareCalculatorResults(zntrack.Node):
    """
    CompareCalculatorResults is a node that compares the results of two calculators.

    It calculates the RMSE between the two calculators and shifts the energies
    of the first calculator by the energy RMSE before computing "adjusted"
    energy errors. It also calculates the signed per-frame errors between the
    two calculators and saves their min/max values.

    Parameters
    ----------
    data : EvaluateCalculatorResults
        The results of the first calculator.
    reference : EvaluateCalculatorResults
        The results of the second calculator.
        The results of the first calculator will be compared to these results.
        Rows of ``reference.plots`` are matched to rows of ``data.plots``
        by position, so both must describe the same frames in the same order.
    """

    data: EvaluateCalculatorResults = zntrack.deps()
    reference: EvaluateCalculatorResults = zntrack.deps()

    # One row per frame; signed errors (data - reference) and adjusted energies.
    plots: pd.DataFrame = zntrack.plots(autosave=True)
    # RMSE of energy, per-atom energy, fmax and fnorm between data and reference.
    rmse: dict = zntrack.metrics()
    # "<quantity>_max" / "<quantity>_min" for every "*_error" column of ``plots``.
    error: dict = zntrack.metrics()

    def run(self):
        """Compute RMSE metrics, per-frame errors and their extrema.

        Raises
        ------
        ValueError
            If ``data`` and ``reference`` contain a different number of rows.
        """
        data = self.data.plots
        ref = self.reference.plots
        if len(data) != len(ref):
            raise ValueError(
                f"Cannot compare {len(data)} data entries with "
                f"{len(ref)} reference entries"
            )

        # NOTE: ``rmse`` below is the module-level helper, not ``self.rmse``.
        e_rmse = rmse(data["energy"], ref["energy"])

        # Rows are compared positionally (as the former ``.iloc`` loop did),
        # so work on plain arrays to avoid pandas index alignment.
        n_atoms = data["n_atoms"].to_numpy()
        energy = data["energy"].to_numpy()
        ref_energy = ref["energy"].to_numpy()

        # The per-atom RMSE must be computed from per-atom energies. Dividing
        # the total-energy RMSE by the number of frames is not a per-atom value.
        # Wrapping in Series gives both operands the same RangeIndex.
        e_per_atom_rmse = rmse(
            pd.Series(energy / n_atoms), pd.Series(ref_energy / n_atoms)
        )

        self.rmse = {
            "energy": e_rmse,
            "energy_per_atom": e_per_atom_rmse,
            "fmax": rmse(data["fmax"], ref["fmax"]),
            "fnorm": rmse(data["fnorm"], ref["fnorm"]),
        }

        # Shift the predicted energies by the energy RMSE, then compare.
        adjusted_energy = energy - e_rmse
        adjusted_energy_error = adjusted_energy - ref_energy

        self.plots = pd.DataFrame(
            {
                "adjusted_energy_error": adjusted_energy_error,
                "adjusted_energy": adjusted_energy,
                "adjusted_energy_error_per_atom": adjusted_energy_error / n_atoms,
                "fmax_error": data["fmax"].to_numpy() - ref["fmax"].to_numpy(),
                "fnorm_error": data["fnorm"].to_numpy() - ref["fnorm"].to_numpy(),
            }
        )

        # Record the extrema of every signed error column.
        self.error = {}
        for key in self.plots.columns:
            if "_error" in key:
                stripped_key = key.replace("_error", "")
                self.error[f"{stripped_key}_max"] = self.plots[key].max()
                self.error[f"{stripped_key}_min"] = self.plots[key].min()

    @property
    def frames(self) -> FRAMES:
        """Frames of the evaluated (``data``) calculator results."""
        return self.data.frames

    @property
    def figures(self) -> FIGURES:
        """One figure per ``plots`` column, labelled with the matching unit."""
        return {
            key: get_figure(key, [self], yaxis_title=_get_yaxis_title(key))
            for key in self.plots.columns
        }

    def compare(self, *nodes: "CompareCalculatorResults") -> ComparisonResults:  # noqa: C901
        """Build combined figures and annotated frames for several nodes.

        Parameters
        ----------
        *nodes : CompareCalculatorResults
            Nodes to compare. Frames are copied from ``nodes[0]``, and
            reference energies/forces come from ``nodes[0].reference``.

        Returns
        -------
        ComparisonResults
            Mapping with ``"frames"`` (copied atoms carrying every node's plot
            values, energies and forces) and ``"figures"`` (one per plot key).

        Raises
        ------
        ValueError
            If no nodes are given, if a plot key is missing from some node,
            if nodes have a different number of frames, or if corresponding
            frames have a different number of atoms.
        """
        if len(nodes) == 0:
            raise ValueError("No nodes to compare provided")

        figures = {}
        for key in nodes[0].plots.columns:
            if not all(key in node.plots.columns for node in nodes):
                raise ValueError(f"Key {key} not found in all nodes")
            figures[key] = get_figure(key, nodes, yaxis_title=_get_yaxis_title(key))

        frames_info = {}
        for node in nodes:
            for key in node.plots.columns:
                frames_info[f"{node.name}_{key}"] = node.plots[key].values

        # TODO: calculate the rmse difference between a fixed
        # one and all the others and shift them accordingly
        # and plot as energy_shifted
        # plot error between curves
        # mlipx pass additional flags to compare function
        # have different compare methods also for correlation plots

        frames = [shallow_copy_atoms(x) for x in nodes[0].frames]
        for key, values in frames_info.items():
            for atoms, value in zip(frames, values):
                atoms.info[key] = value

        for node in nodes:
            node_frames = node.frames  # property: read once per node
            # Without this check, zip would silently drop trailing frames.
            if len(node_frames) != len(frames):
                raise ValueError("Nodes have a different number of frames")
            for node_atoms, atoms in zip(node_frames, frames):
                if len(node_atoms) != len(atoms):
                    raise ValueError("Atoms objects have different lengths")
                # Calculator results are optional; skip frames that lack them.
                with contextlib.suppress(RuntimeError, PropertyNotImplementedError):
                    atoms.info[f"{node.name}_energy"] = (
                        node_atoms.get_potential_energy()
                    )
                    atoms.arrays[f"{node.name}_forces"] = node_atoms.get_forces()

        for ref_atoms, atoms in zip(nodes[0].reference.frames, frames):
            with contextlib.suppress(RuntimeError, PropertyNotImplementedError):
                atoms.info["ref_energy"] = ref_atoms.get_potential_energy()
                atoms.arrays["ref_forces"] = ref_atoms.get_forces()

        return {
            "frames": frames,
            "figures": figures,
        }