|
| 1 | +"""Analyze structures and composition with largest mean error across all models. |
| 2 | +Maybe there's some chemistry/region of materials space that all models struggle with? |
| 3 | +Might point to deficiencies in the data or models architecture. |
| 4 | +""" |
| 5 | + |
| 6 | + |
| 7 | +# %% |
| 8 | +import pandas as pd |
| 9 | +import plotly.express as px |
| 10 | +from pymatgen.core import Composition, Element |
| 11 | +from pymatviz import count_elements, ptable_heatmap_plotly |
| 12 | +from pymatviz.utils import bin_df_cols, save_fig |
| 13 | +from sklearn.metrics import r2_score |
| 14 | +from tqdm import tqdm |
| 15 | + |
| 16 | +from matbench_discovery import FIGS, MODELS, ROOT |
| 17 | +from matbench_discovery.data import DATA_FILES, df_wbm |
| 18 | +from matbench_discovery.preds import ( |
| 19 | + df_each_err, |
| 20 | + df_metrics, |
| 21 | + df_preds, |
| 22 | + each_pred_col, |
| 23 | + each_true_col, |
| 24 | + model_mean_err_col, |
| 25 | +) |
| 26 | + |
| 27 | +__author__ = "Janosh Riebesell" |
| 28 | +__date__ = "2023-02-15" |
| 29 | + |
# mean absolute error across all models per structure, stored both in the
# per-structure error frame and the predictions frame
mean_abs_err = df_each_err.abs().mean(axis=1)
df_each_err[model_mean_err_col] = mean_abs_err
df_preds[model_mean_err_col] = mean_abs_err
| 33 | + |
| 34 | + |
# %% load MP training-set metadata, keyed by material ID
# NOTE(review): na_filter=False presumably keeps literal formula strings like
# "NaN" from being parsed as missing values — confirm against the data file
df_mp = pd.read_csv(DATA_FILES.mp_energies, na_filter=False).set_index("material_id")

# compute number of samples per element in training set
# counting element occurrences not weighted by composition, assuming model don't learn
# much more about iron and oxygen from Fe2O3 than from FeO
train_count_col = "MP Occurrences"
elem_occurrences = count_elements(df_mp.formula_pretty, count_mode="occurrence")
df_elem_err = elem_occurrences.to_frame(name=train_count_col)
| 45 | + |
| 46 | + |
# %% periodic-table heatmap of element occurrence counts in MP
title = "Number of MP structures containing each element"
fig = ptable_heatmap_plotly(df_elem_err[train_count_col], font_size=10)
fig.layout.title.update(text=title, x=0.4, y=0.9)
fig.show()
| 52 | + |
| 53 | + |
# %% map average model error onto elements
frac_comp_col = "fractional composition"
df_wbm[frac_comp_col] = [
    Composition(formula).fractional_composition for formula in tqdm(df_wbm.formula)
]

# wide frame: one column per element, cells hold that element's atomic fraction
# (NaN wherever a structure does not contain the element)
frac_dicts = [comp.as_dict() for comp in df_wbm[frac_comp_col]]
df_frac_comp = pd.DataFrame(frac_dicts, index=df_wbm.index)

# sanity check: each structure's fractions must sum to 1 (up to rounding)
row_totals = df_frac_comp.sum(axis=1).round(6)
assert (row_totals == 1).all(), "composition fractions don't sum to 1"

# df_frac_comp = df_frac_comp.dropna(axis=1, thresh=100) # remove Xe with only 1 entry
| 68 | + |
| 69 | + |
# %% bar plots of per-element structure counts in MP and WBM
# df_frac_comp.where(pd.isna, 1) maps every non-NaN fraction to 1, so each
# column sum counts the WBM structures containing that element
elem_counts_per_dataset = {
    "MP": df_elem_err[train_count_col],
    "WBM": df_frac_comp.where(pd.isna, 1).sum(),
}
for label, counts in elem_counts_per_dataset.items():
    title = f"Number of {label} structures containing each element"
    counts = counts.sort_values().copy()
    # prefix each element symbol with its descending rank by count
    counts.index = [
        f"{len(counts) - rank} {symbol}" for rank, symbol in enumerate(counts.index)
    ]
    fig = counts.plot.bar(backend="plotly", title=title)
    fig.layout.update(showlegend=False)
    fig.show()
