Skip to content

Commit 0249327

Browse files
committed
fix and unignore ruff DTZ005 FBT001 FBT002
1 parent 016e71f commit 0249327

20 files changed

+45
-30
lines changed

.pre-commit-config.yaml

+4 -4
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@ default_install_hook_types: [pre-commit, commit-msg]
88

99
repos:
1010
- repo: https://github.com/astral-sh/ruff-pre-commit
11-
rev: v0.4.2
11+
rev: v0.4.4
1212
hooks:
1313
- id: ruff
1414
args: [--fix]
@@ -57,7 +57,7 @@ repos:
5757
exclude: ^(site/src/figs/.+\.svelte|data/wbm/20.+\..+|site/src/(routes|figs).+\.(yaml|json)|changelog.md)$
5858

5959
- repo: https://github.com/pre-commit/mirrors-eslint
60-
rev: v9.1.1
60+
rev: v9.2.0
6161
hooks:
6262
- id: eslint
6363
types: [file]
@@ -71,15 +71,15 @@ repos:
7171
- typescript-eslint
7272

7373
- repo: https://github.com/python-jsonschema/check-jsonschema
74-
rev: 0.28.2
74+
rev: 0.28.3
7575
hooks:
7676
- id: check-jsonschema
7777
files: ^models/(.+)/\1.*\.yml$
7878
args: [--schemafile, tests/model-schema.yml]
7979
- id: check-github-actions
8080

8181
- repo: https://github.com/RobertCraigie/pyright-python
82-
rev: v1.1.360
82+
rev: v1.1.363
8383
hooks:
8484
- id: pyright
8585
args: [--level, error]

matbench_discovery/__init__.py

+2 -2
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
import json
44
import os
55
import warnings
6-
from datetime import datetime
6+
from datetime import UTC, datetime
77
from importlib.metadata import Distribution
88

99
import matplotlib.pyplot as plt
@@ -43,7 +43,7 @@
4343
# threshold on hull distance for a material to be considered stable
4444
STABILITY_THRESHOLD = 0
4545

46-
timestamp = f"{datetime.now():%Y-%m-%d@%H-%M-%S}"
46+
timestamp = f"{datetime.now(tz=UTC):%Y-%m-%d@%H-%M-%S}"
4747
today = timestamp.split("@")[0]
4848

4949
# filter pymatgen warnings that spam the logs when e.g. applying corrections to

matbench_discovery/data.py

+2
Original file line number | Diff line number | Diff line change
@@ -51,6 +51,7 @@ def as_dict_handler(obj: Any) -> dict[str, Any] | None:
5151

