Skip to content

Commit a3b4362

Browse files
committed
fix pymatviz bin_df_cols util imports + ruff fixes + bump site deps
1 parent 202a305 commit a3b4362

22 files changed

+121
-105
lines changed

.github/workflows/gh-pages.yml

+2-1
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,11 @@ jobs:
1111
build:
1212
uses: janosh/workflows/.github/workflows/nodejs-gh-pages.yml@main
1313
with:
14+
install-cmd: npm install --force
1415
python-version: "3.11"
1516
working-directory: site
1617
pre-build: |
1718
pip install lazydocs
1819
# lazydocs needs package deps to be installed
19-
pip install -e ..
20+
pip install --editable ..
2021
python ../scripts/make_api_docs.py

.pre-commit-config.yaml

+9-11
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
ci:
22
autoupdate_schedule: quarterly
3-
skip: [pyright]
3+
skip: [pyright, eslint]
44

55
default_stages: [commit]
66

77
default_install_hook_types: [pre-commit, commit-msg]
88

99
repos:
1010
- repo: https://github.com/astral-sh/ruff-pre-commit
11-
rev: v0.3.4
11+
rev: v0.3.7
1212
hooks:
1313
- id: ruff
1414
args: [--fix]
@@ -20,7 +20,7 @@ repos:
2020
- id: format-ipy-cells
2121

2222
- repo: https://github.com/pre-commit/pre-commit-hooks
23-
rev: v4.5.0
23+
rev: v4.6.0
2424
hooks:
2525
- id: check-case-conflict
2626
- id: check-symlinks
@@ -57,31 +57,29 @@ repos:
5757
exclude: ^(site/src/figs/.+\.svelte|data/wbm/20.+\..+|site/src/(routes|figs).+\.(yaml|json)|changelog.md)$
5858

5959
- repo: https://github.com/pre-commit/mirrors-eslint
60-
rev: v9.0.0-rc.0
60+
rev: v9.0.0
6161
hooks:
6262
- id: eslint
6363
types: [file]
64-
args: [--fix]
64+
args: [--fix, --config, site/eslint.config.js]
6565
files: \.(js|ts|svelte)$
6666
additional_dependencies:
6767
- eslint
68+
- eslint-plugin-svelte
6869
- svelte
6970
- typescript
70-
- eslint-plugin-svelte
71-
- "@typescript-eslint/eslint-plugin"
72-
- "@typescript-eslint/parser"
73-
- svelte-eslint-parser
71+
- typescript-eslint
7472

7573
- repo: https://github.com/python-jsonschema/check-jsonschema
76-
rev: 0.28.0
74+
rev: 0.28.2
7775
hooks:
7876
- id: check-jsonschema
7977
files: ^models/(.+)/\1.*\.yml$
8078
args: [--schemafile, tests/model-schema.yml]
8179
- id: check-github-actions
8280

8381
- repo: https://github.com/RobertCraigie/pyright-python
84-
rev: v1.1.356
82+
rev: v1.1.358
8583
hooks:
8684
- id: pyright
8785
args: [--level, error]

data/mp/build_phase_diagram.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
# drop the structure, just load ComputedEntry, makes the PPD faster to build and load
4646
mp_computed_entries = [ComputedEntry.from_dict(dct) for dct in tqdm(df.entry)]
4747

48-
print(f"{len(mp_computed_entries) = :,} on {today}")
48+
print(f"{len(mp_computed_entries)=:,} on {today}")
4949
# len(mp_computed_entries) = 146,323 on 2022-09-16
5050
# len(mp_computed_entries) = 154,719 on 2023-02-07
5151

data/mp/eda_mp_trj.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,14 @@
7171

7272

7373
# %%
74-
info_to_id = lambda info: f"{info[Key.task_id]}-{info['calc_id']}-{info['ionic_step']}"
74+
def info_dict_to_id(info: dict[str, int | str]) -> str:
75+
"""Construct a unique frame ID from the atoms info dict."""
76+
return f"{info[Key.task_id]}-{info['calc_id']}-{info['ionic_step']}"
77+
7578

