Skip to content

Commit a9386fb

Browse files
committed
delete outdated 'from matbench_discovery import DEBUG'
import col names from matbench_discovery.preds add license rm .gitmodules
1 parent 50f5821 commit a9386fb

14 files changed

+72
-41
lines changed

.gitmodules

-3
This file was deleted.

data/wbm/eda.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,7 @@
9494

9595

9696
# %% histogram of energy above MP convex hull for WBM
97-
col = "e_above_hull_mp2020_corrected_ppd_mp"
98-
# col = "e_form_per_atom_mp2020_corrected"
97+
col = each_true_col # or e_form_col
9998
mean, std = df_wbm[col].mean(), df_wbm[col].std()
10099

101100
range_x = (mean - 2 * std, mean + 2 * std)

license

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2022 Janosh Riebesell
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
The software is provided "as is", without warranty of any kind, express or
16+
implied, including but not limited to the warranties of merchantability,
17+
fitness for a particular purpose and noninfringement. In no event shall the
18+
authors or copyright holders be liable for any claim, damages or other
19+
liability, whether in an action of contract, tort or otherwise, arising from,
20+
out of or in connection with the software or the use or other dealings in the
21+
software.

models/alignn/test_alignn.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from matbench_discovery import today
2121
from matbench_discovery.data import DATA_FILES, df_wbm
2222
from matbench_discovery.plots import wandb_scatter
23+
from matbench_discovery.preds import e_form_col
2324
from matbench_discovery.slurm import slurm_submit
2425

2526
__author__ = "Janosh Riebesell, Philipp Benner"
@@ -33,7 +34,7 @@
3334
# TODO fix this to load checkpoint from figshare
3435
# model_name = f"{module_dir}/data-train-result/best-model.pth"
3536
task_type = "IS2RE"
36-
target_col = "e_form_per_atom_mp2020_corrected"
37+
target_col = e_form_col
3738
input_col = "initial_structure"
3839
id_col = "material_id"
3940
device = "cuda" if torch.cuda.is_available() else "cpu"

models/alignn_ff/alignn_ff_relax.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99
from pymatgen.io.jarvis import JarvisAtomsAdaptor
1010
from tqdm import tqdm
1111

12-
from matbench_discovery import DEBUG, today
12+
from matbench_discovery import today
1313
from matbench_discovery.data import DATA_FILES, df_wbm
14+
from matbench_discovery.preds import e_form_col as target_col
1415

1516
__author__ = "Janosh Riebesell, Philipp Benner"
1617
__date__ = "2023-07-11"
@@ -28,10 +29,9 @@
2829
# model_name = "mp_e_form_alignn" # pre-trained by NIST
2930
model_name = f"{out_dir}/best-model.pth"
3031
task_type = "IS2RE"
31-
target_col = "e_form_per_atom_mp2020_corrected"
3232
input_col = "initial_structure"
3333
id_col = "material_id"
34-
job_name = f"{model_name}-wbm-{task_type}{'-debug' if DEBUG else ''}"
34+
job_name = f"{model_name}-wbm-{task_type}"
3535
out_path = (
3636
f"{out_dir}/{'alignn-relaxed-structs' if batch == 0 else f'{batch=}'}.json.gz"
3737
)

models/alignn_ff/test_alignn_ff.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@
1818
from sklearn.metrics import r2_score
1919
from tqdm import tqdm
2020

21-
from matbench_discovery import DEBUG, today
21+
from matbench_discovery import today
2222
from matbench_discovery.data import DATA_FILES, df_wbm
2323
from matbench_discovery.plots import wandb_scatter
24+
from matbench_discovery.preds import e_form_col as target_col
2425

