Skip to content

Commit 89976c4

Browse files
committed
move model run/eval scripts from mb_discovery/ to new dir models/
mv mb_discovery/{plot_scripts/plot_funcs.py -> plots.py}
1 parent 50b5f28 commit 89976c4

14 files changed

+58
-33
lines changed

.gitignore

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ job-logs/
2020

2121
# slurm logs
2222
slurm-*out
23-
mb_discovery/**/*.csv
23+
models/**/*.csv
2424

2525
# temporary ignore rule
2626
paper

mb_discovery/__init__.py

-2
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
PKG_DIR = os.path.dirname(__file__)
77
ROOT = os.path.dirname(PKG_DIR)
88

9-
os.makedirs(f"{PKG_DIR}/plots", exist_ok=True)
10-
119

1210
def chunks(xs: Sequence[Any], n: int) -> Generator[Sequence[Any], None, None]:
1311
return (xs[i : i + n] for i in range(0, len(xs), n))

mb_discovery/plot_scripts/hist_classified_stable_as_func_of_hull_dist.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44
import pandas as pd
55

66
from mb_discovery import ROOT
7-
from mb_discovery.plot_scripts import plt
8-
from mb_discovery.plot_scripts.plot_funcs import (
7+
from mb_discovery.plots import (
98
StabilityCriterion,
109
WhichEnergy,
1110
hist_classified_stable_as_func_of_hull_dist,
11+
plt,
1212
)
1313

1414
__author__ = "Rhys Goodall, Janosh Riebesell"

mb_discovery/plot_scripts/hist_classified_stable_as_func_of_hull_dist_batches.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44
import pandas as pd
55

66
from mb_discovery import ROOT
7-
from mb_discovery.plot_scripts import plt
8-
from mb_discovery.plot_scripts.plot_funcs import (
7+
from mb_discovery.plots import (
98
StabilityCriterion,
109
WhichEnergy,
1110
hist_classified_stable_as_func_of_hull_dist,
11+
plt,
1212
)
1313

1414
__author__ = "Rhys Goodall, Janosh Riebesell"

mb_discovery/plot_scripts/precision_recall_vs_calc_count.py

+7-9
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,7 @@
44
import pandas as pd
55

66
from mb_discovery import ROOT
7-
from mb_discovery.plot_scripts import plt
8-
from mb_discovery.plot_scripts.plot_funcs import (
9-
StabilityCriterion,
10-
precision_recall_vs_calc_count,
11-
)
7+
from mb_discovery.plots import StabilityCriterion, plt, precision_recall_vs_calc_count
128

139
__author__ = "Rhys Goodall, Janosh Riebesell"
1410
__date__ = "2022-06-18"
@@ -32,20 +28,22 @@
3228
).set_index("material_id")
3329

3430
dfs["Wrenformer"] = pd.read_csv(
35-
f"{ROOT}/data/2022-08-16-wrenformer-preds.csv.bz2"
31+
f"{ROOT}/models/wrenformer/mp/"
32+
"2022-09-20-wrenformer-e_form-ensemble-1-preds-e_form_per_atom.csv"
3633
).set_index("material_id")
3734

35+
print(f"loaded models: {list(dfs)}")
36+
3837

3938
# %% download wbm-steps-summary.csv (23.31 MB)
40-
df_summary = pd.read_csv(
39+
df_wbm = pd.read_csv(
4140
"https://figshare.com/files/37542841?private_link=ff0ad14505f9624f0c05"
4241
).set_index("material_id")
4342

4443

44+
# %%
4545
stability_crit: StabilityCriterion = "energy"
4646

47-
48-
# %%
4947
for (model_name, df), color in zip(
5048
dfs.items(),
5149
("tab:blue", "tab:orange", "teal", "tab:pink", "black", "red", "turquoise"),

mb_discovery/plot_scripts/rolling_mae_vs_hull_dist.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
import pandas as pd
55

66
from mb_discovery import ROOT
7-
from mb_discovery.plot_scripts import plt
8-
from mb_discovery.plot_scripts.plot_funcs import rolling_mae_vs_hull_dist
7+
from mb_discovery.plots import plt, rolling_mae_vs_hull_dist
98

109
__author__ = "Rhys Goodall, Janosh Riebesell"
1110
__date__ = "2022-06-18"
@@ -17,8 +16,9 @@
1716
markers = ["o", "v", "^", "H", "D", ""]
1817

1918
data_path = (
20-
f"{ROOT}/data/2022-06-11-from-rhys/wren-mp-initial-structures.csv"
21-
# f"{ROOT}/data/2022-08-16-wrenformer-preds.csv.bz2"
19+
# f"{ROOT}/data/2022-06-11-from-rhys/wren-mp-initial-structures.csv"
20+
f"{ROOT}/models/wrenformer/mp/"
21+
"2022-09-20-wrenformer-e_form-ensemble-1-preds-e_form_per_atom.csv"
2222
)
2323
df = pd.read_csv(data_path).set_index("material_id")
2424
legend_label = "Wren"

mb_discovery/plot_scripts/rolling_mae_vs_hull_dist_wbm_batches.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
import pandas as pd
55

66
from mb_discovery import ROOT
7-
from mb_discovery.plot_scripts import plt
8-
from mb_discovery.plot_scripts.plot_funcs import rolling_mae_vs_hull_dist
7+
from mb_discovery.plots import plt, rolling_mae_vs_hull_dist
98

109
__author__ = "Rhys Goodall, Janosh Riebesell"
1110
__date__ = "2022-06-18"

mb_discovery/plot_scripts/plot_funcs.py mb_discovery/plots.py

+33-2
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,52 @@
22

33
from typing import Any, Literal, Sequence, get_args
44

5+
import matplotlib.pyplot as plt
56
import numpy as np
67
import pandas as pd
8+
import plotly.express as px
9+
import plotly.io as pio
710
import scipy.interpolate
811
import scipy.stats
912
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
1013

11-
from mb_discovery.plot_scripts import plt
12-
1314
__author__ = "Janosh Riebesell"
1415
__date__ = "2022-08-05"
1516

1617
StabilityCriterion = Literal["energy", "energy+std", "energy-std"]
1718
WhichEnergy = Literal["true", "pred"]
1819

1920

21+
# --- define global plot settings
22+
px.defaults.labels = {
23+
"n_atoms": "Atom Count",
24+
"n_elems": "Element Count",
25+
"crystal_sys": "Crystal system",
26+
"spg_num": "Space group",
27+
"n_wyckoff": "Number of Wyckoff positions",
28+
"n_sites": "Lattice site count",
29+
"energy_per_atom": "Energy (eV/atom)",
30+
"e_form": "Formation energy (eV/atom)",
31+
"e_above_hull": "Energy above convex hull (eV/atom)",
32+
"e_above_hull_pred": "Predicted energy above convex hull (eV/atom)",
33+
"e_above_mp_hull": "Energy above MP convex hull (eV/atom)",
34+
"e_above_hull_error": "Error in energy above convex hull (eV/atom)",
35+
}
36+
37+
pio.templates.default = "plotly_white"
38+
39+
# https://github.com/plotly/Kaleido/issues/122#issuecomment-994906924
40+
# if a MathJax "loading" message shows up in exported PDFs, try:
41+
# pio.kaleido.scope.mathjax = None
42+
43+
44+
plt.rc("font", size=14)
45+
plt.rc("savefig", bbox="tight", dpi=200)
46+
plt.rc("figure", dpi=200, titlesize=16)
47+
plt.rcParams["figure.constrained_layout.use"] = True
48+
# --- end global plot settings
49+
50+
2051
def hist_classified_stable_as_func_of_hull_dist(
2152
e_above_hull_pred: pd.Series,
2253
e_above_hull_true: pd.Series,

mb_discovery/m3gnet/join_and_plot_m3gnet_relax_results.py models/m3gnet/join_and_plot_m3gnet_relax_results.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,7 @@
1414
from tqdm import tqdm
1515

1616
from mb_discovery import ROOT, as_dict_handler
17-
from mb_discovery.plot_scripts.plot_funcs import (
18-
hist_classified_stable_as_func_of_hull_dist,
19-
)
17+
from mb_discovery.plots import hist_classified_stable_as_func_of_hull_dist
2018

2119
today = f"{datetime.now():%Y-%m-%d}"
2220

mb_discovery/m3gnet/slurm_array_m3gnet_relax_wbm.py models/m3gnet/slurm_array_m3gnet_relax_wbm.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
```sh
2121
sbatch --partition icelake-himem --account LEE-SL3-CPU --array 1-101 \
2222
--time 3:0:0 --job-name m3gnet-wbm-relax-RS2RE --mem 12000 \
23-
--output mb_discovery/m3gnet/slurm_logs/slurm-%A-%a.out \
24-
--wrap "python mb_discovery/m3gnet/slurm_array_m3gnet_relax_wbm.py"
23+
--output models/m3gnet/slurm_logs/slurm-%A-%a.out \
24+
--wrap "python models/m3gnet/slurm_array_m3gnet_relax_wbm.py"
2525
```
2626
2727
--time 2h is probably enough, but missing indices are annoying, so it is best to be safe.

mb_discovery/wrenformer/mp/get_mp_energies.py models/wrenformer/mp/get_mp_energies.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22
from datetime import datetime
33

44
import pandas as pd
5-
from aviary import ROOT
65
from aviary.utils import as_dict_handler
76
from aviary.wren.utils import get_aflow_label_from_spglib
87
from mp_api.client import MPRester
98
from tqdm import tqdm
109

10+
from mb_discovery import ROOT
11+
1112
"""
1213
Download all MP formation and above hull energies on 2022-08-13.
1314
@@ -48,7 +49,7 @@
4849
df["wyckoff"] = [get_aflow_label_from_spglib(x) for x in tqdm(df.structure)]
4950

5051
df.to_json(
51-
f"{ROOT}/datasets/{today}-mp-all-energies.json.gz", default_handler=as_dict_handler
52+
f"{ROOT}/data/{today}-mp-all-energies.json.gz", default_handler=as_dict_handler
5253
)
5354

54-
# df = pd.read_json(f"{ROOT}/datasets/2022-08-13-mp-all-energies.json.gz")
55+
# df = pd.read_json(f"{ROOT}/data/2022-08-13-mp-all-energies.json.gz")

tests/test_plot_funcs.py tests/test_plots.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import pytest
88

99
from mb_discovery import ROOT
10-
from mb_discovery.plot_scripts.plot_funcs import (
10+
from mb_discovery.plots import (
1111
StabilityCriterion,
1212
hist_classified_stable_as_func_of_hull_dist,
1313
precision_recall_vs_calc_count,

0 commit comments

Comments
 (0)