|
1 | 1 | # %%
|
2 |
| -import numpy as np |
3 | 2 | import pandas as pd
|
4 | 3 | from pymatviz.utils import save_fig
|
| 4 | +from sklearn.metrics import auc, precision_recall_curve, roc_curve |
5 | 5 | from tqdm import tqdm
|
6 | 6 |
|
7 | 7 | from matbench_discovery import FIGS
|
8 |
| -from matbench_discovery.metrics import stable_metrics |
9 | 8 | from matbench_discovery.plots import pio
|
10 |
| -from matbench_discovery.preds import ( |
11 |
| - df_wbm, |
12 |
| - e_form_col, |
13 |
| - each_pred_col, |
14 |
| - each_true_col, |
15 |
| - models, |
16 |
| -) |
| 9 | +from matbench_discovery.preds import df_each_pred, df_wbm, each_true_col |
17 | 10 |
|
18 | 11 | __author__ = "Janosh Riebesell"
|
19 | 12 | __date__ = "2023-01-30"
|
|
34 | 27 | # %%
|
35 | 28 | df_roc = pd.DataFrame()
|
36 | 29 |
|
37 |
| -for model in (pbar := tqdm(models)): |
38 |
| - pbar.set_description(model) |
39 |
| - df_wbm[f"{model}_{each_pred_col}"] = df_wbm[each_true_col] + ( |
40 |
| - df_wbm[model] - df_wbm[e_form_col] |
41 |
| - ) |
42 |
| - for stab_treshold in np.arange(-0.4, 0.4, 0.01): |
43 |
| - metrics = stable_metrics( |
44 |
| - df_wbm[each_true_col], df_wbm[f"{model}_{each_pred_col}"], stab_treshold |
45 |
| - ) |
46 |
| - df_tmp = pd.DataFrame( |
47 |
| - {facet_col: model, color_col: stab_treshold, **metrics}, index=[0] |
48 |
| - ) |
49 |
| - df_roc = pd.concat([df_roc, df_tmp]) |
50 |
| - |
| 30 | +for model in (pbar := tqdm(list(df_each_pred), desc="Calculating ROC curves")): |
| 31 | + pbar.set_postfix_str(model) |
| 32 | + na_mask = df_wbm[each_true_col].isna() | df_each_pred[model].isna() |
| 33 | + y_true = (df_wbm[~na_mask][each_true_col] <= 0).astype(int) |
| 34 | + y_pred = df_each_pred[model][~na_mask] |
| 35 | + fpr, tpr, thresholds = roc_curve(y_true, y_pred, pos_label=0) |
| 36 | + AUC = auc(fpr, tpr) |
| 37 | + title = f"{model} · {AUC=:.2f}" |
| 38 | + df_tmp = pd.DataFrame( |
| 39 | + {"FPR": fpr, "TPR": tpr, color_col: thresholds, "AUC": AUC, facet_col: title} |
| 40 | + ).round(3) |
51 | 41 |
|
52 |
| -df_roc = df_roc.round(3) |
| 42 | + df_roc = pd.concat([df_roc, df_tmp]) |
53 | 43 |
|
54 | 44 |
|
55 | 45 | # %%
|
56 |
| -fig = df_roc.plot.scatter( |
57 |
| - x="FPR", |
58 |
| - y="TPR", |
59 |
| - facet_col=facet_col, |
60 |
| - facet_col_wrap=2, |
61 |
| - backend="plotly", |
62 |
| - height=800, |
63 |
| - color=color_col, |
64 |
| - range_x=(0, 1), |
65 |
| - range_y=(0, 1), |
| 46 | +fig = ( |
| 47 | + df_roc.iloc[:: len(df_roc) // 500 or 1] |
| 48 | + .sort_values(["AUC", "FPR"], ascending=False) |
| 49 | + .plot.scatter( |
| 50 | + x="FPR", |
| 51 | + y="TPR", |
| 52 | + facet_col=facet_col, |
| 53 | + facet_col_wrap=2, |
| 54 | + backend="plotly", |
| 55 | + height=150 * len(df_roc[facet_col].unique()), |
| 56 | + color=color_col, |
| 57 | + range_x=(0, 1), |
| 58 | + range_y=(0, 1), |
| 59 | + range_color=(-0.5, 0.5), |
| 60 | + hover_name=facet_col, |
| 61 | + hover_data={facet_col: False}, |
| 62 | + ) |
66 | 63 | )
|
67 | 64 |
|
68 | 65 | for anno in fig.layout.annotations:
|
69 |
| - anno.text = anno.text.split("=")[1] # remove Model= from subplot titles |
| 66 | + anno.text = anno.text.split("=", 1)[1] # remove Model= from subplot titles |
70 | 67 |
|
71 | 68 | fig.layout.coloraxis.colorbar.update(
|
72 |
| - x=1, y=1, xanchor="right", yanchor="top", thickness=14, len=0.27, title_side="right" |
| 69 | + x=1, y=1, xanchor="right", yanchor="top", thickness=14, len=0.2, title_side="right" |
73 | 70 | )
|
74 | 71 | fig.add_shape(type="line", x0=0, y0=0, x1=1, y1=1, line=line, row="all", col="all")
|
75 |
| -fig.add_annotation( |
76 |
| - text="No skill", x=0.5, y=0.5, showarrow=False, yshift=-10, textangle=-30 |
77 |
| -) |
| 72 | +fig.add_annotation(text="No skill", x=0.5, y=0.5, showarrow=False, yshift=-10) |
78 | 73 | # allow scrolling and zooming each subplot individually
|
79 | 74 | fig.update_xaxes(matches=None)
|
80 | 75 | fig.update_yaxes(matches=None)
|
|
86 | 81 |
|
87 | 82 |
|
88 | 83 | # %%
|
89 |
| -fig = df_roc.plot.scatter( |
| 84 | +df_prc = pd.DataFrame() |
| 85 | + |
| 86 | +for model in (pbar := tqdm(list(df_each_pred), desc="Calculating PR curves")): |
| 87 | + pbar.set_postfix_str(model) |
| 88 | + na_mask = df_wbm[each_true_col].isna() | df_each_pred[model].isna() |
| 89 | + y_true = (df_wbm[~na_mask][each_true_col] <= 0).astype(int) |
| 90 | + y_pred = df_each_pred[model][~na_mask] |
| 91 | + prec, recall, thresholds = precision_recall_curve(y_true, y_pred, pos_label=0) |
| 92 | + df_tmp = pd.DataFrame( |
| 93 | + { |
| 94 | + "Precision": prec[:-1], |
| 95 | + "Recall": recall[:-1], |
| 96 | + color_col: thresholds, |
| 97 | + facet_col: model, |
| 98 | + } |
| 99 | + ).round(3) |
| 100 | + |
| 101 | + df_prc = pd.concat([df_prc, df_tmp]) |
| 102 | + |
| 103 | + |
| 104 | +# %% |
| 105 | +fig = df_prc.iloc[:: len(df_prc) // 500 or 1].plot.scatter( |
90 | 106 | x="Recall",
|
91 | 107 | y="Precision",
|
92 | 108 | facet_col=facet_col,
|
93 | 109 | facet_col_wrap=2,
|
94 | 110 | backend="plotly",
|
95 |
| - height=800, |
| 111 | + height=150 * len(df_prc[facet_col].unique()), |
96 | 112 | color=color_col,
|
97 | 113 | range_x=(0, 1),
|
98 |
| - range_y=(0, 1), |
| 114 | + range_y=(0.5, 1), |
| 115 | + range_color=(-0.5, 1), |
| 116 | + hover_name=facet_col, |
| 117 | + hover_data={facet_col: False}, |
99 | 118 | )
|
100 | 119 |
|
101 | 120 | for anno in fig.layout.annotations:
|
102 |
| - anno.text = anno.text.split("=")[1] # remove Model= from subplot titles |
| 121 | + anno.text = anno.text.split("=", 1)[1] # remove Model= from subplot titles |
103 | 122 |
|
104 | 123 | fig.layout.coloraxis.colorbar.update(
|
105 | 124 | x=0.5, y=1.1, thickness=14, len=0.4, orientation="h"
|
|
0 commit comments