2526
__author__ = "Philipp Benner, Janosh Riebesell"
2627
__date__ = "2023-07-11"
@@ -32,12 +33,11 @@
3233
n_splits = 100
3334
# model_name = "mp_e_form_alignnn" # pre-trained by NIST
3435
task_type = "IS2RE"
35-
target_col = "e_form_per_atom_mp2020_corrected"
3636
input_col = "initial_structure"
3737
id_col = "material_id"
3838
device = "cuda" if torch.cuda.is_available() else "cpu"
3939
model_name = f"alignn-ff-wbm-{task_type}"
40-
job_name = f"{model_name}-relaxed-wbm-{task_type}{'-debug' if DEBUG else ''}"
40+
job_name = f"{model_name}-relaxed-wbm-{task_type}"
4141
out_dir = os.getenv("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")
4242
in_dir = os.getenv("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")
4343

models/cgcnn/test_cgcnn.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from matbench_discovery import CHECKPOINT_DIR, ROOT, WANDB_PATH, today
1717
from matbench_discovery.data import DATA_FILES, df_wbm
1818
from matbench_discovery.plots import wandb_scatter
19+
from matbench_discovery.preds import e_form_col as target_col
1920
from matbench_discovery.slurm import slurm_submit
2021

2122
__author__ = "Janosh Riebesell"
@@ -53,8 +54,7 @@
5354

5455
df = pd.read_json(data_path).set_index("material_id")
5556

56-
e_form_col = "e_form_per_atom_mp2020_corrected"
57-
df[e_form_col] = df_wbm[e_form_col]
57+
df[target_col] = df_wbm[target_col]
5858
if task_type == "RS2RE":
5959
df[input_col] = [x["structure"] for x in df.computed_structure_entry]
6060
assert input_col in df, f"{input_col=} not in {list(df)}"
@@ -87,7 +87,7 @@
8787
versions={dep: version(dep) for dep in ("aviary", "numpy", "torch")},
8888
ensemble_size=len(runs),
8989
task_type=task_type,
90-
target_col=e_form_col,
90+
target_col=target_col,
9191
input_col=input_col,
9292
wandb_run_filters=filters,
9393
slurm_vars=slurm_vars,
@@ -97,7 +97,7 @@
9797
wandb.init(project="matbench-discovery", name=job_name, config=run_params)
9898

9999
cg_data = CrystalGraphData(
100-
df, task_dict={e_form_col: "regression"}, structure_col=input_col
100+
df, task_dict={target_col: "regression"}, structure_col=input_col
101101
)
102102
data_loader = DataLoader(
103103
cg_data, batch_size=1024, shuffle=False, collate_fn=collate_batch
@@ -110,16 +110,16 @@
110110
# dropping isolated-atom structs means len(cg_data.df) < len(df)
111111
cache_dir=CHECKPOINT_DIR,
112112
df=cg_data.df.drop(columns=input_col),
113-
target_col=e_form_col,
113+
target_col=target_col,
114114
model_cls=CrystalGraphConvNet,
115115
data_loader=data_loader,
116116
)
117117

118118
slurm_job_id = os.getenv("SLURM_JOB_ID", "debug")
119119
df.round(4).to_csv(f"{out_dir}/{job_name}-preds-{slurm_job_id}.csv.gz")
120-
pred_col = f"{e_form_col}_pred_ens"
120+
pred_col = f"{target_col}_pred_ens"
121121
assert pred_col in df, f"{pred_col=} not in {list(df)}"
122-
table = wandb.Table(dataframe=df[[e_form_col, pred_col]].reset_index())
122+
table = wandb.Table(dataframe=df[[target_col, pred_col]].reset_index())
123123

124124

125125
# %%
@@ -128,4 +128,4 @@
128128

129129
title = f"CGCNN {task_type} ensemble={len(runs)} {MAE=:.4} {R2=:.4}"
130130

131-
wandb_scatter(table, fields=dict(x=e_form_col, y=pred_col), title=title)
131+
wandb_scatter(table, fields=dict(x=target_col, y=pred_col), title=title)

models/mace/analyze_mace.py

+17-9
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,18 @@
66

