Skip to content

Commit 604cb04

Browse files
committed
make slurm_submit() time and account optional
WBM: add 'find large structures that changed symmetry during relaxation'; update all slurm_submit(account="LEE-SL...->matgen"); ruff: unignore PT013 and S301 and apply fixes
1 parent 02b2657 commit 604cb04

21 files changed

+106
-70
lines changed

.pre-commit-config.yaml

+4-4
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ default_install_hook_types: [pre-commit, commit-msg]
88

99
repos:
1010
- repo: https://github.com/astral-sh/ruff-pre-commit
11-
rev: v0.4.6
11+
rev: v0.4.8
1212
hooks:
1313
- id: ruff
1414
args: [--fix]
@@ -57,7 +57,7 @@ repos:
5757
exclude: ^(site/src/figs/.+\.svelte|data/wbm/20.+\..+|site/src/(routes|figs).+\.(yaml|json)|changelog.md)$
5858

5959
- repo: https://github.com/pre-commit/mirrors-eslint
60-
rev: v9.3.0
60+
rev: v9.4.0
6161
hooks:
6262
- id: eslint
6363
types: [file]
@@ -71,15 +71,15 @@ repos:
7171
- typescript-eslint
7272

7373
- repo: https://github.com/python-jsonschema/check-jsonschema
74-
rev: 0.28.4
74+
rev: 0.28.5
7575
hooks:
7676
- id: check-jsonschema
7777
files: ^models/(.+)/\1.*\.yml$
7878
args: [--schemafile, tests/model-schema.yml]
7979
- id: check-github-actions
8080

8181
- repo: https://github.com/RobertCraigie/pyright-python
82-
rev: v1.1.365
82+
rev: v1.1.366
8383
hooks:
8484
- id: pyright
8585
args: [--level, error]

data/wbm/compile_wbm_test_set.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -541,7 +541,7 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
541541

542542
# %%
543543
with gzip.open(DATA_FILES.mp_patched_phase_diagram, "rb") as zip_file:
544-
ppd_mp: PatchedPhaseDiagram = pickle.load(zip_file)
544+
ppd_mp: PatchedPhaseDiagram = pickle.load(zip_file) # noqa: S301
545545

546546

547547
# %% calculate e_above_hull for each material

data/wbm/eda_wbm.py

+38-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import pandas as pd
1010
import plotly.express as px
1111
from matplotlib.colors import SymLogNorm
12-
from pymatgen.core import Composition
12+
from pymatgen.core import Composition, Structure
1313
from pymatviz import (
1414
count_elements,
1515
ptable_heatmap,
@@ -18,6 +18,7 @@
1818
spacegroup_sunburst,
1919
)
2020
from pymatviz.io import save_fig
21+
from pymatviz.structure_viz import plot_structure_2d
2122
from pymatviz.utils import si_fmt, si_fmt_int
2223

2324
from matbench_discovery import PDF_FIGS, ROOT, SITE_FIGS, STABILITY_THRESHOLD
@@ -371,3 +372,39 @@
371372
img_name = "mp-vs-wbm-arity-hist"
372373
save_fig(fig, f"{SITE_FIGS}/{img_name}.svelte")
373374
save_fig(fig, f"{PDF_FIGS}/{img_name}.pdf", width=450, height=280)
375+
376+
377+
# %% find large structures that changed symmetry during relaxation
378+
df_sym_change = (
379+
df_wbm.query(f"{Key.init_wyckoff} != {Key.wyckoff}")
380+
.filter(regex="wyckoff|sites")
381+
.nlargest(10, Key.n_sites)
382+
)
383+
384+
385+
# %%
386+
df_wbm_structs = pd.read_json(DATA_FILES.wbm_cses_plus_init_structs).set_index(
387+
Key.mat_id
388+
)
389+
390+
391+
# %%
392+
for wbm_id in df_sym_change.index:
393+
init_struct = Structure.from_dict(df_wbm_structs.loc[wbm_id][Key.init_struct])
394+
final_struct = Structure.from_dict(df_wbm_structs.loc[wbm_id][Key.cse]["structure"])
395+
init_struct.properties[Key.mat_id] = f"{wbm_id}-init"
396+
final_struct.properties[Key.mat_id] = f"{wbm_id}-final"
397+
398+
plot_structure_2d([init_struct, final_struct])
399+
400+
401+
# %% export initial and final structures with symmetry change to CIF and JSON
402+
wbm_id = df_sym_change.index[0]
403+
404+
struct = Structure.from_dict(df_wbm_structs.loc[wbm_id][Key.cse]["structure"])
405+
struct.to(f"{module_dir}/{wbm_id}.cif")
406+
struct.to(f"{module_dir}/{wbm_id}.json")
407+
408+
struct = Structure.from_dict(df_wbm_structs.loc[wbm_id][Key.init_struct])
409+
struct.to(f"{module_dir}/{wbm_id}-init.cif")
410+
struct.to(f"{module_dir}/{wbm_id}-init.json")

