
Commit 4936079 (1 parent: 9fdab8c)

provisionally last set of fixes/refactors to data/wbm/fetch_process_wbm_dataset.py

9 files changed: +129 -74 lines

.github/workflows/test.yml (+1 -1)

@@ -29,4 +29,4 @@ jobs:
 
       - name: Run tests
         id: tests
-        run: pytest --durations 0 --cov matbench_discovery
+        run: pytest --durations 0 --cov .

.gitignore (+1 -3)

@@ -3,6 +3,7 @@
 
 # cache
 __pycache__
+.coverage*
 
 # datasets
 *.p
@@ -13,9 +14,6 @@ __pycache__
 data/**/raw
 data/**/202*
 
-# checkpoint files of trained models
-pretrained/
-
 # Weights and Biases logs
 wandb/
 job-logs/

data/wbm/fetch_process_wbm_dataset.py (+105 -54)
@@ -8,8 +8,9 @@
 from glob import glob
 
 import pandas as pd
+from aviary.wren.utils import get_aflow_label_from_spglib
 from pymatgen.analysis.phase_diagram import PatchedPhaseDiagram
-from pymatgen.core import Structure
+from pymatgen.core import Composition, Structure
 from pymatgen.entries.compatibility import (
     MaterialsProject2020Compatibility as MP2020Compat,
 )
@@ -144,8 +145,10 @@ def increment_wbm_material_id(wbm_id: str) -> str:
         print(f"bad {wbm_id=}")
         return wbm_id
 
-    assert prefix == "step"
-    assert step_num.isdigit() and material_num.isdigit()
+    msg = f"bad {wbm_id=}, {prefix=} {step_num=} {material_num=}"
+    assert prefix == "step", msg
+    assert step_num.isdigit(), msg
+    assert material_num.isdigit(), msg
 
     return f"wbm-step-{step_num}-{int(material_num) + 1}"
 
@@ -266,7 +269,9 @@ def increment_wbm_material_id(wbm_id: str) -> str:
 
 
 # %%
-df_wbm["formula_from_cse"] = [x.formula for x in df_wbm.pop("composition_from_cse")]
+df_wbm["formula_from_cse"] = [
+    x.alphabetical_formula for x in df_wbm.pop("composition_from_cse")
+]
 df_wbm[["initial_structure", "computed_structure_entry", "formula_from_cse"]].to_json(
     f"{module_dir}/{today}-wbm-cses+init-structs.json.bz2"
 )
@@ -278,8 +283,8 @@ def increment_wbm_material_id(wbm_id: str) -> str:
     "nsites": "n_sites",
     "vol": "volume",
     "e": "uncorrected_energy",
-    "e_form": "e_form_per_atom",
-    "e_hull": "e_hull",
+    "e_form": "e_form_per_atom_wbm",
+    "e_hull": "e_hull_wbm",
     "gap": "bandgap_pbe",
     "id": "material_id",
 }
@@ -313,11 +318,32 @@ def increment_wbm_material_id(wbm_id: str) -> str:
 df_summary.index = df_summary.index.map(increment_wbm_material_id)
 assert sum(df_summary.index != df_wbm.index) == 0
 
+# sort formulas alphabetically
+df_summary["alph_formula"] = [
+    Composition(x).alphabetical_formula for x in df_summary.formula
+]
+assert sum(df_summary.alph_formula != df_summary.formula) == 219_215
+assert df_summary.alph_formula[3] == "Ag2 Au1 Hg1"
+assert df_summary.formula[3] == "Ag2 Hg1 Au1"
+
+df_summary["formula"] = df_summary.pop("alph_formula")
+
+
+# %%
+# check summary and CSE formulas agree
+assert all(df_summary["formula"] == df_wbm.formula_from_cse)
+
+
 # fix bad energy which is 0 in df_summary but a more realistic -63.68 in CSE
 df_summary.at["wbm-step-2-18689", "uncorrected_energy"] = df_wbm.loc[
     "wbm-step-2-18689"
 ].computed_structure_entry["energy"]
 
+# NOTE careful with ComputedEntries as object vs as dicts, the meaning of keys changes:
+# cse.energy == cse.uncorrected_energy + cse.correction
+# whereas
+# cse.as_dict()["energy"] == cse.uncorrected_energy
+
 
 # %% scatter plot summary energies vs CSE energies
 df_summary["uncorrected_energy_from_cse"] = [
@@ -355,17 +381,24 @@ def increment_wbm_material_id(wbm_id: str) -> str:
 assert n_corrected == 100931, f"{n_corrected=}"
 
 corr_label = "mp2020" if isinstance(mp_compat, MP2020Compat) else "legacy"
-df_summary[f"e_correction_{corr_label}"] = [
-    cse.energy - cse.uncorrected_energy for cse in df_wbm.cse
+df_summary[f"e_correction_per_atom_{corr_label}"] = [
+    cse.correction_per_atom for cse in df_wbm.cse
 ]
 
-assert df_summary.e_correction_mp2020.mean().round(4) == -0.9979
-assert df_summary.e_correction_legacy.mean().round(4) == -0.0643
-assert (df_summary.filter(like="corrections").abs() > 1e-4).sum().to_dict() == {
-    "e_correction_mp2020": 100931,
-    "e_correction_legacy": 39595,
-}
+assert df_summary.e_correction_per_atom_mp2020.mean().round(4) == -0.1067
+assert df_summary.e_correction_per_atom_legacy.mean().round(4) == -0.0643
+assert (df_summary.filter(like="correction").abs() > 1e-4).sum().to_dict() == {
+    "e_correction_per_atom_mp2020": 100931,
+    "e_correction_per_atom_legacy": 39595,
+}, "unexpected number of materials received non-zero corrections"
 
+ax = density_scatter(
+    df_summary.e_correction_per_atom_legacy,
+    df_summary.e_correction_per_atom_mp2020,
+    xlabel="legacy corrections (eV / atom)",
+    ylabel="MP2020 corrections (eV / atom)",
+)
+# ax.figure.savefig(f"{ROOT}/tmp/{today}-legacy-vs-mp2020-corrections.png")
 
 # mp_compat.process_entry(cse) for CSE with id wbm-step-1-24459 causes Jupyter kernel to
 # crash reason unknown, still occurs even after updating deps like pymatgen, numpy,
@@ -382,65 +415,64 @@ def increment_wbm_material_id(wbm_id: str) -> str:
 
 
 # %%
-with gzip.open(f"{module_dir}/2022-10-13-rhys/ppd-mp.pkl.gz", "rb") as zip_file:
-    ppd_rhys: PatchedPhaseDiagram = pickle.load(zip_file)
-
-
 with gzip.open(f"{ROOT}/data/2022-09-18-ppd-mp.pkl.gz", "rb") as zip_file:
-    ppd_mp = pickle.load(zip_file)
+    ppd_mp: PatchedPhaseDiagram = pickle.load(zip_file)
 
 
-# %%
+# %% calculate e_above_hull for each material
 # this loop needs the warnings filter above to not crash Jupyter kernel with logs
 # takes ~20 min at 200 it/s for 250k entries in WBM
+e_above_hull_key = "e_above_hull_uncorrected_ppd_mp"
+assert e_above_hull_key not in df_summary
+
 for entry in tqdm(df_wbm.cse):
     assert entry.entry_id.startswith("wbm-step-")
-    corr_label = "mp2020_" if isinstance(mp_compat, MP2020Compat) else "legacy_"
-    # corr_label = "un"
-    at_idx = entry.entry_id, f"e_above_hull_{corr_label}corrected_ppd_mp"
 
-    if at_idx not in df_summary or pd.isna(df_summary.at[at_idx]):
-        # use entry.(uncorrected_)energy_per_atom
-        e_above_hull = (
-            entry.corrected_energy_per_atom
-            - ppd_mp.get_hull_energy_per_atom(entry.composition)
-        )
-        df_summary.at[at_idx] = e_above_hull
+    e_per_atom = entry.uncorrected_energy_per_atom
+    e_hull_per_atom = ppd_mp.get_hull_energy_per_atom(entry.composition)
+    e_above_hull = e_per_atom - e_hull_per_atom
+
+    df_summary.at[entry.entry_id, e_above_hull_key] = e_above_hull
 
 
-# %% compute formation energies
+# add old + new MP energy corrections to above hull energies
+for corrections in ("mp2020", "legacy"):
+    df_summary[e_above_hull_key.replace("un", f"{corrections}_")] = (
+        df_summary[e_above_hull_key]
+        + df_summary[f"e_correction_per_atom_{corrections}"]
+    )
+
+
+# %% calculate formation energies from CSEs wrt MP elemental reference energies
 # first make sure source and target dfs have matching indices
 assert sum(df_wbm.index != df_summary.index) == 0
 
-e_form_key = "e_form_per_atom_uncorrected_ppd_mp_rhys"
-for mat_id, cse in tqdm(df_wbm.cse.items(), total=len(df_wbm)):
-    assert mat_id == cse.entry_id, f"{mat_id=} {cse.entry_id=}"
-    assert mat_id in df_summary.index, f"{mat_id=} not in df_summary"
-    df_summary.at[cse.entry_id, e_form_key] = ppd_rhys.get_form_energy_per_atom(cse)
+e_form_key = "e_form_per_atom_uncorrected_mp_refs"
+assert e_form_key not in df_summary
 
-assert len(df_summary) == sum(step_lens)
+for row in tqdm(df_wbm.itertuples(), total=len(df_wbm)):
+    mat_id, cse, formula = row.Index, row.cse, row.formula_from_cse
+    assert mat_id == cse.entry_id, f"{mat_id=} != {cse.entry_id=}"
+    assert mat_id in df_summary.index, f"{mat_id=} not in df_summary"
 
-df_summary["e_form_per_atom_legacy_corrected_ppd_mp_rhys"] = (
-    df_summary[e_form_key] + df_summary.e_correction_legacy
-)
+    entry_like = dict(composition=formula, energy=cse.uncorrected_energy)
+    e_form = get_e_form_per_atom(entry_like)
+    e_form_ppd = ppd_mp.get_form_energy_per_atom(cse)
 
+    # make sure the PPD and functional method of calculating formation energy agree
+    assert abs(e_form - e_form_ppd) < 1e-7, f"{e_form=} != {e_form_ppd=}"
+    df_summary.at[cse.entry_id, e_form_key] = e_form
 
-# %% calculate formation energies from CSEs wrt MP elemental reference energies
-df_summary["e_form_per_atom_uncorrected"] = [
-    get_e_form_per_atom(dict(composition=row.formula, energy=row.uncorrected_energy))
-    for row in tqdm(df_summary.itertuples(), total=len(df_summary))
-]
+assert len(df_summary) == sum(
+    step_lens
+), f"rows were added: {len(df_summary)=} {sum(step_lens)=}"
 
 
-# %% MP2020 corrections are much larger than legacy corrections
-ax = density_scatter(
-    df_summary.e_correction_legacy / df_summary.n_sites,
-    df_summary.e_correction_mp2020 / df_summary.n_sites,
-    xlabel="legacy corrections (eV / atom)",
-    ylabel="MP2020 corrections (eV / atom)",
-)
-ax.axis("equal")
-# ax.figure.savefig(f"{ROOT}/tmp/{today}-legacy-vs-mp2020-corrections.png")
+# add old + new MP energy corrections to formation energies
+for corrections in ("mp2020", "legacy"):
+    df_summary[e_form_key.replace("un", f"{corrections}_")] = (
+        df_summary[e_form_key] + df_summary[f"e_correction_per_atom_{corrections}"]
+    )
 
 
 # %%
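
The replace("un", f"{corrections}_") loops above are what generate e_above_hull_mp2020_corrected_ppd_mp, the column the plot scripts below switch to, plus the analogous corrected formation-energy columns. A small self-contained pandas sketch of that naming and arithmetic with made-up numbers (not data from the repo):

import pandas as pd

# toy stand-in for df_summary with made-up values (eV/atom)
df = pd.DataFrame(
    {
        "e_above_hull_uncorrected_ppd_mp": [0.12, -0.05],
        "e_correction_per_atom_mp2020": [-0.10, 0.00],
        "e_correction_per_atom_legacy": [-0.02, 0.00],
    },
    index=["wbm-step-1-1", "wbm-step-1-2"],
)

e_above_hull_key = "e_above_hull_uncorrected_ppd_mp"
for corrections in ("mp2020", "legacy"):
    # "un" -> "mp2020_"/"legacy_" turns ...uncorrected... into ...mp2020_corrected...
    df[e_above_hull_key.replace("un", f"{corrections}_")] = (
        df[e_above_hull_key] + df[f"e_correction_per_atom_{corrections}"]
    )

# new columns: e_above_hull_mp2020_corrected_ppd_mp, e_above_hull_legacy_corrected_ppd_mp
print(df.filter(like="e_above_hull").round(3))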
@@ -457,3 +489,22 @@ def increment_wbm_material_id(wbm_id: str) -> str:
 df_wbm["cse"] = [
     ComputedStructureEntry.from_dict(x) for x in tqdm(df_wbm.computed_structure_entry)
 ]
+
+df_wbm["init_struct"] = df_wbm["wyckoff"] = float("nan")
+for idx, dct in tqdm(df_wbm.initial_structure.items(), total=len(df_wbm)):
+    if not df_wbm[idx, "init_struct"]:
+        df_wbm.at[idx, "init_struct"] = struct = Structure.from_dict(dct)
+    if not df_wbm[idx, "wyckoff"]:
+        df_wbm.at[idx, "wyckoff"] = get_aflow_label_from_spglib(struct)
+
+
+# %% make sure material IDs within each step are consecutive
+for step in range(1, 6):
+    df = df_summary[df_summary.index.str.startswith(f"wbm-step-{step}-")]
+    step_len = step_lens[step - 1]
+    assert len(df) == step_len, f"{step=} has {len(df)=}, expected {step_len=}"
+
+    step_counts = list(df.index.str.split("-").str[-1].astype(int))
+    assert step_counts == list(
+        range(1, step_len + 1)
+    ), f"{step=} counts not consecutive"

matbench_discovery/plot_scripts/hist_classified_stable_as_func_of_hull_dist_batches.py (+1 -1)

@@ -70,7 +70,7 @@
 model_name = "m3gnet"
 df = dfs[model_name]
 
-df["e_above_hull_mp"] = df_wbm.e_above_hull_mp2020_corrected
+df["e_above_hull_mp"] = df_wbm.e_above_hull_mp2020_corrected_ppd_mp
 df["e_form_per_atom"] = df_wbm.e_form_per_atom_mp2020_corrected
 
 

matbench_discovery/plot_scripts/precision_recall_vs_calc_count.py (+2 -2)

@@ -79,7 +79,7 @@
 except AttributeError as exc:
     raise KeyError(f"{model_name = }") from exc
 
-df["e_above_hull_mp"] = df_wbm.e_above_hull_mp2020_corrected
+df["e_above_hull_mp"] = df_wbm.e_above_hull_mp2020_corrected_ppd_mp
 df["e_form_per_atom"] = df_wbm.e_form_per_atom_mp2020_corrected
 df["e_above_hull_pred"] = model_preds - df.e_form_per_atom
 if n_nans := df.isna().values.sum() > 0:
@@ -107,7 +107,7 @@
 # optimal recall line finds all stable materials without any false positives
 # can be included to confirm all models start out of with near optimal recall
 # and to see how much each model overshoots total n_stable
-n_below_hull = sum(df_wbm.e_above_hull_mp2020_corrected < 0)
+n_below_hull = sum(df_wbm.e_above_hull_mp2020_corrected_ppd_mp < 0)
 ax.plot(
     [0, n_below_hull],
     [0, 100],

matbench_discovery/plot_scripts/rolling_mae_vs_hull_dist.py (+1 -1)

@@ -36,7 +36,7 @@
 # df = df.query("~contains_rare_earths")
 
 
-df["e_above_hull_mp"] = df_wbm.e_above_hull_mp2020_corrected
+df["e_above_hull_mp"] = df_wbm.e_above_hull_mp2020_corrected_ppd_mp
 
 assert all(n_nans := df.isna().sum() == 0), f"Found {n_nans} NaNs"
 

matbench_discovery/plot_scripts/rolling_mae_vs_hull_dist_wbm_batches.py (+1 -1)

@@ -26,7 +26,7 @@
 ).set_index("material_id")
 
 
-df_wrenformer["e_above_hull_mp"] = df_wbm.e_above_hull_mp2020_corrected
+df_wrenformer["e_above_hull_mp"] = df_wbm.e_above_hull_mp2020_corrected_ppd_mp
 assert df_wrenformer.e_above_hull_mp.isna().sum() == 0
 
 target_col = "e_form_per_atom"

tests/test_plots.py (+1 -1)

@@ -24,7 +24,7 @@
     f"{DATA_DIR}/{model_name.lower()}-mp-initial-structures.csv", nrows=100
 ).set_index("material_id")
 
-df["e_above_hull_mp"] = df_wbm.e_above_hull_mp2020_corrected
+df["e_above_hull_mp"] = df_wbm.e_above_hull_mp2020_corrected_ppd_mp
 
 model_preds = df.filter(like=r"_pred").mean(axis=1)
 

tests/test_slurm.py (+16 -10)

@@ -10,17 +10,23 @@
 
 @pytest.mark.parametrize("py_file_path", [None, "path/to/file.py"])
 def test_slurm_submit(capsys: CaptureFixture[str], py_file_path: str | None) -> None:
-    kwargs = dict(
-        job_name="test_job",
-        log_dir="test_log_dir",
-        partition="fake-partition",
-        account="fake-account",
+    job_name = "test_job"
+    log_dir = "tmp"
+    time = "0:0:1"
+    partition = "fake-partition"
+    account = "fake-account"
+
+    func_call = lambda: slurm_submit_python(
+        job_name=job_name,
+        log_dir=log_dir,
+        time=time,
+        partition=partition,
+        account=account,
         py_file_path=py_file_path,
-        time="0:0:1",
         slurm_flags=("--test-flag",),
     )
-    slurm_submit_python(**kwargs)  # type: ignore
 
+    func_call()
     stdout, stderr = capsys.readouterr()
     # check slurm_submit_python() did nothing in normal mode
     assert stderr == stderr == ""
@@ -29,13 +35,13 @@ def test_slurm_submit(capsys: CaptureFixture[str], py_file_path: str | None) ->
     with pytest.raises(SystemExit), patch("sys.argv", ["slurm-submit"]), patch(
         "matbench_discovery.slurm.subprocess.run"
     ) as mock_subprocess_run:
-        slurm_submit_python(**kwargs)  # type: ignore
+        func_call()
 
     assert mock_subprocess_run.call_count == 1
 
     sbatch_cmd = (
-        "sbatch --partition=fake-partition --account=fake-account --time=0:0:1 "
-        "--job-name test_job --output test_log_dir/slurm-%A-%a.out --test-flag "
+        f"sbatch --partition={partition} --account={account} --time={time} "
+        f"--job-name {job_name} --output {log_dir}/slurm-%A-%a.out --test-flag "
         f"--wrap python {py_file_path or __file__}"
     )
     stdout, stderr = capsys.readouterr()