77
import pandas as pd
88
from pymatviz import density_scatter, ptable_heatmap_plotly, spacegroup_sunburst
9+
from pymatviz.utils import save_fig
910

1011
from matbench_discovery import plots as plots
1112
from matbench_discovery.data import df_wbm
1213
from matbench_discovery.preds import PRED_FILES
14+
from matbench_discovery.preds import e_form_col as target_col
1315

1416
__author__ = "Janosh Riebesell"
1517
__date__ = "2023-07-23"
1618

1719
module_dir = os.path.dirname(__file__)
1820
id_col = "material_id"
19-
target_col = "e_form_per_atom_mp2020_corrected"
2021
pred_col = "e_form_per_atom_mace"
2122

2223

@@ -29,29 +30,36 @@
2930

3031

3132
# %%
32-
density_scatter(df=df_mace, x=target_col, y=pred_col)
33+
ax = density_scatter(df=df_mace, x=target_col, y=pred_col)
34+
ax.set(title=f"{len(df_mace):,} MACE severe energy underpredictions")
35+
save_fig(ax, "mace-hull-dist-scatter.pdf")
3336

3437

3538
# %%
36-
df_bad = df_mace.query(f"{target_col} - {pred_col} > 2")
39+
df_low = df_mace.query(f"{target_col} - {pred_col} > 2")
3740

38-
ax = density_scatter(df=df_bad, x=target_col, y=pred_col)
39-
ax.set(title=f"{len(df_bad):,} MACE severe energy underpredictions")
41+
ax = density_scatter(df=df_low, x=target_col, y=pred_col)
42+
ax.set(title=f"{len(df_low):,} MACE severe energy underpredictions")
43+
save_fig(ax, "mace-too-low-hull-dist-scatter.pdf")
4044

4145

4246
# %%
43-
fig = ptable_heatmap_plotly(df_bad.formula)
44-
title = f"Elements in {len(df_bad):,} MACE severe energy underpredictions"
47+
fig = ptable_heatmap_plotly(df_low.formula)
48+
title = f"Elements in {len(df_low):,} MACE severe energy underpredictions"
4549
fig.layout.title.update(text=title, x=0.4, y=0.95)
4650
fig.show()
4751

52+
save_fig(fig, "mace-too-low-elements-heatmap.pdf")
53+
4854

4955
# %%
50-
fig = spacegroup_sunburst(df_bad[spg_col], title="MACE spacegroups")
51-
title = f"Spacegroup sunburst of {len(df_bad):,} MACE severe energy underpredictions"
56+
fig = spacegroup_sunburst(df_low[spg_col], title="MACE spacegroups")
57+
title = f"Spacegroup sunburst of {len(df_low):,} MACE severe energy underpredictions"
5258
fig.layout.title.update(text=title, x=0.5)
5359
fig.show()
5460

61+
save_fig(fig, "mace-too-low-spacegroup-sunburst.pdf")
62+
5563

