Skip to content

Commit 4ae422d

Browse files
committed
fix scripts/hist_classified_stable_vs_hull_dist_models.py
was overlapping, not stacking histograms in same plot and failed for backend=plotly
1 parent f5c3e37 commit 4ae422d

13 files changed

+120
-108
lines changed

data/wbm/analysis.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33

44
import pandas as pd
55
from pymatviz import count_elements, ptable_heatmap_plotly
6+
from pymatviz.utils import save_fig
67

78
from matbench_discovery import ROOT, today
8-
from matbench_discovery.plots import write_html
99

1010
module_dir = os.path.dirname(__file__)
1111

@@ -47,7 +47,7 @@
4747

4848
# %%
4949
fig.write_image(f"{module_dir}/{today}-wbm-elements.svg", width=1000, height=500)
50-
write_html(fig, f"{module_dir}/{today}-wbm-elements.svelte")
50+
save_fig(fig, f"{module_dir}/{today}-wbm-elements.svelte")
5151

5252

5353
# %% load MP training set
@@ -82,4 +82,4 @@
8282

8383
# %%
8484
fig.write_image(f"{module_dir}/{today}-mp-elements.svg", width=1000, height=500)
85-
write_html(fig, f"{module_dir}/{today}-mp-elements.svelte")
85+
save_fig(fig, f"{module_dir}/{today}-mp-elements.svelte")

data/wbm/fetch_process_wbm_dataset.py

+5-8
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
)
1919
from pymatgen.entries.computed_entries import ComputedStructureEntry
2020
from pymatviz import density_scatter
21+
from pymatviz.utils import save_fig
2122
from tqdm import tqdm
2223

2324
from matbench_discovery import ROOT, today
@@ -448,14 +449,10 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
448449
# no need to store all 250k x values in plot, leads to 1.7 MB file, subsample every 10th
449450
# point is enough to see the distribution
450451
fig.data[0].x = fig.data[0].x[::10]
451-
# recommended to upload to vecta.io/nano afterwards for compression
452-
fig.write_image(f"{module_dir}/{today}-hist-e-form-per-atom.svg", width=800, height=300)
453-
fig.write_html(
454-
f"{module_dir}/{today}-hist-e-form-per-atom.svelte",
455-
include_plotlyjs=False,
456-
full_html=False,
457-
config=dict(showTips=False, displayModeBar=False, scrollZoom=True),
458-
)
452+
# recommended to upload SVG to vecta.io/nano afterwards for compression
453+
img_path = f"{module_dir}/{today}-hist-e-form-per-atom"
454+
save_fig(fig, f"{img_path}.svg", width=800, height=300)
455+
save_fig(fig, f"{img_path}.svelte")
459456

460457

461458
# %%

matbench_discovery/plots.py

+2-24
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,8 @@ def hist_classified_stable_vs_hull_dist(
221221
color="clf",
222222
nbins=20000,
223223
range_x=x_lim,
224-
opacity=0.9,
224+
barmode="stack",
225+
color_discrete_map=dict(zip(labels, px.colors.qualitative.Pastel)),
225226
**kwargs,
226227
)
227228
ax.update_layout(
@@ -638,26 +639,3 @@ def wandb_scatter(table: wandb.Table, fields: dict[str, str], **kwargs: Any) ->
638639
)
639640

640641
wandb.log({"true_pred_scatter": scatter_plot})
641-
642-
643-
def write_html(fig: go.Figure, path: str, **kwargs: Any) -> None:
644-
"""Write a plotly figure to an HTML file. If the file is has .svelte extension,
645-
insert `{...$$props}` into the figure's top-level div so it can be styled by
646-
consuming Svelte code
647-
648-
Args:
649-
fig (go.Figure): Plotly figure.
650-
path (str): Path to HTML file that will be created.
651-
**kwargs: Keyword arguments passed to fig.write_html().
652-
"""
653-
config = dict(
654-
showTips=False, displayModeBar=False, scrollZoom=True, responsive=True
655-
)
656-
fig.write_html(
657-
path, include_plotlyjs=False, full_html=False, config=config, **kwargs
658-
)
659-
if path.lower().endswith(".svelte"):
660-
# insert {...$$props} into top-level div to be able to post-process and style
661-
# plotly figures from within Svelte files
662-
text = open(path).read().replace("<div>", "<div {...$$props}>", 1)
663-
open(path, "w").write(text)

scripts/cumulative_clf_metrics.py

