|
| 1 | +# %% |
| 2 | +import plotly.express as px |
| 3 | +import plotly.graph_objects as go |
| 4 | +import seaborn as sns |
| 5 | +from pymatviz.utils import save_fig |
| 6 | + |
| 7 | +from matbench_discovery import FIGS, PDF_FIGS, plots |
| 8 | +from matbench_discovery.preds import df_each_err, models |
| 9 | + |
| 10 | +__author__ = "Janosh Riebesell" |
| 11 | +__date__ = "2023-05-25" |
| 12 | + |
| 13 | + |
| 14 | +# %% |
| 15 | +ax = df_each_err[models].plot.box( |
| 16 | + showfliers=False, |
| 17 | + rot=90, |
| 18 | + figsize=(12, 6), |
| 19 | + # color="blue", |
| 20 | + # different fill colors for each box |
| 21 | + # patch_artist=True, |
| 22 | + # notch=True, |
| 23 | + # bootstrap=10000, |
| 24 | + showmeans=True, |
| 25 | + # meanline=True, |
| 26 | +) |
| 27 | +ax.axhline(0, linewidth=1, color="gray", linestyle="--") |
| 28 | + |
| 29 | + |
| 30 | +# %% |
| 31 | +ax = sns.violinplot( |
| 32 | + data=df_each_err[models], inner="quartile", linewidth=0.3, palette="Set2", width=1 |
| 33 | +) |
| 34 | +ax.set(ylim=(-0.9, 0.9)) |
| 35 | + |
| 36 | + |
| 37 | +# %% |
| 38 | +px.box( |
| 39 | + df_each_err[models].melt(), |
| 40 | + x="variable", |
| 41 | + y="value", |
| 42 | + color="variable", |
| 43 | + points=False, |
| 44 | + hover_data={"variable": False}, |
| 45 | +) |
| 46 | + |
| 47 | + |
| 48 | +# %% |
| 49 | +px.violin( |
| 50 | + df_each_err[models].melt(), |
| 51 | + x="variable", |
| 52 | + y="value", |
| 53 | + color="variable", |
| 54 | + violinmode="overlay", |
| 55 | + box=True, |
| 56 | + # points="all", |
| 57 | + hover_data={"variable": False}, |
| 58 | + width=1000, |
| 59 | + height=500, |
| 60 | +) |
| 61 | + |
| 62 | + |
| 63 | +# %% |
| 64 | +fig = go.Figure() |
| 65 | +fig.layout.yaxis.title = plots.quantity_labels["e_above_hull_error"] |
| 66 | +fig.layout.margin = dict(l=0, r=0, b=0, t=0) |
| 67 | + |
| 68 | +for col in models: |
| 69 | + val_min = df_each_err[col].quantile(0.05) |
| 70 | + lower_box = df_each_err[col].quantile(0.25) |
| 71 | + median = df_each_err[col].median() |
| 72 | + upper_box = df_each_err[col].quantile(0.75) |
| 73 | + val_max = df_each_err[col].quantile(0.95) |
| 74 | + |
| 75 | + box_plot = go.Box( |
| 76 | + y=[val_min, lower_box, median, upper_box, val_max], |
| 77 | + name=col, |
| 78 | + width=0.7, |
| 79 | + ) |
| 80 | + fig.add_trace(box_plot) |
| 81 | + |
| 82 | +fig.layout.legend.update(orientation="h", y=1.15) |
| 83 | +fig.show() |
| 84 | +save_fig(fig, f"{FIGS}/box-hull-dist-errors.svelte") |
| 85 | +save_fig(fig, f"{PDF_FIGS}/box-hull-dist-errors.pdf") |
0 commit comments