|
4 | 4 | import numpy as np
|
5 | 5 | import pandas as pd
|
6 | 6 | from aviary.wren.utils import get_isopointal_proto_from_aflow
|
| 7 | +from IPython.display import display |
7 | 8 | from pymatviz import spacegroup_hist, spacegroup_sunburst
|
8 | 9 | from pymatviz.io import df_to_html_table, df_to_pdf, save_fig
|
9 |
| -from pymatviz.powerups import add_identity_line, bin_df_cols |
| 10 | +from pymatviz.powerups import add_identity_line |
10 | 11 | from pymatviz.ptable import ptable_heatmap_plotly
|
| 12 | +from pymatviz.utils import bin_df_cols |
11 | 13 |
|
12 | 14 | from matbench_discovery import PDF_FIGS, SITE_FIGS, Model
|
13 | 15 | from matbench_discovery.data import DATA_FILES, df_wbm
|
|
20 | 22 |
|
21 | 23 | # %%
|
22 | 24 | model = Model.wrenformer
|
| 25 | +model_low = model.lower() |
23 | 26 | max_each_true = 1
|
24 | 27 | min_each_pred = 1
|
25 | 28 | df_each_pred[Key.each_true] = df_preds[Key.each_true]
|
|
42 | 45 |
|
43 | 46 |
|
44 | 47 | # %%
|
45 |
| -ax = spacegroup_hist(df_bad[Key.spacegroup]) |
46 |
| -ax.set_title(f"Spacegroup hist for {title}", y=1.15) |
47 |
| -save_fig(ax, f"{PDF_FIGS}/spacegroup-hist-{model.lower()}-failures.pdf") |
| 48 | +fig = spacegroup_hist(df_bad[Key.spacegroup]) |
| 49 | +fig.layout.title.update(text=f"Spacegroup hist for {title}", y=0.96) |
| 50 | +fig.layout.margin.update(l=0, r=0, t=80, b=0) |
| 51 | +save_fig(fig, f"{PDF_FIGS}/spacegroup-hist-{model.lower()}-failures.pdf") |
| 52 | +fig.show() |
48 | 53 |
|
49 | 54 |
|
50 | 55 | # %%
|
|
68 | 73 | df_proto_counts[proto_col] = df_proto_counts[proto_col].str.replace("_", "-")
|
69 | 74 |
|
70 | 75 | styler = df_proto_counts.head(10).style.background_gradient(cmap="viridis")
|
71 |
| - |
72 |
| -df_to_html_table(styler, f"{SITE_FIGS}/proto-counts-{model}-failures.svelte") |
73 |
| -df_to_pdf(styler, f"{PDF_FIGS}/proto-counts-{model}-failures.pdf") |
| 76 | +styler.set_caption(f"Top 10 {proto_col} in {len(df_bad)} {model} failures") |
| 77 | +display(styler) |
| 78 | +df_to_html_table(styler, f"{SITE_FIGS}/proto-counts-{model_low}-failures.svelte") |
| 79 | +df_to_pdf(styler, f"{PDF_FIGS}/proto-counts-{model_low}-failures.pdf") |
74 | 80 |
|
75 | 81 |
|
76 | 82 | # %%
|
77 |
| -fig = spacegroup_sunburst(df_bad[Key.spacegroup], width=350, height=350) |
| 83 | +fig = spacegroup_sunburst( |
| 84 | + df_bad[Key.spacegroup], width=350, height=350, show_counts="percent" |
| 85 | +) |
78 | 86 | # fig.layout.title.update(text=f"Spacegroup sunburst for {title}", x=0.5, font_size=14)
|
79 | 87 | fig.layout.margin.update(l=1, r=1, t=1, b=1)
|
80 | 88 | fig.show()
|
81 | 89 |
|
82 | 90 |
|
83 | 91 | # %%
|
84 |
| -save_fig(fig, f"{PDF_FIGS}/spacegroup-sunburst-{model.lower()}-failures.pdf") |
85 |
| -save_fig(fig, f"{SITE_FIGS}/spacegroup-sunburst-{model}-failures.svelte") |
| 92 | +save_fig(fig, f"{PDF_FIGS}/spacegroup-sunburst-{model_low}-failures.pdf") |
| 93 | +save_fig(fig, f"{SITE_FIGS}/spacegroup-sunburst-{model_low}-failures.svelte") |
86 | 94 |
|
87 | 95 |
|
88 | 96 | # %%
|
89 | 97 | fig = ptable_heatmap_plotly(df_bad[Key.formula])
|
90 | 98 | fig.layout.title = f"Elements in {title}"
|
91 | 99 | fig.layout.margin = dict(l=0, r=0, t=50, b=0)
|
92 | 100 | fig.show()
|
93 |
| -save_fig(fig, f"{PDF_FIGS}/elements-{model.lower()}-failures.pdf") |
| 101 | +save_fig(fig, f"{PDF_FIGS}/elements-{model_low}-failures.pdf") |
94 | 102 |
|
95 | 103 |
|
96 | 104 | # %%
|
|
0 commit comments