+6-9
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
# %%
22
import pandas as pd
3+
from pymatviz.utils import save_fig
34

45
from matbench_discovery import FIGS, today
56
from matbench_discovery.data import load_df_wbm_with_preds
6-
from matbench_discovery.plots import cumulative_precision_recall, write_html
7+
from matbench_discovery.plots import cumulative_precision_recall
78

89
__author__ = "Janosh Riebesell, Rhys Goodall"
910
__date__ = "2022-12-04"
@@ -12,7 +13,7 @@
1213
# %%
1314
models = (
1415
# Wren, CGCNN IS2RE, CGCNN RS2RE
15-
"Voronoi RF, Wrenformer, MEGNet, M3GNet, BOWSR MEGNet"
16+
"Voronoi RF, Wrenformer, MEGNet, M3GNet, BOWSR MEGNet, CGCNN, CGCNN debug"
1617
).split(", ")
1718

1819
df_wbm = load_df_wbm_with_preds(models=models).round(3)
@@ -50,16 +51,12 @@
5051

5152

5253
# %%
53-
img_path = f"{FIGS}/{today}-cumulative-clf-metrics"
54-
5554
# file will be served by site
5655
# so we round y floats to reduce file size since
5756
for trace in fig.data:
5857
assert isinstance(trace.y[0], float)
5958
trace.y = [round(y, 3) for y in trace.y]
6059

61-
if hasattr(fig, "write_image"):
62-
fig.write_image(f"{img_path}.pdf")
63-
write_html(fig, f"{img_path}.svelte")
64-
else:
65-
fig.savefig(f"{img_path}.pdf")
60+
img_path = f"{FIGS}/{today}-cumulative-clf-metrics"
61+
# save_fig(fig, f"{img_path}.pdf")
62+
save_fig(fig, f"{img_path}.svelte")

scripts/hist_classified_stable_vs_hull_dist.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# %%
2+
from pymatviz.utils import save_fig
3+
24
from matbench_discovery import FIGS, today
35
from matbench_discovery.data import load_df_wbm_with_preds
46
from matbench_discovery.plots import WhichEnergy, hist_classified_stable_vs_hull_dist
@@ -60,9 +62,6 @@
6062

6163

6264
# %%
63-
img_path = f"{FIGS}/{today}-wren-wbm-hull-dist-hist-{which_energy=}.pdf"
64-
if hasattr(ax, "write_image"):
65-
# fig.write_image(img_path)
66-
ax.write_html(img_path.replace(".pdf", ".html"))
67-
else:
68-
ax.figure.savefig(img_path)
65+
img_path = f"{FIGS}/{today}-wren-wbm-hull-dist-hist-{which_energy=}"
66+
# save_fig(ax, f"{img_path}.pdf")
67+
save_fig(ax, f"{img_path}.html")

scripts/hist_classified_stable_vs_hull_dist_batches.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# %%
2+
from pymatviz.utils import save_fig
3+
24
from matbench_discovery import FIGS, today
35
from matbench_discovery.data import load_df_wbm_with_preds
46
from matbench_discovery.plots import (
@@ -70,4 +72,4 @@
7072

7173
# %%
7274
img_path = f"{FIGS}/{today}-{model_name}-wbm-hull-dist-hist-batches.pdf"
73-
# ax.figure.savefig(img_path)
75+
save_fig(ax, img_path)

scripts/hist_classified_stable_vs_hull_dist_models.py

+29-18
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# %%
22
from plotly.subplots import make_subplots
3+
from pymatviz.utils import save_fig
34

45
from matbench_discovery import FIGS, today
56
from matbench_discovery.data import load_df_wbm_with_preds
@@ -21,10 +22,9 @@
2122

2223

2324
# %%
24-
models = (
25-
"Wren, CGCNN, CGCNN IS2RE, CGCNN RS2RE, Voronoi RF, "
26-
"Wrenformer, MEGNet, M3GNet, BOWSR MEGNet"
27-
).split(", ")
25+
models = sorted(
26+
"CGCNN, Voronoi RF, Wrenformer, MEGNet, M3GNet, BOWSR MEGNet".split(", ")
27+
)
2828
df_wbm = load_df_wbm_with_preds(models=models).round(3)
2929

3030
target_col = "e_form_per_atom_mp2020_corrected"
@@ -35,11 +35,15 @@
3535
which_energy: WhichEnergy = "true"
3636
model_name = "Wrenformer"
3737

38-
backend: Backend = "matplotlib"
38+
backend: Backend = "plotly"
39+
rows, cols = len(models) // 3, 3
3940
if backend == "matplotlib":
40-
fig, axs = plt.subplots(nrows=3, ncols=3, figsize=(18, 12))
41+
fig, axs = plt.subplots(nrows=rows, ncols=cols, figsize=(6 * cols, 5 * rows))
4142
else:
42-
fig = make_subplots(rows=3, cols=3)
43+
x_title = "distance to convex hull (eV/atom)"
44+
fig = make_subplots(
45+
rows=rows, cols=cols, y_title="Count", x_title=x_title, subplot_titles=models
46+
)
4347

4448

4549
for idx, model_name in enumerate(models):
@@ -48,22 +52,27 @@
4852
e_above_hull_pred=df_wbm[e_above_hull_col]
4953
+ (df_wbm[model_name] - df_wbm[target_col]),
5054
which_energy=which_energy,
51-
ax=axs.flat[idx],
55+
ax=axs.flat[idx] if backend == "matplotlib" else None,
5256
backend=backend,
5357
)
5458
title = f"{model_name} ({len(df_wbm[model_name].dropna()):,})"
5559
text = f"Enrichment\nFactor = {metrics['enrichment']:.3}"
60+
row, col = idx % rows + 1, idx // rows + 1
5661

