Skip to content

Commit c3c084b

Browse files
committed
drop trajectories when joining slurm array m3gnet results
also remove n in _n1, _n2, etc. from Rhys wren ensemble CSV column names
1 parent 7c30b09 commit c3c084b

File tree

2 files changed

+49
-25
lines changed

2 files changed

+49
-25
lines changed

mb_discovery/m3gnet/join_and_plot_m3gnet_relax_results.py

+45-21
Original file line numberDiff line numberDiff line change
@@ -24,35 +24,44 @@
2424

2525
# %%
2626
glob_pattern = "2022-08-16-m3gnet-wbm-relax-results/*.json.gz"
27-
file_paths = glob(f"{ROOT}/data/{glob_pattern}")
27+
file_paths = sorted(glob(f"{ROOT}/data/{glob_pattern}"))
2828
print(f"Found {len(file_paths):,} files for {glob_pattern = }")
2929

30-
3130
dfs: dict[str, pd.DataFrame] = {}
31+
32+
33+
# %%
3234
# 2022-08-16 tried multiprocessing.Pool() to load files in parallel but was somehow
3335
# slower than serial loading
3436
for file_path in tqdm(file_paths):
3537
if file_path in dfs:
3638
continue
3739
try:
38-
dfs[file_path] = pd.read_json(file_path)
40+
# keep whole dataframe in memory
41+
df = pd.read_json(file_path)
42+
df.index = df.index.str.replace("_", "-")
43+
df.index.name = "material_id"
44+
col_map = dict(
45+
final_structure="m3gnet_structure", trajectory="m3gnet_trajectory"
46+
)
47+
df = df.rename(columns=col_map)
48+
df.reset_index().to_json(file_path)
49+
df["m3gnet_energy"] = df.m3gnet_trajectory.map(lambda x: x["energies"][-1][0])
50+
df["m3gnet_structure"] = df.m3gnet_structure.map(Structure.from_dict)
51+
df["formula"] = df.m3gnet_structure.map(lambda x: x.formula)
52+
df["volume"] = df.m3gnet_structure.map(lambda x: x.volume)
53+
df["n_sites"] = df.m3gnet_structure.map(len)
54+
dfs[file_path] = df.drop(columns=["m3gnet_trajectory"])
3955
except (ValueError, FileNotFoundError):
4056
# pandas v1.5+ correctly raises FileNotFoundError, below raises ValueError
4157
continue
4258

4359

4460
# %%
4561
df_m3gnet = pd.concat(dfs.values())
46-
df_m3gnet.index.name = "material_id"
4762
if any(df_m3gnet.index.str.contains("_")):
4863
df_m3gnet.index = df_m3gnet.index.str.replace("_", "-")
4964

50-
df_m3gnet = df_m3gnet.rename(
51-
columns=dict(final_structure="m3gnet_structure", trajectory="m3gnet_trajectory")
52-
)
53-
54-
df_m3gnet["m3gnet_energy"] = df_m3gnet.trajectory.map(lambda x: x["energies"][-1][0])
55-
5665

5766
# %%
5867
# 2022-01-25-ppd-mp+wbm.pkl.gz (235 MB)
@@ -64,12 +73,13 @@
6473
)
6574

6675

67-
df_m3gnet["m3gnet_structure"] = df_m3gnet.m3gnet_structure.map(Structure.from_dict)
68-
df_m3gnet["pd_entry"] = [
76+
pd_entries_m3gnet = [
6977
PDEntry(row.m3gnet_structure.composition, row.m3gnet_energy)
7078
for row in df_m3gnet.itertuples()
7179
]
72-
df_m3gnet["e_form_m3gnet"] = df_m3gnet.pd_entry.map(ppd_mp_wbm.get_form_energy_per_atom)
80+
df_m3gnet["e_form_m3gnet"] = [
81+
ppd_mp_wbm.get_form_energy_per_atom(x) for x in pd_entries_m3gnet
82+
]
7383

7484

7585
# %%
@@ -80,11 +90,27 @@
8090
df_m3gnet["e_above_mp_hull"] = df_hull.e_above_mp_hull
8191

8292

83-
df_summary = pd.read_csv(f"{ROOT}/data/wbm-steps-summary.csv", comment="#").set_index(
84-
"material_id"
85-
)
93+
df_wbm = pd.read_csv( # download wbm-steps-summary.csv (23.31 MB)
94+
"https://figshare.com/ndownloader/files/36714216?private_link=ff0ad14505f9624f0c05"
95+
).set_index("material_id")
96+
97+
df_m3gnet["e_form_wbm"] = df_wbm.e_form
98+
df_m3gnet["wbm_energy"] = df_wbm.energy
8699

87-
df_m3gnet["e_form_wbm"] = df_summary.e_form
100+
pd_entries_wbm = [
101+
PDEntry(row.m3gnet_structure.composition, row.wbm_energy)
102+
for row in df_m3gnet.itertuples()
103+
]
104+
df_m3gnet["e_form_ppd_2022_01_25"] = [
105+
ppd_mp_wbm.get_form_energy_per_atom(x) for x in pd_entries_wbm
106+
]
107+
108+
109+
df_m3gnet.filter(like="e_form").plot.scatter(x="e_form_m3gnet", y="e_form_wbm")
110+
df_m3gnet.filter(like="e_form").plot.scatter(
111+
x="e_form_m3gnet", y="e_form_ppd_2022_01_25"
112+
)
113+
df_m3gnet.filter(like="e_form").plot.scatter(x="e_form_wbm", y="e_form_ppd_2022_01_25")
88114

89115

90116
# %%
@@ -94,14 +120,12 @@
94120

95121
# %%
96122
out_path = f"{ROOT}/data/{today}-m3gnet-wbm-relax-results.json.gz"
97-
df_m3gnet.drop(columns=["pd_entry"]).reset_index().to_json(
98-
out_path, default_handler=as_dict_handler
99-
)
123+
df_m3gnet.reset_index().to_json(out_path, default_handler=as_dict_handler)
100124

101125

102126
# %%
103127
ax_hull_dist_hist = hist_classify_stable_as_func_of_hull_dist(
104-
formation_energy_targets=df_m3gnet.e_form_wbm,
128+
formation_energy_targets=df_m3gnet.e_form_ppd,
105129
formation_energy_preds=df_m3gnet.e_form_m3gnet,
106130
e_above_hull_vals=df_m3gnet.e_above_mp_hull,
107131
)

mb_discovery/plot_scripts/precision_recall_as_func_of_calc_count.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,12 @@
5555

5656
e_above_mp_hull = df.e_above_mp_hull
5757

58-
# mean = df.filter(like="pred").mean(axis=1) - e_hull
59-
mean = df.filter(like="pred").mean(axis=1) - df[target_col] + e_above_mp_hull
58+
# mean = df.filter(regex=r"_pred_\d").mean(axis=1) - e_hull
59+
mean = df.filter(regex=r"_pred_\d").mean(axis=1) - df[target_col] + e_above_mp_hull
6060

61-
# epistemic_var = df.filter(like="pred").var(axis=1, ddof=0)
61+
# epistemic_var = df.filter(regex=r"_pred_\d").var(axis=1, ddof=0)
6262

63-
# aleatoric_var = (df.filter(like="ale") ** 2).mean(axis=1)
63+
# aleatoric_var = (df.filter(like="_ale_") ** 2).mean(axis=1)
6464

6565
# full_std = (epistemic_var + aleatoric_var) ** 0.5
6666

0 commit comments

Comments
 (0)