Skip to content

Commit e7f9fe8

Browse files
committed
add marginal hull distribution along rolling_mae_vs_hull_dist_models plot top edge
use different line styles and markers for roc-models-all-in-one.svelte
1 parent eb11ab0 commit e7f9fe8

9 files changed

+78
-43
lines changed

matbench_discovery/plots.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -547,16 +547,16 @@ def rolling_mae_vs_hull_dist(
547547
scatter_kwds = dict(
548548
fill="toself", opacity=0.2, hoverinfo="skip", showlegend=False
549549
)
550-
peril_cone_anno = "MAE > |E<sub>above hull</sub>|"
550+
triangle_anno = "MAE > |E<sub>above hull</sub>|"
551551
fig.add_scatter(
552552
x=(-1, -dft_acc, dft_acc, 1) if show_dft_acc else (-1, 0, 1),
553553
y=(1, dft_acc, dft_acc, 1) if show_dft_acc else (1, 0, 1),
554-
name=peril_cone_anno,
554+
name=triangle_anno,
555555
fillcolor="red",
556556
**scatter_kwds,
557557
)
558558
fig.add_annotation(
559-
x=0, y=0.8, text=peril_cone_anno, showarrow=False, yref="paper"
559+
x=0, y=0.7, text=triangle_anno, showarrow=False, yref="paper"
560560
)
561561

562562
if show_dummy_mae:

scripts/model_figs/roc_prc_curves_models.py

+19-15
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,18 @@
7474
range_y=(0, 1.02),
7575
hover_name=facet_col,
7676
hover_data={facet_col: False},
77-
**kwds if facet_plot else dict(color=facet_col, markers=True, marker_size=3),
77+
**(kwds if facet_plot else dict(color=facet_col, markers=True)),
7878
)
7979

8080
for anno in fig.layout.annotations:
8181
anno.text = anno.text.split("=", 1)[1] # remove Model= from subplot titles
8282

83+
line_styles = "solid dash dot dashdot".split() * 3
84+
markers = "circle square triangle-up triangle-down diamond cross star x".split() * 2
85+
for trace, ls, marker in zip(fig.data, line_styles, markers):
86+
trace.line.dash = ls
87+
trace.marker.symbol = marker
88+
8389
if not facet_plot:
8490
fig.layout.legend.update(x=1, y=0, xanchor="right", title=None)
8591
fig.layout.coloraxis.colorbar.update(thickness=14, title_side="right")
@@ -100,37 +106,35 @@
100106

101107
# %%
102108
save_fig(fig, f"{SITE_FIGS}/{img_name}.svelte")
103-
save_fig(fig, f"{PDF_FIGS}/{img_name}.pdf", width=1000, height=400)
109+
save_fig(fig, f"{PDF_FIGS}/{img_name}.pdf", width=500, height=500)
104110

105111

106112
# %%
107113
df_prc = pd.DataFrame()
114+
prec_col, recall_col = "Precision", "Recall"
108115

109116
for model in (pbar := tqdm(list(df_each_pred), desc="Calculating ROC curves")):
110117
pbar.set_postfix_str(model)
111118
na_mask = df_preds[each_true_col].isna() | df_each_pred[model].isna()
112119
y_true = (df_preds[~na_mask][each_true_col] <= STABILITY_THRESHOLD).astype(int)
113120
y_pred = df_each_pred[model][~na_mask]
114121
prec, recall, thresholds = precision_recall_curve(y_true, y_pred, pos_label=0)
115-
df_tmp = pd.DataFrame(
116-
{
117-
"Precision": prec[:-1],
118-
"Recall": recall[:-1],
119-
color_col: thresholds,
120-
facet_col: model,
121-
}
122-
).round(3)
123-
124-
df_prc = pd.concat([df_prc, df_tmp])
122+
dct = {
123+
prec_col: prec[:-1],
124+
recall_col: recall[:-1],
125+
color_col: thresholds,
126+
facet_col: model,
127+
}
128+
df_prc = pd.concat([df_prc, pd.DataFrame(dct).round(3)])
125129

126130

127131
# %%
128132
n_cols = 3
129133
n_rows = math.ceil(len(models) / n_cols)
130134