5664
"""
5765
Space groups of MACE underpredictions look unremarkable but unusually heavy in Silicon,

models/megnet/test_megnet.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from matbench_discovery import timestamp, today
2424
from matbench_discovery.data import DATA_FILES, df_wbm
2525
from matbench_discovery.plots import wandb_scatter
26-
from matbench_discovery.preds import PRED_FILES
26+
from matbench_discovery.preds import PRED_FILES, e_form_col
2727
from matbench_discovery.slurm import slurm_submit
2828

2929
__author__ = "Janosh Riebesell"
@@ -63,7 +63,6 @@
6363
}[task_type]
6464
print(f"\nJob started running {timestamp}")
6565
print(f"{data_path=}")
66-
e_form_col = "e_form_per_atom_mp2020_corrected"
6766
assert e_form_col in df_wbm, f"{e_form_col=} not in {list(df_wbm)=}"
6867

6968
df_in: pd.DataFrame = np.array_split(

models/voronoi/train_test_voronoi_rf.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from matbench_discovery import today
1616
from matbench_discovery.data import DATA_FILES, df_wbm, glob_to_df
1717
from matbench_discovery.plots import wandb_scatter
18+
from matbench_discovery.preds import e_form_col as test_e_form_col
1819
from matbench_discovery.slurm import slurm_submit
1920
from models.voronoi import featurizer
2021

@@ -55,8 +56,6 @@
5556
df_test = pd.read_csv(test_path).set_index("material_id")
5657
print(f"{df_test.shape=}")
5758

58-
test_e_form_col = "e_form_per_atom_mp2020_corrected"
59-
6059

6160
for df, df_tar, col in (
6261
(df_train, df_mp, train_e_form_col),

models/wrenformer/test_wrenformer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from matbench_discovery import CHECKPOINT_DIR, WANDB_PATH, today
2121
from matbench_discovery.data import DATA_FILES
2222
from matbench_discovery.plots import wandb_scatter
23+
from matbench_discovery.preds import e_form_col
2324
from matbench_discovery.slurm import slurm_submit
2425

2526
__author__ = "Janosh Riebesell"
@@ -44,7 +45,6 @@
4445

4546

4647
# %%
47-
e_form_col = "e_form_per_atom_mp2020_corrected"
4848
input_col = "wyckoff_spglib"
4949
df = pd.read_csv(data_path).dropna(subset=input_col).set_index("material_id")
5050

scripts/model_figs/hist_classified_stable_vs_hull_dist_models.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,13 @@
1212

1313
from matbench_discovery import FIGS, PDF_FIGS, today
1414
from matbench_discovery.plots import hist_classified_stable_vs_hull_dist, plt
15-
from matbench_discovery.preds import df_metrics, df_preds, e_form_col, each_true_col
15+
from matbench_discovery.preds import (
16+
df_metrics,
17+
df_preds,
18+
e_form_col,
19+
each_pred_col,
20+
each_true_col,
21+
)
1622

1723
__author__ = "Janosh Riebesell"
1824
__date__ = "2022-12-01"
@@ -21,7 +27,6 @@
2127
# %%
2228
hover_cols = (df_preds.index.name, e_form_col, each_true_col, "formula")
2329
e_form_preds = "e_form_per_atom_pred"
24-
each_pred_col = "e_above_hull_pred"
2530
facet_col = "Model"
2631
# sort facet plots by model's F1 scores (optionally only show top n=6)
2732
models = list(df_metrics.T.F1.sort_values().index)[::-1]

scripts/model_figs/per_element_errors.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@
147147
# %%
148148
expected_cols = {
149149
*"ALIGNN, BOWSR, CGCNN, CGCNN+P, CHGNet, M3GNet, MEGNet, "
150-
f"{train_count_col}, Mean error all models, {test_set_std_col}, Voronoi RF, "
150+
f"{train_count_col}, {model_mean_err_col}, {test_set_std_col}, Voronoi RF, "
151151
"Wrenformer".split(", ")
152152
}
153153
assert {*df_elem_err} >= expected_cols

tests/test_plots.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,16 @@
1717
hist_classified_stable_vs_hull_dist,
1818
rolling_mae_vs_hull_dist,
1919
)
20-
from matbench_discovery.preds import load_df_wbm_with_preds
20+
from matbench_discovery.preds import (
21+
e_form_col,
22+
each_pred_col,
23+
each_true_col,
24+
load_df_wbm_with_preds,
25+
)
2126

2227
AxLine = Literal["x", "y", "xy", ""]
2328
models = ["MEGNet", "CGCNN", "Voronoi RF"]
2429
df_wbm = load_df_wbm_with_preds(models, nrows=100)
25-
each_true_col = "e_above_hull_mp2020_corrected_ppd_mp"
26-
each_pred_col = "e_above_hull_pred"
27-
e_form_col = "e_form_per_atom_mp2020_corrected"
2830

2931

3032
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)