|
26 | 26 | df_preds,
|
27 | 27 | each_true_col,
|
28 | 28 | model_mean_err_col,
|
| 29 | + model_std_col, |
29 | 30 | )
|
30 | 31 |
|
31 | 32 | __author__ = "Janosh Riebesell"
|
32 | 33 | __date__ = "2023-02-15"
|
33 | 34 |
|
| 35 | +models = list(df_each_pred) |
| 36 | +df_preds[model_std_col] = df_preds[models].std(axis=1) |
34 | 37 | df_each_err[model_mean_err_col] = df_preds[model_mean_err_col] = df_each_err.abs().mean(
|
35 | 38 | axis=1
|
36 | 39 | )
|
|
158 | 161 | # save_fig(fig, f"{FIGS}/scatter-largest-each-errors-fp-diff-models.svelte")
|
159 | 162 |
|
160 | 163 |
|
161 |
| -# %% plotly scatter plot of largest model errors with points sized by mean error and |
162 |
| -# colored by true stability. |
163 |
| -# while some points lie on a horizontal line of constant error, more follow the identity |
164 |
| -# line suggesting the models failed to learn the true physics in these materials |
165 |
| -fig = df_preds.nlargest(200, model_mean_err_col).plot.scatter( |
166 |
| - x=each_true_col, |
167 |
| - y=model_mean_err_col, |
168 |
| - color=each_true_col, |
169 |
| - size=model_mean_err_col, |
170 |
| - backend="plotly", |
| 164 | +# %% |
| 165 | +df_mp = pd.read_csv(DATA_FILES.mp_energies, na_filter=False).set_index("material_id") |
| 166 | +train_count_col = "MP Occurrences" |
| 167 | +df_elem_counts = count_elements(df_mp.formula_pretty, count_mode="occurrence").to_frame( |
| 168 | + name=train_count_col |
171 | 169 | )
|
172 |
| -fig.layout.coloraxis.colorbar.update( |
173 |
| - title="DFT distance to convex hull (eV/atom)", |
174 |
| - title_side="top", |
175 |
| - yanchor="bottom", |
176 |
| - y=1, |
177 |
| - xanchor="center", |
178 |
| - x=0.5, |
179 |
| - orientation="h", |
180 |
| - thickness=12, |
| 170 | +n_examp_for_rarest_elem_col = "Examples for rarest element in structure" |
| 171 | +df_wbm[n_examp_for_rarest_elem_col] = [ |
| 172 | + df_elem_counts[train_count_col].loc[list(map(str, Composition(formula)))].min() |
| 173 | + for formula in tqdm(df_wbm.formula) |
| 174 | +] |
| 175 | +df_preds[n_examp_for_rarest_elem_col] = df_wbm[n_examp_for_rarest_elem_col] |
| 176 | + |
| 177 | + |
| 178 | +# %% scatter plot of largest model errors vs. DFT hull distance |
| 179 | +# while some points lie on a horizontal line of constant error, more follow the identity |
| 180 | +# line, showing models are biased to predict low energies — likely a result of
| 181 | +# training on MP, which is highly enriched in low-energy structures.
| 182 | +# it's also possible the models failed to learn whatever physics makes these
| 183 | +# materials highly unstable
| 184 | +fig = ( |
| 185 | + df_preds.nlargest(200, model_mean_err_col) |
| 186 | + .round(2) |
| 187 | + .plot.scatter( |
| 188 | + x=each_true_col, |
| 189 | + y=model_mean_err_col, |
| 190 | + color=model_std_col, |
| 191 | + size=n_examp_for_rarest_elem_col, |
| 192 | + backend="plotly", |
| 193 | + hover_name="material_id", |
| 194 | + hover_data=["formula"], |
| 195 | + color_continuous_scale="Turbo", |
| 196 | + ) |
181 | 197 | )
|
| 198 | +# yanchor="bottom", y=1, xanchor="center", x=0.5, orientation="h", thickness=12 |
| 199 | +fig.layout.coloraxis.colorbar.update(title_side="right", thickness=14) |
182 | 200 | add_identity_line(fig)
|
| 201 | +fig.layout.title = ( |
| 202 | + "Largest model errors vs. DFT hull distance colored by model disagreement" |
| 203 | +) |
| 204 | +# tried setting error_y=model_std_col but looks bad |
| 205 | +# fig.update_traces(error_y=dict(color="rgba(255,255,255,0.2)", width=3, thickness=2)) |
183 | 206 | fig.show()
|
| 207 | +# save_fig(fig, f"{FIGS}/scatter-largest-errors-models-mean-vs-each-true.svelte") |
| 208 | +# save_fig( |
| 209 | +# fig, f"{ROOT}/tmp/figures/scatter-largest-errors-models-mean-vs-each-true.pdf" |
| 210 | +# ) |
184 | 211 |
|
185 | 212 |
|
186 | 213 | # %% find materials that were misclassified by all models
|
|
203 | 230 |
|
204 | 231 |
|
205 | 232 | # %%
|
| 233 | +normalized = True |
206 | 234 | elem_counts: dict[str, pd.Series] = {}
|
207 | 235 | for col in ("All models false neg", "All models false pos"):
|
208 | 236 | elem_counts[col] = elem_counts.get(
|
209 | 237 | col, count_elements(df_preds[df_preds[col]].formula)
|
210 | 238 | )
|
211 |
| - fig = ptable_heatmap_plotly(elem_counts[col], font_size=10) |
212 |
| - fig.layout.title = col |
213 |
| - fig.layout.margin.update(l=0, r=0, t=50, b=0) |
| 239 | + fig = ptable_heatmap_plotly( |
| 240 | + elem_counts[col] / df_elem_counts[train_count_col] |
| 241 | + if normalized |
| 242 | + else elem_counts[col], |
| 243 | + color_bar=dict(title=col), |
| 244 | + precision=".3f", |
| 245 | + cscale_range=[0, 0.1], |
| 246 | + ) |
214 | 247 | fig.show()
|
215 | 248 |
|
| 249 | +# TODO plot these for each model individually |
| 250 | + |
216 | 251 |
|
217 | 252 | # %% map abs EACH model errors onto elements in structure weighted by composition
|
218 | 253 | # fraction and average over all test set structures
|
|
234 | 269 | # df_frac_comp = df_frac_comp.dropna(axis=1, thresh=100) # remove Xe with only 1 entry
|
235 | 270 |
|
236 | 271 |
|
237 |
| -# %% TODO investigate if structures with largest mean over models error can be |
238 |
| -# attributed to DFT gone wrong. would be cool if models can be run across large |
| 272 | +# %% TODO investigate if structures with largest mean error across all models can
| 273 | +# be attributed to DFT gone wrong. would be cool if models can be run across large
239 | 274 | # databases as correctness checkers
|
240 | 275 | df_each_err.abs().mean().sort_values()
|
241 | 276 | df_each_err.abs().mean(axis=1).nlargest(25)
|
|
0 commit comments