Skip to content

Commit 2fcf6c2

Browse files
committed
change ModelCard heading color scale from PuOr to Cividis
define reusable col names in matbench_discovery/__init__.py
1 parent e84c175 commit 2fcf6c2

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

48 files changed

+183
-165
lines changed

data/mp/build_phase_diagram.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from pymatgen.ext.matproj import MPRester
1818
from tqdm import tqdm
1919

20-
from matbench_discovery import ROOT, today
20+
from matbench_discovery import ROOT, id_col, today
2121
from matbench_discovery.data import DATA_FILES
2222
from matbench_discovery.energy import get_e_form_per_atom, get_elemental_ref_entries
2323

@@ -30,7 +30,7 @@
3030
# save all ComputedStructureEntries to disk
3131
# mp-15590 appears twice so we drop_duplicates()
3232
df = pd.DataFrame(all_mp_computed_structure_entries, columns=["entry"])
33-
df.index.name = "material_id"
33+
df.index.name = id_col
3434
df.index = [e.entry_id for e in df.entry]
3535
df.reset_index().to_json(
3636
f"{module_dir}/{today}-mp-computed-structure-entries.json.gz",
@@ -40,7 +40,7 @@
4040

4141
# %%
4242
data_path = f"{module_dir}/2023-02-07-mp-computed-structure-entries.json.gz"
43-
df = pd.read_json(data_path).set_index("material_id")
43+
df = pd.read_json(data_path).set_index(id_col)
4444

4545
# drop the structure, just load ComputedEntry, makes the PPD faster to build and load
4646
mp_computed_entries = [ComputedEntry.from_dict(dct) for dct in tqdm(df.entry)]
@@ -63,9 +63,7 @@
6363

6464

6565
# %% build phase diagram with both MP entries + WBM entries
66-
df_wbm = pd.read_json(DATA_FILES.wbm_computed_structure_entries).set_index(
67-
"material_id"
68-
)
66+
df_wbm = pd.read_json(DATA_FILES.wbm_computed_structure_entries).set_index(id_col)
6967

7068
# using ComputedStructureEntry vs ComputedEntry here is important as CSEs receive
7169
# more accurate energy corrections that take into account peroxide/superoxide nature
@@ -104,7 +102,7 @@
104102
json.dump(elemental_ref_entries, file, default=lambda x: x.as_dict())
105103

106104

107-
df_mp = pd.read_csv(DATA_FILES.mp_energies, na_filter=False).set_index("material_id")
105+
df_mp = pd.read_csv(DATA_FILES.mp_energies, na_filter=False).set_index(id_col)
108106

109107

110108
# %%

data/mp/get_mp_energies.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from pymatviz.utils import annotate_metrics
99
from tqdm import tqdm
1010

11-
from matbench_discovery import STABILITY_THRESHOLD, today
11+
from matbench_discovery import STABILITY_THRESHOLD, id_col, today
1212
from matbench_discovery.data import DATA_FILES
1313

1414
"""
@@ -26,7 +26,7 @@
2626

2727
# %%
2828
fields = {
29-
"material_id",
29+
id_col,
3030
"formula_pretty",
3131
"formation_energy_per_atom",
3232
"energy_per_atom",
@@ -46,7 +46,7 @@
4646

4747

4848
# %%
49-
df = pd.DataFrame(docs).set_index("material_id")
49+
df = pd.DataFrame(docs).set_index(id_col)
5050

5151
df_spg = pd.json_normalize(df.pop("symmetry"))[["number", "symbol"]]
5252
df["spacegroup_symbol"] = df_spg.symbol.to_numpy()
@@ -56,7 +56,7 @@
5656

5757

5858
# %%
59-
df_cse = pd.read_json(DATA_FILES.mp_computed_structure_entries).set_index("material_id")
59+
df_cse = pd.read_json(DATA_FILES.mp_computed_structure_entries).set_index(id_col)
6060

6161
struct_col = "structure"
6262
df_cse[struct_col] = [
@@ -76,7 +76,7 @@
7676
assert (spg_nums.sort_index() == df_spg["number"].sort_index()).all()
7777

7878
df.to_csv(DATA_FILES.mp_energies)
79-
# df = pd.read_csv(DATA_FILES.mp_energies, na_filter=False).set_index("material_id")
79+
# df = pd.read_csv(DATA_FILES.mp_energies, na_filter=False).set_index(id_col)
8080

8181

8282
# %% reproduce fig. 1b from https://arxiv.org/abs/2001.10591 (as data consistency check)

data/wbm/compare_cse_vs_ce_mp_2020_corrections.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from pymatgen.entries.computed_entries import ComputedEntry, ComputedStructureEntry
1111
from tqdm import tqdm
1212

13-
from matbench_discovery import ROOT, today
13+
from matbench_discovery import ROOT, id_col, today
1414
from matbench_discovery.data import DATA_FILES, df_wbm
1515
from matbench_discovery.energy import get_e_form_per_atom
1616
from matbench_discovery.plots import plt
@@ -22,9 +22,7 @@
2222
"""
2323

2424

25-
df_cse = pd.read_json(DATA_FILES.wbm_computed_structure_entries).set_index(
26-
"material_id"
27-
)
25+
df_cse = pd.read_json(DATA_FILES.wbm_computed_structure_entries).set_index(id_col)
2826

2927
cses = [
3028
ComputedStructureEntry.from_dict(dct)

data/wbm/eda.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
)
1313
from pymatviz.io import save_fig
1414

15-
from matbench_discovery import PDF_FIGS, ROOT, SITE_FIGS, STABILITY_THRESHOLD
15+
from matbench_discovery import PDF_FIGS, ROOT, SITE_FIGS, STABILITY_THRESHOLD, id_col
1616
from matbench_discovery import plots as plots
1717
from matbench_discovery.data import DATA_FILES, df_wbm
1818
from matbench_discovery.energy import mp_elem_reference_entries
@@ -180,14 +180,14 @@
180180
# many models struggle on the halogens in per-element error periodic table heatmaps
181181
# https://janosh.github.io/matbench-discovery/models
182182
df_2d_tsne = pd.read_csv(f"{module_dir}/tsne/one-hot-112-composition-2d.csv.gz")
183-
df_2d_tsne = df_2d_tsne.set_index("material_id")
183+
df_2d_tsne = df_2d_tsne.set_index(id_col)
184184

185185
df_3d_tsne = pd.read_csv(f"{module_dir}/tsne/one-hot-112-composition-3d.csv.gz")
186186
model = "Wrenformer"
187187
df_3d_tsne = pd.read_csv(
188188
f"{module_dir}/tsne/one-hot-112-composition+{model}-each-err-3d-metric=eucl.csv.gz"
189189
)
190-
df_3d_tsne = df_3d_tsne.set_index("material_id")
190+
df_3d_tsne = df_3d_tsne.set_index(id_col)
191191

192192
df_wbm[list(df_2d_tsne)] = df_2d_tsne
193193
df_wbm[list(df_3d_tsne)] = df_3d_tsne
@@ -205,7 +205,7 @@
205205
x="2d t-SNE 1",
206206
y="2d t-SNE 2",
207207
color=color_col,
208-
hover_name="material_id",
208+
hover_name=id_col,
209209
hover_data=("formula", each_true_col),
210210
range_color=(0, clr_range_max),
211211
)
@@ -219,7 +219,7 @@
219219
y="3d t-SNE 2",
220220
z="3d t-SNE 3",
221221
color=color_col,
222-
custom_data=["material_id", "formula", each_true_col, color_col],
222+
custom_data=[id_col, "formula", each_true_col, color_col],
223223
range_color=(0, clr_range_max),
224224
)
225225
fig.data[0].hovertemplate = (

data/wbm/fetch_process_wbm_dataset.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from pymatviz.io import save_fig
1919
from tqdm import tqdm
2020

21-
from matbench_discovery import SITE_FIGS, today
21+
from matbench_discovery import SITE_FIGS, id_col, today
2222
from matbench_discovery.data import DATA_FILES
2323
from matbench_discovery.energy import get_e_form_per_atom
2424
from matbench_discovery.plots import pio
@@ -156,7 +156,7 @@ def increment_wbm_material_id(wbm_id: str) -> str:
156156

157157

158158
df_wbm.index = df_wbm.index.map(increment_wbm_material_id)
159-
df_wbm.index.name = "material_id"
159+
df_wbm.index.name = id_col
160160
assert df_wbm.index[0] == "wbm-1-1"
161161
assert df_wbm.index[-1] == "wbm-5-23308"
162162

@@ -296,13 +296,13 @@ def increment_wbm_material_id(wbm_id: str) -> str:
296296
"e_form": "e_form_per_atom_wbm",
297297
"e_hull": "e_above_hull_wbm",
298298
"gap": "bandgap_pbe",
299-
"id": "material_id",
299+
"id": id_col,
300300
}
301301
# WBM summary was shared twice, once on Google Drive, once on Materials Cloud
302302
# download both and check for consistency
303303
df_summary = pd.read_csv(
304304
f"{module_dir}/raw/wbm-summary.txt", sep="\t", names=col_map.values()
305-
).set_index("material_id")
305+
).set_index(id_col)
306306