5762
if backend == "matplotlib":
5863
ax.text(0.02, 0.25, text, fontsize=16, transform=ax.transAxes)
5964
ax.set(title=title)
6065

66+
# no need to store all 250k x values in plot, leads to 1.7 MB file, subsample every 10th
67+
# point is enough to see the distribution
68+
for trace in ax.data:
69+
trace.x = trace.x[::10]
70+
6171
else:
62-
ax.add_annotation(text=text, x=0.5, y=0.5, showarrow=False)
63-
ax.update_xaxes(title_text=title)
72+
fig.add_annotation(text=text, x=0.5, y=0.5, showarrow=False)
73+
fig.add_traces(ax.data, rows=row, cols=col)
74+
# fig.update_xaxes(title_text=title, row=row, col=col)
6475

65-
for trace in ax.data:
66-
fig.append_trace(trace, row=idx % 3 + 1, col=idx // 3 + 1)
6776

6877
if backend == "matplotlib":
6978
fig.suptitle(f"{today} {which_energy=}", y=1.07, fontsize=16)
@@ -74,14 +83,16 @@
7483
bbox_to_anchor=(0.5, -0.05),
7584
frameon=False,
7685
)
86+
else:
87+
fig.update_xaxes(range=[-0.4, 0.4])
88+
fig.update_layout(showlegend=False, barmode="stack")
89+
7790

78-
fig.show()
91+
fig.show(config=dict(responsive=True))
7992

8093

8194
# %%
8295
img_path = f"{FIGS}/{today}-wbm-hull-dist-hist-models"
83-
# if hasattr(fig, "write_image"):
84-
# fig.write_image(f"{img_path}.pdf")
85-
# fig.write_html(f"{img_path}.html", include_ploltyjs="cdn")
86-
# else:
87-
# fig.savefig(f"{img_path}.pdf")
96+
save_fig(fig, f"{img_path}.html")
97+
# save_fig(fig, f"{img_path}.png", scale=3)
98+
# save_fig(fig, f"{img_path}.pdf")

scripts/make_api_docs.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
from glob import glob
44
from subprocess import run
55

6-
# update generated API docs on production builds
6+
# Update auto-generated API docs. Also tweak lazydocs's markdown output for
7+
# - prettier badges linking to source code on GitHub
8+
# - remove bold tags since they break inline code
79