7679
df_mp_trj = pd.DataFrame(
7780
{
78-
info_to_id(atoms.info): atoms.info
81+
info_dict_to_id(atoms.info): atoms.info
7982
| {key: atoms.arrays.get(key) for key in ("forces", "magmoms")}
8083
| {"formula": str(atoms.symbols), Key.site_nums: atoms.symbols}
8184
for atoms_list in tqdm(mp_trj_atoms.values(), total=len(mp_trj_atoms))
@@ -101,8 +104,8 @@
101104
# %%
102105
def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]:
103106
"""Annotate each periodic table tile with the number of values in its histogram."""
104-
facecolor = cmap(norm(np.sum(len(hist_vals)))) if hist_vals else "none"
105-
bbox = dict(facecolor=facecolor, alpha=0.4, pad=2, edgecolor="none")
107+
face_color = cmap(norm(np.sum(len(hist_vals)))) if hist_vals else "none"
108+
bbox = dict(facecolor=face_color, alpha=0.4, pad=2, edgecolor="none")
106109
return dict(text=si_fmt(len(hist_vals), ".0f"), bbox=bbox)
107110

108111

@@ -116,7 +119,7 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]:
116119
# project magmoms onto symbols in dict
117120
df_mp_trj_elem_magmom = pd.DataFrame(
118121
[
119-
dict(zip(elems, magmoms))
122+
dict(zip(elems, magmoms, strict=False))
120123
for elems, magmoms in df_mp_trj.set_index(Key.site_nums)[Key.magmoms]
121124
.dropna()
122125
.items()
@@ -159,7 +162,7 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]:
159162
if srs_mp_trj_elem_forces is None:
160163
df_mp_trj_elem_forces = pd.DataFrame(
161164
[
162-
dict(zip(elems, np.abs(forces).mean(axis=1)))
165+
dict(zip(elems, np.abs(forces).mean(axis=1), strict=False))
163166
for elems, forces in df_mp_trj.set_index(Key.site_nums)[Key.forces].items()
164167
]
165168
)

data/mp/get_mp_energies.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
docs = mpr.thermo.search(fields=fields, thermo_types=["GGA_GGA+U"])
4343

4444
assert fields == set(docs[0]), f"missing fields: {fields - set(docs[0])}"
45-
print(f"{today}: {len(docs) = :,}")
45+
print(f"{today}: {len(docs)=:,}")
4646
# 2022-08-13: len(docs) = 146,323
4747
# 2023-01-10: len(docs) = 154,718
4848

data/mp/get_mp_traj.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
key=lambda doc: int(doc[Key.task_id].split("-")[1]),
5555
)
5656

57-
print(f"{today}: {len(task_docs) = :,}")
57+
print(f"{today}: {len(task_docs)=:,}")
5858

5959
df_tasks = pd.DataFrame(task_docs).drop(columns=["_id"]).set_index(Key.task_id)
6060
df_tasks.task_type.value_counts(dropna=False).plot.pie()

data/wbm/eda_wbm.py

+5
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959

6060