matbench_discovery/data.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def load(
117117
print(f"Loading {key!r} from cached file at {cache_path!r}")
118118
if ".pkl" in file_path: # handle key='mp_patched_phase_diagram' separately
119119
with gzip.open(cache_path, "rb") as zip_file:
120-
return pickle.load(zip_file)
120+
return pickle.load(zip_file) # noqa: S301
121121
if ".pth" in file_path: # handle model checkpoints (e.g. key='alignn_checkpoint')
122122
return cache_path
123123

matbench_discovery/slurm.py

+19-14
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,9 @@ def _get_calling_file_path(frame: int = 1) -> str:
2929
def slurm_submit(
3030
job_name: str,
3131
out_dir: str,
32-
time: str,
33-
account: str,
32+
*,
33+
time: str | None = None,
34+
account: str | None = None,
3435
partition: str | None = None,
3536
py_file_path: str | None = None,
3637
slurm_flags: str | Sequence[str] = (),
@@ -72,30 +73,34 @@ def slurm_submit(
7273

7374
os.makedirs(out_dir, exist_ok=True) # slurm fails if out_dir is missing
7475

76+
# ensure pre_cmd ends with a semicolon
77+
if pre_cmd and not pre_cmd.strip().endswith(";"):
78+
pre_cmd += ";"
79+
7580
cmd = [
76-
*f"sbatch --{account=} --{time=}".replace("'", "").split(),
77-
*("--job-name", job_name),
81+
*("sbatch", "--job-name", job_name),
7882
*("--output", f"{out_dir}/slurm-%A{'-%a' if array else ''}.log"),
7983
*(slurm_flags.split() if isinstance(slurm_flags, str) else slurm_flags),
80-
*("--wrap", f"{pre_cmd} python {py_file_path}".strip()),
84+
*("--wrap", f"{pre_cmd or ''} python {py_file_path}".strip()),
8185
]
82-
if partition:
83-
cmd += ["--partition", partition]
84-
if array:
85-
cmd += ["--array", array]
86+
for flag in (f"{time=}", f"{account=}", f"{partition=}", f"{array=}"):
87+
key, val = flag.split("=")
88+
if val != "None":
89+
cmd += (f"--{key}", val)
8690

8791
is_log_file = not sys.stdout.isatty()
8892
is_slurm_job = "SLURM_JOB_ID" in os.environ
8993

9094
slurm_vars = {
91-
f"slurm_{key}": val
95+
f"slurm_{key}": os.environ[f"SLURM_{key}".upper()]
9296
for key in SLURM_KEYS
93-
if (val := os.getenv(f"SLURM_{key}".upper()))
97+
if f"SLURM_{key}".upper() in os.environ
9498
}
95-
slurm_vars["slurm_timelimit"] = time
96-
if slurm_flags:
99+
if time is not None:
100+
slurm_vars["slurm_timelimit"] = time
101+
if slurm_flags != ():
97102
slurm_vars["slurm_flags"] = str(slurm_flags)
98-
if pre_cmd:
103+
if pre_cmd not in ("", None):
99104
slurm_vars["pre_cmd"] = pre_cmd
100105

101106
# print sbatch command into slurm log file and at job submission time

models/alignn/test_alignn.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,13 @@
5858

5959
slurm_vars = slurm_submit(
6060
job_name=job_name,
61-
partition="ampere",
62-
account="LEE-SL3-GPU",
61+
account="matgen",
6362
time="11:55:0",
6463
out_dir=out_dir,
6564
slurm_flags="--nodes 1 --gpus-per-node 1",
6665
# pre_cmd is platform specific, remove when running on other systems
6766
# just left here for reference
68-
pre_cmd=". /etc/profile.d/modules.sh; module load rhel8/default-amp;"
69-
"module load cuda/11.8",
67+
pre_cmd="module load cuda/11.8",
7068
)
7169

7270

models/bowsr/test_bowsr.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,7 @@
4646
slurm_vars = slurm_submit(
4747
job_name=job_name,
4848
out_dir=out_dir,
49-
partition="skylake",
50-
account="LEE-SL3-CPU",
49+
account="matgen",
5150
time="11:55:0",
5251
# --time=2:0:0 is probably enough but best be safe.
5352
array=f"1-{slurm_array_task_count}%{slurm_max_parallel}",

models/cgcnn/test_cgcnn.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,7 @@
3535

3636
slurm_vars = slurm_submit(
3737
job_name=job_name,
38-
partition="ampere",
39-
account="LEE-SL3-GPU",
38+
account="matgen",
4039
time="2:0:0",
4140
out_dir=out_dir,
4241
slurm_flags="--nodes 1 --gpus-per-node 1",

models/cgcnn/train_cgcnn.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,7 @@
4141

4242
slurm_vars = slurm_submit(
4343
job_name=job_name,
44-
partition="ampere",
45-
account="LEE-SL3-GPU",
44+
account="matgen",
4645
time="11:55:0",
4746
array=f"1-{ensemble_size}",
4847
out_dir=out_dir,

models/megnet/test_megnet.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,9 @@
3939
slurm_vars = slurm_submit(
4040
job_name=job_name,
4141
out_dir=module_dir,
42-
partition="icelake-himem",
43-
account="LEE-SL3-CPU",
42+
account="matgen",
4443
time="11:55:0",
45-
slurm_flags=("--mem", "30G"),
44+
slurm_flags="--mem 30G",
4645
array=f"1-{slurm_array_task_count}",
4746
# TF_CPP_MIN_LOG_LEVEL=2 means INFO and WARNING logs are not printed
4847
# https://stackoverflow.com/a/40982782

models/voronoi_rf/train_test_voronoi_rf.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,7 @@
4141
slurm_vars = slurm_submit(
4242
job_name=job_name,
4343
out_dir=out_dir,
44-
partition="icelake-himem",
45-
account="LEE-SL3-CPU",
44+
account="matgen",
4645
time="6:0:0",
4746
)
4847

models/voronoi_rf/voronoi_featurize_dataset.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,7 @@
4343

4444
slurm_vars = slurm_submit(
4545
job_name=job_name,
46-
partition="icelake-himem",
47-
account="LEE-SL3-CPU",
46+
account="matgen",
4847
time="11:55:0",
4948
array=f"1-{slurm_array_task_count}",
5049
slurm_flags=("--mem", "15G") if data_name == "mp" else (),

models/wrenformer/test_wrenformer.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,7 @@
3232

3333
slurm_vars = slurm_submit(
3434
job_name=job_name,
35-
partition="ampere",
36-
account="LEE-SL3-GPU",
35+
account="matgen",
3736
time="2:0:0",
3837
out_dir=out_dir,
3938
slurm_flags="--nodes 1 --gpus-per-node 1",

models/wrenformer/train_wrenformer.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,7 @@
3232

3333
slurm_vars = slurm_submit(
3434
job_name=job_name,
35-
partition="ampere",
36-
account="LEE-SL3-GPU",
35+
account="matgen",
3736
time="8:0:0",
3837
array=f"1-{ensemble_size}",
3938
out_dir=out_dir,

pyproject.toml

-2
Original file line numberDiff line numberDiff line change
@@ -104,10 +104,8 @@ ignore = [
104104
"PLR", # pylint refactor
105105
"PLW2901", # redefined-loop-name
106106
"PT006", # pytest-parametrize-names-wrong-type
107-
"PT013", # pytest-incorrect-pytest-import
108107
"PTH",
109108
"S108",
110-
"S301",
111109
"S310",
112110
"S311",
113111
"S603",

scripts/compute_struct_fingerprints.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,7 @@
3939
slurm_vars = slurm_submit(
4040
job_name=f"{data_name}-struct-fingerprints",
4141
out_dir=out_dir,
42-
partition="icelake-himem",
43-
account="LEE-SL3-CPU",
42+
account="matgen",
4443
time="6:0:0",
4544
array=f"1-{slurm_array_task_count}",
4645
slurm_flags=("--mem", "30G"),

scripts/project_compositions.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@
3030
slurm_vars = slurm_submit(
3131
job_name=f"{data_name}-{projection_type}-{out_dim}d",
3232
out_dir=out_dir,
33-
partition="icelake-himem",
34-
account="LEE-SL3-CPU",
33+
account="matgen",
3534
time="6:0:0",
3635
)
3736

tests/test_data.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import pandas as pd
99
import pytest
1010
from pymatgen.core import Lattice, Structure
11-
from pytest import CaptureFixture
1211

1312
from matbench_discovery import FIGSHARE_DIR, ROOT
1413
from matbench_discovery.data import (
@@ -45,7 +44,7 @@ def test_load(
4544
df_float: pd.DataFrame,
4645
# df with Structures and ComputedStructureEntries as dicts
4746
df_with_pmg_objects: pd.DataFrame,
48-
capsys: CaptureFixture[str],
47+
capsys: pytest.CaptureFixture[str],
4948
tmp_path: Path,
5049
key: str,
5150
hydrate: bool,
@@ -148,7 +147,7 @@ def test_load_no_mock(
148147
version: str,
149148
expected_shape: tuple[int, int],
150149
expected_cols: set[str],
151-
capsys: CaptureFixture[str],
150+
capsys: pytest.CaptureFixture[str],
152151
tmp_path: Path,
153152
) -> None:
154153
assert os.listdir(tmp_path) == [], "cache_dir should be empty"

tests/test_energy.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from pymatgen.analysis.phase_diagram import PDEntry
66
from pymatgen.core import Lattice, Structure
77
from pymatgen.entries.computed_entries import ComputedEntry, Entry
8-
from pytest import approx
98

109
from matbench_discovery.energy import (
1110
get_e_form_per_atom,
@@ -59,5 +58,5 @@ def test_mp_ref_energies() -> None:
5958
"""Test MP elemental reference energies are in sync with PDEntries saved to disk."""
6059
for key, val in mp_elemental_ref_energies.items():
6160
actual = mp_elem_reference_entries[key].energy_per_atom
62-
assert actual == approx(val, abs=1e-3), f"{key=}"
63-
assert actual == approx(val, abs=1e-3), f"{key=}"
61+
assert actual == pytest.approx(val, abs=1e-3), f"{key=}"
62+
assert actual == pytest.approx(val, abs=1e-3), f"{key=}"

tests/test_metrics.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import numpy as np
44
import pandas as pd
55
import pytest
6-
from pytest import approx
76

87
from matbench_discovery.enums import Key
98
from matbench_discovery.metrics import classify_stable, stable_metrics
@@ -48,7 +47,7 @@ def test_stable_metrics() -> None:
4847
RMSE=1.157,
4948
R2=-3.030,
5049
).items():
51-
assert metrics[key] == approx(val, abs=1e-3), f"{key=}"
50+
assert metrics[key] == pytest.approx(val, abs=1e-3), f"{key=}"
5251

5352
assert math.isnan(metrics["F1"])
5453

0 commit comments

Comments
 (0)