Skip to content

Commit 442abf9

Browse files
committed
fix typos, bump min pymatviz==0.10.1, bump ruff and fix errors
1 parent eb6ff66 commit 442abf9

12 files changed

+51
-49
lines changed

.pre-commit-config.yaml

+3-3
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ default_install_hook_types: [pre-commit, commit-msg]
88

99
repos:
1010
- repo: https://github.com/astral-sh/ruff-pre-commit
11-
rev: v0.5.6
11+
rev: v0.6.1
1212
hooks:
1313
- id: ruff
1414
args: [--fix]
@@ -57,7 +57,7 @@ repos:
5757
exclude: ^(site/src/figs/.+\.svelte|data/wbm/20.+\..+|site/src/(routes|figs).+\.(yaml|json)|changelog.md)$
5858

5959
- repo: https://github.com/pre-commit/mirrors-eslint
60-
rev: v9.8.0
60+
rev: v9.9.0
6161
hooks:
6262
- id: eslint
6363
types: [file]
@@ -79,7 +79,7 @@ repos:
7979
- id: check-github-actions
8080

8181
- repo: https://github.com/RobertCraigie/pyright-python
82-
rev: v1.1.375
82+
rev: v1.1.376
8383
hooks:
8484
- id: pyright
8585
args: [--level, error]

