Skip to content

Commit 8a0c3bb

Browse files
committed
move /paper/preprint to /paper
add CHGNet to /si rolling MAE batches model comparison
1 parent 0fad3bd commit 8a0c3bb

10 files changed

+193
-101
lines changed

.pre-commit-config.yaml

+6-6
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@ default_install_hook_types: [pre-commit, commit-msg]
77

88
repos:
99
- repo: https://github.com/charliermarsh/ruff-pre-commit
10-
rev: v0.0.258
10+
rev: v0.0.260
1111
hooks:
1212
- id: ruff
1313
args: [--fix]
1414

1515
- repo: https://github.com/psf/black
16-
rev: 23.1.0
16+
rev: 23.3.0
1717
hooks:
1818
- id: black
1919

@@ -34,13 +34,13 @@ repos:
3434
- id: trailing-whitespace
3535

3636
- repo: https://github.com/pre-commit/mirrors-mypy
37-
rev: v1.0.1
37+
rev: v1.1.1
3838
hooks:
3939
- id: mypy
4040
additional_dependencies: [types-pyyaml, types-requests]
4141

4242
- repo: https://github.com/codespell-project/codespell
43-
rev: v2.2.2
43+
rev: v2.2.4
4444
hooks:
4545
- id: codespell
4646
stages: [commit, commit-msg]
@@ -49,7 +49,7 @@ repos:
4949
args: [--ignore-words-list, "nd,te,fpr"]
5050

5151
- repo: https://github.com/pre-commit/mirrors-prettier
52-
rev: v3.0.0-alpha.4
52+
rev: v3.0.0-alpha.6
5353
hooks:
5454
- id: prettier
5555
args: [--write] # edit files in-place
@@ -60,7 +60,7 @@ repos:
6060
exclude: ^(site/src/figs/.+\.svelte|data/wbm/20.+\..+|site/src/routes/.+\.(yml|yaml|json))$
6161

6262
- repo: https://github.com/pre-commit/mirrors-eslint
63-
rev: v8.34.0
63+
rev: v8.37.0
6464
hooks:
6565
- id: eslint
6666
types: [file]

matbench_discovery/plots.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,7 @@ def rolling_mae_vs_hull_dist(
339339
with_sem: bool = True,
340340
show_dft_acc: bool = False,
341341
show_dummy_mae: bool = False,
342+
pbar: bool = True,
342343
**kwargs: Any,
343344
) -> plt.Axes | go.Figure:
344345
"""Rolling mean absolute error as the energy to the convex hull is varied. A scale
@@ -380,6 +381,8 @@ def rolling_mae_vs_hull_dist(
380381
meV/atom. Defaults to False.
381382
show_dummy_mae (bool, optional): If True, plot a line at the dummy MAE of always
382383
predicting the target mean. Defaults to False.
384+
pbar (bool, optional): If True, show a progress bar during rolling MAE
385+
calculation. Defaults to True.
383386
**kwargs: Additional keyword arguments to pass to df.plot().
384387
385388
Returns:
@@ -396,8 +399,10 @@ def rolling_mae_vs_hull_dist(
396399
df_rolling_err = pd.DataFrame(columns=models, index=bins)
397400
df_err_std = df_rolling_err.copy()
398401

399-
for model in (pbar := tqdm(models, desc="Calculating rolling MAE")):
400-
pbar.set_postfix_str(model)
402+
for model in (
403+
prog_bar := tqdm(models, desc="Calculating rolling MAE", disable=not pbar)
404+
):
405+
prog_bar.set_postfix_str(model)
401406
for idx, bin_center in enumerate(bins):
402407
low = bin_center - window
403408
high = bin_center + window

scripts/hist_classified_stable_vs_hull_dist_batches.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
df_preds[each_pred_col] = (
2828
df_preds[each_true_col] + df_preds[model_name] - df_preds[e_form_col]
2929
)
30-
df_preds[(batch_col := "batch_idx")] = df_preds.index.str.split("-").str[-2].astype(int)
30+
df_preds[(batch_col := "batch_idx")] = df_preds.index.str.split("-").str[1].astype(int)
3131

3232

3333
# %% matplotlib

scripts/hist_classified_stable_vs_hull_dist_models.py

+4-7
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,7 @@
1010
from pymatviz.utils import save_fig
1111

1212
from matbench_discovery import FIGS, ROOT, today
13-
from matbench_discovery.plots import (
14-
hist_classified_stable_vs_hull_dist,
15-
plt,
16-
)
13+
from matbench_discovery.plots import hist_classified_stable_vs_hull_dist, plt
1714
from matbench_discovery.preds import df_metrics, df_preds, e_form_col, each_true_col
1815

1916
__author__ = "Janosh Riebesell"
@@ -25,8 +22,8 @@
2522
e_form_preds = "e_form_per_atom_pred"
2623
each_pred_col = "e_above_hull_pred"
2724
facet_col = "Model"
28-
models = list(df_metrics)
29-
# models = df_metrics.T.MAE.nsmallest(6).index # top 6 models by MAE
25+
# sort facet plots by each model's F1 score, highest first (slice the list, e.g. models[:6], to show only the top n)
26+
models = list(df_metrics.T.F1.sort_values().index)[::-1]
3027

3128
df_melt = df_preds.melt(
3229
id_vars=hover_cols,
@@ -45,7 +42,7 @@
4542
rows, cols = len(models) // 2, 2
4643
which_energy: Final = "true"
4744
kwds = (
48-
dict(facet_col=facet_col, facet_col_wrap=cols)
45+
dict(facet_col=facet_col, facet_col_wrap=cols, category_orders={facet_col: models})
4946
if backend == "plotly"
5047
else dict(by=facet_col, figsize=(20, 20), layout=(rows, cols), bins=500)
5148
)

scripts/prc_roc_curves_models.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,13 @@
1111
from tqdm import tqdm
1212

1313
from matbench_discovery import FIGS
14-
from matbench_discovery.plots import pio
14+
from matbench_discovery import plots as plots
1515
from matbench_discovery.preds import df_each_pred, df_preds, each_true_col
1616

1717
__author__ = "Janosh Riebesell"
1818
__date__ = "2023-01-30"
1919

2020

21-
pio.templates.default
2221
line = dict(dash="dash", width=0.5)
2322

2423
facet_col = "Model"

scripts/rolling_mae_vs_hull_dist_wbm_batches.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,8 @@
1313
__author__ = "Rhys Goodall, Janosh Riebesell"
1414
__date__ = "2022-06-18"
1515

16-
df_each_pred[(batch_col := "batch_idx")] = (
17-
"Batch " + df_each_pred.index.str.split("-").str[1]
18-
)
16+
batch_col = "batch_idx"
17+
df_each_pred[batch_col] = "Batch " + df_each_pred.index.str.split("-").str[1]
1918
df_err, df_std = None, None # variables to cache rolling MAE and std
2019

2120

@@ -43,6 +42,7 @@
4342
backend="matplotlib",
4443
ax=ax,
4544
just_plot_lines=idx > 1,
45+
pbar=False,
4646
)
4747

4848

@@ -54,7 +54,7 @@
5454

5555

5656
# %% plotly
57-
model = "Wrenformer" # ["M3GNet", "Wrenformer", "MEGNet", "Voronoi RF"]
57+
model = "CHGNet"
5858
df_pivot = df_each_pred.pivot(columns=batch_col, values=model)
5959

6060
# unstack two-level column index into new model column

site/src/figs/chgnet-rolling-mae-vs-hull-dist-wbm-batches.svelte

+1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

site/src/figs/hist-clf-true-hull-dist-models.svelte

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)