131135
fig = df_prc.iloc[:: len(df_roc) // 500 or 1].plot.scatter(
132-
x="Recall",
133-
y="Precision",
136+
x=recall_col,
137+
y=prec_col,
134138
facet_col=facet_col,
135139
facet_col_wrap=n_cols,
136140
facet_row_spacing=0.04,
@@ -149,7 +153,7 @@
149153
anno.text = anno.text.split("=", 1)[1] # remove Model= from subplot titles
150154

151155
fig.layout.coloraxis.colorbar.update(
152-
x=0.5, y=1.03, thickness=14, len=0.4, orientation="h"
156+
x=0.5, y=1.03, thickness=11, len=0.8, orientation="h"
153157
)
154158
fig.add_hline(y=0.5, line=line)
155159
fig.add_annotation(

scripts/model_figs/rolling_mae_vs_hull_dist_models.py

+26-3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# %%
55
from typing import Final
66

7+
import numpy as np
8+
import plotly.graph_objects as go
79
from pymatviz.utils import save_fig
810

911
from matbench_discovery import PDF_FIGS, SITE_FIGS
@@ -23,7 +25,6 @@
2325

2426

2527
# %%
26-
# sort df columns by MAE (so that the legend is sorted too)
2728
backend: Final = "plotly"
2829

2930
fig, df_err, df_std = rolling_mae_vs_hull_dist(
@@ -52,13 +53,35 @@
5253

5354
# increase line width
5455
fig.update_traces(line=dict(width=3))
55-
fig.layout.legend.update(bgcolor="rgba(0,0,0,0)")
56+
fig.layout.legend.update(
57+
bgcolor="rgba(0,0,0,0)", title="", x=1.01, y=0, yanchor="bottom"
58+
)
5659
# increase legend handle size and reverse order
5760
fig.layout.margin.update(l=5, r=5, t=5, b=55)
61+
62+
# plot marginal histogram of true hull distances
63+
counts, bins = np.histogram(
64+
df_preds[each_true_col], bins=400, range=fig.layout.xaxis.range
65+
)
66+
marginal_trace = go.Scatter(
67+
x=bins, y=counts, name="Density", fill="tozeroy", showlegend=False, yaxis="y2"
68+
)
69+
marginal_trace.marker.color = "rgba(0, 150, 200, 1)"
70+
# add marginal trace to existing figure
71+
fig.add_trace(marginal_trace)
72+
73+
# update layout to include marginal plot
74+
fig.layout.update(
75+
yaxis1=dict(domain=[0, 0.75]), # main yaxis
76+
yaxis2=dict( # marginal yaxis
77+
domain=[0.8, 1], tickformat="s", tickvals=[*range(0, 100_000, 2000)]
78+
),
79+
)
5880
fig.show()
81+
5982
img_name = "rolling-mae-vs-hull-dist-models"
6083

6184

6285
# %%
6386
save_fig(fig, f"{SITE_FIGS}/{img_name}.svelte")
64-
save_fig(fig, f"{PDF_FIGS}/{img_name}.pdf", width=520, height=350)
87+
save_fig(fig, f"{PDF_FIGS}/{img_name}.pdf", width=600, height=400)

site/package.json

+11-11
Original file line numberDiff line numberDiff line change
@@ -20,35 +20,35 @@
2020
"@iconify/svelte": "^3.1.4",
2121
"@rollup/plugin-yaml": "^4.1.1",
2222
"@sveltejs/adapter-static": "^2.0.3",
23-
"@sveltejs/kit": "^1.22.4",
24-
"@sveltejs/vite-plugin-svelte": "^2.4.3",
25-
"@typescript-eslint/eslint-plugin": "^6.2.1",
26-
"@typescript-eslint/parser": "^6.2.1",
23+
"@sveltejs/kit": "^1.22.6",
24+
"@sveltejs/vite-plugin-svelte": "^2.4.5",
25+
"@typescript-eslint/eslint-plugin": "^6.4.0",
26+
"@typescript-eslint/parser": "^6.4.0",
2727
"d3-scale-chromatic": "^3.0.0",
2828
"elementari": "^0.2.2",
29-
"eslint": "^8.46.0",
29+
"eslint": "^8.47.0",
3030
"eslint-plugin-svelte": "^2.32.4",
3131
"hastscript": "^8.0.0",
3232
"highlight.js": "^11.8.0",
3333
"js-yaml": "^4.1.0",
3434
"katex": "^0.16.8",
3535
"mdsvex": "^0.11.0",
36-
"prettier": "^3.0.1",
36+
"prettier": "^3.0.2",
3737
"prettier-plugin-svelte": "^3.0.3",
3838
"rehype-autolink-headings": "^6.1.1",
3939
"rehype-katex-svelte": "^1.2.0",
4040
"rehype-slug": "^5.1.0",
4141
"remark-math": "3.0.0",
42-
"svelte": "^4.1.2",
43-
"svelte-check": "^3.4.6",
42+
"svelte": "^4.2.0",
43+
"svelte-check": "^3.5.0",
4444
"svelte-multiselect": "^10.1.0",
4545
"svelte-preprocess": "^5.0.4",
4646
"svelte-toc": "^0.5.5",
4747
"svelte-zoo": "^0.4.9",
48-
"svelte2tsx": "^0.6.19",
49-
"tslib": "^2.6.1",
48+
"svelte2tsx": "^0.6.20",
49+
"tslib": "^2.6.2",
5050
"typescript": "5.1.6",
51-
"vite": "^4.4.8"
51+
"vite": "^4.4.9"
5252
},
5353
"prettier": {
5454
"semi": false,

site/src/figs/roc-models-all-in-one.svelte

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

site/src/figs/rolling-mae-vs-hull-dist-models.svelte

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

site/src/lib/ModelCard.svelte

+6-3
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
export let stats: ModelStatLabel[] // [key, label, unit][]
1010
export let sort_by: keyof ModelData
1111
export let show_details: boolean = false
12+
export let style: string | null = null
13+
export let metrics_style: string | null = null
1214
1315
$: ({ model_name, missing_preds, missing_percent, hyperparams, notes, training_set } =
1416
data)
@@ -23,7 +25,7 @@
2325
const target = { target: `_blank`, rel: `noopener` }
2426
</script>
2527

26-
<h2 id={model_name.toLowerCase().replaceAll(` `, `-`)}>
28+
<h2 id={model_name.toLowerCase().replaceAll(` `, `-`)} {style}>
2729
{model_name}
2830
<button
2931
on:click={() => (show_details = !show_details)}
@@ -107,7 +109,7 @@
107109
</section>
108110
</div>
109111
{/if}
110-
<section class="metrics">
112+
<section class="metrics" style={metrics_style}>
111113
<h3 class="toc-exclude">Metrics</h3>
112114
<ul>
113115
{#each stats as { key, label, unit }}
@@ -143,7 +145,7 @@
143145
<section>
144146
<h3 class="toc-exclude">Notes</h3>
145147
<ul>
146-
{#each [`description`, `training`].filter((k) => k in (notes ?? {})) as key}
148+
{#each [`description`, `training`].filter((key) => key in (notes ?? {})) as key}
147149
<li>{@html notes[key]}</li>
148150
{/each}
149151
</ul>
@@ -154,6 +156,7 @@
154156
h2 {
155157
margin: 8pt 0 1em;
156158
text-align: center;
159+
border-radius: 5pt;
157160
}
158161
button {
159162
background: none;

site/src/lib/References.svelte

+3-3
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
<ol>
2121
{#each found_on_page as { title, id, author, DOI, URL: href, issued } (id)}
2222
<li>
23-
<strong {id}>{title}</strong>
23+
<p {id}>{title}</p>
2424
<span>
2525
{@html author
2626
.slice(0, n_authors)
@@ -61,8 +61,8 @@
6161
ol > li {
6262
margin: 1ex 0;
6363
}
64-
ol > li > strong {
65-
display: block;
64+
ol > li > p {
65+
margin: 0;
6666
}
6767
ol > li > :is(small, span) {
6868
font-weight: lighter;

site/src/routes/models/+page.svelte

+8-3
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@
8686
</ul>
8787

8888
<legend>
89-
best
89+
heading color: best
9090
<ColorBar color_scale={interpolatePuOr} style="min-width: min(70vw, 400px);" />
9191
worst
9292
</legend>
@@ -97,9 +97,14 @@
9797
animate:flip={{ duration: 400 }}
9898
in:fade={{ delay: 100 }}
9999
out:fade={{ delay: 100 }}
100-
style="background-color: {bg_color(model[sort_by], min_val, max_val)};"
101100
>
102-
<ModelCard data={model} {stats} {sort_by} bind:show_details />
101+
<ModelCard
102+
data={model}
103+
{stats}
104+
{sort_by}
105+
bind:show_details
106+
style="background-color: {bg_color(model[sort_by], min_val, max_val)};"
107+
/>
103108
<!-- maybe show this text in a tooltip: This model was not trained on the
104109
canonical training set. It's results should not be seen as a one-to-one
105110
comparison to the other models but rather proof of concept of what is possible. -->

0 commit comments

Comments
 (0)