Skip to content

Commit e203f8f

Browse files
committed
fix trainable params in test_m3gnet.py
update site to sveltekit v2, vite v5 ruff unignore NPY002 and fix violations
1 parent e42d70c commit e203f8f

13 files changed

+57
-55
lines changed

.pre-commit-config.yaml

+4-4
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ default_install_hook_types: [pre-commit, commit-msg]
77

88
repos:
99
- repo: https://github.com/astral-sh/ruff-pre-commit
10-
rev: v0.1.7
10+
rev: v0.1.9
1111
hooks:
1212
- id: ruff
1313
args: [--fix]
@@ -30,7 +30,7 @@ repos:
3030
- id: trailing-whitespace
3131

3232
- repo: https://github.com/pre-commit/mirrors-mypy
33-
rev: v1.7.1
33+
rev: v1.8.0
3434
hooks:
3535
- id: mypy
3636
additional_dependencies: [types-pyyaml, types-requests]
@@ -45,7 +45,7 @@ repos:
4545
args: [--ignore-words-list, "nd,te,fpr", --check-filenames]
4646

4747
- repo: https://github.com/pre-commit/mirrors-prettier
48-
rev: v4.0.0-alpha.3
48+
rev: v4.0.0-alpha.8
4949
hooks:
5050
- id: prettier
5151
args: [--write] # edit files in-place
@@ -56,7 +56,7 @@ repos:
5656
exclude: ^(site/src/figs/.+\.svelte|data/wbm/20.+\..+|site/src/routes/.+\.(yaml|json)|changelog.md)$
5757

5858
- repo: https://github.com/pre-commit/mirrors-eslint
59-
rev: v8.55.0
59+
rev: v8.56.0
6060
hooks:
6161
- id: eslint
6262
types: [file]

matbench_discovery/structure.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
__author__ = "Janosh Riebesell"
1111
__date__ = "2022-12-02"
1212

13-
np.random.seed(0) # ensure reproducible structure perturbations
13+
rng = np.random.default_rng(0) # ensure reproducible structure perturbations
1414

1515