matbench_discovery/data.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def __new__(
214214

215215
obj._rel_path = file_path # type: ignore[attr-defined] # noqa: SLF001
216216
obj._url = url # type: ignore[attr-defined] # noqa: SLF001
217-
obj._label = label # type: ignore[attr-defined] # noqa: SLF001
217+
obj._label = label # type: ignore[attr-defined] # noqa: SLF001
218218

219219
return obj
220220

matbench_discovery/enums.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ class Task(LabelEnum):
8080
S2RE = "S2RE", "structure to relaxed energy"
8181
S2EF = "S2EF", "structure to energy, force"
8282
S2EFS = "S2EFS", "structure to energy, force, stress"
83-
S2EFSM = "S2EFSM" "structure to energy, force, stress, magmoms"
83+
S2EFSM = "S2EFSM", "structure to energy, force, stress, magmoms"
8484
IP2E = "IP2E", "initial prototype to energy"
8585
IS2E = "IS2E", "initial structure to energy"
8686
# IS2RE is for models that learned a discrete version of PES like CGCNN+P

matbench_discovery/plots.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -190,10 +190,10 @@ def hist_classified_stable_vs_hull_dist(
190190
)[which_energy]
191191

192192
if stability_threshold is not None:
193-
for ax in [fig] if isinstance(fig, plt.Axes) else fig.flat:
194-
ax.set(xlabel=xlabel, ylabel=y_label, xlim=x_lim)
193+
for ax_i in [fig] if isinstance(fig, plt.Axes) else fig.flat:
194+
ax_i.set(xlabel=xlabel, ylabel=y_label, xlim=x_lim)
195195
label = "Stability Threshold"
196-
ax.axvline(
196+
ax_i.axvline(
197197
stability_threshold, color="black", linestyle="--", label=label
198198
)
199199

@@ -228,8 +228,8 @@ def hist_classified_stable_vs_hull_dist(
228228
)
229229

230230
if backend == MATPLOTLIB:
231-
for ax in fig.flat if isinstance(fig, np.ndarray) else [fig]:
232-
ax_acc = ax.twinx()
231+
for ax_i in fig.flat if isinstance(fig, np.ndarray) else [fig]:
232+
ax_acc = ax_i.twinx()
233233
ax_acc.set_ylabel("Rolling Accuracy", color="darkblue")
234234
ax_acc.tick_params(labelcolor="darkblue")
235235
ax_acc.set(ylim=(0, 1.1))
@@ -681,14 +681,14 @@ def cumulative_metrics(
681681
rmse_interp = cubic_interpolate(model_range, rmse_cum[:n_pred_stable])
682682
dfs["RMSE"][model_name] = dict(zip(xs_model, rmse_interp(xs_model)))
683683

684-
for key in dfs:
684+
for key, df_i in dfs.items():
685+
# will be used as facet_col in plotly to split different metrics into subplots
686+
df_i["metric"] = key
685687
# drop all-NaN rows so plotly plot x-axis only extends to largest number of
686688
# predicted materials by any model
687-
dfs[key] = dfs[key].dropna(how="all")
688-
# will be used as facet_col in plotly to split different metrics into subplots
689-
dfs[key]["metric"] = key
689+
dfs[key] = df_i.dropna(how="all")
690690

691-
df_cum = pd.concat(dfs.values())
691+
df_cumu_metrics = pd.concat(dfs.values())
692692
# subselect rows for speed, plot has sufficient precision with 1k rows
693693
n_stable = sum(e_above_hull_true <= STABILITY_THRESHOLD)
694694

@@ -752,7 +752,7 @@ def cumulative_metrics(
752752
elif backend == PLOTLY:
753753
n_cols = kwargs.pop("facet_col_wrap", 2)
754754
kwargs.setdefault("facet_col_spacing", 0.03)
755-
fig = df_cum.plot(
755+
fig = df_cumu_metrics.plot(
756756
backend=backend,
757757
facet_col="metric",
758758
facet_col_wrap=n_cols,
@@ -802,7 +802,7 @@ def cumulative_metrics(
802802
else:
803803
raise ValueError(f"Unknown {backend=}")
804804

805-
return fig, df_cum
805+
return fig, df_cumu_metrics
806806

807807

808808
def wandb_scatter(table: wandb.Table, fields: dict[str, str], **kwargs: Any) -> None:

matbench_discovery/preds.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -72,14 +72,14 @@ class Model(Files, base_dir=f"{ROOT}/models"):
7272
"Wrenformer",
7373
)
7474

75-
## Proprietary Models
75+
# --- Proprietary Models
7676
# GNoMe
7777
gnome = "gnome/2023-11-01-gnome-preds-50076332.csv.gz", None, "GNoME"
7878

7979
# MatterSim
8080
mattersim = "mattersim/mattersim-wbm-IS2RE.csv.gz", None, "MatterSim"
8181

82-
## Miscellaneous
82+
# --- Model Combos
8383
# # CHGNet-relaxed structures fed into MEGNet for formation energy prediction
8484
# chgnet_megnet = "chgnet/2023-03-06-chgnet-0.2.0-wbm-IS2RE.csv.gz", None, "CHGNet→MEGNet"
8585
# # M3GNet-relaxed structures fed into MEGNet for formation energy prediction

models/mattersim/test_mattersim.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535

3636

3737
def dummy_mattersim_calculator(
38-
backbone: Literal["m3gnet", "graphomer"] = "m3gnet",
38+
backbone: Literal["m3gnet", "graphormer"] = "m3gnet",
3939
) -> SinglePointCalculator:
4040
"""
4141
This is a dummy function that makes a MatterSim calculator
@@ -69,7 +69,7 @@ def relax_atoms_list(
6969
atoms_list: list[ase.Atoms],
7070
fmax: float = 0.01,
7171
steps: int = 500,
72-
backbone: Literal["m3gnet", "graphomer"] = "m3gnet",
72+
backbone: Literal["m3gnet", "graphormer"] = "m3gnet",
7373
) -> list[ase.Atoms]:
7474
"""
7575
This function relax the atoms.
@@ -88,11 +88,11 @@ def relax_atoms_list(
8888
else:
8989
atoms.info["converged"] = True
9090

91-
if backbone == "graphomer":
91+
if backbone == "graphormer":
9292
# Please note that we only re-calculate the
93-
# energy in the case of MatterSim(graphomer).
93+
# energy in the case of MatterSim(graphormer).
9494
# The structure relaxation is always done with MatterSim(m3gnet).
95-
calc = dummy_mattersim_calculator(backbone="graphomer")
95+
calc = dummy_mattersim_calculator(backbone="graphormer")
9696
atoms.set_calculator(calc)
9797

9898
relaxed_atoms_list.append(atoms)
@@ -106,7 +106,6 @@ def parse_relaxed_atoms_list_as_df(
106106
) -> pd.DataFrame:
107107
e_form_col = "e_form_per_atom_mattersim"
108108

109-
## Read pre-computed CSEs by WBM
110109
wbm_cse_paths = DataFiles.wbm_computed_structure_entries.path
111110
df_cse = pd.read_json(wbm_cse_paths).set_index(Key.mat_id)
112111

@@ -167,7 +166,9 @@ def parse_single_atoms(atoms: ase.Atoms) -> tuple[str, bool, float, float, float
167166

168167
if __name__ == "__main__":
169168
init_wbm_atoms_list = convert_wbm_to_atoms_list()
170-
relaxed_wbm_atoms_list = relax_atoms_list(init_wbm_atoms_list, backbone="graphomer")
169+
relaxed_wbm_atoms_list = relax_atoms_list(
170+
init_wbm_atoms_list, backbone="graphormer"
171+
)
171172
parse_relaxed_atoms_list_as_df(relaxed_wbm_atoms_list).to_csv(
172173
"mattersim-wbm-IS2RE.csv.gz"
173174
)

pyproject.toml

+1-4
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,10 @@ dependencies = [
3131
"pandas>=2.0.0",
3232
"plotly",
3333
"pymatgen",
34-
"pymatviz[export-figs,df-pdf-export,df-svg-export]>=0.10.0",
34+
"pymatviz[export-figs,df-pdf-export,df-svg-export]>=0.10.1",
3535
"scikit-learn",
3636
"scipy",
3737
"seaborn",
38-
# TODO remove svgutils after next pymatviz release 0.10.1
39-
"svgutils",
40-
"svgwrite",
4138
"tqdm",
4239
"wandb",
4340

tests/conftest.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from pymatviz.enums import Key
77

88

9-
@pytest.fixture()
9+
@pytest.fixture
1010
def dummy_struct() -> Structure:
1111
return Structure(
1212
lattice=Lattice.cubic(4.2),
@@ -15,14 +15,14 @@ def dummy_struct() -> Structure:
1515
)
1616

1717

18-
@pytest.fixture()
18+
@pytest.fixture
1919
def df_float() -> pd.DataFrame:
2020
rng = np.random.default_rng(0)
2121

2222
return pd.DataFrame(rng.normal(size=(10, 5)), columns=[*"ABCDE"])
2323

2424

25-
@pytest.fixture()
25+
@pytest.fixture
2626
def df_mixed() -> pd.DataFrame:
2727
rng = np.random.default_rng(0)
2828

@@ -32,7 +32,7 @@ def df_mixed() -> pd.DataFrame:
3232
return pd.DataFrame(dict(floats=floats, bools=bools, strings=strings))
3333

3434

35-
@pytest.fixture()
35+
@pytest.fixture
3636
def df_with_pmg_objects(dummy_struct: Structure) -> pd.DataFrame:
3737
# create a dummy df with a structure column on which to test (de-)serialization
3838
df_dummy = pd.DataFrame(dict(material_id=range(5), structure=[dummy_struct] * 5))

tests/test_data.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@
4747

4848
def test_as_dict_handler() -> None:
4949
class C:
50-
def as_dict(self) -> dict[str, Any]:
50+
@staticmethod
51+
def as_dict() -> dict[str, Any]:
5152
return {"foo": "bar"}
5253

5354
assert as_dict_handler(C()) == {"foo": "bar"}

tests/test_models.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,17 @@ def test_model_dirs_have_metadata() -> None:
2727

2828
for model_name, metadata in MODEL_METADATA.items():
2929
model_dir = metadata["model_dir"]
30-
for key in required:
30+
for key, val in required.items():
3131
assert key in metadata, f"Required {key=} missing in {model_dir}"
32-
if isinstance(required[key], dict):
33-
missing_keys = {*required[key]} - {*metadata[key]} # type: ignore[misc]
32+
if isinstance(val, dict):
33+
missing_keys = {*val} - {*metadata[key]}
3434
assert (
3535
not missing_keys
3636
), f"Missing sub-keys {missing_keys} of {key=} in {model_dir}"
3737
continue
3838

39-
err_msg = f"Invalid {key=}, expected {required[key]} in {model_dir}"
40-
assert isinstance(metadata[key], required[key]), err_msg # type: ignore[arg-type]
39+
err_msg = f"Invalid {key=}, expected {val} in {model_dir}"
40+
assert isinstance(metadata[key], val), err_msg # type: ignore[arg-type]
4141

4242
authors, date_added, mbd_version, yml_model_name, model_version, repo = (
4343
metadata[key] for key in list(required)[:-1]

tests/test_plots.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def test_cumulative_metrics(
3838
backend: Backend,
3939
metrics: tuple[str, ...],
4040
) -> None:
41-
fig, df_metrics = cumulative_metrics(
41+
fig, df_cumu_metrics = cumulative_metrics(
4242
e_above_hull_true=df_wbm[MbdKey.each_true],
4343
df_preds=df_wbm[models],
4444
backend=backend,
@@ -47,8 +47,8 @@ def test_cumulative_metrics(
4747
metrics=metrics,
4848
)
4949

50-
assert isinstance(df_metrics, pd.DataFrame)
51-
assert list(df_metrics) == [*models, "metric"]
50+
assert isinstance(df_cumu_metrics, pd.DataFrame)
51+
assert list(df_cumu_metrics) == [*models, "metric"]
5252

5353
if backend == MATPLOTLIB:
5454
assert isinstance(fig, plt.Figure)

tests/test_slurm.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from matbench_discovery.slurm import _get_calling_file_path, slurm_submit
77

88

9-
@patch.dict(os.environ, {"SLURM_JOB_ID": "1234"}, clear=True)
109
@pytest.mark.parametrize("py_file_path", [None, "path/to/file.py"])
1110
@pytest.mark.parametrize("partition", [None, "fake-partition"])
1211
@pytest.mark.parametrize("time", [None, "0:0:1"])
@@ -34,7 +33,14 @@ def test_slurm_submit(
3433
pre_cmd=pre_cmd,
3534
)
3635

37-
slurm_vars = slurm_submit(**kwargs) # type: ignore[arg-type]
36+
slurm_submit(**kwargs) # type: ignore[arg-type]
37+
38+
stdout, stderr = capsys.readouterr()
39+
# check slurm_submit() did nothing in normal mode
40+
assert stdout == stderr == ""
41+
42+
with patch.dict(os.environ, {"SLURM_JOB_ID": "1234"}, clear=True):
43+
slurm_vars = slurm_submit(**kwargs) # type: ignore[arg-type]
3844
expected_slurm_vars = dict(slurm_job_id="1234", slurm_flags="--foo")
3945
if time is not None:
4046
expected_slurm_vars["slurm_timelimit"] = time
@@ -44,15 +50,12 @@ def test_slurm_submit(
4450
expected_slurm_vars["pre_cmd"] = pre_cmd
4551
assert slurm_vars == expected_slurm_vars
4652

47-
stdout, stderr = capsys.readouterr()
48-
# check slurm_submit() did nothing in normal mode
49-
assert stderr == stderr == ""
50-
5153
# check slurm_submit() prints cmd and calls subprocess.run() in submit mode
5254
with (
5355
pytest.raises(SystemExit),
5456
patch("sys.argv", ["slurm-submit"]),
5557
patch("matbench_discovery.slurm.subprocess.run") as mock_subprocess_run,
58+
patch.dict(os.environ, {"SLURM_JOB_ID": "1234"}, clear=True),
5659
):
5760
slurm_submit(**kwargs) # type: ignore[arg-type]
5861

0 commit comments

Comments
 (0)