"""Centralize data-loading and computing metrics for plotting scripts."""

from collections.abc import Sequence

import numpy as np
import pandas as pd
from sklearn.metrics import r2_score

from matbench_discovery.data import load_df_wbm_preds


def classify_stable(
    e_above_hull_true: pd.Series,
    e_above_hull_pred: pd.Series,
    stability_threshold: float | None = 0,
) -> tuple[pd.Series, pd.Series, pd.Series, pd.Series]:
    """Classify model stability predictions as true/false positives/negatives
    (usually w.r.t. DFT ground-truth labels). All energies are assumed to be in
    eV/atom (though it shouldn't really matter as long as they're consistent).

    Args:
        e_above_hull_true (pd.Series): Ground-truth energy above convex hull values.
        e_above_hull_pred (pd.Series): Model-predicted energy above convex hull values.
        stability_threshold (float | None, optional): Maximum energy above the convex
            hull for a material to still be considered stable. Usually 0, 0.05 or 0.1.
            Defaults to 0, meaning a material has to be directly on the hull to be
            called stable. Negative values mean a material has to pull the known hull
            down by that amount to count as stable. Few materials lie below the known
            hull, so only negative values very close to 0 make sense.

    Returns:
        tuple[TP, FN, FP, TN]: Boolean masks as pd.Series for true positives,
            false negatives, false positives and true negatives (in that order).
    """
    stability_threshold = stability_threshold or 0  # guard against None

    actual_pos = e_above_hull_true <= stability_threshold
    actual_neg = e_above_hull_true > stability_threshold
    model_pos = e_above_hull_pred <= stability_threshold
    model_neg = e_above_hull_pred > stability_threshold

    true_pos = actual_pos & model_pos
    false_neg = actual_pos & model_neg
    false_pos = actual_neg & model_pos
    true_neg = actual_neg & model_neg

    return true_pos, false_neg, false_pos, true_neg

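
# A minimal usage sketch for classify_stable() (illustrative values, not part of
# this module):
#   e_true = pd.Series([-0.05, 0.02, 0.3])  # DFT hull distances in eV/atom
#   e_pred = pd.Series([0.01, -0.01, 0.25])  # model-derived hull distances
#   tp, fn, fp, tn = classify_stable(e_true, e_pred, stability_threshold=0)
# Here the first material is a false negative (truly stable but predicted unstable),
# the second a false positive, and the third a true negative.
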
def stable_metrics(
    true: Sequence[float], pred: Sequence[float], stability_threshold: float = 0
) -> dict[str, float]:
    """Get a dictionary of stability prediction metrics: mostly binary
    classification metrics, plus MAE, RMSE and R2.

    Args:
        true (Sequence[float]): True energy values.
        pred (Sequence[float]): Predicted energy values.
        stability_threshold (float): Where to place the stability threshold relative
            to the convex hull in eV/atom, usually 0 or 0.1. Defaults to 0.

    Note: Could be replaced by sklearn.metrics.classification_report(), which takes
        binary labels, i.e. classification_report(true > 0, pred > 0, output_dict=True)
        should give equivalent results.

    Returns:
        dict[str, float]: dictionary of classification metrics with keys DAF, Precision,
            Recall, Accuracy, F1, TPR, FPR, TNR, FNR, MAE, RMSE, R2.
    """
    true_pos, false_neg, false_pos, true_neg = classify_stable(
        true, pred, stability_threshold
    )

    n_true_pos, n_false_pos, n_true_neg, n_false_neg = map(
        sum, (true_pos, false_pos, true_neg, false_neg)
    )

    n_total_pos = n_true_pos + n_false_neg
    prevalence = n_total_pos / len(true)  # null rate
    precision = n_true_pos / (n_true_pos + n_false_pos)
    recall = n_true_pos / n_total_pos

    # drop NaN pairs so they don't propagate into the error metrics below
    is_nan = np.isnan(true) | np.isnan(pred)
    true, pred = np.array(true)[~is_nan], np.array(pred)[~is_nan]

    return dict(
        DAF=precision / prevalence,
        Precision=precision,
        Recall=recall,
        Accuracy=(n_true_pos + n_true_neg) / len(true),
        F1=2 * (precision * recall) / (precision + recall),
        TPR=n_true_pos / n_total_pos,
        FPR=n_false_pos / (n_true_neg + n_false_pos),
        TNR=n_true_neg / (n_true_neg + n_false_pos),
        FNR=n_false_neg / n_total_pos,
        MAE=np.abs(true - pred).mean(),
        RMSE=((true - pred) ** 2).mean() ** 0.5,
        R2=r2_score(true, pred),
    )
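
# Worked example for DAF (hypothetical numbers, not from the dataset): if 10% of all
# candidates are truly stable (prevalence = 0.1) and 30% of the materials a model
# predicts as stable really are (precision = 0.3), then DAF = precision / prevalence
# = 3, i.e. the model's stable predictions are enriched 3x over random selection.
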

models = sorted(
    "Wrenformer, CGCNN, Voronoi Random Forest, MEGNet, "
    "M3GNet MEGNet, BOWSR MEGNet".split(", ")
)
e_form_col = "e_form_per_atom_mp2020_corrected"
each_true_col = "e_above_hull_mp2020_corrected_ppd_mp"
each_pred_col = "e_above_hull_pred"

df_wbm = load_df_wbm_preds(models).round(3)

for col in [e_form_col, each_true_col]:
    assert col in df_wbm, f"{col=} not in {list(df_wbm)=}"

df_metrics = pd.DataFrame()
for model in models:
    # predicted hull distance = true hull distance + (model-predicted minus DFT
    # formation energy), since the convex hull itself is held fixed
    df_metrics[model] = stable_metrics(
        df_wbm[each_true_col],
        df_wbm[each_true_col] + df_wbm[model] - df_wbm[e_form_col],
    )

assert df_metrics.T.MAE.between(0, 0.2).all(), "MAE not in range"
assert df_metrics.T.R2.between(0.1, 1).all(), "R2 not in range"
assert df_metrics.T.RMSE.between(0, 0.25).all(), "RMSE not in range"
assert df_metrics.isna().sum().sum() == 0, "NaNs in metrics"
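
# Downstream usage sketch (hypothetical; the exact import path of this module is an
# assumption): plotting scripts can import the precomputed frames instead of
# recomputing them, e.g.
#   from matbench_discovery.preds import df_metrics, df_wbm, each_true_col
#   df_metrics.T.F1.sort_values().plot.barh()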