Skip to content

Commit d9bb043

Browse files
committed
use same colors, line styles and markers for a given model across plots
starting with cumulative-precision-recall, roc-models-all-in-one, rolling-mae-vs-hull-dist-models
1 parent c35ccdc commit d9bb043

13 files changed

+98
-56
lines changed

matbench_discovery/plots.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,10 @@
3636

3737
plotly_markers = SymbolValidator().values[2::3] # noqa: PD011
3838
plotly_line_styles = DashValidator().values[:-1] # noqa: PD011
39-
# repeat line styles as many as times as needed to match number of markers
39+
plotly_colors = px.colors.qualitative.Plotly
40+
# repeat line styles/colors as many as times as needed to match number of markers
4041
plotly_line_styles *= len(plotly_markers) // len(plotly_line_styles)
42+
plotly_colors *= len(plotly_markers) // len(plotly_colors)
4143

4244

4345
def plotly_unit(text: str) -> str:
@@ -614,10 +616,12 @@ def rolling_mae_vs_hull_dist(
614616

615617
line_styles = "solid dash dot dashdot".split()
616618
markers = "circle square triangle-up triangle-down diamond cross star x".split()
617-
combinations = [(ls, mark) for mark in markers for ls in line_styles]
618-
for idx, trace in enumerate(fig.data):
619-
ls, marker = combinations[idx % len(combinations)]
620-
trace.line.dash = ls
619+
from matbench_discovery.preds import model_styles
620+
621+
for trace in fig.data:
622+
if style := model_styles.get(trace.name):
623+
ls, _marker, color = style
624+
trace.line = dict(color=color, dash=ls, width=2)
621625
# marker_spacing = 2
622626
# trace = go.Scatter(
623627
# x=trace.x[::marker_spacing],

matbench_discovery/preds.py

+15-7
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,14 @@
99
from matbench_discovery import ROOT, STABILITY_THRESHOLD
1010
from matbench_discovery.data import Files, df_wbm, glob_to_df
1111
from matbench_discovery.metrics import stable_metrics
12-
from matbench_discovery.plots import ev_per_atom, model_labels, quantity_labels
12+
from matbench_discovery.plots import (
13+
ev_per_atom,
14+
model_labels,
15+
plotly_colors,
16+
plotly_line_styles,
17+
plotly_markers,
18+
quantity_labels,
19+
)
1320

1421
"""Centralize data-loading and computing metrics for plotting scripts"""
1522

@@ -52,11 +59,11 @@ class PredFiles(Files):
5259

5360
# original MEGNet straight from publication, not re-trained
5461
megnet = "megnet/2022-11-18-megnet-wbm-IS2RE.csv.gz"
55-
# CHGNet-relaxed structures fed into MEGNet for formation energy prediction
56-
chgnet_megnet = "chgnet/2023-03-04-chgnet-wbm-IS2RE.csv.gz"
57-
# M3GNet-relaxed structures fed into MEGNet for formation energy prediction
58-
m3gnet_megnet = "m3gnet/2022-10-31-m3gnet-wbm-IS2RE.csv.gz"
59-
megnet_rs2re = "megnet/2023-08-23-megnet-wbm-RS2RE.csv.gz"
62+
# # CHGNet-relaxed structures fed into MEGNet for formation energy prediction
63+
# chgnet_megnet = "chgnet/2023-03-04-chgnet-wbm-IS2RE.csv.gz"
64+
# # M3GNet-relaxed structures fed into MEGNet for formation energy prediction
65+
# m3gnet_megnet = "m3gnet/2022-10-31-m3gnet-wbm-IS2RE.csv.gz"
66+
# megnet_rs2re = "megnet/2023-08-23-megnet-wbm-RS2RE.csv.gz"
6067

6168
# Magpie composition+Voronoi tessellation structure features + sklearn random forest
6269
voronoi_rf = "voronoi/2022-11-27-train-test/e-form-preds-IS2RE.csv.gz"
@@ -172,7 +179,8 @@ def load_df_wbm_with_preds(
172179
df_metrics_10k = df_metrics_10k.round(3).sort_values("F1", axis=1, ascending=False)
173180

174181
models = list(df_metrics.T.MAE.sort_values().index)
175-
182+
# used for consistent markers, line styles and colors for a given model across plots
183+
model_styles = dict(zip(models, zip(plotly_line_styles, plotly_markers, plotly_colors)))
176184

177185
# To avoid confusion for anyone reading this code, we calculate the formation energy MAE
178186
# here and report it as the MAE for the energy above the convex hull prediction. The

scripts/model_figs/cumulative_metrics.py

+15-12
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,12 @@
1313
from pymatviz.utils import save_fig
1414

1515
from matbench_discovery import PDF_FIGS, SITE_FIGS
16-
from matbench_discovery.plots import (
17-
cumulative_metrics,
18-
plotly_line_styles,
19-
plotly_markers,
20-
)
16+
from matbench_discovery.plots import cumulative_metrics
2117
from matbench_discovery.preds import (
2218
df_each_pred,
2319
df_preds,
2420
each_true_col,
21+
model_styles,
2522
models,
2623
)
2724

@@ -30,18 +27,17 @@
3027

3128

3229
# %%
33-
metrics = ("Precision", "Recall")
34-
# metrics = ("MAE", "RMSE")
30+
# metrics = ("Precision", "Recall")
31+
metrics = ("MAE", "RMSE")
3532
range_y = {
36-
("MAE", "RMSE"): (0, 0.5),
33+
("MAE", "RMSE"): (0, 0.7),
3734
("Precision", "Recall"): (0, 1),
3835
}[metrics]
3936
fig, df_metric = cumulative_metrics(
4037
e_above_hull_true=df_preds[each_true_col],
4138
df_preds=df_each_pred[models],
4239
project_end_point="xy",
4340
backend=(backend := "plotly"),
44-
range_y=range_y,
4541
metrics=metrics,
4642
# facet_col_wrap=2,
4743
# increase facet col gap
@@ -54,6 +50,9 @@
5450
# fig.suptitle(title)
5551
fig.text(0.5, -0.08, x_label, ha="center", fontdict={"size": 16})
5652
if backend == "plotly":
53+
for key in filter(lambda key: key.startswith("yaxis"), fig.layout):
54+
fig.layout[key].range = range_y
55+
5756
fig.layout.margin.update(l=0, r=0, t=30, b=50)
5857
fig.add_annotation(
5958
x=0.5,
@@ -71,10 +70,14 @@
7170
# )
7271
# if "MAE" in metrics:
7372
# fig.layout.legend.update(traceorder="reversed")
73+
assert len(metrics) * len(models) == len(
74+
fig.data
75+
), f"expected one trace per model per metric, got {len(fig.data)}"
7476

75-
for trace, ls, marker in zip(fig.data, plotly_line_styles, plotly_markers):
76-
trace.line.dash = ls
77-
trace.marker.symbol = marker
77+
for trace in fig.data:
78+
if line_style := model_styles.get(trace.name):
79+
ls, _marker, color = line_style
80+
trace.line = dict(color=color, dash=ls, width=2)
7881

7982
# show only the N best models by default
8083
# if trace.name in df_metrics.T.sort_values("F1").index[:-6]:

scripts/model_figs/make_hull_dist_box_plot.py

+28-15
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,11 @@
3333
)
3434
ax.set(ylim=(-0.9, 0.9))
3535

36+
for idx, label in enumerate(ax.get_xticklabels()):
37+
label.set_va("bottom" if idx % 2 else "top")
38+
# lower all labels
39+
label.set_y(label.get_position()[1] - 0.05)
40+
3641

3742
# %%
3843
px.violin(
@@ -54,29 +59,37 @@
5459
fig.layout.yaxis.title = plots.quantity_labels["e_above_hull_error"]
5560
fig.layout.margin = dict(l=0, r=0, b=0, t=0)
5661

57-
for col in models:
58-
val_min = df_each_err[col].quantile(0.05)
59-
lower_box = df_each_err[col].quantile(0.25)
60-
median = df_each_err[col].median()
61-
upper_box = df_each_err[col].quantile(0.75)
62-
val_max = df_each_err[col].quantile(0.95)
63-
64-
box_plot = go.Box(
65-
y=[val_min, lower_box, median, upper_box, val_max],
66-
name=col,
67-
width=0.7,
68-
)
62+
for idx, model in enumerate(models):
63+
ys = [df_each_err[model].quantile(quant) for quant in (0.05, 0.25, 0.5, 0.75, 0.95)]
64+
65+
box_plot = go.Box(y=ys, name=model, width=0.7)
6966
fig.add_trace(box_plot)
7067

71-
fig.layout.legend.update(orientation="h", y=1.15)
68+
# Add an annotation for the interquartile range
69+
IQR = ys[3] - ys[1]
70+
median = ys[2]
71+
fig.add_annotation(
72+
x=idx, y=1, text=f"{IQR:.2}", showarrow=False, yref="paper", yshift=-10
73+
)
74+
fig.add_annotation(
75+
x=idx,
76+
y=median,
77+
text=f"{median:.2}",
78+
showarrow=False,
79+
yshift=7,
80+
# bgcolor="rgba(0, 0, 0, 0.2)",
81+
# width=50,
82+
)
83+
fig.add_annotation(x=-0.6, y=1, text="IQR", showarrow=False, yref="paper", yshift=-10)
84+
85+
fig.layout.legend.update(orientation="h", y=1.2)
7286
# prevent x-labels from rotating
7387
fig.layout.xaxis.tickangle = 0
7488
# use line breaks to offset every other x-label
7589
x_labels_with_offset = [
76-
label if idx % 2 == 0 else f"<br>{label}" for idx, label in enumerate(models)
90+
f"{'<br>' * (idx % 2)}{label}" for idx, label in enumerate(models)
7791
]
7892
fig.layout.xaxis.update(tickvals=models, ticktext=x_labels_with_offset)
79-
8093
fig.show()
8194

8295

scripts/model_figs/model_compute_cost.py scripts/model_figs/model_run_times.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from tqdm import tqdm
2020

2121
from matbench_discovery import PDF_FIGS, SITE_FIGS, SITE_MODELS, WANDB_PATH
22-
from matbench_discovery.preds import df_metrics, df_metrics_10k, df_preds
22+
from matbench_discovery.preds import df_metrics, df_metrics_10k, df_preds, model_styles
2323

2424
__author__ = "Janosh Riebesell"
2525
__date__ = "2022-11-28"
@@ -194,6 +194,7 @@
194194
text_auto=".0f",
195195
text=time_col,
196196
color=model_col,
197+
color_discrete_sequence=[model_styles[model][2] for model in df_melt[model_col]],
197198
)
198199
# reduce bar width
199200
fig.update_traces(width=0.8)
@@ -202,8 +203,11 @@
202203
fig.layout.legend.update(title=title, orientation="h", xanchor="center", x=0.4, y=1.2)
203204
fig.layout.xaxis.title = ""
204205
fig.layout.margin.update(l=0, r=0, t=0, b=0)
205-
save_fig(fig, f"{SITE_FIGS}/model-run-times-bar.svelte")
206+
fig.show()
207+
206208

209+
# %%
210+
save_fig(fig, f"{SITE_FIGS}/model-run-times-bar.svelte")
207211
pdf_fig = go.Figure(fig)
208212
# replace legend with annotation in PDF
209213
pdf_fig.layout.showlegend = False
@@ -217,4 +221,3 @@
217221
yref="paper",
218222
)
219223
save_fig(pdf_fig, f"{PDF_FIGS}/model-run-times-bar.pdf", height=300, width=800)
220-
fig.show()

scripts/model_figs/roc_prc_curves_models.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,13 @@
1414

1515
from matbench_discovery import PDF_FIGS, SITE_FIGS, STABILITY_THRESHOLD
1616
from matbench_discovery import plots as plots
17-
from matbench_discovery.preds import df_each_pred, df_preds, each_true_col, models
17+
from matbench_discovery.preds import (
18+
df_each_pred,
19+
df_preds,
20+
each_true_col,
21+
model_styles,
22+
models,
23+
)
1824

1925
__author__ = "Janosh Riebesell"
2026
__date__ = "2023-01-30"
@@ -80,9 +86,13 @@
8086
for anno in fig.layout.annotations:
8187
anno.text = anno.text.split("=", 1)[1] # remove Model= from subplot titles
8288

83-
for trace, ls, marker in zip(fig.data, plots.plotly_line_styles, plots.plotly_markers):
84-
trace.line.dash = ls
85-
trace.marker.symbol = marker
89+
90+
for trace in fig.data:
91+
if styles := model_styles.get(trace.name.split(" · ")[0]):
92+
ls, marker, color = styles
93+
trace.line = dict(color=color, dash=ls, width=2)
94+
trace.marker = dict(color=color, symbol=marker, size=4)
95+
8696

8797
if not facet_plot:
8898
fig.layout.legend.update(x=1, y=0, xanchor="right", title=None)

scripts/model_figs/rolling_mae_vs_hull_dist_models.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,10 @@
4646
for line in fig.lines:
4747
line._linewidth *= 2
4848
else:
49+
show_n_best_models = len(models)
4950
for trace in fig.data:
5051
model = trace.name.split(" MAE=")[0]
51-
if model in df_metrics.T.sort_values("MAE").index[8:]:
52+
if model in df_metrics.T.sort_values("MAE").index[show_n_best_models:]:
5253
trace.visible = "legendonly" # show only top models by default
5354

5455
fig.update_traces(line=dict(width=3)) # increase line width

site/package.json

+4-4
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,14 @@
2020
"@iconify/svelte": "^3.1.4",
2121
"@rollup/plugin-yaml": "^4.1.1",
2222
"@sveltejs/adapter-static": "^2.0.3",
23-
"@sveltejs/kit": "^1.22.6",
23+
"@sveltejs/kit": "^1.23.0",
2424
"@sveltejs/vite-plugin-svelte": "^2.4.5",
2525
"@typescript-eslint/eslint-plugin": "^6.4.1",
2626
"@typescript-eslint/parser": "^6.4.1",
2727
"d3-scale-chromatic": "^3.0.0",
2828
"elementari": "^0.2.2",
29-
"eslint": "^8.47.0",
30-
"eslint-plugin-svelte": "^2.32.4",
29+
"eslint": "^8.48.0",
30+
"eslint-plugin-svelte": "^2.33.0",
3131
"hastscript": "^8.0.0",
3232
"highlight.js": "^11.8.0",
3333
"js-yaml": "^4.1.0",
@@ -47,7 +47,7 @@
4747
"svelte-zoo": "^0.4.9",
4848
"svelte2tsx": "^0.6.20",
4949
"tslib": "^2.6.2",
50-
"typescript": "5.1.6",
50+
"typescript": "5.2.2",
5151
"vite": "^4.4.9"
5252
},
5353
"prettier": {

0 commit comments

Comments
 (0)