Skip to content

Commit 6450ebb

Browse files
committed
add aflow wyckoff labels to 2022-10-19-wbm-summary.csv
1 parent c1e55e1 commit 6450ebb

File tree

6 files changed

+44
-29
lines changed

6 files changed

+44
-29
lines changed

data/mp/get_mp_energies.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646

4747
df["spacegroup_number"] = df.pop("symmetry").map(lambda x: x.number)
4848

49-
df["wyckoff"] = [get_aflow_label_from_spglib(x) for x in tqdm(df.structure)]
49+
df["wyckoff_spglib"] = [get_aflow_label_from_spglib(x) for x in tqdm(df.structure)]
5050

5151
df.to_json(f"{module_dir}/{today}-mp-energies.json.gz", default_handler=as_dict_handler)
5252

data/wbm/fetch_process_wbm_dataset.py

+33-17
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,10 @@
7979
# %%
8080
json_paths = sorted(glob(f"{module_dir}/raw/wbm-structures-step-*.json.bz2"))
8181
step_lens = (61848, 52800, 79205, 40328, 23308)
82-
# step 3 has 79,211 structures but only 79,205 ComputedStructureEntries
82+
# step 3 has 79,211 initial structures but only 79,205 ComputedStructureEntries
8383
# i.e. 6 extra structures that are missing energy, volume, etc. in the summary file
8484
bad_struct_ids = (70802, 70803, 70825, 70826, 70828, 70829)
85+
# step 5 has 2 missing initial structures: 23166, 23294
8586

8687

8788
assert len(json_paths) == len(step_lens), "Mismatch in WBM steps and JSON files"
@@ -229,6 +230,14 @@ def increment_wbm_material_id(wbm_id: str) -> str:
229230
).value_counts().to_dict() == {"GGA": 248481, "GGA+U": 9008}
230231

231232

233+
# drop two materials with missing initial structures
234+
assert list(df_wbm.query("initial_structure.isna()").index) == [
235+
"wbm-step-5-23166",
236+
"wbm-step-5-23294",
237+
]
238+
df_wbm = df_wbm.dropna(subset=["initial_structure"])
239+
240+
232241
# %% get composition from CSEs
233242
df_wbm["composition_from_cse"] = [
234243
ComputedStructureEntry.from_dict(cse).composition
@@ -273,12 +282,13 @@ def increment_wbm_material_id(wbm_id: str) -> str:
273282
x.alphabetical_formula for x in df_wbm.pop("composition_from_cse")
274283
]
275284

276-
for key, col_name in (
277-
("cses", "computed_structure_entry"),
278-
("init-structs", "initial_structure"),
285+
for fname, cols in (
286+
("cses", ["computed_structure_entry"]),
287+
("init-structs", ["initial_structure"]),
288+
("cses+init-structs", ["initial_structure", "computed_structure_entry"]),
279289
):
280-
cols = ["initial_structure", "formula_from_cse", col_name]
281-
df_wbm[cols].reset_index().to_json(f"{module_dir}/{today}-wbm-{key}.json.bz2")
290+
cols = ["formula_from_cse", *cols]
291+
df_wbm[cols].reset_index().to_json(f"{module_dir}/{today}-wbm-{fname}.json.bz2")
282292

283293

284294
# %%
@@ -486,26 +496,32 @@ def increment_wbm_material_id(wbm_id: str) -> str:
486496
f"{module_dir}/2022-10-19-wbm-cses+init-structs.json.bz2"
487497
).set_index("material_id")
488498

489-
df_init_struct = pd.read_json(
490-
f"{module_dir}/2022-10-19-wbm-init-structs.json.bz2"
491-
).set_index("material_id")
492-
493499
df_wbm["cse"] = [
494500
ComputedStructureEntry.from_dict(x) for x in tqdm(df_wbm.computed_structure_entry)
495501
]
496502

497503

498504
# %%
499-
df_wbm["init_struct"] = [
500-
Structure.from_dict(x) if x else None for x in tqdm(df_wbm.initial_structure)
501-
]
505+
df_init_struct = pd.read_json(
506+
f"{module_dir}/2022-10-19-wbm-init-structs.json.bz2"
507+
).set_index("material_id")
502508

