|
1 | 1 | # %%
|
2 | 2 | from plotly.subplots import make_subplots
|
| 3 | +from pymatviz.utils import save_fig |
3 | 4 |
|
4 | 5 | from matbench_discovery import FIGS, today
|
5 | 6 | from matbench_discovery.data import load_df_wbm_with_preds
|
|
21 | 22 |
|
22 | 23 |
|
23 | 24 | # %%
|
24 |
| -models = ( |
25 |
| - "Wren, CGCNN, CGCNN IS2RE, CGCNN RS2RE, Voronoi RF, " |
26 |
| - "Wrenformer, MEGNet, M3GNet, BOWSR MEGNet" |
27 |
| -).split(", ") |
| 25 | +models = sorted( |
| 26 | + "CGCNN, Voronoi RF, Wrenformer, MEGNet, M3GNet, BOWSR MEGNet".split(", ") |
| 27 | +) |
28 | 28 | df_wbm = load_df_wbm_with_preds(models=models).round(3)
|
29 | 29 |
|
30 | 30 | target_col = "e_form_per_atom_mp2020_corrected"
|
|
35 | 35 | which_energy: WhichEnergy = "true"
|
36 | 36 | model_name = "Wrenformer"
|
37 | 37 |
|
38 |
| -backend: Backend = "matplotlib" |
| 38 | +backend: Backend = "plotly" |
| 39 | +rows, cols = len(models) // 3, 3 |
39 | 40 | if backend == "matplotlib":
|
40 |
| - fig, axs = plt.subplots(nrows=3, ncols=3, figsize=(18, 12)) |
| 41 | + fig, axs = plt.subplots(nrows=rows, ncols=cols, figsize=(6 * cols, 5 * rows)) |
41 | 42 | else:
|
42 |
| - fig = make_subplots(rows=3, cols=3) |
| 43 | + x_title = "distance to convex hull (eV/atom)" |
| 44 | + fig = make_subplots( |
| 45 | + rows=rows, cols=cols, y_title="Count", x_title=x_title, subplot_titles=models |
| 46 | + ) |
43 | 47 |
|
44 | 48 |
|
45 | 49 | for idx, model_name in enumerate(models):
|
|
48 | 52 | e_above_hull_pred=df_wbm[e_above_hull_col]
|
49 | 53 | + (df_wbm[model_name] - df_wbm[target_col]),
|
50 | 54 | which_energy=which_energy,
|
51 |
| - ax=axs.flat[idx], |
| 55 | + ax=axs.flat[idx] if backend == "matplotlib" else None, |
52 | 56 | backend=backend,
|
53 | 57 | )
|
54 | 58 | title = f"{model_name} ({len(df_wbm[model_name].dropna()):,})"
|
55 | 59 | text = f"Enrichment\nFactor = {metrics['enrichment']:.3}"
|
| 60 | + row, col = idx % rows + 1, idx // rows + 1 |
56 | 61 |
|
57 | 62 | if backend == "matplotlib":
|
58 | 63 | ax.text(0.02, 0.25, text, fontsize=16, transform=ax.transAxes)
|
59 | 64 | ax.set(title=title)
|
60 | 65 |
|
| 66 | + # no need to store all 250k x values in plot, leads to 1.7 MB file, subsample every 10th |
| 67 | + # point is enough to see the distribution |
| 68 | + for trace in ax.data: |
| 69 | + trace.x = trace.x[::10] |
| 70 | + |
61 | 71 | else:
|
62 |
| - ax.add_annotation(text=text, x=0.5, y=0.5, showarrow=False) |
63 |
| - ax.update_xaxes(title_text=title) |
| 72 | + fig.add_annotation(text=text, x=0.5, y=0.5, showarrow=False) |
| 73 | + fig.add_traces(ax.data, rows=row, cols=col) |
| 74 | + # fig.update_xaxes(title_text=title, row=row, col=col) |
64 | 75 |
|
65 |
| - for trace in ax.data: |
66 |
| - fig.append_trace(trace, row=idx % 3 + 1, col=idx // 3 + 1) |
67 | 76 |
|
68 | 77 | if backend == "matplotlib":
|
69 | 78 | fig.suptitle(f"{today} {which_energy=}", y=1.07, fontsize=16)
|
|
74 | 83 | bbox_to_anchor=(0.5, -0.05),
|
75 | 84 | frameon=False,
|
76 | 85 | )
|
| 86 | +else: |
| 87 | + fig.update_xaxes(range=[-0.4, 0.4]) |
| 88 | + fig.update_layout(showlegend=False, barmode="stack") |
| 89 | + |
77 | 90 |
|
78 |
| -fig.show() |
| 91 | +fig.show(config=dict(responsive=True)) |
79 | 92 |
|
80 | 93 |
|
81 | 94 | # %%
|
82 | 95 | img_path = f"{FIGS}/{today}-wbm-hull-dist-hist-models"
|
83 |
| -# if hasattr(fig, "write_image"): |
84 |
| -# fig.write_image(f"{img_path}.pdf") |
85 |
| -# fig.write_html(f"{img_path}.html", include_ploltyjs="cdn") |
86 |
| -# else: |
87 |
| -# fig.savefig(f"{img_path}.pdf") |
| 96 | +save_fig(fig, f"{img_path}.html") |
| 97 | +# save_fig(fig, f"{img_path}.png", scale=3) |
| 98 | +# save_fig(fig, f"{img_path}.pdf") |
0 commit comments