Skip to content

Commit 90e1b89

Browse files
committed
compare IS2RE and RS2RE energies in eda_wbm_pre_vs_post_m3gnet_relaxation.py
add figures/2022-08-22-m3gnet-energy-per-atom-scatter-is2re-vs-rs2re.png
1 parent 7006822 commit 90e1b89

File tree

2 files changed

+95
-41
lines changed

2 files changed

+95
-41
lines changed

mb_discovery/m3gnet/eda_wbm_pre_vs_post_m3gnet_relaxation.py

+88-35
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from pymatgen.core import Structure
88
from pymatgen.util.coord import pbc_diff
99
from pymatviz.utils import add_identity_line
10+
from sklearn.metrics import r2_score
1011

1112
from mb_discovery import ROOT
1213

@@ -27,37 +28,39 @@
2728

2829

2930
# %%
30-
df_m3gnet = pd.read_json(
31-
f"{ROOT}/data/2022-08-16-m3gnet-wbm-relax-results.json.gz"
31+
df_m3gnet_is2re = pd.read_json(
32+
f"{ROOT}/data/2022-08-16-m3gnet-wbm-relax-results-IS2RE.json.gz"
33+
).set_index("material_id")
34+
df_m3gnet_rs2re = pd.read_json(
35+
f"{ROOT}/data/2022-08-19-m3gnet-wbm-relax-results-RS3RE.json.gz"
3236
).set_index("material_id")
33-
34-
print("Number of WBM crystals for which we have M3GNet results:")
35-
print(f"{len(df_m3gnet):,} / {len(df_wbm):,} = {len(df_m3gnet)/len(df_wbm):.1%}")
3637

3738

3839
# %% spread M3GNet post-pseudo-relaxation lattice params into separate columns
39-
df_m3gnet["final_energy"] = df_m3gnet.trajectory.map(lambda x: x["energies"][-1][0])
40+
df_m3gnet_is2re["final_energy"] = df_m3gnet_is2re.trajectory.map(
41+
lambda x: x["energies"][-1][0]
42+
)
4043

4144
df_m3gnet_lattice = pd.json_normalize(
42-
df_m3gnet.initial_structure.map(lambda x: x["lattice"])
45+
df_m3gnet_is2re.initial_structure.map(lambda x: x["lattice"])
4346
).add_prefix("m3gnet_")
44-
df_m3gnet[df_m3gnet_lattice.columns] = df_m3gnet_lattice.to_numpy()
45-
df_m3gnet
47+
df_m3gnet_is2re[df_m3gnet_lattice.columns] = df_m3gnet_lattice.to_numpy()
48+
df_m3gnet_is2re
4649

4750

4851
# %% spread WBM initial and final lattice params into separate columns
49-
df_m3gnet["final_wbm_structure"] = df_wbm.cse.map(lambda x: x["structure"])
52+
df_m3gnet_is2re["final_wbm_structure"] = df_wbm.cse.map(lambda x: x["structure"])
5053
df_wbm_final_lattice = pd.json_normalize(
51-
df_m3gnet.final_wbm_structure.map(lambda x: x["lattice"])
54+
df_m3gnet_is2re.final_wbm_structure.map(lambda x: x["lattice"])
5255
).add_prefix("final_wbm_")
53-
df_m3gnet[df_wbm_final_lattice.columns] = df_wbm_final_lattice.to_numpy()
56+
df_m3gnet_is2re[df_wbm_final_lattice.columns] = df_wbm_final_lattice.to_numpy()
5457

5558

56-
df_m3gnet["initial_wbm_structure"] = df_wbm.initial_structure
59+
df_m3gnet_is2re["initial_wbm_structure"] = df_wbm.initial_structure
5760
df_wbm_initial_lattice = pd.json_normalize(
58-
df_m3gnet.initial_structure.map(lambda x: x["lattice"])
61+
df_m3gnet_is2re.initial_structure.map(lambda x: x["lattice"])
5962
).add_prefix("initial_wbm_")
60-
df_m3gnet[df_wbm_initial_lattice.columns] = df_wbm_initial_lattice.to_numpy()
63+
df_m3gnet_is2re[df_wbm_initial_lattice.columns] = df_wbm_initial_lattice.to_numpy()
6164

6265

6366
# %%
@@ -78,7 +81,7 @@
7881

7982
# %%
8083
px.histogram(
81-
df_m3gnet.filter(like="volume"),
84+
df_m3gnet_is2re.filter(like="volume"),
8285
nbins=500,
8386
barmode="overlay",
8487
opacity=0.5,
@@ -88,10 +91,10 @@
8891

8992
# %%
9093
fig = px.scatter(
91-
df_m3gnet.round(1),
94+
df_m3gnet_is2re.round(1),
9295
x="final_wbm_volume",
9396
y=["initial_wbm_volume", "m3gnet_volume"],
94-
hover_data=[df_m3gnet.index],
97+
hover_data=[df_m3gnet_is2re.index],
9598
)
9699
add_identity_line(fig)
97100
fig.update_layout(
@@ -102,60 +105,63 @@
102105

103106
# %% histogram of alpha lattice angles (similar results for beta and gamma)
104107
fig = px.histogram(
105-
df_m3gnet.filter(like="alpha"), nbins=1000, barmode="overlay", log_y=True
108+
df_m3gnet_is2re.filter(like="alpha"), nbins=1000, barmode="overlay", log_y=True
106109
)
107110
fig.show()
108111

109112

110113
# %%
111114
px.histogram(
112-
df_m3gnet.filter(regex="_c$"),
115+
df_m3gnet_is2re.filter(regex="_c$"),
113116
nbins=1000,
114117
log_y=True,
115118
barmode="overlay",
116119
opacity=0.5,
117120
)
118121

119122

120-
# %%
121-
df_m3gnet["final_m3gnet_structure"] = df_m3gnet.final_structure.map(Structure.from_dict)
122-
df_m3gnet["initial_wbm_structure"] = df_m3gnet.initial_wbm_structure.map(
123+
# %% compute mean absolute PBC difference between initial and final fractional
124+
# coordinates of crystal sites
125+
df_m3gnet_is2re["m3gnet_structure"] = df_m3gnet_is2re.m3gnet_structure.map(
126+
Structure.from_dict
127+
)
128+
df_m3gnet_is2re["initial_wbm_structure"] = df_m3gnet_is2re.initial_wbm_structure.map(
123129
Structure.from_dict
124130
)
125-
df_m3gnet["final_wbm_structure"] = df_m3gnet.final_wbm_structure.map(
131+
df_m3gnet_is2re["final_wbm_structure"] = df_m3gnet_is2re.final_wbm_structure.map(
126132
Structure.from_dict
127133
)
128134

129135

130-
df_m3gnet["m3gnet_pbc_diffs"] = [
136+
df_m3gnet_is2re["m3gnet_pbc_diffs"] = [
131137
abs(
132138
pbc_diff(
133139
row.initial_wbm_structure.frac_coords,
134-
row.final_m3gnet_structure.frac_coords,
140+
row.m3gnet_structure.frac_coords,
135141
)
136142
).mean()
137-
for row in df_m3gnet.itertuples()
143+
for row in df_m3gnet_is2re.itertuples()
138144
]
139145

140146

141-
df_m3gnet["wbm_pbc_diffs"] = [
147+
df_m3gnet_is2re["wbm_pbc_diffs"] = [
142148
abs(
143149
pbc_diff(
144150
row.initial_wbm_structure.frac_coords,
145151
row.final_wbm_structure.frac_coords,
146152
)
147153
).mean()
148-
for row in df_m3gnet.itertuples()
154+
for row in df_m3gnet_is2re.itertuples()
149155
]
150156

151-
df_m3gnet["m3gnet_to_final_wbm_pbc_diffs"] = [
157+
df_m3gnet_is2re["m3gnet_to_final_wbm_pbc_diffs"] = [
152158
abs(
153159
pbc_diff(
154-
row.final_m3gnet_structure.frac_coords,
160+
row.m3gnet_structure.frac_coords,
155161
row.final_wbm_structure.frac_coords,
156162
)
157163
).mean()
158-
for row in df_m3gnet.itertuples()
164+
for row in df_m3gnet_is2re.itertuples()
159165
]
160166

161167

@@ -164,13 +170,60 @@
164170
"and M3GNet"
165171
)
166172

167-
wbm_pbc_diffs_mean = df_m3gnet.wbm_pbc_diffs.mean()
173+
wbm_pbc_diffs_mean = df_m3gnet_is2re.wbm_pbc_diffs.mean()
168174
print(f"{wbm_pbc_diffs_mean = :.3}")
169175

170-
m3gnet_pbc_diffs_mean = df_m3gnet.m3gnet_pbc_diffs.mean()
176+
m3gnet_pbc_diffs_mean = df_m3gnet_is2re.m3gnet_pbc_diffs.mean()
171177
print(f"{m3gnet_pbc_diffs_mean = :.3}")
172178

173-
m3gnet_to_final_wbm_pbc_diffs_mean = df_m3gnet.m3gnet_to_final_wbm_pbc_diffs.mean()
179+
m3gnet_to_final_wbm_pbc_diffs_mean = (
180+
df_m3gnet_is2re.m3gnet_to_final_wbm_pbc_diffs.mean()
181+
)
174182
print(f"{m3gnet_to_final_wbm_pbc_diffs_mean = :.3}")
175183

176184
print(f"{wbm_pbc_diffs_mean / m3gnet_pbc_diffs_mean = :.3}")
185+
186+
187+
# %%
188+
# plt_fig = df_m3gnet_is2re.plot.scatter(
189+
# x="e_m3gnet_per_atom_rs2re", y="e_m3gnet_per_atom_is2re"
190+
# )
191+
# df_m3gnet_is2re.filter(like="m3gnet_energy").hist(bins=100)
192+
193+
df_m3gnet_is2re["m3gnet_energy_rs2re"] = df_m3gnet_rs2re.m3gnet_energy
194+
195+
for task_type in ["is2re", "rs2re"]:
196+
e_per_atom = df_m3gnet_is2re[f"m3gnet_energy_{task_type}"] / df_m3gnet_is2re.n_sites
197+
198+
df_m3gnet_is2re[f"e_m3gnet_per_atom_{task_type}"] = e_per_atom
199+
200+
fig = px.scatter(
201+
df_m3gnet_is2re,
202+
x="e_m3gnet_per_atom_rs2re",
203+
y="e_m3gnet_per_atom_is2re",
204+
render_mode="webgl",
205+
)
206+
add_identity_line(fig)
207+
208+
len_overlap = df_m3gnet_is2re.filter(like="e_m3gnet_per_atom_").dropna().shape[0]
209+
x_vals, y_vals = df_m3gnet_is2re.filter(like="e_m3gnet_per_atom_").dropna().values.T
210+
211+
MAE = abs(x_vals - y_vals).mean()
212+
R2 = r2_score(x_vals, y_vals)
213+
214+
title = f"data size = {len_overlap:,} \t {MAE = :.2} \t {R2 = :.4}"
215+
fig.update_layout(title=dict(text=title, x=0.5))
216+
217+
# 250k scatter points require exporting to PNG, interactive version freezes the
218+
# notebook server
219+
fig.show(renderer="png", scale=2)
220+
fig.write_image(
221+
f"{ROOT}/figures/{today}-m3gnet-energy-per-atom-scatter-is2re-vs-rs2re.png", scale=2
222+
)
223+
224+
225+
# %% write df back to compressed JSON
226+
# filter out columns containing 'rs2re'
227+
# df_m3gnet_is2re.reset_index().filter(regex="^((?!rs2re).)*$").to_json(
228+
# f"{ROOT}/data/2022-08-16-m3gnet-wbm-relax-results-IS2RE-2.json.gz"
229+
# ).set_index("material_id")

mb_discovery/m3gnet/join_and_plot_m3gnet_relax_results.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@
2323

2424

2525
# %%
26-
# glob_pattern = "2022-08-16-m3gnet-relax-wbm-IS3RE/*.json.gz"
27-
glob_pattern = "2022-08-19-m3gnet-relax-wbm-RS3RE/*.json.gz"
26+
task_type = "RS3RE"
27+
date = "2022-08-19"
28+
glob_pattern = f"{date}-m3gnet-relax-wbm-{task_type}/*.json.gz"
2829
file_paths = sorted(glob(f"{ROOT}/data/{glob_pattern}"))
2930
print(f"Found {len(file_paths):,} files for {glob_pattern = }")
3031

@@ -78,7 +79,7 @@
7879
PDEntry(row.m3gnet_structure.composition, row.m3gnet_energy)
7980
for row in df_m3gnet.itertuples()
8081
]
81-
df_m3gnet["e_form_m3gnet"] = [
82+
df_m3gnet["e_form_m3gnet_from_ppd"] = [
8283
ppd_mp_wbm.get_form_energy_per_atom(x) for x in pd_entries_m3gnet
8384
]
8485

@@ -113,14 +114,14 @@
113114

114115

115116
# %%
116-
out_path = f"{ROOT}/data/{today}-m3gnet-wbm-relax-results.json.gz"
117+
out_path = f"{ROOT}/data/{today}-m3gnet-relax-wbm-{task_type}.json.gz"
117118
df_m3gnet.reset_index().to_json(out_path, default_handler=as_dict_handler)
118119

119120

120121
# %%
121122
ax_hull_dist_hist = hist_classify_stable_as_func_of_hull_dist(
122-
formation_energy_targets=df_m3gnet.e_form_ppd,
123-
formation_energy_preds=df_m3gnet.e_form_m3gnet,
123+
formation_energy_targets=df_m3gnet.e_form_ppd_2022_01_25,
124+
formation_energy_preds=df_m3gnet.e_form_m3gnet_from_ppd,
124125
e_above_hull_vals=df_m3gnet.e_above_mp_hull,
125126
)
126127

0 commit comments

Comments
 (0)