Skip to content

Commit 0e9e5dc

Browse files
committed
split model pred loading from CSV into new module matbench_discovery/preds.py
1 parent 24f6868 commit 0e9e5dc

14 files changed

+78
-74
lines changed

matbench_discovery/metrics.py

+6-30
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
"""Centralize data-loading and computing metrics for plotting scripts"""
2-
31
from __future__ import annotations
42

53
from collections.abc import Sequence
@@ -8,7 +6,12 @@
86
import pandas as pd
97
from sklearn.metrics import r2_score
108

11-
from matbench_discovery.data import load_df_wbm_preds
9+
"""Functions to classify energy above convex hull predictions as true/false
10+
positive/negative and compute performance metrics.
11+
"""
12+
13+
__author__ = "Janosh Riebesell"
14+
__date__ = "2023-02-01"
1215

1316

1417
def classify_stable(
@@ -98,30 +101,3 @@ def stable_metrics(
98101
RMSE=((true - pred) ** 2).mean() ** 0.5,
99102
R2=r2_score(true, pred),
100103
)
101-
102-
103-
models = sorted(
104-
"Wrenformer, CGCNN, Voronoi Random Forest, MEGNet, M3GNet + MEGNet, "
105-
"BOWSR + MEGNet".split(", ")
106-
)
107-
e_form_col = "e_form_per_atom_mp2020_corrected"
108-
each_true_col = "e_above_hull_mp2020_corrected_ppd_mp"
109-
each_pred_col = "e_above_hull_pred"
110-
111-
df_wbm = load_df_wbm_preds(models).round(3)
112-
113-
for col in [e_form_col, each_true_col]:
114-
assert col in df_wbm, f"{col=} not in {list(df_wbm)=}"
115-
116-
117-
df_metrics = pd.DataFrame()
118-
for model in models:
119-
df_metrics[model] = stable_metrics(
120-
df_wbm[each_true_col],
121-
df_wbm[each_true_col] + df_wbm[e_form_col] - df_wbm[model],
122-
)
123-
124-
assert df_metrics.T.MAE.between(0, 0.2).all(), "MAE not in range"
125-
assert df_metrics.T.R2.between(0.1, 1).all(), "R2 not in range"
126-
assert df_metrics.T.RMSE.between(0, 0.25).all(), "RMSE not in range"
127-
assert df_metrics.isna().sum().sum() == 0, "NaNs in metrics"

matbench_discovery/preds.py

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from __future__ import annotations
2+
3+
import pandas as pd
4+
5+
from matbench_discovery.data import load_df_wbm_preds
6+
from matbench_discovery.metrics import stable_metrics
7+
8+
"""Centralize data-loading and computing metrics for plotting scripts"""
9+
10+
__author__ = "Janosh Riebesell"
11+
__date__ = "2023-02-04"
12+
13+
models = sorted(
14+
"Wrenformer, CGCNN, Voronoi Random Forest, MEGNet, M3GNet + MEGNet, "
15+
"BOWSR + MEGNet".split(", ")
16+
)
17+
e_form_col = "e_form_per_atom_mp2020_corrected"
18+
each_true_col = "e_above_hull_mp2020_corrected_ppd_mp"
19+
each_pred_col = "e_above_hull_pred"
20+
21+
df_wbm = load_df_wbm_preds(models).round(3)
22+
23+
for col in [e_form_col, each_true_col]:
24+
assert col in df_wbm, f"{col=} not in {list(df_wbm)=}"
25+
26+
27+
df_metrics = pd.DataFrame()
28+
for model in models:
29+
df_metrics[model] = stable_metrics(
30+
df_wbm[each_true_col],
31+
df_wbm[each_true_col] + df_wbm[e_form_col] - df_wbm[model],
32+
)
33+
34+
assert df_metrics.T.MAE.between(0, 0.2).all(), "MAE not in range"
35+
assert df_metrics.T.R2.between(0.1, 1).all(), "R2 not in range"
36+
assert df_metrics.T.RMSE.between(0, 0.25).all(), "RMSE not in range"
37+
assert df_metrics.isna().sum().sum() == 0, "NaNs in metrics"

scripts/cumulative_clf_metrics.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
from pymatviz.utils import save_fig
44

55
from matbench_discovery import FIGS, STATIC, today
6-
from matbench_discovery.metrics import df_wbm, e_form_col, each_true_col, models
76
from matbench_discovery.plots import cumulative_precision_recall
7+
from matbench_discovery.preds import df_wbm, e_form_col, each_true_col, models
88

99
__author__ = "Janosh Riebesell, Rhys Goodall"
1010
__date__ = "2022-12-04"

scripts/hist_classified_stable_vs_hull_dist.py

+2-7
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,9 @@
22
from pymatviz.utils import save_fig
33

44
from matbench_discovery import FIGS, today
5-
from matbench_discovery.metrics import (
6-
df_wbm,
7-
e_form_col,
8-
each_pred_col,
9-
each_true_col,
10-
stable_metrics,
11-
)
5+
from matbench_discovery.metrics import stable_metrics
126
from matbench_discovery.plots import WhichEnergy, hist_classified_stable_vs_hull_dist
7+
from matbench_discovery.preds import df_wbm, e_form_col, each_pred_col, each_true_col
138

149
__author__ = "Rhys Goodall, Janosh Riebesell"
1510
__date__ = "2022-06-18"

scripts/hist_classified_stable_vs_hull_dist_batches.py

+2-7
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,14 @@
22
from pymatviz.utils import save_fig
33

44
from matbench_discovery import FIGS, today
5-
from matbench_discovery.metrics import (
6-
df_wbm,
7-
e_form_col,
8-
each_pred_col,
9-
each_true_col,
10-
stable_metrics,
11-
)
5+
from matbench_discovery.metrics import stable_metrics
126
from matbench_discovery.plots import (
137
Backend,
148
WhichEnergy,
159
hist_classified_stable_vs_hull_dist,
1610
plt,
1711
)
12+
from matbench_discovery.preds import df_wbm, e_form_col, each_pred_col, each_true_col
1813

1914
__author__ = "Rhys Goodall, Janosh Riebesell"
2015
__date__ = "2022-08-25"

scripts/hist_classified_stable_vs_hull_dist_models.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22
from pymatviz.utils import save_fig
33

44
from matbench_discovery import STATIC, today
5-
from matbench_discovery.metrics import (
5+
from matbench_discovery.plots import Backend, hist_classified_stable_vs_hull_dist, plt
6+
from matbench_discovery.preds import (
67
df_metrics,
78
df_wbm,
89
e_form_col,
910
each_true_col,
1011
models,
1112
)
12-
from matbench_discovery.plots import Backend, hist_classified_stable_vs_hull_dist, plt
1313

1414
__author__ = "Janosh Riebesell"
1515
__date__ = "2022-12-01"

scripts/prc_roc_curves_models.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@
55
from tqdm import tqdm
66

77
from matbench_discovery import FIGS, today
8-
from matbench_discovery.metrics import (
8+
from matbench_discovery.metrics import stable_metrics
9+
from matbench_discovery.plots import pio
10+
from matbench_discovery.preds import (
911
df_wbm,
1012
e_form_col,
1113
each_pred_col,
1214
each_true_col,
1315
models,
14-
stable_metrics,
1516
)
16-
from matbench_discovery.plots import pio
1717

1818
__author__ = "Janosh Riebesell"
1919
__date__ = "2023-01-30"

scripts/rolling_mae_vs_hull_dist.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# %%
22
from matbench_discovery import FIGS, today
3-
from matbench_discovery.metrics import df_metrics, df_wbm, e_form_col, each_true_col
43
from matbench_discovery.plots import rolling_mae_vs_hull_dist
4+
from matbench_discovery.preds import df_metrics, df_wbm, e_form_col, each_true_col
55

66
__author__ = "Rhys Goodall, Janosh Riebesell"
77
__date__ = "2022-06-18"

scripts/rolling_mae_vs_hull_dist_all_models.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
from pymatviz.utils import save_fig
33

44
from matbench_discovery import FIGS, STATIC, today
5-
from matbench_discovery.metrics import df_metrics, df_wbm, e_form_col, each_true_col
65
from matbench_discovery.plots import Backend, rolling_mae_vs_hull_dist
6+
from matbench_discovery.preds import df_metrics, df_wbm, e_form_col, each_true_col
77

88
__author__ = "Rhys Goodall, Janosh Riebesell"
99
__date__ = "2022-06-18"

scripts/rolling_mae_vs_hull_dist_wbm_batches.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# %%
22
from matbench_discovery import FIGS, today
3-
from matbench_discovery.metrics import df_wbm, e_form_col, each_true_col
43
from matbench_discovery.plots import plt, rolling_mae_vs_hull_dist
4+
from matbench_discovery.preds import df_wbm, e_form_col, each_true_col
55

66
__author__ = "Rhys Goodall, Janosh Riebesell"
77
__date__ = "2022-06-18"

scripts/scatter_e_above_hull_models.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,15 @@
33
from pymatviz.utils import add_identity_line, save_fig
44

55
from matbench_discovery import FIGS, STATIC, today
6-
from matbench_discovery.metrics import (
7-
classify_stable,
6+
from matbench_discovery.metrics import classify_stable, stable_metrics
7+
from matbench_discovery.plots import clf_color_map, clf_labels, px
8+
from matbench_discovery.preds import (
89
df_wbm,
910
e_form_col,
1011
each_pred_col,
1112
each_true_col,
1213
models,
13-
stable_metrics,
1414
)
15-
from matbench_discovery.plots import clf_color_map, clf_labels, px
1615

1716
__author__ = "Janosh Riebesell"
1817
__date__ = "2022-11-28"

site/src/routes/how-to-contribute/+page.md

+4-4
Original file line numberDiff line numberDiff line change
@@ -147,17 +147,17 @@ and place the above-listed files there. The file structure should look like this
147147
```txt
148148
matbench-discovery-root
149149
└── models
150-
└── <model name>
150+
└── <model_name>
151151
├── metadata.yml
152152
├── <yyyy-mm-dd>-<model_name>-preds.(json|csv).gz
153153
├── test_<model_name>.py
154-
├── readme.md # optional
155-
└── train_<model_name>.py # optional
154+
├── readme.md # optional
155+
└── train_<model_name>.py # optional
156156
```
157157

158158
You can include arbitrary other supporting files like metadata and model features (below 10MB to keep `git clone` time low) if they are needed to run the model or help others reproduce your results. For larger files, please upload to [Figshare](https://figshare.com) or similar and link them somewhere in your files.
159159

160-
### Step 3: Create a PR to the [Matbench Discovery repo](https://github.com/janosh/matbench-discovery)
160+
### Step 3: Open a PR to the [Matbench Discovery repo](https://github.com/janosh/matbench-discovery)
161161

162162
Commit your files to the repo on a branch called `<model_name>` and create a pull request (PR) to the Matbench Discovery repository.
163163

tests/test_energy.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,14 @@
66
from pymatgen.analysis.phase_diagram import PDEntry
77
from pymatgen.core import Lattice, Structure
88
from pymatgen.entries.computed_entries import ComputedEntry, Entry
9+
from pytest import approx
910

10-
from matbench_discovery.energy import get_e_form_per_atom, get_elemental_ref_entries
11+
from matbench_discovery.energy import (
12+
get_e_form_per_atom,
13+
get_elemental_ref_entries,
14+
mp_elem_reference_entries,
15+
mp_elemental_ref_energies,
16+
)
1117

1218
dummy_struct = Structure(
1319
lattice=Lattice.cubic(5),
@@ -49,3 +55,11 @@ def test_get_elemental_ref_entries(
4955
expected = {"Fe": constructor(*entries[2]), "O": constructor(*entries[3])}
5056

5157
assert elemental_ref_entries == expected
58+
59+
60+
def test_mp_ref_energies() -> None:
61+
"""Test MP elemental reference energies are in sync with PDEntries saved to disk."""
62+
for key, val in mp_elemental_ref_energies.items():
63+
actual = mp_elem_reference_entries[key].energy_per_atom
64+
assert actual == approx(val, abs=1e-3), f"{key=}"
65+
assert actual == approx(val, abs=1e-3), f"{key=}"

tests/test_metrics.py

-12
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,6 @@
77
import pytest
88
from pytest import approx
99

10-
from matbench_discovery.energy import (
11-
mp_elem_reference_entries,
12-
mp_elemental_ref_energies,
13-
)
1410
from matbench_discovery.metrics import classify_stable, stable_metrics
1511

1612

@@ -83,11 +79,3 @@ def test_stable_metrics() -> None:
8379
# test stable_metrics docstring is up to date, all returned metrics should be listed
8480
assert stable_metrics.__doc__ # for mypy
8581
assert all(key in stable_metrics.__doc__ for key in metrics)
86-
87-
88-
def test_mp_ref_energies() -> None:
89-
"""Test MP elemental reference energies are in sync with PDEntries saved to disk."""
90-
for key, val in mp_elemental_ref_energies.items():
91-
actual = mp_elem_reference_entries[key].energy_per_atom
92-
assert actual == approx(val, abs=1e-3), f"{key=}"
93-
assert actual == approx(val, abs=1e-3), f"{key=}"

0 commit comments

Comments
 (0)