810
pkg = json.load(open("site/package.json"))
911
route = "site/src/routes/api"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# %%
2+
from plotly.subplots import make_subplots
3+
from pymatviz.utils import save_fig
4+
5+
from matbench_discovery import FIGS, today
6+
from matbench_discovery.data import load_df_wbm_with_preds
7+
from matbench_discovery.plots import Backend, rolling_mae_vs_hull_dist
8+
9+
__author__ = "Rhys Goodall, Janosh Riebesell"
10+
__date__ = "2022-06-18"
11+
12+
13+
# %%
14+
models = sorted(
15+
"Wrenformer, CGCNN, Voronoi RF, MEGNet, M3GNet, BOWSR MEGNet".split(", ")
16+
)
17+
18+
df_wbm = load_df_wbm_with_preds(models=models).round(3)
19+
20+
21+
# %%
22+
target_col = "e_form_per_atom_mp2020_corrected"
23+
e_above_hull_col = "e_above_hull_mp2020_corrected_ppd_mp"
24+
backend: Backend = "plotly"
25+
26+
rows, cols = len(models) // 3, 3
27+
if backend == "plotly":
28+
fig = make_subplots(rows=rows, cols=cols)
29+
30+
31+
for idx, model_name in enumerate(models):
32+
row, col = idx % rows + 1, idx // rows + 1
33+
34+
# assert df_wbm[model_name].isna().sum() < 100
35+
preds = df_wbm[target_col] - df_wbm[model_name]
36+
MAE = (df_wbm[e_above_hull_col] - preds).abs().mean()
37+
38+
ax = rolling_mae_vs_hull_dist(
39+
e_above_hull_true=df_wbm[e_above_hull_col],
40+
e_above_hull_error=preds,
41+
label=f"{model_name} · {MAE=:.2f}",
42+
backend=backend,
43+
)
44+
if backend == "plotly":
45+
fig.add_traces(ax.data, row=row, col=col)
46+
47+
if hasattr(ax, "legend"):
48+
# increase line width in legend
49+
legend = ax.legend(frameon=False, loc="lower right")
50+
ax.figure.set_size_inches(10, 9)
51+
for line in legend.get_lines():
52+
line._linewidth *= 3
53+
54+
55+
fig.show()
56+
57+
58+
# %%
59+
img_path = f"{FIGS}/{today}-rolling-mae-vs-hull-dist-compare-models"
60+
save_fig(fig, f"{img_path}.pdf")

site/package.json

+2-5
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
"preview": "vite preview",
1414
"serve": "vite build && vite preview",
1515
"check": "svelte-check",
16-
"make-api-docs": "python ../scripts/make_api_docs.py"
16+
"make-api-docs": "cd .. && python ../scripts/make_api_docs.py"
1717
},
1818
"devDependencies": {
1919
"@iconify/svelte": "^3.0.1",
@@ -25,16 +25,13 @@
2525
"@typescript-eslint/parser": "^5.48.1",
2626
"eslint": "^8.31.0",
2727
"eslint-plugin-svelte3": "^4.0.0",
28-
"hast-util-from-string": "^2.0.0",
29-
"hast-util-select": "^5.0.3",
30-
"hast-util-to-string": "^2.0.0",
3128
"hastscript": "^7.2.0",
32-
"highlight.js": "^11.7.0",
3329
"katex": "^0.16.4",
3430
"mdsvex": "^0.10.6",
3531
"prettier": "^2.8.2",
3632
"prettier-plugin-svelte": "^2.9.0",
3733
"rehype-autolink-headings": "^6.1.1",
34+
"rehype-katex-svelte": "^1.1.2",
3835
"rehype-slug": "^5.1.0",
3936
"remark-math": "3.0.0",
4037
"svelte": "^3.55.1",

site/src/app.html

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
<!-- math display -->
3131
<link
3232
rel="stylesheet"
33-
href="https://cdn.jsdelivr.net/npm/katex@0.15.0/dist/katex.min.css"
33+
href="https://cdn.jsdelivr.net/npm/katex@latest/dist/katex.min.css"
3434
/>
3535

3636
%sveltekit.head%

site/svelte.config.js

+1-15
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
11
import adapter from '@sveltejs/adapter-static'
2-
import { fromString } from 'hast-util-from-string'
3-
import { selectAll } from 'hast-util-select'
4-
import { toString } from 'hast-util-to-string'
52
import { s } from 'hastscript'
63
import katex from 'katex'
74
import { mdsvex } from 'mdsvex'
@@ -11,18 +8,7 @@ import math from 'remark-math'
118
import preprocess from 'svelte-preprocess'
129

1310
const rehypePlugins = [
14-
// from https://github.com/kwshi/rehype-katex-svelte
15-
(options = {}) =>
16-
(tree) => {
17-
for (const node of selectAll(`.math-inline,.math-display`, tree)) {
18-
const displayMode = node.properties?.className?.includes(`math-display`)
19-
const rendered = katex.renderToString(toString(node), {
20-
...options,
21-
displayMode,
22-
})
23-
fromString(node, `{@html ${JSON.stringify(rendered)}}`)
24-
}
25-
},
11+
katex,
2612
heading_slugs,
2713
[
2814
link_headings,

0 commit comments

Comments
 (0)