Skip to content

Commit 1f69374

Browse files
committed
clarify element-errors-ptable-heatmap.svelte normalized checkbox purpose
with more detailed label also remove abstract from preprint/references.yaml combine 2023-03-(04|06)-chgnet-0.2.0 preds with 500 and 2000 relax steps into one file (2023-03-06-chgnet-0.2.0-wbm-IS2RE.csv.gz)
1 parent 2f4c33a commit 1f69374

11 files changed

+31
-951
lines changed

data/wbm/eda.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import numpy as np
55
import pandas as pd
66
import plotly.express as px
7-
import plotly.graph_objects as go
87
from pymatgen.core import Composition
98
from pymatviz import (
109
count_elements,
@@ -107,7 +106,7 @@
107106
fig = px.bar(
108107
x=bins[bins < 0], y=left_counts, opacity=0.7, labels={"x": x_label, "y": "Counts"}
109108
)
110-
fig.add_trace(go.Bar(x=bins[bins >= 0], y=right_counts, opacity=0.7))
109+
fig.add_bar(x=bins[bins >= 0], y=right_counts, opacity=0.7)
111110

112111
if col.startswith("e_above_hull"):
113112
n_stable = sum(df_wbm[col] <= STABILITY_THRESHOLD)

matbench_discovery/plots.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -625,15 +625,14 @@ def rolling_mae_vs_hull_dist(
625625
color=plotly_colors[idx], dash=plotly_line_styles[idx], width=3
626626
)
627627
# marker_spacing = 2
628-
# trace = go.Scatter(
628+
# fig.add_scatter(
629629
# x=trace.x[::marker_spacing],
630630
# y=trace.y[::marker_spacing],
631631
# mode="markers",
632632
# marker=dict(symbol=marker, color=trace.line.color, size=8),
633633
# showlegend=False,
634634
# legendgroup=getattr(trace, "legendgroup", None),
635635
# )
636-
# fig.add_trace(trace)
637636
return fig, df_rolling_err, df_err_std
638637

639638

Binary file not shown.

models/chgnet/ctk_trajectory_viewer.py

+7-9
Original file line numberDiff line numberDiff line change
@@ -89,19 +89,17 @@ def plot_energy_and_forces(
8989
"""Plot energy and forces as a function of relaxation step."""
9090
fig = go.Figure()
9191
# energy trace = primary y-axis
92-
fig.add_trace(go.Scatter(x=df.index, y=df[e_col], mode="lines", name="Energy"))
92+
fig.add_scatter(x=df.index, y=df[e_col], mode="lines", name="Energy")
9393
# get energy line color
9494
line_color = fig.data[0].line.color
9595

9696
# forces trace = secondary y-axis
97-
fig.add_trace(
98-
go.Scatter(
99-
x=df.index,
100-
y=df[force_col],
101-
mode="lines",
102-
name="Forces",
103-
yaxis="y2",
104-
),
97+
fig.add_scatter(
98+
x=df.index,
99+
y=df[force_col],
100+
mode="lines",
101+
name="Forces",
102+
yaxis="y2",
105103
)
106104

107105
fig.update_layout(

scripts/analyze_model_failure_cases.py

+4-8
Original file line numberDiff line numberDiff line change
@@ -84,14 +84,13 @@
8484
large_errors = df_each_err[model].abs().nlargest(n_structs)
8585
small_errors = df_each_err[model].abs().nsmallest(n_structs)
8686
for label, errors in zip(("min", "max"), (large_errors, small_errors)):
87-
scatter = go.Histogram(
87+
fig.add_histogram(
8888
x=df_wbm.loc[errors.index][fp_diff_col].values,
8989
name=f"{model} err<sub>{label}</sub>",
9090
visible="legendonly" if idx else True,
9191
legendgroup=model,
9292
hovertemplate=("SSFP diff: %{x:.2f}<br>Count: %{y}"),
9393
)
94-
fig.add_trace(scatter)
9594

9695
title = (
9796
f"Norm-diff between initial/final SiteStatsFingerprint<br>"
@@ -121,7 +120,7 @@
121120
for idx, model in enumerate(df_metrics):
122121
errors = df_each_err[model].abs().nlargest(n_structs)
123122
model_mae = errors.mean().round(3)
124-
scatter = go.Scatter(
123+
fig.add_scatter(
125124
x=df_wbm.loc[errors.index][fp_diff_col].values,
126125
y=errors.values,
127126
mode="markers",
@@ -136,7 +135,6 @@
136135
customdata=df_wbm.loc[errors.index][["material_id", "formula"]].values,
137136
legendrank=model_mae,
138137
)
139-
fig.add_trace(scatter)
140138

141139
title = (
142140
f"Norm-diff between initial/final SiteStatsFingerprint<br>"
@@ -246,7 +244,7 @@
246244
model_mae = df_each_err[model].loc[df_largest_fp_diff.index].abs().mean()
247245

248246
visible = "legendonly" if idx else True
249-
scatter = go.Scatter(
247+
fig.add_scatter(
250248
x=df_largest_fp_diff.values,
251249
y=df_each_err[model].loc[df_largest_fp_diff.index].abs(),
252250
mode="markers",
@@ -265,7 +263,6 @@
265263
marker=dict(color=color),
266264
legendrank=model_mae,
267265
)
268-
fig.add_trace(scatter)
269266
# add dashed mean line for each model that toggles with the scatter plot
270267
# fig.add_hline(
271268
# y=model_mae,
@@ -357,14 +354,13 @@
357354
fig = go.Figure()
358355
for model in df_metrics:
359356
errors = getattr(df_each_err[model].abs(), which)(n_structs)
360-
violin = go.Violin(
357+
fig.add_violin(
361358
x=df_wbm.loc[errors.index][fp_diff_col].values,
362359
name=f"{model} err<sub>{label}</sub>",
363360
legendgroup=model,
364361
hovertemplate=("SSFP diff: %{x:.2f}<br>Count: %{y}"),
365362
spanmode="hard",
366363
)
367-
fig.add_trace(violin)
368364
fig.layout.update(showlegend=False)
369365
fig.layout.xaxis.title = "SSFP norm-diff before/after relaxation"
370366
fig.show()

scripts/model_figs/make_hull_dist_box_plot.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,7 @@
6262
for idx, model in enumerate(models):
6363
ys = [df_each_err[model].quantile(quant) for quant in (0.05, 0.25, 0.5, 0.75, 0.95)]
6464

65-
box_plot = go.Box(y=ys, name=model, width=0.7)
66-
fig.add_trace(box_plot)
65+
fig.add_box(y=ys, name=model, width=0.7)
6766

6867
# Add an annotation for the interquartile range
6968
IQR = ys[3] - ys[1]

scripts/model_figs/rolling_mae_vs_hull_dist_models.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from typing import Final
66

77
import numpy as np
8-
import plotly.graph_objects as go
98
from pymatviz.io import save_fig
109

1110
from matbench_discovery import PDF_FIGS, SITE_FIGS
@@ -65,11 +64,10 @@
6564
counts, bins = np.histogram(
6665
df_preds[each_true_col] + noise, bins=100, range=fig.layout.xaxis.range
6766
)
68-
marginal_trace = go.Scatter(
67+
fig.add_scatter(
6968
x=bins, y=counts, name="Density", fill="tozeroy", showlegend=False, yaxis="y2"
7069
)
71-
marginal_trace.marker.color = "rgba(0, 150, 200, 1)"
72-
fig.add_trace(marginal_trace)
70+
fig.data[-1].marker.color = "rgba(0, 150, 200, 1)"
7371

7472
# update layout to include marginal plot
7573
fig.layout.update(

site/src/routes/data/+page.svelte

+2
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
color_scale={color_scale[0]}
4747
{log}
4848
bind:active_element={active_wbm_elem}
49+
show_photo={false}
4950
>
5051
<TableInset slot="inset">
5152
<label for="log">Log color scale<Toggle id="log" bind:checked={log} /></label>
@@ -81,6 +82,7 @@
8182
color_scale={color_scale[0]}
8283
{log}
8384
bind:active_element={active_mp_elem}
85+
show_photo={false}
8486
>
8587
<TableInset slot="inset">
8688
<label for="log">Log color scale<Toggle id="log" bind:checked={log} /></label>

site/src/routes/data/tmi/+page.svelte

+1
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ structure was generated in).
7171
{log}
7272
color_scale={color_scale[0]}
7373
bind:active_element
74+
show_photo={false}
7475
>
7576
<TableInset slot="inset">
7677
<PtableInset element={active_element} elem_counts={active_counts} />

site/src/routes/models/element-errors-ptable-heatmap.svelte

+12-1
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,14 @@
6969
{cbar_max}
7070
</label>
7171
<label>
72-
Divide errors by test set energies std. dev. over structures containing each element
72+
Divide each element value by its std. dev. of target energies over all test structures
73+
containing a given element
7374
<input type="checkbox" bind:checked={normalized} />
7475
</label>
76+
<small>
77+
This is meant to correct for the fact that some elements are inherently more difficult
78+
to predict since some have a wider distribution of convex hull distances.
79+
</small>
7580
</form>
7681

7782
<PeriodicTable
@@ -82,6 +87,7 @@
8287
tile_props={{
8388
precision: `0.2`,
8489
}}
90+
show_photo={false}
8591
>
8692
<TableInset slot="inset" style="align-content: center;">
8793
<PtableInset
@@ -113,4 +119,9 @@
113119
place-content: center;
114120
gap: 1ex;
115121
}
122+
form label + small {
123+
max-width: 60em;
124+
margin: 0 auto;
125+
text-align: center;
126+
}
116127
</style>

0 commit comments

Comments
 (0)