|
21 | 21 |
|
22 | 22 |
|
23 | 23 | # %%
|
| 24 | +model_stats: dict[str, dict[str, str | int | float]] = {} |
24 | 25 | models: dict[str, dict[str, Any]] = {
|
25 | 26 | "CGCNN": dict(
|
26 | 27 | n_runs=10,
|
|
57 | 58 | display_name={"$regex": "m3gnet-wbm-IS2RE"},
|
58 | 59 | ),
|
59 | 60 | ),
|
60 |
| - "BOWSR MEGNet": dict( |
| 61 | + "BOWSR + MEGNet": dict( |
61 | 62 | n_runs=500,
|
62 | 63 | filters=dict(
|
63 | 64 | created_at={"$gt": "2023-01-20", "$lt": "2023-01-22"},
|
|
66 | 67 | ),
|
67 | 68 | }
|
68 | 69 |
|
69 |
| -assert set(models) == set(PRED_FILENAMES), f"{set(models)=} != {set(PRED_FILENAMES)=}" |
70 |
| - |
71 |
| - |
72 |
| -model_stats: dict[str, dict[str, str | int | float]] = {} |
# Every model listed in `models` must have a matching predictions file,
# otherwise the metrics below would silently skip it.
unknown_models = set(models) - set(PRED_FILENAMES)
assert not unknown_models, f"{unknown_models=} missing predictions file"
73 | 73 |
|
74 | 74 |
|
75 | 75 | # %% calculate total model run times from wandb logs
|
76 | 76 | # NOTE these model run times are pretty meaningless since some models were run on GPU
|
77 |
| -# (Wrenformer and CGCNN), others on CPU. Also BOWSR MEGNet, M3GNet and MEGNet weren't |
| 77 | +# (Wrenformer and CGCNN), others on CPU. Also BOWSR + MEGNet, M3GNet and MEGNet weren't |
78 | 78 | # trained from scratch. Their run times only indicate the time needed to predict the
|
79 | 79 | # test set.
|
80 | 80 |
|
|
110 | 110 | title=f"Run time distribution for {model}", xlabel="Run time [h]", ylabel="Count"
|
111 | 111 | )
|
112 | 112 |
|
# Synthesize a combined entry: M3GNet structure relaxation followed by a
# MEGNet energy prediction. Its run time is the sum of both models' run times;
# all other stats are copied from the M3GNet entry.
m3gnet_megnet = model_stats["M3GNet"].copy()
m3gnet_megnet[time_col] = (
    model_stats["MEGNet"][time_col] + model_stats["M3GNet"][time_col]  # type: ignore
)
model_stats["M3GNet + MEGNet"] = m3gnet_megnet

# One row per model, one column per recorded stat.
df_metrics = pd.DataFrame(model_stats).T.rename_axis(index="Model")
|
115 |
| -# on 2022-11-28: |
116 |
| -# run_times = {'Voronoi Random Forest': 739608, |
117 |
| -# 'Wrenformer': 208399, |
118 |
| -# 'MEGNet': 12396, |
119 |
| -# 'M3GNet': 301138, |
120 |
| -# 'BOWSR MEGNet': 9105237} |
121 | 120 |
|
122 | 121 |
|
123 | 122 | # %%
|
124 |
| -df_wbm = load_df_wbm_preds(list(models)) |
# Column names for the MP2020-corrected formation energy and the
# corresponding energy above the MP patched-phase-diagram hull.
e_form_col = "e_form_per_atom_mp2020_corrected"
each_true_col = "e_above_hull_mp2020_corrected_ppd_mp"

# Load WBM test-set predictions for every model we have stats for
# (keys of model_stats, which includes the combined M3GNet + MEGNet entry).
df_wbm = load_df_wbm_preds(list(model_stats))
|
127 | 126 |
|
128 | 127 |
|
129 | 128 | # %%
|
130 |
| -for model in models: |
| 129 | +for model in model_stats: |
131 | 130 | each_pred = df_wbm[each_true_col] + df_wbm[model] - df_wbm[e_form_col]
|
132 | 131 |
|
133 | 132 | metrics = stable_metrics(df_wbm[each_true_col], each_pred)
|
|
165 | 164 | }
|
166 | 165 | df_styled.set_table_styles([dict(selector=sel, props=styles[sel]) for sel in styles])
|
167 | 166 |
|
168 |
| -html_path = f"{FIGS}/{today}-metrics-table.svelte" |
169 |
| -df_styled.to_html(html_path) |
| 167 | +# df_styled.to_html(f"{FIGS}/{today}-metrics-table.svelte") |
170 | 168 |
|
171 | 169 |
|
172 | 170 | # %% write model metrics to json for use by the website
|
173 |
| -df_metrics["missing_preds"] = df_wbm[list(models)].isna().sum() |
# Count NaN predictions per model and record them both as an absolute
# count and as a percentage of all WBM test-set rows.
n_missing_per_model = df_wbm[list(model_stats)].isna().sum()
df_metrics["missing_preds"] = n_missing_per_model
df_metrics["missing_percent"] = df_metrics.missing_preds.map(
    lambda x: f"{x / len(df_wbm):.2%}"
)
|
|
0 commit comments