503509
wyckoff_col = "wyckoff_spglib"
504-
for idx, struct in tqdm(df_wbm.init_struct.items(), total=len(df_wbm)):
505-
if struct is None:
510+
if wyckoff_col not in df_init_struct:
511+
df_init_struct[wyckoff_col] = None
512+
513+
for idx, struct in tqdm(
514+
df_init_struct.initial_structure.items(), total=len(df_init_struct)
515+
):
516+
if not pd.isna(df_summary.at[idx, wyckoff_col]):
506517
continue
507-
if not df_wbm.at[idx, wyckoff_col]:
508-
df_wbm.at[idx, wyckoff_col] = get_aflow_label_from_spglib(struct)
518+
try:
519+
struct = Structure.from_dict(struct)
520+
df_summary.at[idx, wyckoff_col] = get_aflow_label_from_spglib(struct)
521+
except Exception as exc:
522+
print(f"{idx=} {exc=}")
523+
524+
assert df_summary[wyckoff_col].isna().sum() == 0
509525

510526

511527
# %% make sure material IDs within each step are consecutive

matbench_discovery/plot_scripts/hist_classified_stable_as_func_of_hull_dist.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -72,5 +72,4 @@
7272
ax.legend(loc="center left", frameon=False)
7373

7474
fig_name = f"wren-wbm-hull-dist-hist-{which_energy=}-{stability_crit=}"
75-
img_path = f"{ROOT}/figures/{today}-{fig_name}.pdf"
76-
# fig.savefig(img_path)
75+
# fig.savefig(f"{ROOT}/figures/{today}-{fig_name}.pdf")

matbench_discovery/plot_scripts/precision_recall.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515

1616

1717
# %%
18-
rare = "all"
19-
2018
dfs: dict[str, pd.DataFrame] = {}
2119
for model_name in ("wren", "cgcnn", "voronoi"):
2220
csv_path = (
@@ -118,7 +116,6 @@
118116
ax.set(xlim=(0, None))
119117

120118

121-
img_name = f"{today}-precision-recall-vs-calc-count-{rare=}"
122119
# x-ticks every 10k materials
123120
# ax.set(xticks=range(0, int(ax.get_xlim()[1]), 10_000))
124121

@@ -128,4 +125,4 @@
128125

129126

130127
# %%
131-
fig.savefig(f"{ROOT}/figures/{img_name}.pdf")
128+
# fig.savefig(f"{ROOT}/figures/{today}-precision-recall-curves.pdf")

models/wrenformer/mp/use_wrenformer_ensemble.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,10 @@
3939

4040
# %%
4141
data_path = f"{ROOT}/data/wbm/2022-10-19-wbm-summary.csv"
42-
df = pd.read_csv(data_path).set_index("material_id")
42+
target_col = "e_form_per_atom_mp2020_corrected"
43+
input_col = "wyckoff_spglib"
44+
df = pd.read_csv(data_path).dropna(subset=input_col).set_index("material_id")
4345

44-
target_col = "e_form_per_atom"
45-
input_col = "wyckoff"
4646
assert target_col in df, f"{target_col=} not in {list(df)}"
4747
assert input_col in df, f"{input_col=} not in {list(df)}"
4848

models/wrenformer/slurm_train_wrenformer_ensemble.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
from datetime import datetime
44

55
import pandas as pd
6-
from aviary import ROOT
76
from aviary.train import df_train_test_split, train_wrenformer
87

8+
from matbench_discovery import ROOT
99
from matbench_discovery.slurm import slurm_submit_python
1010

1111
"""
@@ -45,13 +45,15 @@
4545
batch_size = 128
4646
slurm_array_task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
4747
timestamp = f"{datetime.now():%Y-%m-%d@%H-%M-%S}"
48+
input_col = "wyckoff_spglib"
4849

4950
print(f"Job started running {timestamp}")
5051
print(f"{run_name=}")
5152
print(f"{data_path=}")
5253

5354
df = pd.read_json(data_path).set_index("material_id", drop=False)
54-
assert target_col in df
55+
assert target_col in df, f"{target_col=} not in {list(df)}"
56+
assert input_col in df, f"{input_col=} not in {list(df)}"
5557
train_df, test_df = df_train_test_split(df, test_size=0.3)
5658

5759
run_params = dict(
@@ -70,6 +72,7 @@
7072
# folds=(n_folds, slurm_array_task_id),
7173
epochs=epochs,
7274
checkpoint="wandb", # None | 'local' | 'wandb',
75+
input_col=input_col,
7376
learning_rate=learning_rate,
7477
batch_size=batch_size,
7578
wandb_path="janosh/matbench-discovery",

0 commit comments

Comments
 (0)