307307
df_summary_bz2 = pd.read_csv(
308308
f"{mat_cloud_url}&filename=summary.txt.bz2", sep="\t"
@@ -618,7 +618,7 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
618618
suggest = "not found, run scripts/compute_struct_fingerprints.py to generate"
619619
fp_diff_col = "site_stats_fingerprint_init_final_norm_diff"
620620
try:
621-
df_fp = pd.read_json(fingerprints_path).set_index("material_id")
621+
df_fp = pd.read_json(fingerprints_path).set_index(id_col)
622622
df_summary[fp_diff_col] = df_fp[fp_diff_col]
623623
except FileNotFoundError:
624624
print(f"{fingerprints_path=} {suggest}")
@@ -633,11 +633,11 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
633633
# %% only here to load data for later inspection
634634
if False:
635635
df_summary = pd.read_csv(f"{module_dir}/2022-10-19-wbm-summary.csv.gz").set_index(
636-
"material_id"
636+
id_col
637637
)
638638
df_wbm = pd.read_json(
639639
f"{module_dir}/2022-10-19-wbm-computed-structure-entries+init-structs.json.bz2"
640-
).set_index("material_id")
640+
).set_index(id_col)
641641

642642
df_wbm["cse"] = [
643643
ComputedStructureEntry.from_dict(dct)

matbench_discovery/__init__.py

+6
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
SITE_FIGS = f"{ROOT}/site/src/figs" # directory for interactive figures
88
SITE_MODELS = f"{ROOT}/site/src/routes/models" # directory to write model analysis
99
FIGSHARE = f"{ROOT}/data/figshare"
10+
SCRIPTS = f"{ROOT}/scripts"
1011
PDF_FIGS = f"{ROOT}/paper/figs" # directory for light-themed PDF figures
1112

1213
for directory in [SITE_FIGS, SITE_MODELS, FIGSHARE, PDF_FIGS]:
@@ -31,3 +32,8 @@
3132
warnings.filterwarnings(
3233
action="ignore", category=UserWarning, module="pymatgen", lineno=lineno
3334
)
35+
36+
id_col = "material_id"
37+
init_struct_col = "initial_structure"
38+
struct_col = "structure"
39+
e_form_col = "formation_energy_per_atom"

matbench_discovery/data.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from pymatgen.analysis.phase_diagram import PatchedPhaseDiagram
1717
from tqdm import tqdm
1818

19-
from matbench_discovery import FIGSHARE
19+
from matbench_discovery import FIGSHARE, id_col
2020

2121
# repo URL to raw files on GitHub
2222
RAW_REPO_URL = "https://github.com/janosh/matbench-discovery/raw"
@@ -119,8 +119,8 @@ def load(
119119
print(f"\n\nvariable dump:\n{file=},\n{reader=}\n{kwargs=}")
120120
raise
121121

122-
if "material_id" in df:
123-
df = df.set_index("material_id")
122+
if id_col in df:
123+
df = df.set_index(id_col)
124124
if hydrate:
125125
for col in df:
126126
if not isinstance(df[col].iloc[0], dict):
@@ -256,4 +256,4 @@ def _on_not_found(self, key: str, msg: str) -> None: # type: ignore[override]
256256

257257

258258
df_wbm = load("wbm_summary")
259-
df_wbm["material_id"] = df_wbm.index
259+
df_wbm[id_col] = df_wbm.index

matbench_discovery/preds.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from tqdm import tqdm
88

99
from matbench_discovery import ROOT, STABILITY_THRESHOLD
10+
from matbench_discovery import id_col as default_id_col
1011
from matbench_discovery.data import Files, df_wbm, glob_to_df
1112
from matbench_discovery.metrics import stable_metrics
1213
from matbench_discovery.plots import (
@@ -83,7 +84,7 @@ class PredFiles(Files):
8384
def load_df_wbm_with_preds(
8485
models: Sequence[str] = (*PRED_FILES,),
8586
pbar: bool = True,
86-
id_col: str = "material_id",
87+
id_col: str = default_id_col,
8788
**kwargs: Any,
8889
) -> pd.DataFrame:
8990
"""Load WBM summary dataframe with model predictions from disk.

models/alignn/test_alignn.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from sklearn.metrics import r2_score
1818
from tqdm import tqdm
1919

20-
from matbench_discovery import today
20+
from matbench_discovery import id_col, today
2121
from matbench_discovery.data import DATA_FILES, df_wbm
2222
from matbench_discovery.plots import wandb_scatter
2323
from matbench_discovery.preds import e_form_col
@@ -36,7 +36,6 @@
3636
task_type = "IS2RE"
3737
target_col = e_form_col
3838
input_col = "initial_structure"
39-
id_col = "material_id"
4039
device = "cuda" if torch.cuda.is_available() else "cpu"
4140
job_name = f"{model_name}-wbm-{task_type}"
4241
out_dir = os.getenv("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")

models/alignn/train_alignn.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from torch.utils.data import DataLoader
1919
from tqdm import tqdm
2020

21-
from matbench_discovery import today
21+
from matbench_discovery import e_form_col, id_col, struct_col, today
2222
from matbench_discovery.data import DATA_FILES
2323
from matbench_discovery.slurm import slurm_submit
2424

@@ -30,10 +30,8 @@
3030

3131
# %%
3232
model_name = "alignn-mp-e_form"
33-
target_col = "formation_energy_per_atom"
34-
struct_col = "structure"
33+
target_col = e_form_col
3534
input_col = "atoms"
36-
id_col = "material_id"
3735
device = "cuda" if torch.cuda.is_available() else "cpu"
3836
job_name = f"train-{model_name}"
3937

@@ -48,7 +46,6 @@
4846

4947
slurm_vars = slurm_submit(
5048
job_name=job_name,
51-
# partition="perlmuttter",
5249
account="matgen",
5350
time="4:0:0",
5451
out_dir=out_dir,

models/alignn_ff/alignn_ff_relax.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from pymatgen.io.jarvis import JarvisAtomsAdaptor
1010
from tqdm import tqdm
1111

12-
from matbench_discovery import today
12+
from matbench_discovery import id_col, init_struct_col, today
1313
from matbench_discovery.data import DATA_FILES, df_wbm
1414
from matbench_discovery.preds import e_form_col as target_col
1515

@@ -29,8 +29,7 @@
2929
# model_name = "mp_e_form_alignn" # pre-trained by NIST
3030
model_name = f"{out_dir}/best-model.pth"
3131
task_type = "IS2RE"
32-
input_col = "initial_structure"
33-
id_col = "material_id"
32+
input_col = init_struct_col
3433
job_name = f"{model_name}-wbm-{task_type}"
3534
out_path = (
3635
f"{out_dir}/{'alignn-relaxed-structs' if batch == 0 else f'{batch=}'}.json.gz"

models/alignn_ff/test_alignn_ff.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from sklearn.metrics import r2_score
1919
from tqdm import tqdm
2020

21-
from matbench_discovery import today
21+
from matbench_discovery import init_struct_col, today
2222
from matbench_discovery.data import DATA_FILES, df_wbm
2323
from matbench_discovery.plots import wandb_scatter
2424
from matbench_discovery.preds import e_form_col as target_col
@@ -33,8 +33,7 @@
3333
n_splits = 100
3434
# model_name = "mp_e_form_alignnn" # pre-trained by NIST
3535
task_type = "IS2RE"
36-
input_col = "initial_structure"
37-
id_col = "material_id"
36+
input_col = init_struct_col
3837
device = "cuda" if torch.cuda.is_available() else "cpu"
3938
model_name = f"alignn-ff-wbm-{task_type}"
4039
job_name = f"{model_name}-relaxed-wbm-{task_type}"

models/bowsr/join_bowsr_results.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import pymatviz
99
from tqdm import tqdm
1010

11-
from matbench_discovery.data import DATA_FILES
11+
from matbench_discovery.data import DATA_FILES, id_col
1212

1313
__author__ = "Janosh Riebesell"
1414
__date__ = "2022-09-22"
@@ -30,14 +30,14 @@
3030
for file_path in tqdm(file_paths):
3131
if file_path in dfs:
3232
continue
33-
dfs[file_path] = pd.read_json(file_path).set_index("material_id")
33+
dfs[file_path] = pd.read_json(file_path).set_index(id_col)
3434

3535

3636
df_bowsr = pd.concat(dfs.values()).round(4)
3737

3838

3939
# %% compare against WBM formation energy targets to make sure we got sensible results
40-
df_wbm = pd.read_csv(DATA_FILES.wbm_summary).set_index("material_id")
40+
df_wbm = pd.read_csv(DATA_FILES.wbm_summary).set_index(id_col)
4141

4242

4343
print(
@@ -75,4 +75,4 @@
7575

7676

7777
# in_path = f"{module_dir}/2023-01-23-bowsr-megnet-wbm-IS2RE.json.gz"
78-
# df_bowsr = pd.read_json(in_path).set_index("material_id")
78+
# df_bowsr = pd.read_json(in_path).set_index(id_col)

0 commit comments

Comments
 (0)