1616
def perturb_structure(struct: Structure, gamma: float = 1.5) -> Structure:
@@ -29,8 +29,8 @@ def perturb_structure(struct: Structure, gamma: float = 1.5) -> Structure:
2929
"""
3030
perturbed = struct.copy()
3131
for site in perturbed:
32-
magnitude = np.random.weibull(gamma)
33-
vec = np.random.randn(3) # TODO maybe make func recursive to deal with 0-vector
32+
magnitude = rng.weibull(gamma)
33+
# Draw 3 independent standard-normal components so the normalized vector is a
# uniformly random direction. Note: rng.normal(3) would pass 3 as `loc` and
# return a single scalar (mean 3), not a 3-vector like np.random.randn(3) did.
vec = rng.normal(size=3)  # TODO maybe make func recursive to deal with 0-vector
3434
vec /= np.linalg.norm(vec) # unit vector
3535
site.coords += vec * magnitude
3636
site.to_unit_cell(in_place=True)
@@ -42,7 +42,7 @@ def perturb_structure(struct: Structure, gamma: float = 1.5) -> Structure:
4242
import matplotlib.pyplot as plt
4343

4444
gamma = 1.5
45-
samples = np.array([np.random.weibull(gamma) for _ in range(10000)])
45+
samples = np.array([rng.weibull(gamma) for _ in range(10_000)])
4646
mean = samples.mean()
4747

4848
# reproduces the dist in https://www.nature.com/articles/s41524-022-00891-8#Fig5

models/cgcnn/plot_structure_perturbation.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@
1010
__author__ = "Janosh Riebesell"
1111
__date__ = "2022-12-02"
1212

13+
rng = np.random.default_rng(0)
14+
1315

1416
# %%
15-
ax = pd.Series(np.random.weibull(1.5, 100000)).hist(bins=100)
17+
ax = pd.Series(rng.weibull(1.5, 100_000)).hist(bins=100)
1618
title = "Distribution of perturbation magnitudes"
1719
ax.set(xlabel="magnitude of perturbation", ylabel="count", title=title)
1820

models/chgnet/join_chgnet_results.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
# %%
3131
module_dir = os.path.dirname(__file__)
3232
task_type = "IS2RE"
33-
date = "2023-10-23"
33+
date = "2023-12-21"
3434
glob_pattern = f"{date}-chgnet-*-wbm-{task_type}*/*.json.gz"
3535
file_paths = sorted(glob(f"{module_dir}/{glob_pattern}"))
3636
print(f"Found {len(file_paths):,} files for {glob_pattern = }")

models/m3gnet/test_m3gnet.py

+12-7
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,17 @@
3232
# direct: DIRECT cluster sampling, ms: manual sampling
3333
model_type: Literal["orig", "direct", "manual-sampling"] = "orig"
3434
# set large job array size for smaller data splits and faster testing/debugging
35-
slurm_array_task_count = 100
35+
slurm_array_task_count = 50
3636
job_name = f"m3gnet-{model_type}-wbm-{task_type}"
3737
out_dir = os.getenv("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")
3838

3939
slurm_vars = slurm_submit(
4040
job_name=job_name,
4141
out_dir=out_dir,
42-
partition="icelake-himem",
43-
account="LEE-SL3-CPU",
44-
time="3:0:0",
42+
account="matgen",
43+
time="11:55:0",
4544
array=f"1-{slurm_array_task_count}",
46-
slurm_flags=("--mem", "12G"),
45+
slurm_flags="--qos shared --constraint cpu --mem 16G",
4746
# TF_CPP_MIN_LOG_LEVEL=2 means INFO and WARNING logs are not printed
4847
# https://stackoverflow.com/a/40982782
4948
pre_cmd="TF_CPP_MIN_LOG_LEVEL=2",
@@ -88,7 +87,13 @@
8887
task_type=task_type,
8988
df=dict(shape=str(df_in.shape), columns=", ".join(df_in)),
9089
slurm_vars=slurm_vars,
91-
trainable_params=sum(param.numel() for param in m3gnet.parameters()),
90+
trainable_params=sum(
91+
[np.prod(weight.shape) for weight in m3gnet.potential.model.trainable_weights]
92+
),
93+
checkpoint=checkpoint,
94+
model_type=model_type,
95+
out_path=out_path,
96+
job_name=job_name,
9297
)
9398

9499
run_name = f"{job_name}-{slurm_array_task_id}"
@@ -103,7 +108,7 @@
103108

104109
structures = df_in[input_col].map(Structure.from_dict).to_dict()
105110

106-
for material_id in tqdm(structures, desc="Relaxing", disable=None):
111+
for material_id in tqdm(structures, desc="Relaxing"):
107112
if material_id in relax_results:
108113
continue
109114
try:

models/mace/join_mace_results.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
module_dir = os.path.dirname(__file__)
3030
task_type = "IS2RE"
3131
e_form_mace_col = "e_form_per_atom_mace"
32-
3332
date = "2023-12-11"
3433
glob_pattern = f"{date}-mace-wbm-{task_type}*/*.json.gz"
3534
file_paths = sorted(glob(f"{module_dir}/{glob_pattern}"))
@@ -92,16 +91,15 @@
9291

9392

9493
# %%
95-
bad_mask = (df_wbm[e_form_mace_col] - df_wbm[e_form_col]) < -3
96-
df_wbm[bad_mask].to_csv(f"{module_dir}/mace-underpredictions<-3.csv")
94+
bad_mask = (df_wbm[e_form_mace_col] - df_wbm[e_form_col]) < -5
9795
print(f"{sum(bad_mask)=}")
9896
ax = density_scatter(df=df_wbm[~bad_mask], x=e_form_col, y=e_form_mace_col)
9997

10098

10199
# %%
102100
out_path = file_paths[0].rsplit("/", 1)[0]
103101
df_mace = df_mace.round(4)
104-
df_mace[~bad_mask].select_dtypes("number").to_csv(f"{out_path}.csv.gz")
102+
df_mace.select_dtypes("number").to_csv(f"{out_path}.csv.gz")
105103
df_mace.reset_index().to_json(f"{out_path}.json.gz", default_handler=as_dict_handler)
106104

107105
df_bad = df_mace[bad_mask].drop(columns=[entry_col, struct_col])

models/mace/test_mace.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33

44
import os
55
from importlib.metadata import version
6-
from typing import Any
6+
from typing import Any, Literal
77

88
import numpy as np
99
import pandas as pd
1010
import torch
1111
import wandb
12-
from ase.filters import FrechetCellFilter
12+
from ase.filters import ExpCellFilter, FrechetCellFilter
1313
from ase.optimize import FIRE, LBFGS
1414
from mace.calculators import mace_mp
1515
from mace.tools import count_parameters
@@ -31,7 +31,7 @@
3131
task_type = "IS2RE" # "RS2RE"
3232
module_dir = os.path.dirname(__file__)
3333
# set large job array size for smaller data splits and faster testing/debugging
34-
slurm_array_task_count = 20
34+
slurm_array_task_count = 50
3535
ase_optimizer = "FIRE"
3636
job_name = f"mace-wbm-{task_type}-{ase_optimizer}"
3737
out_dir = os.getenv("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")
@@ -42,15 +42,16 @@
4242
"2023-10-29-mace-16M-pbenner-mptrj-no-conditional-loss",
4343
"https://tinyurl.com/y7uhwpje",
4444
][-1]
45+
ase_filter: Literal["frechet", "exp"] = "frechet"
4546

4647
slurm_vars = slurm_submit(
4748
job_name=job_name,
4849
out_dir=out_dir,
4950
account="matgen",
50-
time="9:55:0",
51+
time="11:55:0",
5152
array=f"1-{slurm_array_task_count}",
52-
slurm_flags="--qos shared --constraint gpu --gpus 1",
53-
# slurm_flags="--qos shared --constraint cpu --mem 16G",
53+
# slurm_flags="--qos shared --constraint gpu --gpus 1",
54+
slurm_flags="--qos shared --constraint cpu --mem 32G",
5455
)
5556

5657

@@ -98,6 +99,7 @@
9899
trainable_params=count_parameters(mace_calc.models[0]),
99100
model_name=model_name,
100101
dtype=dtype,
102+
ase_filter=ase_filter,
101103
)
102104

103105
run_name = f"{job_name}-{slurm_array_task_id}"
@@ -112,6 +114,7 @@
112114
df_in[input_col] = [x["structure"] for x in df_in.computed_structure_entry]
113115

114116
structs = df_in[input_col].map(Structure.from_dict).to_dict()
117+
filter_cls = {"frechet": FrechetCellFilter, "exp": ExpCellFilter}[ase_filter]
115118

116119
for material_id in tqdm(structs, desc="Relaxing"):
117120
if material_id in relax_results:
@@ -121,7 +124,7 @@
121124
atoms = structs[material_id].to_ase_atoms()
122125
atoms.calc = mace_calc
123126
if max_steps > 0:
124-
atoms = FrechetCellFilter(atoms)
127+
atoms = filter_cls(atoms)
125128
optim_cls = {"FIRE": FIRE, "LBFGS": LBFGS}[ase_optimizer]
126129
optimizer = optim_cls(atoms, logfile="/dev/null")
127130

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,6 @@ ignore = [
9898
"FIX002",
9999
"INP001",
100100
"N806", # non-lowercase-variable-in-function
101-
"NPY002",
102101
"PD901", # pandas-df-variable-name
103102
"PERF203", # try-except-in-loop
104103
"PLC0414", # useless-import-alias
@@ -119,6 +118,7 @@ ignore = [
119118
]
120119
pydocstyle.convention = "google"
121120
isort.known-third-party = ["wandb"]
121+
isort.split-on-trailing-comma = false
122122

123123
[tool.ruff.per-file-ignores]
124124
"tests/*" = ["D", "S101"]

scripts/model_figs/make_hull_dist_box_plot.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
# different fill colors for each box
2121
# patch_artist=True,
2222
# notch=True,
23-
# bootstrap=10000,
23+
# bootstrap=10_000,
2424
showmeans=True,
2525
# meanline=True,
2626
)

scripts/model_figs/model_run_times.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@
155155
).update_traces(
156156
textinfo="percent+label",
157157
textfont_size=14,
158-
marker=dict(line=dict(color="#000000", width=2)),
158+
marker=dict(line=dict(color="black", width=2)),
159159
hoverinfo="label+percent+name",
160160
texttemplate="%{label}<br>%{percent:.1%}",
161161
hovertemplate="%{label} %{percent:.1%} (%{value:.1f} h)",

site/package.json

+16-16
Original file line numberDiff line numberDiff line change
@@ -17,38 +17,38 @@
1717
"changelog": "npx auto-changelog --output ../changelog.md --hide-credit --commit-limit false --latest-version x.y.z"
1818
},
1919
"devDependencies": {
20-
"@iconify/svelte": "^3.1.4",
20+
"@iconify/svelte": "^3.1.6",
2121
"@rollup/plugin-yaml": "^4.1.2",
22-
"@sveltejs/adapter-static": "^2.0.3",
23-
"@sveltejs/kit": "^1.27.4",
24-
"@sveltejs/vite-plugin-svelte": "^2.5.1",
25-
"@typescript-eslint/eslint-plugin": "^6.10.0",
26-
"@typescript-eslint/parser": "^6.10.0",
22+
"@sveltejs/adapter-static": "^3.0.1",
23+
"@sveltejs/kit": "^2.0.6",
24+
"@sveltejs/vite-plugin-svelte": "^3.0.1",
25+
"@typescript-eslint/eslint-plugin": "^6.16.0",
26+
"@typescript-eslint/parser": "^6.16.0",
2727
"d3-scale-chromatic": "^3.0.0",
2828
"elementari": "^0.2.2",
29-
"eslint": "^8.53.0",
30-
"eslint-plugin-svelte": "^2.35.0",
29+
"eslint": "^8.56.0",
30+
"eslint-plugin-svelte": "^2.35.1",
3131
"hastscript": "^8.0.0",
3232
"highlight.js": "^11.9.0",
3333
"js-yaml": "^4.1.0",
3434
"katex": "^0.16.9",
3535
"mdsvex": "^0.11.0",
36-
"prettier": "^3.0.3",
37-
"prettier-plugin-svelte": "^3.0.3",
36+
"prettier": "^3.1.1",
37+
"prettier-plugin-svelte": "^3.1.2",
3838
"rehype-autolink-headings": "^7.1.0",
3939
"rehype-katex-svelte": "^1.2.0",
4040
"rehype-slug": "^6.0.0",
4141
"remark-math": "3.0.0",
42-
"svelte": "^4.2.2",
43-
"svelte-check": "^3.5.2",
42+
"svelte": "^4.2.8",
43+
"svelte-check": "^3.6.2",
4444
"svelte-multiselect": "^10.2.0",
45-
"svelte-preprocess": "^5.0.4",
45+
"svelte-preprocess": "^5.1.3",
4646
"svelte-toc": "^0.5.6",
4747
"svelte-zoo": "^0.4.9",
48-
"svelte2tsx": "^0.6.23",
48+
"svelte2tsx": "^0.6.27",
4949
"tslib": "^2.6.2",
50-
"typescript": "5.2.2",
51-
"vite": "^4.5.0"
50+
"typescript": "5.3.3",
51+
"vite": "^5.0.10"
5252
},
5353
"prettier": {
5454
"semi": false,

tests/test_metrics.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ def test_stable_metrics() -> None:
5555

5656
# test stable_metrics gives the same result as sklearn.metrics.classification_report
5757
# for random numpy data
58-
np.random.seed(0)
59-
y_true, y_pred = np.random.randn(100, 2).T
58+
rng = np.random.default_rng(0)
59+
y_true, y_pred = rng.normal(size=(2, 100))
6060
metrics = stable_metrics(y_true, y_pred)
6161

6262
from sklearn.metrics import classification_report

tests/test_structure.py

-6
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,19 @@
22

33
from typing import TYPE_CHECKING
44

5-
import numpy as np
6-
75
from matbench_discovery.structure import perturb_structure
86

97
if TYPE_CHECKING:
108
from pymatgen.core import Structure
119

1210

1311
def test_perturb_structure(dummy_struct: Structure) -> None:
14-
np.random.seed(0)
1512
perturbed = perturb_structure(dummy_struct)
1613
assert len(perturbed) == len(dummy_struct)
1714

1815
for site, new in zip(dummy_struct, perturbed):
1916
assert site.specie == new.specie
2017
assert tuple(site.coords) != tuple(new.coords)
2118

22-
# test that the perturbation is reproducible
23-
np.random.seed(0)
24-
assert perturbed == perturb_structure(dummy_struct)
2519
# but different on subsequent calls
2620
assert perturb_structure(dummy_struct) != perturb_structure(dummy_struct)

0 commit comments

Comments
 (0)