| 81 | + |
| 82 | + |
# %% plot structure counts for each element in MP and WBM in a grouped bar chart
df_struct_counts = pd.DataFrame(index=df_elem_err.index)
df_struct_counts["MP"] = df_elem_err[train_count_col]
# non-NaN fractions mapped to 1 so each column sum counts structures per element
df_struct_counts["WBM"] = df_frac_comp.where(pd.isna, 1).sum()
min_count = 10  # only show elements with at least 10 structures
# fix: use >= so elements with exactly min_count structures are kept, matching
# the stated "at least" cutoff (previously a strict > dropped the boundary case)
df_struct_counts = df_struct_counts[df_struct_counts.sum(axis=1) >= min_count]
normalized = False  # if True, plot percent of each dataset instead of raw counts
if normalized:
    df_struct_counts["MP"] /= len(df_mp) / 100
    df_struct_counts["WBM"] /= len(df_wbm) / 100
y_col = "percent" if normalized else "count"
fig = (
    df_struct_counts.reset_index()
    .melt(var_name="dataset", value_name=y_col, id_vars="symbol")
    .sort_values([y_col, "symbol"])
    .plot.bar(
        x="symbol",
        y=y_col,
        backend="plotly",
        title="Number of structures containing each element",
        color="dataset",
        barmode="group",
    )
)

fig.layout.update(bargap=0.1)
fig.layout.legend.update(x=0.02, y=0.98, font_size=16)
fig.show()
save_fig(fig, f"{FIGS}/bar-element-counts-mp+wbm-{normalized=}.svelte")
| 112 | + |
| 113 | + |
# %% per-element spread of DFT hull distances over the test set
test_set_std_col = "Test set standard deviation (eV/atom)"
# build a presence mask (1 where a structure contains the element, NaN elsewhere)
# so each column's std is taken only over structures containing that element
presence_mask = df_frac_comp.mask(df_frac_comp.notna(), 1)
df_elem_err[test_set_std_col] = (
    presence_mask * df_wbm[each_true_col].values[:, None]
).std()
| 119 | + |
| 120 | + |
# %% periodic-table heatmap of the per-element test-set standard deviation
per_elem_std = df_elem_err[test_set_std_col]
fig = ptable_heatmap_plotly(per_elem_std, precision=".2f", colorscale="Inferno")
fig.show()
| 126 | + |
| 127 | + |
# %% per-element mean absolute error for each model (plus the inter-model mean)
normalized = True
cs_range = (0, 0.5)  # same range for all plots
# cs_range = (None, None)  # different range for each plot
for model in (*df_metrics, model_mean_err_col):
    # weight each structure's absolute error by its elemental fractions, then
    # average over the test set to attribute error to individual elements
    weighted_errs = df_frac_comp * df_each_err[model].abs().values[:, None]
    df_elem_err[model] = weighted_errs.mean()
    # deep copy so the normalization below can't mutate the dataframe column
    elem_err = df_elem_err[model].copy(deep=True)
    if normalized:
        elem_err /= df_elem_err[test_set_std_col]
        elem_err.name = f"{model} (normalized by test set std)"
    else:
        elem_err.name = f"{model} (eV/atom)"
    fig = ptable_heatmap_plotly(
        elem_err, precision=".2f", colorscale="Inferno", cscale_range=cs_range
    )
    fig.show()
| 146 | + |
| 147 | + |
# %% persist per-element model errors (rounded to shrink the JSON payload)
# sanity check: no model/column should be missing errors for too many elements
assert df_elem_err.isna().sum().lt(35).all()
df_elem_err.round(4).to_json(f"{MODELS}/per-element-model-each-errors.json")
| 151 | + |
| 152 | + |
# %% scatter plot error by element against prevalence in training set
# for checking correlation and R2 of elemental prevalence in MP training data vs.
# model error
df_elem_err["elem_name"] = [Element(symb).long_name for symb in df_elem_err.index]
xy_no_nan = df_elem_err[[train_count_col, model_mean_err_col]].dropna()
R2 = r2_score(xy_no_nan[train_count_col], xy_no_nan[model_mean_err_col])
r_P = df_elem_err[model_mean_err_col].corr(df_elem_err[train_count_col])

