|
7 | 7 | from pymatgen.core import Structure
|
8 | 8 | from pymatgen.util.coord import pbc_diff
|
9 | 9 | from pymatviz.utils import add_identity_line
|
| 10 | +from sklearn.metrics import r2_score |
10 | 11 |
|
11 | 12 | from mb_discovery import ROOT
|
12 | 13 |
|
|
27 | 28 |
|
28 | 29 |
|
29 | 30 | # %%
|
30 |
| -df_m3gnet = pd.read_json( |
31 |
| - f"{ROOT}/data/2022-08-16-m3gnet-wbm-relax-results.json.gz" |
| 31 | +df_m3gnet_is2re = pd.read_json( |
| 32 | + f"{ROOT}/data/2022-08-16-m3gnet-wbm-relax-results-IS2RE.json.gz" |
| 33 | +).set_index("material_id") |
| 34 | +df_m3gnet_rs2re = pd.read_json( |
| 35 | + f"{ROOT}/data/2022-08-19-m3gnet-wbm-relax-results-RS3RE.json.gz" |
32 | 36 | ).set_index("material_id")
|
33 |
| - |
34 |
| -print("Number of WBM crystals for which we have M3GNet results:") |
35 |
| -print(f"{len(df_m3gnet):,} / {len(df_wbm):,} = {len(df_m3gnet)/len(df_wbm):.1%}") |
36 | 37 |
|
37 | 38 |
|
38 | 39 | # %% spread M3GNet post-pseudo-relaxation lattice params into separate columns
|
39 |
| -df_m3gnet["final_energy"] = df_m3gnet.trajectory.map(lambda x: x["energies"][-1][0]) |
| 40 | +df_m3gnet_is2re["final_energy"] = df_m3gnet_is2re.trajectory.map( |
| 41 | + lambda x: x["energies"][-1][0] |
| 42 | +) |
40 | 43 |
|
41 | 44 | df_m3gnet_lattice = pd.json_normalize(
|
42 |
| - df_m3gnet.initial_structure.map(lambda x: x["lattice"]) |
| 45 | + df_m3gnet_is2re.initial_structure.map(lambda x: x["lattice"]) |
43 | 46 | ).add_prefix("m3gnet_")
|
44 |
| -df_m3gnet[df_m3gnet_lattice.columns] = df_m3gnet_lattice.to_numpy() |
45 |
| -df_m3gnet |
| 47 | +df_m3gnet_is2re[df_m3gnet_lattice.columns] = df_m3gnet_lattice.to_numpy() |
| 48 | +df_m3gnet_is2re |
46 | 49 |
|
47 | 50 |
|
48 | 51 | # %% spread WBM initial and final lattice params into separate columns
|
49 |
| -df_m3gnet["final_wbm_structure"] = df_wbm.cse.map(lambda x: x["structure"]) |
| 52 | +df_m3gnet_is2re["final_wbm_structure"] = df_wbm.cse.map(lambda x: x["structure"]) |
50 | 53 | df_wbm_final_lattice = pd.json_normalize(
|
51 |
| - df_m3gnet.final_wbm_structure.map(lambda x: x["lattice"]) |
| 54 | + df_m3gnet_is2re.final_wbm_structure.map(lambda x: x["lattice"]) |
52 | 55 | ).add_prefix("final_wbm_")
|
53 |
| -df_m3gnet[df_wbm_final_lattice.columns] = df_wbm_final_lattice.to_numpy() |
| 56 | +df_m3gnet_is2re[df_wbm_final_lattice.columns] = df_wbm_final_lattice.to_numpy() |
54 | 57 |
|
55 | 58 |
|
56 |
| -df_m3gnet["initial_wbm_structure"] = df_wbm.initial_structure |
| 59 | +df_m3gnet_is2re["initial_wbm_structure"] = df_wbm.initial_structure |
57 | 60 | df_wbm_initial_lattice = pd.json_normalize(
|
58 |
| - df_m3gnet.initial_structure.map(lambda x: x["lattice"]) |
| 61 | + df_m3gnet_is2re.initial_structure.map(lambda x: x["lattice"]) |
59 | 62 | ).add_prefix("initial_wbm_")
|
60 |
| -df_m3gnet[df_wbm_initial_lattice.columns] = df_wbm_initial_lattice.to_numpy() |
| 63 | +df_m3gnet_is2re[df_wbm_initial_lattice.columns] = df_wbm_initial_lattice.to_numpy() |
61 | 64 |
|
62 | 65 |
|
63 | 66 | # %%
|
|
78 | 81 |
|
79 | 82 | # %%
|
80 | 83 | px.histogram(
|
81 |
| - df_m3gnet.filter(like="volume"), |
| 84 | + df_m3gnet_is2re.filter(like="volume"), |
82 | 85 | nbins=500,
|
83 | 86 | barmode="overlay",
|
84 | 87 | opacity=0.5,
|
|
88 | 91 |
|
89 | 92 | # %%
|
90 | 93 | fig = px.scatter(
|
91 |
| - df_m3gnet.round(1), |
| 94 | + df_m3gnet_is2re.round(1), |
92 | 95 | x="final_wbm_volume",
|
93 | 96 | y=["initial_wbm_volume", "m3gnet_volume"],
|
94 |
| - hover_data=[df_m3gnet.index], |
| 97 | + hover_data=[df_m3gnet_is2re.index], |
95 | 98 | )
|
96 | 99 | add_identity_line(fig)
|
97 | 100 | fig.update_layout(
|
|
102 | 105 |
|
103 | 106 | # %% histogram of alpha lattice angles (similar results for beta and gamma)
|
104 | 107 | fig = px.histogram(
|
105 |
| - df_m3gnet.filter(like="alpha"), nbins=1000, barmode="overlay", log_y=True |
| 108 | + df_m3gnet_is2re.filter(like="alpha"), nbins=1000, barmode="overlay", log_y=True |
106 | 109 | )
|
107 | 110 | fig.show()
|
108 | 111 |
|
109 | 112 |
|
110 | 113 | # %%
|
111 | 114 | px.histogram(
|
112 |
| - df_m3gnet.filter(regex="_c$"), |
| 115 | + df_m3gnet_is2re.filter(regex="_c$"), |
113 | 116 | nbins=1000,
|
114 | 117 | log_y=True,
|
115 | 118 | barmode="overlay",
|
116 | 119 | opacity=0.5,
|
117 | 120 | )
|
118 | 121 |
|
119 | 122 |
|
120 |
| -# %% |
121 |
| -df_m3gnet["final_m3gnet_structure"] = df_m3gnet.final_structure.map(Structure.from_dict) |
122 |
| -df_m3gnet["initial_wbm_structure"] = df_m3gnet.initial_wbm_structure.map( |
| 123 | +# %% compute mean absolute PBC difference between initial and final fractional |
| 124 | +# coordinates of crystal sites |
| 125 | +df_m3gnet_is2re["m3gnet_structure"] = df_m3gnet_is2re.m3gnet_structure.map( |
| 126 | + Structure.from_dict |
| 127 | +) |
| 128 | +df_m3gnet_is2re["initial_wbm_structure"] = df_m3gnet_is2re.initial_wbm_structure.map( |
123 | 129 | Structure.from_dict
|
124 | 130 | )
|
125 |
| -df_m3gnet["final_wbm_structure"] = df_m3gnet.final_wbm_structure.map( |
| 131 | +df_m3gnet_is2re["final_wbm_structure"] = df_m3gnet_is2re.final_wbm_structure.map( |
126 | 132 | Structure.from_dict
|
127 | 133 | )
|
128 | 134 |
|
129 | 135 |
|
130 |
| -df_m3gnet["m3gnet_pbc_diffs"] = [ |
| 136 | +df_m3gnet_is2re["m3gnet_pbc_diffs"] = [ |
131 | 137 | abs(
|
132 | 138 | pbc_diff(
|
133 | 139 | row.initial_wbm_structure.frac_coords,
|
134 |
| - row.final_m3gnet_structure.frac_coords, |
| 140 | + row.m3gnet_structure.frac_coords, |
135 | 141 | )
|
136 | 142 | ).mean()
|
137 |
| - for row in df_m3gnet.itertuples() |
| 143 | + for row in df_m3gnet_is2re.itertuples() |
138 | 144 | ]
|
139 | 145 |
|
140 | 146 |
|
141 |
| -df_m3gnet["wbm_pbc_diffs"] = [ |
| 147 | +df_m3gnet_is2re["wbm_pbc_diffs"] = [ |
142 | 148 | abs(
|
143 | 149 | pbc_diff(
|
144 | 150 | row.initial_wbm_structure.frac_coords,
|
145 | 151 | row.final_wbm_structure.frac_coords,
|
146 | 152 | )
|
147 | 153 | ).mean()
|
148 |
| - for row in df_m3gnet.itertuples() |
| 154 | + for row in df_m3gnet_is2re.itertuples() |
149 | 155 | ]
|
150 | 156 |
|
151 |
| -df_m3gnet["m3gnet_to_final_wbm_pbc_diffs"] = [ |
| 157 | +df_m3gnet_is2re["m3gnet_to_final_wbm_pbc_diffs"] = [ |
152 | 158 | abs(
|
153 | 159 | pbc_diff(
|
154 |
| - row.final_m3gnet_structure.frac_coords, |
| 160 | + row.m3gnet_structure.frac_coords, |
155 | 161 | row.final_wbm_structure.frac_coords,
|
156 | 162 | )
|
157 | 163 | ).mean()
|
158 |
| - for row in df_m3gnet.itertuples() |
| 164 | + for row in df_m3gnet_is2re.itertuples() |
159 | 165 | ]
|
160 | 166 |
|
161 | 167 |
|
|
164 | 170 | "and M3GNet"
|
165 | 171 | )
|
166 | 172 |
|
167 |
| -wbm_pbc_diffs_mean = df_m3gnet.wbm_pbc_diffs.mean() |
| 173 | +wbm_pbc_diffs_mean = df_m3gnet_is2re.wbm_pbc_diffs.mean() |
168 | 174 | print(f"{wbm_pbc_diffs_mean = :.3}")
|
169 | 175 |
|
170 |
| -m3gnet_pbc_diffs_mean = df_m3gnet.m3gnet_pbc_diffs.mean() |
| 176 | +m3gnet_pbc_diffs_mean = df_m3gnet_is2re.m3gnet_pbc_diffs.mean() |
171 | 177 | print(f"{m3gnet_pbc_diffs_mean = :.3}")
|
172 | 178 |
|
173 |
| -m3gnet_to_final_wbm_pbc_diffs_mean = df_m3gnet.m3gnet_to_final_wbm_pbc_diffs.mean() |
| 179 | +m3gnet_to_final_wbm_pbc_diffs_mean = ( |
| 180 | + df_m3gnet_is2re.m3gnet_to_final_wbm_pbc_diffs.mean() |
| 181 | +) |
174 | 182 | print(f"{m3gnet_to_final_wbm_pbc_diffs_mean = :.3}")
|
175 | 183 |
|
176 | 184 | print(f"{wbm_pbc_diffs_mean / m3gnet_pbc_diffs_mean = :.3}")
|
| 185 | + |
| 186 | + |
| 187 | +# %% |
| 188 | +# plt_fig = df_m3gnet_is2re.plot.scatter( |
| 189 | +# x="e_m3gnet_per_atom_rs2re", y="e_m3gnet_per_atom_is2re" |
| 190 | +# ) |
| 191 | +# df_m3gnet_is2re.filter(like="m3gnet_energy").hist(bins=100) |
| 192 | + |
| 193 | +df_m3gnet_is2re["m3gnet_energy_rs2re"] = df_m3gnet_rs2re.m3gnet_energy |
| 194 | + |
| 195 | +for task_type in ["is2re", "rs2re"]: |
| 196 | + e_per_atom = df_m3gnet_is2re[f"m3gnet_energy_{task_type}"] / df_m3gnet_is2re.n_sites |
| 197 | + |
| 198 | + df_m3gnet_is2re[f"e_m3gnet_per_atom_{task_type}"] = e_per_atom |
| 199 | + |
| 200 | +fig = px.scatter( |
| 201 | + df_m3gnet_is2re, |
| 202 | + x="e_m3gnet_per_atom_rs2re", |
| 203 | + y="e_m3gnet_per_atom_is2re", |
| 204 | + render_mode="webgl", |
| 205 | +) |
| 206 | +add_identity_line(fig) |
| 207 | + |
| 208 | +len_overlap = df_m3gnet_is2re.filter(like="e_m3gnet_per_atom_").dropna().shape[0] |
| 209 | +x_vals, y_vals = df_m3gnet_is2re.filter(like="e_m3gnet_per_atom_").dropna().values.T |
| 210 | + |
| 211 | +MAE = abs(x_vals - y_vals).mean() |
| 212 | +R2 = r2_score(x_vals, y_vals) |
| 213 | + |
| 214 | +title = f"data size = {len_overlap:,} \t {MAE = :.2} \t {R2 = :.4}" |
| 215 | +fig.update_layout(title=dict(text=title, x=0.5)) |
| 216 | + |
| 217 | +# 250k scatter points require exporting to PNG, interactive version freezes the |
| 218 | +# notebook server |
| 219 | +fig.show(renderer="png", scale=2) |
| 220 | +fig.write_image( |
| 221 | + f"{ROOT}/figures/{today}-m3gnet-energy-per-atom-scatter-is2re-vs-rs2re.png", scale=2 |
| 222 | +) |
| 223 | + |
| 224 | + |
| 225 | +# %% write df back to compressed JSON |
| 226 | +# filter out columns containing 'rs2re' |
| 227 | +# df_m3gnet_is2re.reset_index().filter(regex="^((?!rs2re).)*$").to_json( |
| 228 | +# f"{ROOT}/data/2022-08-16-m3gnet-wbm-relax-results-IS2RE-2.json.gz" |
| 229 | +# ).set_index("material_id") |
0 commit comments