|
| 1 | +# %% |
from __future__ import annotations

from typing import Any

import pandas as pd
import requests
import wandb
from sklearn.metrics import f1_score, precision_score, r2_score, recall_score
from tqdm import tqdm

from matbench_discovery import ROOT, today
from matbench_discovery.load_preds import load_df_wbm_with_preds
| 14 | + |
| 15 | +__author__ = "Janosh Riebesell" |
| 16 | +__date__ = "2022-11-28" |
| 17 | + |
| 18 | + |
# %%
# Per-model wandb query config: the expected number of logged runs plus the
# MongoDB-style filters used to locate those runs in the wandb project.
models: dict[str, dict[str, Any]] = {
    "Wren": {"n_runs": 0},
    "CGCNN": {
        "n_runs": 10,
        "filters": {
            "created_at": {"$gt": "2022-11-21", "$lt": "2022-11-23"},
            "display_name": {"$regex": "cgcnn-robust-formation_energy_per_atom"},
        },
    },
    "Voronoi RF": {
        "n_runs": 70,
        "filters": {
            "created_at": {"$gt": "2022-11-17", "$lt": "2022-11-28"},
            "display_name": {"$regex": "voronoi-features"},
        },
    },
    "Wrenformer": {
        "n_runs": 10,
        "filters": {
            "created_at": {"$gt": "2022-11-14", "$lt": "2022-11-16"},
            "display_name": {"$regex": "wrenformer-robust-mp-formation_energy"},
        },
    },
    "MEGNet": {
        "n_runs": 1,
        "filters": {
            "created_at": {"$gt": "2022-11-17", "$lt": "2022-11-19"},
            "display_name": {"$regex": "megnet-wbm-IS2RE"},
        },
    },
    "M3GNet": {
        "n_runs": 99,
        "filters": {
            "created_at": {"$gt": "2022-10-31", "$lt": "2022-11-01"},
            "display_name": {"$regex": "m3gnet-wbm-IS2RE"},
        },
    },
    "BOWSR MEGNet": {
        "n_runs": 1000,
        "filters": {
            "created_at": {"$gt": "2022-11-22", "$lt": "2022-11-25"},
            "display_name": {"$regex": "bowsr-megnet"},
        },
    },
}
| 65 | + |
# maps model name -> {"Run time": total seconds, "Hardware": "GPU: x, CPU: y"}
run_times: dict[str, dict[str, str | int | float]] = {}


# %% calculate total model run times from wandb logs
# NOTE these model run times are pretty meaningless since some models were run on GPU
# (Wrenformer and CGCNN), others on CPU. Also BOWSR MEGNet, M3GNet and MEGNet weren't
# trained from scratch. Their run times only indicate the time needed to predict the
# test set.

wandb_api = wandb.Api()  # create the API client once instead of once per model

for model in (pbar := tqdm(models)):
    model_dict = models[model]
    n_runs, filters = (model_dict.get(key) for key in ("n_runs", "filters"))
    # skip models with nothing logged, and don't re-fetch on notebook re-runs
    if n_runs == 0 or model in run_times:
        continue
    pbar.set_description(model)

    runs = wandb_api.runs("janosh/matbench-discovery", filters=filters)

    # explicit raise instead of assert so the check survives `python -O`
    if len(runs) != n_runs:
        raise ValueError(f"found {len(runs)=} for {model}, expected {n_runs}")

    # sum the wandb-reported runtime over all runs of this model
    run_time = sum(run.summary.get("_wandb", {}).get("runtime", 0) for run in runs)
    # NOTE we assume all jobs have the same metadata here
    # timeout so a hung HTTP request can't stall the whole loop indefinitely
    metadata = requests.get(runs[0].file("wandb-metadata.json").url, timeout=30).json()

    n_gpu, n_cpu = metadata.get("gpu_count", 0), metadata.get("cpu_count", 0)
    run_times[model] = {"Run time": run_time, "Hardware": f"GPU: {n_gpu}, CPU: {n_cpu}"}
| 92 | + |
| 93 | + |
# Raw run-time totals (the "Run time" sums, presumably seconds — before the
# Hardware field was added) as recorded on 2022-11-28:
# run_times = {'Voronoi RF': 739608,
#              'Wrenformer': 208399,
#              'MEGNet': 12396,
#              'M3GNet': 301138,
#              'BOWSR MEGNet': 9105237}
| 100 | + |
| 101 | + |
# %%
# Column holding the DFT reference target and the precomputed hull distance.
target_col = "e_form_per_atom_mp2020_corrected"
e_above_hull_col = "e_above_hull_mp2020_corrected_ppd_mp"

# Load WBM summary data joined with each model's predictions (3-decimal rounding),
# then drop rows with extreme formation energies (>= 5 eV/atom).
df_wbm = load_df_wbm_with_preds(models=models).round(3)
df_wbm = df_wbm.query(f"{target_col} < 5")

e_above_hull = df_wbm[e_above_hull_col]
| 111 | + |
| 112 | + |
# %%
# Start the metrics table from the run-time info (one row per model).
df_metrics = pd.DataFrame(run_times).T

for model in models:
    dct: dict[str, float] = {}
    # NOTE(review): treats the error in predicted e_form as the error in
    # e_above_hull (same DFT offset on both) — confirm this holds for all models
    e_above_hull_pred = df_wbm[model] - df_wbm[target_col]

    is_stable_true = e_above_hull < 0
    is_stable_pred = e_above_hull_pred < 0

    dct["F1"] = f1_score(is_stable_true, is_stable_pred)
    # fix: Precision/Recall were previously computed with
    # f1_score(..., pos_label=...), which returns the F1 score of the selected
    # class, not precision or recall — use the dedicated metrics instead
    dct["Precision"] = precision_score(is_stable_true, is_stable_pred)
    dct["Recall"] = recall_score(is_stable_true, is_stable_pred)

    dct["MAE"] = (e_above_hull_pred - e_above_hull).abs().mean()
    dct["RMSE"] = ((e_above_hull_pred - e_above_hull) ** 2).mean() ** 0.5

    # r2_score can't handle NaNs: align both series on non-NaN predictions
    pred_no_nan = e_above_hull_pred.dropna()
    dct["R2"] = r2_score(e_above_hull.loc[pred_no_nan.index], pred_no_nan)

    df_metrics.loc[model, list(dct)] = dct.values()
| 132 | + |
| 133 | + |
# Render the metrics table: 3-decimal display precision, viridis heat map.
# (pass gmap=np.log10(df_metrics) to background_gradient for a log color scale)
styler = df_metrics.style.format(precision=3)
df_styled = styler.background_gradient(cmap="viridis")
df_styled
| 139 | + |
| 140 | + |
# %%
# Basic table CSS: sans-serif font, collapsed borders, padded left-aligned cells.
styles = {
    "": "font-family: sans-serif; border-collapse: collapse;",
    "td, th": "border: 1px solid #ddd; text-align: left; padding: 8px;",
}
table_styles = [{"selector": sel, "props": props} for sel, props in styles.items()]
df_styled.set_table_styles(table_styles)

df_styled.to_html(f"{ROOT}/figures/{today}-metrics-table.html")
0 commit comments