# label only outliers (high error or very common elements) to avoid clutter
point_labels = df_elem_err.index.where(
    (df_elem_err[model_mean_err_col] > 0.04)
    | (df_elem_err[train_count_col] > 6_000)
)
fig = df_elem_err.plot.scatter(
    x=train_count_col,
    y=model_mean_err_col,
    backend="plotly",
    hover_name="elem_name",
    text=point_labels,
    title="Per-element error vs element-occurrence in MP training "
    f"set: r<sub>Pearson</sub>={r_P:.2f}, R<sup>2</sup>={R2:.2f}",
    hover_data={model_mean_err_col: ":.2f", train_count_col: ":,.0f"},
)
fig.update_traces(textposition="top center")  # place text above scatter points
fig.layout.title.update(xanchor="center", x=0.5)
fig.show()

# save_fig(fig, f"{FIGS}/element-prevalence-vs-error.svelte")
save_fig(fig, f"{ROOT}/tmp/figures/element-prevalence-vs-error.pdf")
| 179 | + |
| 180 | + |
# %% plot EACH errors against least prevalent element in structure (by occurrence in
# MP training set). this seems to correlate more with model error
n_examp_for_rarest_elem_col = "Examples for rarest element in structure"
# parse formulas once and cache (skipped if an earlier run already populated it)
df_wbm["composition"] = df_wbm.get("composition", df_wbm.formula.map(Composition))
# for each structure, the MP occurrence count of its least common element.
# fix: removed a leftover no-op debug line that indexed df_elem_err and discarded
# the result, and reuse the cached compositions instead of re-parsing each formula
df_wbm[n_examp_for_rarest_elem_col] = [
    df_elem_err.loc[list(map(str, comp))][train_count_col].min()
    for comp in tqdm(df_wbm.composition)
]
| 190 | + |
| 191 | + |
# %% long-format table of absolute per-model errors, binned for plotting
abs_errs = df_each_err.abs().reset_index()
df_melt = abs_errs.melt(
    var_name="Model", value_name=each_pred_col, id_vars="material_id"
).set_index("material_id")
df_melt[n_examp_for_rarest_elem_col] = df_wbm[n_examp_for_rarest_elem_col]

# bin the scatter points per model to keep the figure light-weight
df_bin = bin_df_cols(df_melt, [n_examp_for_rarest_elem_col, each_pred_col], ["Model"])
df_bin = df_bin.reset_index().set_index("material_id")
df_bin["formula"] = df_wbm.formula
| 204 | + |
| 205 | + |
# %% faceted scatter of binned absolute EACH errors vs MP occurrence count of the
# least prevalent element in each structure, one panel per model
fig = px.scatter(
    df_bin.reset_index(),
    x=n_examp_for_rarest_elem_col,
    y=each_pred_col,
    color="Model",
    facet_col="Model",
    facet_col_wrap=3,
    hover_data=dict(material_id=True, formula=True, Model=False),
    title="Absolute errors in model-predicted E<sub>above hull</sub> vs. occurrence "
    "count in MP training set<br>of least prevalent element in structure",
)
fig.layout.update(showlegend=False)
fig.layout.title.update(x=0.5, xanchor="center", y=0.95)
fig.layout.margin.update(t=100)
# remove axis labels
fig.update_xaxes(title="")
fig.update_yaxes(title="")
# facet annotations default to "Model=<name>"; keep only the model name
for anno in fig.layout.annotations:
    anno.text = anno.text.split("=")[1]

# shared x-axis label, replacing the per-facet labels removed above
fig.add_annotation(
    text="MP occurrence count of least prevalent element in structure",
    x=0.5,
    y=-0.18,
    xref="paper",
    yref="paper",
    showarrow=False,
)
# shared (rotated) y-axis label
fig.add_annotation(
    text="Absolute error in E<sub>above hull</sub>",
    x=-0.07,
    y=0.5,
    xref="paper",
    yref="paper",
    showarrow=False,
    textangle=-90,
)

fig.show()
save_fig(fig, f"{FIGS}/each-error-vs-least-prevalent-element-in-struct.svelte")
0 commit comments