|
1 |
| -"""Centralize data-loading and computing metrics for plotting scripts""" |
2 |
| - |
3 | 1 | from __future__ import annotations
|
4 | 2 |
|
5 | 3 | from collections.abc import Sequence
|
|
8 | 6 | import pandas as pd
|
9 | 7 | from sklearn.metrics import r2_score
|
10 | 8 |
|
11 |
| -from matbench_discovery.data import load_df_wbm_preds |
| 9 | +"""Functions to classify energy above convex hull predictions as true/false |
| 10 | +positive/negative and compute performance metrics. |
| 11 | +""" |
| 12 | + |
| 13 | +__author__ = "Janosh Riebesell" |
| 14 | +__date__ = "2023-02-01" |
12 | 15 |
|
13 | 16 |
|
14 | 17 | def classify_stable(
|
@@ -98,30 +101,3 @@ def stable_metrics(
|
98 | 101 | RMSE=((true - pred) ** 2).mean() ** 0.5,
|
99 | 102 | R2=r2_score(true, pred),
|
100 | 103 | )
|
101 |
| - |
102 |
| - |
103 |
| -models = sorted( |
104 |
| - "Wrenformer, CGCNN, Voronoi Random Forest, MEGNet, M3GNet + MEGNet, " |
105 |
| - "BOWSR + MEGNet".split(", ") |
106 |
| -) |
107 |
| -e_form_col = "e_form_per_atom_mp2020_corrected" |
108 |
| -each_true_col = "e_above_hull_mp2020_corrected_ppd_mp" |
109 |
| -each_pred_col = "e_above_hull_pred" |
110 |
| - |
111 |
| -df_wbm = load_df_wbm_preds(models).round(3) |
112 |
| - |
113 |
| -for col in [e_form_col, each_true_col]: |
114 |
| - assert col in df_wbm, f"{col=} not in {list(df_wbm)=}" |
115 |
| - |
116 |
| - |
117 |
| -df_metrics = pd.DataFrame() |
118 |
| -for model in models: |
119 |
| - df_metrics[model] = stable_metrics( |
120 |
| - df_wbm[each_true_col], |
121 |
| - df_wbm[each_true_col] + df_wbm[e_form_col] - df_wbm[model], |
122 |
| - ) |
123 |
| - |
124 |
| -assert df_metrics.T.MAE.between(0, 0.2).all(), "MAE not in range" |
125 |
| -assert df_metrics.T.R2.between(0.1, 1).all(), "R2 not in range" |
126 |
| -assert df_metrics.T.RMSE.between(0, 0.25).all(), "RMSE not in range" |
127 |
| -assert df_metrics.isna().sum().sum() == 0, "NaNs in metrics" |
0 commit comments