5252
def load(
5353
key: str,
54+
*,
5455
version: str = figshare_versions[-1],
5556
cache_dir: str | Path = default_cache_dir,
5657
hydrate: bool = False,
@@ -149,6 +150,7 @@ def load(
149150

150151
def glob_to_df(
151152
pattern: str,
153+
*,
152154
reader: Callable[[Any], pd.DataFrame] | None = None,
153155
pbar: bool = True,
154156
**kwargs: Any,

matbench_discovery/energy.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@
1616

1717

1818
def get_elemental_ref_entries(
19-
entries: Sequence[EntryLike], verbose: bool = True
19+
entries: Sequence[EntryLike], *, verbose: bool = True
2020
) -> dict[str, Entry]:
2121
"""Get the lowest energy pymatgen Entry for each element in a list of entries.
2222

matbench_discovery/enums.py

+2 -2
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77

88

99
class LabelEnum(StrEnum):
10-
"""StrEnum with optional label and description attributes plus dict() method."""
10+
"""StrEnum with optional label and description attributes plus dict() methods."""
1111

1212
def __new__(
1313
cls, val: str, label: str | None = None, desc: str | None = None
@@ -153,7 +153,7 @@ class Open(LabelEnum):
153153

154154
@unique
155155
class TestSubset(LabelEnum):
156-
"""Test set subsets."""
156+
"""Which subset of the test data to use for evaluation."""
157157

158158
uniq_protos = "uniq_protos", "Unique Structure Prototypes"
159159
ten_k_most_stable = "10k_most_stable", "10k Most Stable"

matbench_discovery/metrics.py

+6 -1
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
def classify_stable(
1818
e_above_hull_true: pd.Series,
1919
e_above_hull_pred: pd.Series,
20+
*,
2021
stability_threshold: float | None = 0,
2122
fillna: bool = True,
2223
) -> tuple[pd.Series, pd.Series, pd.Series, pd.Series]:
@@ -69,6 +70,7 @@ def classify_stable(
6970
def stable_metrics(
7071
each_true: Sequence[float],
7172
each_pred: Sequence[float],
73+
*,
7274
stability_threshold: float = STABILITY_THRESHOLD,
7375
fillna: bool = True,
7476
) -> dict[str, float]:
@@ -95,7 +97,10 @@ def stable_metrics(
9597
Recall, Accuracy, F1, TPR, FPR, TNR, FNR, MAE, RMSE, R2.
9698
"""
9799
n_true_pos, n_false_neg, n_false_pos, n_true_neg = map(
98-
sum, classify_stable(each_true, each_pred, stability_threshold, fillna)
100+
sum,
101+
classify_stable(
102+
each_true, each_pred, stability_threshold=stability_threshold, fillna=fillna
103+
),
99104
)
100105

101106
n_total_pos = n_true_pos + n_false_neg

matbench_discovery/plots.py

+9 -2
Original file line number | Diff line number | Diff line change
@@ -108,7 +108,9 @@ def hist_classified_stable_vs_hull_dist(
108108
df.groupby(kwargs["facet_col"]) if "facet_col" in kwargs else [(None, df)]
109109
):
110110
true_pos, false_neg, false_pos, true_neg = classify_stable(
111-
df_group[each_true_col], df_group[each_pred_col], stability_threshold
111+
df_group[each_true_col],
112+
df_group[each_pred_col],
113+
stability_threshold=stability_threshold,
112114
)
113115

114116
# switch between hist of DFT-computed and model-predicted convex hull distance
@@ -264,6 +266,7 @@ def hist_classified_stable_vs_hull_dist(
264266
def rolling_mae_vs_hull_dist(
265267
e_above_hull_true: pd.Series,
266268
e_above_hull_preds: pd.DataFrame | dict[str, pd.Series],
269+
*,
267270
df_rolling_err: pd.DataFrame | None = None,
268271
df_err_std: pd.DataFrame | None = None,
269272
window: float = 0.04,
@@ -567,6 +570,7 @@ def rolling_mae_vs_hull_dist(
567570
def cumulative_metrics(
568571
e_above_hull_true: pd.Series,
569572
df_preds: pd.DataFrame,
573+
*,
570574
metrics: Sequence[str] = ("Precision", "Recall"),
571575
stability_threshold: float = 0, # set stability threshold as distance to convex
572576
# hull in eV / atom, usually 0 or 0.1 eV
@@ -635,7 +639,10 @@ def cumulative_metrics(
635639
each_true = e_above_hull_true.loc[each_pred.index]
636640

637641
true_pos_cum, false_neg_cum, false_pos_cum, _true_neg_cum = map(
638-
np.cumsum, classify_stable(each_true, each_pred, stability_threshold)
642+
np.cumsum,
643+
classify_stable(
644+
each_true, each_pred, stability_threshold=stability_threshold
645+
),
639646
)
640647

641648
# precision aka positive predictive value (PPV)

matbench_discovery/preds.py

+1
Original file line number | Diff line number | Diff line change
@@ -66,6 +66,7 @@ class PredFiles(Files):
6666

6767

6868
def load_df_wbm_with_preds(
69+
*,
6970
models: Sequence[str] = (*PRED_FILES,),
7071
pbar: bool = True,
7172
id_col: str = Key.mat_id,

models/alignn/train_alignn.py

+1
Original file line number | Diff line number | Diff line change
@@ -95,6 +95,7 @@
9595

9696
def df_to_loader(
9797
df: pd.DataFrame,
98+
*,
9899
batch_size: int = 128,
99100
line_graph: bool = True,
100101
pin_memory: bool = False,

models/cgcnn/plot_structure_perturbation.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,6 @@
3333

3434
# %%
3535
fig, axs = plt.subplots(3, 4, figsize=(12, 10))
36-
for idx, ax in enumerate(axs.flat, 1):
36+
for idx, ax in enumerate(axs.flat, start=1):
3737
plot_structure_2d(perturb_structure(struct), ax=ax)
3838
ax.set(title=f"perturbation {idx}")

models/chgnet/analyze_chgnet.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -79,7 +79,7 @@
7979
struct_col = Key.init_struct
8080

8181
fig.suptitle(f"{n_struct} {struct_col} {title}", fontsize=16, fontweight="bold", y=1.05)
82-
for idx, row in enumerate(df_cse.loc[df_diff.index].itertuples(), 1):
82+
for idx, row in enumerate(df_cse.loc[df_diff.index].itertuples(), start=1):
8383
struct = Structure.from_dict(getattr(row, struct_col))
8484
ax = plot_structure_2d(struct, ax=axs.flat[idx - 1])
8585
_, spg_num = struct.get_space_group_info()

pyproject.toml

+3 -7
Original file line number | Diff line number | Diff line change
@@ -90,14 +90,10 @@ ignore = [
9090
"C408", # unnecessary-collection-call
9191
"C901",
9292
"COM812",
93-
"D100", # undocumented-public-module
9493
"D205", # blank-line-after-summary
95-
"DTZ005",
9694
"E731", # lambda-assignment
9795
"EM101",
9896
"EM102",
99-
"FBT001",
100-
"FBT002",
10197
"FIX002",
10298
"INP001",
10399
"ISC001",
@@ -125,9 +121,9 @@ isort.known-third-party = ["wandb"]
125121
isort.split-on-trailing-comma = false
126122

127123
[tool.ruff.lint.per-file-ignores]
128-
"tests/*" = ["D", "S101"]
129-
"matbench_discovery/plots.py" = ["ERA001"] # allow commented out code
130-
"matbench_discovery/preds.py" = ["ERA001"] # allow commented out code
124+
"tests/*" = ["D", "FBT001", "FBT002", "S101"]
125+
"matbench_discovery/plots.py" = ["ERA001"] # allow commented out code
126+
"matbench_discovery/preds.py" = ["ERA001"] # allow commented out code
131127
"scripts/*" = ["D", "ERA001", "S101"]
132128
"models/*" = ["D", "ERA001", "S101"]
133129
"data/*" = ["ERA001", "S101"]

scripts/analyze_model_failure_cases.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -51,7 +51,7 @@
5151
)
5252
fig.suptitle(title, fontsize=20, fontweight="bold", y=1.05)
5353

54-
for idx, (mat_id, error) in enumerate(errors.items(), 1):
54+
for idx, (mat_id, error) in enumerate(errors.items(), start=1):
5555
struct = df_cse[struct_col].loc[mat_id]
5656
if "structure" in struct:
5757
struct = struct["structure"]

scripts/model_figs/parity_energy_models.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -176,7 +176,7 @@
176176
y_title = fig.layout.yaxis.title.text
177177

178178
# iterate over subplots and set new title
179-
for idx, anno in enumerate(fig.layout.annotations, 1):
179+
for idx, anno in enumerate(fig.layout.annotations, start=1):
180180
traces = [t for t in fig.data if t.xaxis == f"x{idx if idx > 1 else ''}"]
181181
# assert len(traces) in (0, 4), f"Plots must have 0 or 4 traces, got {len(traces)=}"
182182

scripts/model_figs/per_element_errors.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -66,7 +66,7 @@
6666
df_elem_err.index.name = "symbol"
6767

6868

69-
# %%
69+
# %% plot number of structures containing each element in MP and WBM
7070
for label, srs in (
7171
("MP", df_elem_err[train_count_col]),
7272
("WBM", df_frac_comp.where(pd.isna, 1).sum()),

scripts/model_figs/rolling_mae_vs_hull_dist_wbm_batches.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -63,7 +63,7 @@
6363
assert len(markers) == 5 # number of iterations of element substitution in WBM data set
6464
model = Model.chgnet
6565

66-
for idx, marker in enumerate(markers, 1):
66+
for idx, marker in enumerate(markers, start=1):
6767
# select all rows from WBM step=idx
6868
df_step = df_preds[df_preds.index.str.startswith(f"wbm-{idx}-")]
6969
df_each_step = df_each_pred[df_each_pred.index.str.startswith(f"wbm-{idx}-")]

scripts/project_compositions.py

+3 -3
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22

33
# %%
44
import os
5-
from datetime import datetime
5+
from datetime import UTC, datetime
66
from typing import Any, Literal
77

88
import numpy as np
@@ -39,7 +39,7 @@
3939
print(f"{data_path=}")
4040
print(f"{out_dim=}")
4141
print(f"{projection_type=}")
42-
start_time = datetime.now()
42+
start_time = datetime.now(tz=UTC)
4343
print(f"job started at {start_time:%Y-%m-%d %H:%M:%S}")
4444
df_in = pd.read_csv(data_path, na_filter=False).set_index(Key.mat_id)
4545

@@ -92,7 +92,7 @@ def sum_one_hot_elem(formula: str) -> np.ndarray[Any, np.int64]:
9292
df_in[out_cols].to_csv(out_path)
9393

9494
print(f"Wrote projections to {out_path!r}")
95-
end_time = datetime.now()
95+
end_time = datetime.now(tz=UTC)
9696
print(
9797
f"Job finished at {end_time:%Y-%m-%d %H:%M:%S} and took "
9898
f"{(end_time - start_time).seconds} sec"

scripts/update_wandb_runs.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@
3434
updated_runs: list[Run] = []
3535
wet_run = input("Wet run or dry run? [w/d] ").lower().startswith("w")
3636

37-
for idx, run in enumerate(runs, 1):
37+
for idx, run in enumerate(runs, start=1):
3838
old_config, new_config = run.config.copy(), run.config.copy()
3939

4040
new_display_name = run.display_name.replace(

scripts/upload_to_figshare.py

+3 -1
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,9 @@
2929
BASE_URL = "https://api.figshare.com/v2"
3030

3131

32-
def make_request(method: str, url: str, data: Any = None, binary: bool = False) -> Any:
32+
def make_request(
33+
method: str, url: str, *, data: Any = None, binary: bool = False
34+
) -> Any:
3335
"""Make a token-authorized HTTP request to the Figshare API."""
3436
headers = {"Authorization": f"token {TOKEN}"}
3537
if data is not None and not binary:

scripts/wbm_umap_projection.py

+1
Original file line number | Diff line number | Diff line change
@@ -40,6 +40,7 @@
4040
# %%
4141
def featurize_dataframe(
4242
df_in: pd.DataFrame | pd.Series,
43+
*,
4344
struct_col: str = "structure",
4445
ignore_errors: bool = True,
4546
chunk_size: int = 30,

0 commit comments

Comments
 (0)