6161
# %% print prevalence of stable structures in full WBM and uniq-prototypes only
62+
print(f"{STABILITY_THRESHOLD=}")
6263
for df, label in (
6364
(df_wbm, "full WBM"),
6465
(df_wbm.query(Key.uniq_proto), "WBM unique prototypes"),
@@ -67,6 +68,10 @@
6768
stable_rate = n_stable / len(df)
6869
print(f"{label}: {stable_rate=:.1%} ({n_stable:,} out of {len(df):,})")
6970

71+
# on 2024-04-15: STABILITY_THRESHOLD=0
72+
# full WBM: stable_rate=16.7% (42,825 out of 256,963)
73+
# WBM unique prototypes: stable_rate=15.3% (32,942 out of 215,488)
74+
7075

7176
# %%
7277
for dataset, count_mode, elem_counts in all_counts:

matbench_discovery/plots.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
colorway = ("lightseagreen", "orange", "lightsalmon", "dodgerblue")
4040
clf_labels = ("True Positive", "False Negative", "False Positive", "True Negative")
4141
clf_colors = ("lightseagreen", "orange", "lightsalmon", "dodgerblue")
42-
clf_color_map = dict(zip(clf_labels, clf_colors))
42+
clf_color_map = dict(zip(clf_labels, clf_colors, strict=True))
4343

4444

4545
def hist_classified_stable_vs_hull_dist(

matbench_discovery/structure.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def perturb_structure(struct: Structure, gamma: float = 1.5) -> Structure:
4545
plt.axvline(mean, color="gray", linestyle="dashed", linewidth=1)
4646
# annotate the mean line
4747
plt.annotate(
48-
f"{mean = :.2f}",
48+
f"{mean=:.2f}",
4949
xy=(mean, 1),
5050
# use ax coords for y
5151
xycoords=("data", "axes fraction"),

models/bowsr/join_bowsr_results.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,7 @@
3838
df_wbm = pd.read_csv(DATA_FILES.wbm_summary).set_index(Key.mat_id)
3939

4040

41-
print(
42-
f"{len(df_bowsr) - len(df_wbm) = :,} missing ({len(df_bowsr):,} - {len(df_wbm):,})"
43-
)
41+
print(f"{len(df_bowsr) - len(df_wbm)=:,} missing ({len(df_bowsr):,} - {len(df_wbm):,})")
4442

4543

4644
# %% sanity check: since Bowsr uses MEGNet as energy model final BOWSR energy and Megnet

models/chgnet/analyze_chgnet.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,9 @@
7979
struct_col = Key.init_struct
8080

8181
fig.suptitle(f"{n_struct} {struct_col} {title}", fontsize=16, fontweight="bold", y=1.05)
82-
for idx, (ax, row) in enumerate(
83-
zip(axs.flat, df_cse.loc[df_diff.index].itertuples()), 1
84-
):
82+
for idx, row in enumerate(df_cse.loc[df_diff.index].itertuples(), 1):
8583
struct = Structure.from_dict(getattr(row, struct_col))
86-
plot_structure_2d(struct, ax=ax)
84+
ax = plot_structure_2d(struct, ax=axs.flat[idx - 1])
8785
_, spg_num = struct.get_space_group_info()
8886
formula = struct.composition.reduced_formula
8987
ax.set_title(f"{idx}. {formula} (spg={spg_num})\n{row.Index}", fontweight="bold")

models/chgnet/ctk_trajectory_viewer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,8 @@ def update_structure(step: int) -> tuple[Structure, go.Figure]:
187187
init_struct.lattice = lattice
188188
if len(init_struct) != len(coords):
189189
raise ValueError(f"{len(init_struct)} != {len(coords)}")
190-
for site, coord in zip(init_struct, coords):
191-
site.coords = coord
190+
for idx, site in enumerate(init_struct):
191+
site.coords = coords[idx]
192192

193193
spg = init_struct.get_space_group_info()
194194
title = f"{material_id} - Spacegroup = {spg}"

models/m3gnet/pre_vs_post_m3gnet_relaxation.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -162,17 +162,17 @@
162162
)
163163

164164
wbm_pbc_diffs_mean = df_m3gnet_is2re.wbm_pbc_diffs.mean()
165-
print(f"{wbm_pbc_diffs_mean = :.3}")
165+
print(f"{wbm_pbc_diffs_mean=:.3}")
166166

167167
m3gnet_pbc_diffs_mean = df_m3gnet_is2re.m3gnet_pbc_diffs.mean()
168-
print(f"{m3gnet_pbc_diffs_mean = :.3}")
168+
print(f"{m3gnet_pbc_diffs_mean=:.3}")
169169

170170
m3gnet_to_final_wbm_pbc_diffs_mean = (
171171
df_m3gnet_is2re.m3gnet_to_final_wbm_pbc_diffs.mean()
172172
)
173-
print(f"{m3gnet_to_final_wbm_pbc_diffs_mean = :.3}")
173+
print(f"{m3gnet_to_final_wbm_pbc_diffs_mean=:.3}")
174174

175-
print(f"{wbm_pbc_diffs_mean / m3gnet_pbc_diffs_mean = :.3}")
175+
print(f"{wbm_pbc_diffs_mean / m3gnet_pbc_diffs_mean=:.3}")
176176

177177

178178
# %%

models/wrenformer/analyze_wrenformer.py

+19-11
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44
import numpy as np
55
import pandas as pd
66
from aviary.wren.utils import get_isopointal_proto_from_aflow
7+
from IPython.display import display
78
from pymatviz import spacegroup_hist, spacegroup_sunburst
89
from pymatviz.io import df_to_html_table, df_to_pdf, save_fig
9-
from pymatviz.powerups import add_identity_line, bin_df_cols
10+
from pymatviz.powerups import add_identity_line
1011
from pymatviz.ptable import ptable_heatmap_plotly
12+
from pymatviz.utils import bin_df_cols
1113

1214
from matbench_discovery import PDF_FIGS, SITE_FIGS, Model
1315
from matbench_discovery.data import DATA_FILES, df_wbm
@@ -20,6 +22,7 @@
2022

2123
# %%
2224
model = Model.wrenformer
25+
model_low = model.lower()
2326
max_each_true = 1
2427
min_each_pred = 1
2528
df_each_pred[Key.each_true] = df_preds[Key.each_true]
@@ -42,9 +45,11 @@
4245

4346

4447
# %%
45-
ax = spacegroup_hist(df_bad[Key.spacegroup])
46-
ax.set_title(f"Spacegroup hist for {title}", y=1.15)
47-
save_fig(ax, f"{PDF_FIGS}/spacegroup-hist-{model.lower()}-failures.pdf")
48+
fig = spacegroup_hist(df_bad[Key.spacegroup])
49+
fig.layout.title.update(text=f"Spacegroup hist for {title}", y=0.96)
50+
fig.layout.margin.update(l=0, r=0, t=80, b=0)
51+
save_fig(fig, f"{PDF_FIGS}/spacegroup-hist-{model.lower()}-failures.pdf")
52+
fig.show()
4853

4954

5055
# %%
@@ -68,29 +73,32 @@
6873
df_proto_counts[proto_col] = df_proto_counts[proto_col].str.replace("_", "-")
6974

7075
styler = df_proto_counts.head(10).style.background_gradient(cmap="viridis")
71-
72-
df_to_html_table(styler, f"{SITE_FIGS}/proto-counts-{model}-failures.svelte")
73-
df_to_pdf(styler, f"{PDF_FIGS}/proto-counts-{model}-failures.pdf")
76+
styler.set_caption(f"Top 10 {proto_col} in {len(df_bad)} {model} failures")
77+
display(styler)
78+
df_to_html_table(styler, f"{SITE_FIGS}/proto-counts-{model_low}-failures.svelte")
79+
df_to_pdf(styler, f"{PDF_FIGS}/proto-counts-{model_low}-failures.pdf")
7480

7581

7682
# %%
77-
fig = spacegroup_sunburst(df_bad[Key.spacegroup], width=350, height=350)
83+
fig = spacegroup_sunburst(
84+
df_bad[Key.spacegroup], width=350, height=350, show_counts="percent"
85+
)
7886
# fig.layout.title.update(text=f"Spacegroup sunburst for {title}", x=0.5, font_size=14)
7987
fig.layout.margin.update(l=1, r=1, t=1, b=1)
8088
fig.show()
8189

8290

8391
# %%
84-
save_fig(fig, f"{PDF_FIGS}/spacegroup-sunburst-{model.lower()}-failures.pdf")
85-
save_fig(fig, f"{SITE_FIGS}/spacegroup-sunburst-{model}-failures.svelte")
92+
save_fig(fig, f"{PDF_FIGS}/spacegroup-sunburst-{model_low}-failures.pdf")
93+
save_fig(fig, f"{SITE_FIGS}/spacegroup-sunburst-{model_low}-failures.svelte")
8694

8795

8896
# %%
8997
fig = ptable_heatmap_plotly(df_bad[Key.formula])
9098
fig.layout.title = f"Elements in {title}"
9199
fig.layout.margin = dict(l=0, r=0, t=50, b=0)
92100
fig.show()
93-
save_fig(fig, f"{PDF_FIGS}/elements-{model.lower()}-failures.pdf")
101+
save_fig(fig, f"{PDF_FIGS}/elements-{model_low}-failures.pdf")
94102

95103

96104
# %%

scripts/analyze_model_failure_cases.py

+9-11
Original file line numberDiff line numberDiff line change
@@ -41,29 +41,27 @@
4141
n_structs = len(axs.flat)
4242
struct_col = {"initial": Key.init_struct, "final": Key.cse}[init_or_final]
4343

44-
errs = {
44+
errors = {
4545
"best": df_each_err[Key.each_err_models].nsmallest(n_structs),
4646
"worst": df_each_err[Key.each_err_models].nlargest(n_structs),
4747
}[good_or_bad]
4848
title = (
49-
f"{good_or_bad.title()} {len(errs)} {init_or_final} structures (across "
49+
f"{good_or_bad.title()} {len(errors)} {init_or_final} structures (across "
5050
f"{len(list(df_each_pred))} models)\nErrors in (ev/atom)"
5151
)
5252
fig.suptitle(title, fontsize=20, fontweight="bold", y=1.05)
5353

54-
for idx, (ax, (mat_id, error)) in enumerate(zip(axs.flat, errs.items()), 1):
54+
for idx, (mat_id, error) in enumerate(errors.items(), 1):
5555
struct = df_cse[struct_col].loc[mat_id]
5656
if "structure" in struct:
5757
struct = struct["structure"]
5858
struct = Structure.from_dict(struct)
59-
plot_structure_2d(struct, ax=ax)
59+
ax = plot_structure_2d(struct, ax=axs.flat[idx - 1])
6060
_, spg_num = struct.get_space_group_info()
6161
formula = struct.composition.reduced_formula
62-
ax.set_title(
63-
f"{idx}. {formula} (spg={spg_num})\n{mat_id} {error=:.2f}",
64-
fontweight="bold",
65-
)
66-
out_path = f"{PDF_FIGS}/{good_or_bad}-{len(errs)}-structures-{init_or_final}.webp"
62+
ax_title = f"{idx}. {formula} (spg={spg_num})\n{mat_id} {error=:.2f}"
63+
ax.set_title(ax_title, fontweight="bold")
64+
out_path = f"{PDF_FIGS}/{good_or_bad}-{len(errors)}-structures-{init_or_final}.webp"
6765
# fig.savefig(out_path, dpi=300)
6866

6967

@@ -73,7 +71,7 @@
7371
for idx, model in enumerate((Key.each_err_models, *df_metrics)):
7472
large_errors = df_each_err[model].abs().nlargest(n_structs)
7573
small_errors = df_each_err[model].abs().nsmallest(n_structs)
76-
for label, errors in zip(("min", "max"), (large_errors, small_errors)):
74+
for label, errors in (("min", large_errors), ("max", small_errors)):
7775
fig.add_histogram(
7876
x=df_wbm.loc[errors.index][fp_diff_col].values,
7977
name=f"{model} err<sub>{label}</sub>",
@@ -339,7 +337,7 @@
339337
y_label = "E<sub>above hull</sub> error (eV/atom)"
340338
n_structs = 1000
341339

342-
for label, which in zip(("min", "max"), ("nlargest", "nsmallest")):
340+
for label, which in (("min", "nlargest"), ("max", "nsmallest")):
343341
fig = go.Figure()
344342
for model in df_metrics:
345343
errors = getattr(df_each_err[model].abs(), which)(n_structs)

scripts/model_figs/parity_energy_models.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
import numpy as np
1111
import plotly.express as px
1212
from pymatviz.io import save_fig
13-
from pymatviz.powerups import add_identity_line, bin_df_cols
13+
from pymatviz.powerups import add_identity_line
14+
from pymatviz.utils import bin_df_cols
1415

1516
from matbench_discovery import PDF_FIGS, SITE_FIGS
1617
from matbench_discovery.enums import Key, TestSubset

scripts/model_figs/rolling_mae_vs_hull_dist_wbm_batches.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@
6969
df_each_step = df_each_pred[df_each_pred.index.str.startswith(f"wbm-{idx}-")]
7070

7171
title = f"Batch {idx} ({len(df_step.filter(like='e_').dropna()):,})"
72-
assert 1e4 < len(df_step) < 1e5, print(f"{len(df_step) = :,}")
72+
assert 1e4 < len(df_step) < 1e5, print(f"{len(df_step)=:,}")
7373
assert (df_step.index == df_each_step.index).all()
7474

7575
ax, df_err, df_std = rolling_mae_vs_hull_dist(

0 commit comments

Comments
 (0)