|
17 | 17 | df_hull = pd.read_csv(f"{DATA_DIR}/wbm-e-above-mp-hull.csv").set_index("material_id")
|
18 | 18 |
|
19 | 19 | dfs: dict[str, pd.DataFrame] = {}
|
20 |
| -for model_name in ("Wren", "CGCNN", "Voronoi"): |
21 |
| - df = pd.read_csv( |
22 |
| - f"{DATA_DIR}/{model_name.lower()}-mp-initial-structures.csv" |
23 |
| - ).set_index("material_id") |
| 20 | +for model_name in ("wren", "cgcnn", "voronoi"): |
| 21 | + csv_path = f"{DATA_DIR}/{model_name}-mp-initial-structures.csv" |
| 22 | + df = pd.read_csv(csv_path).set_index("material_id") |
24 | 23 | dfs[model_name] = df
|
25 | 24 |
|
26 |
| -dfs["M3GNet"] = pd.read_json( |
| 25 | +dfs["m3gnet"] = pd.read_json( |
27 | 26 | f"{ROOT}/models/m3gnet/2022-08-16-m3gnet-wbm-IS2RE.json.gz"
|
28 | 27 | ).set_index("material_id")
|
29 | 28 |
|
30 |
| -dfs["Wrenformer"] = pd.read_csv( |
| 29 | +dfs["wrenformer"] = pd.read_csv( |
31 | 30 | f"{ROOT}/models/wrenformer/mp/"
|
32 | 31 | "2022-09-20-wrenformer-e_form-ensemble-1-preds-e_form_per_atom.csv"
|
33 | 32 | ).set_index("material_id")
|
34 | 33 |
|
35 |
| -dfs["BOWSR Megnet"] = pd.read_json( |
| 34 | +dfs["bowsr_megnet"] = pd.read_json( |
36 | 35 | f"{ROOT}/models/bowsr/2022-09-22-bowsr-wbm-megnet-IS2RE.json.gz"
|
37 | 36 | ).set_index("material_id")
|
38 | 37 |
|
|
69 | 68 | std_total = None
|
70 | 69 |
|
71 | 70 | try:
|
72 |
| - if model_name == "M3GNet": |
| 71 | + if model_name == "m3gnet": |
73 | 72 | model_preds = df.e_form_m3gnet
|
74 |
| - elif "Wrenformer" in model_name: |
| 73 | + elif "wrenformer" in model_name: |
75 | 74 | model_preds = df.e_form_per_atom_pred_ens
|
76 | 75 | elif len(pred_cols := df.filter(like="e_form_pred").columns) >= 1:
|
77 | 76 | # Voronoi+RF has single prediction column, Wren and CGCNN each have 10
|
78 | 77 | # other cases are unexpected
|
79 | 78 | assert len(pred_cols) in (1, 10), f"{model_name=} has {len(pred_cols)=}"
|
80 | 79 | model_preds = df[pred_cols].mean(axis=1)
|
81 |
| - elif "BOWSR" in model_name: |
| 80 | + elif "bowsr" in model_name: |
82 | 81 | model_preds = df.e_form_per_atom_bowsr
|
83 | 82 | else:
|
84 | 83 | raise ValueError(f"Unhandled {model_name = }")
|
|
107 | 106 | # keep this outside loop so all model names appear in legend
|
108 | 107 | ax.legend(frameon=False, loc="lower right")
|
109 | 108 |
|
| 109 | +img_name = f"{today}-precision-recall-vs-calc-count-{rare=}" |
| 110 | +ax.set(title=img_name.replace("-", "/", 2).replace("-", " ").title()) |
| 111 | + |
110 | 112 |
|
111 | 113 | # %%
|
112 |
| -img_path = f"{ROOT}/figures/{today}-precision-recall-vs-calc-count-{rare=}.pdf" |
113 |
| -ax.figure.savefig(img_path) |
| 114 | +ax.figure.savefig(f"{ROOT}/figures/{img_name}.pdf") |
0 commit comments