Commit 0773112

add new module matbench_discovery/data.py with func load_wbm()
add tests/test_data.py with test_load_wbm()
1 parent 4972c01 commit 0773112

11 files changed: +240 -62 lines

matbench_discovery/__init__.py (-15)

@@ -2,26 +2,11 @@
 
 import os
 import sys
-from collections.abc import Generator, Sequence
 from datetime import datetime
-from typing import Any
 
 ROOT = os.path.dirname(os.path.dirname(__file__))
 DEBUG = "slurm-submit" not in sys.argv and "SLURM_JOB_ID" not in os.environ
 CHECKPOINT_DIR = f"{ROOT}/wandb/checkpoints"
 
 timestamp = f"{datetime.now():%Y-%m-%d@%H-%M-%S}"
 today = timestamp.split("@")[0]
-
-
-def chunks(xs: Sequence[Any], n: int) -> Generator[Sequence[Any], None, None]:
-    return (xs[i : i + n] for i in range(0, len(xs), n))
-
-
-def as_dict_handler(obj: Any) -> dict[str, Any] | None:
-    """Use as default_handler kwarg to json.dump() or pandas.to_json()."""
-    try:
-        return obj.as_dict()  # all MSONable objects implement as_dict()
-    except AttributeError:
-        return None  # replace unhandled objects with None in serialized data
-        # removes e.g. non-serializable AseAtoms from M3GNet relaxation trajectories
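Since chunks() and as_dict_handler() now live in matbench_discovery/data.py rather than the package root, downstream imports change accordingly (the model script diffs below apply exactly this change); a one-line sketch of the new import path:

# before this commit: from matbench_discovery import as_dict_handler, chunks
from matbench_discovery.data import as_dict_handler, chunks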

matbench_discovery/data.py (+114)

@@ -0,0 +1,114 @@
+from __future__ import annotations
+
+import os
+from collections.abc import Generator, Sequence
+from typing import Any
+
+import pandas as pd
+from pymatgen.core import Structure
+from pymatgen.entries.computed_entries import ComputedStructureEntry
+from tqdm import tqdm
+
+data_files = {
+    "summary": "2022-10-19-wbm-summary.csv",
+    "initial-structures": "2022-10-19-wbm-init-structs.json.bz2",
+    "computed-structure-entries": "2022-10-19-wbm-cses.json.bz2",
+}
+
+base_url = "https://raw.githubusercontent.com/janosh/matbench-discovery/main/data/wbm"
+default_cache_loc = os.path.expanduser("~/.cache/matbench-discovery")
+
+
+def chunks(xs: Sequence[Any], n: int) -> Generator[Sequence[Any], None, None]:
+    return (xs[i : i + n] for i in range(0, len(xs), n))
+
+
+def as_dict_handler(obj: Any) -> dict[str, Any] | None:
+    """Pass this to json.dump(default=) or pandas.to_json(default_handler=) to
+    convert Python classes with an as_dict() method to dictionaries on serialization.
+    Objects without an as_dict() method are replaced with None in the serialized data.
+    """
+    try:
+        return obj.as_dict()  # all MSONable objects implement as_dict()
+    except AttributeError:
+        return None  # replace unhandled objects with None in serialized data
+        # removes e.g. non-serializable AseAtoms from M3GNet relaxation trajectories
+
+
+def load_wbm(
+    parts: Sequence[str] = ("summary",),
+    version: int = 1,
+    cache_dir: str | None = default_cache_loc,
+    hydrate: bool = False,
+) -> pd.DataFrame | dict[str, pd.DataFrame]:
+    """Download and cache parts of the WBM dataset, returned as pandas dataframe(s).
+
+    Args:
+        parts (Sequence[str], optional): Which parts of the WBM dataset to load. Can
+            be any subset of {'summary', 'initial-structures', 'computed-structure-entries'}.
+            Defaults to ["summary"], a dataframe with columns for material properties
+            like VASP energy, formation energy, energy above the convex hull (3 columns
+            with old, new and no Materials Project energy corrections applied for each),
+            volume, band gap, number of sites per unit cell, and more.
+        version (int, optional): Which version of the dataset to load. Defaults to 1
+            (currently the only available option).
+        cache_dir (str, optional): Where to cache data files on local drive. Defaults
+            to '~/.cache/matbench-discovery'. Set to None to disable caching.
+        hydrate (bool, optional): Whether to hydrate pymatgen objects. If False,
+            Structures and ComputedStructureEntries are returned as dictionaries which
+            can be hydrated on-demand with df.col.map(Structure.from_dict). Defaults to
+            False as it noticeably increases load time.
+
+    Raises:
+        ValueError: On bad version or bad keys for which data parts to load.
+
+    Returns:
+        pd.DataFrame | dict[str, pd.DataFrame]: Single dataframe if a single part was
+            requested, dictionary of dataframes if multiple data parts were requested.
+    """
+    if version != 1:
+        raise ValueError(f"Only version 1 currently available, got {version=}")
+    if missing := set(parts) - set(data_files):
+        raise ValueError(f"{missing} must be subset of {set(data_files)}")
+
+    dfs = {}
+    for key in parts:
+        file = data_files[key]
+        reader = pd.read_csv if file.endswith(".csv") else pd.read_json
+
+        cache_path = f"{cache_dir}/{file}"
+        if os.path.isfile(cache_path):
+            df = reader(cache_path)
+        else:
+            url = f"{base_url}/{file}"
+            print(f"Downloading {url=}")
+            df = reader(url)
+            if cache_dir and not os.path.isfile(cache_path):
+                os.makedirs(cache_dir, exist_ok=True)
+                if ".csv" in file:
+                    df.to_csv(cache_path)
+                elif ".json" in file:
+                    df.reset_index().to_json(
+                        cache_path, default_handler=as_dict_handler
+                    )
+                else:
+                    raise ValueError(f"Unexpected file type {file}")
+
+        df = df.set_index("material_id")
+        if hydrate:
+            for col in df:
+                if not isinstance(df[col].iloc[0], dict):
+                    continue
+                try:
+                    df[col] = [
+                        ComputedStructureEntry.from_dict(d)
+                        for d in tqdm(df[col], desc=col)
+                    ]
+                except Exception:
+                    df[col] = [Structure.from_dict(d) for d in tqdm(df[col], desc=col)]
+
+        dfs[key] = df
+
+    if len(parts) == 1:
+        return dfs[parts[0]]
+    return dfs
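For context, a minimal usage sketch of the new loader (variable names and the struct_col placeholder are illustrative, not part of the commit; the calls follow the docstring and code above):

from pymatgen.core import Structure

from matbench_discovery.data import load_wbm

# single part: returns one dataframe indexed by material_id, cached under
# ~/.cache/matbench-discovery on first download
df_summary = load_wbm(["summary"])

# multiple parts: returns a dict of dataframes keyed by part name
dfs = load_wbm(["summary", "initial-structures"])
df_init = dfs["initial-structures"]

# with hydrate=False (the default), pymatgen objects stay as plain dicts and can
# be hydrated on demand as the docstring suggests
struct_col = "initial_structure"  # hypothetical column name, not taken from this commit
structures = df_init[struct_col].map(Structure.from_dict)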

models/bowsr/test_bowsr.py (+2 -1)

@@ -13,7 +13,8 @@
 from maml.apps.bowsr.optimizer import BayesianOptimizer
 from tqdm import tqdm
 
-from matbench_discovery import DEBUG, ROOT, as_dict_handler, timestamp, today
+from matbench_discovery import DEBUG, ROOT, timestamp, today
+from matbench_discovery.data import as_dict_handler
 from matbench_discovery.slurm import slurm_submit
 
 __author__ = "Janosh Riebesell"

models/cgcnn/test_cgcnn.py (+2 -1)

@@ -109,7 +109,8 @@
     data_loader=data_loader,
 )
 
-df_preds.to_csv(f"{out_dir}/{job_name}-preds.csv", index=False)
+slurm_job_id = os.environ.get("SLURM_JOB_ID", "debug")
+df_preds.to_csv(f"{out_dir}/{job_name}-preds-{slurm_job_id}.csv", index=False)
 pred_col = f"{target_col}_pred_ens"
 assert pred_col in df, f"{pred_col=} not in {list(df)}"
 table = wandb.Table(dataframe=df_preds[[target_col, pred_col]].reset_index())

models/m3gnet/join_m3gnet_results.py (+2 -1)

@@ -8,7 +8,8 @@
 from pymatgen.analysis.phase_diagram import PDEntry
 from tqdm import tqdm
 
-from matbench_discovery import ROOT, as_dict_handler, today
+from matbench_discovery import ROOT, today
+from matbench_discovery.data import as_dict_handler
 from matbench_discovery.energy import get_e_form_per_atom
 
 __author__ = "Janosh Riebesell"

models/m3gnet/test_m3gnet.py (+2 -1)

@@ -12,7 +12,8 @@
 from m3gnet.models import Relaxer
 from tqdm import tqdm
 
-from matbench_discovery import DEBUG, ROOT, as_dict_handler, timestamp, today
+from matbench_discovery import DEBUG, ROOT, timestamp, today
+from matbench_discovery.data import as_dict_handler
 from matbench_discovery.slurm import slurm_submit
 
 """

models/wrenformer/test_wrenformer.py (+2 -1)

@@ -96,7 +96,8 @@
     runs, data_loader=data_loader, df=df, model_cls=Wrenformer, target_col=target_col
 )
 
-df.to_csv(f"{out_dir}/{job_name}-preds.csv")
+slurm_job_id = os.environ.get("SLURM_JOB_ID", "debug")
+df.to_csv(f"{out_dir}/{job_name}-preds-{slurm_job_id}.csv")
 
 
 # %%

tests/conftest.py (+22)

@@ -0,0 +1,22 @@
+from __future__ import annotations
+
+import pandas as pd
+import pytest
+from pymatgen.core import Lattice, Structure
+
+
+@pytest.fixture
+def dummy_df_with_structures(dummy_struct: Structure) -> pd.DataFrame:
+    # create a dummy df with a structure column
+    df = pd.DataFrame(dict(material_id=range(10), structure=[dummy_struct] * 10))
+    df["volume"] = [x.volume for x in df.structure]
+    return df
+
+
+@pytest.fixture
+def dummy_struct() -> Structure:
+    return Structure(
+        lattice=Lattice.cubic(5),
+        species=("Fe", "O"),
+        coords=((0, 0, 0), (0.5, 0.5, 0.5)),
+    )

tests/test_data.py (+86)

@@ -0,0 +1,86 @@
+from __future__ import annotations
+
+from tempfile import TemporaryDirectory
+from typing import Any
+from unittest.mock import patch
+
+import pandas as pd
+import pytest
+from pymatgen.core import Lattice, Structure
+
+from matbench_discovery.data import as_dict_handler, chunks, data_files, load_wbm
+
+structure = Structure(
+    lattice=Lattice.cubic(5),
+    species=("Fe", "O"),
+    coords=((0, 0, 0), (0.5, 0.5, 0.5)),
+)
+
+
+@pytest.mark.parametrize(
+    "parts, cache_dir, hydrate",
+    [
+        (["summary"], None, True),
+        (["initial-structures"], TemporaryDirectory().name, True),
+        (["computed-structure-entries"], None, False),
+        (["summary", "initial-structures"], TemporaryDirectory().name, True),
+    ],
+)
+def test_load_wbm(
+    parts: list[str],
+    cache_dir: str | None,
+    hydrate: bool,
+    dummy_df_with_structures: pd.DataFrame,
+) -> None:
+    # intercept HTTP requests to GitHub raw user content and return dummy df instead
+    with patch("matbench_discovery.data.pd.read_csv") as read_csv, patch(
+        "matbench_discovery.data.pd.read_json"
+    ) as read_json:
+        read_csv.return_value = read_json.return_value = dummy_df_with_structures
+        out = load_wbm(parts, cache_dir=cache_dir, hydrate=hydrate)
+
+    assert read_json.call_count + read_csv.call_count == len(parts)
+
+    if len(parts) > 1:
+        assert isinstance(out, dict)
+        assert list(out) == parts
+        for df in out.values():
+            assert isinstance(df, pd.DataFrame)
+    else:
+        assert isinstance(out, pd.DataFrame)
+
+
+def test_load_wbm_raises() -> None:
+    with pytest.raises(
+        ValueError,
+        match=f"must be subset of {set(data_files)}",
+    ):
+        load_wbm(["invalid-part"])
+
+    with pytest.raises(
+        ValueError, match="Only version 1 currently available, got version=2"
+    ):
+        load_wbm(version=2)
+
+
+def test_chunks() -> None:
+    assert list(chunks([], 1)) == []
+    assert list(chunks([1], 1)) == [[1]]
+    assert list(chunks([1, 2], 1)) == [[1], [2]]
+    assert list(chunks([1, 2, 3], 1)) == [[1], [2], [3]]
+    assert list(chunks([1, 2, 3], 2)) == [[1, 2], [3]]
+    assert list(chunks(range(1, 4), 2)) == [range(1, 3), range(3, 4)]
+    assert list(chunks(range(1, 5), 2)) == [range(1, 3), range(3, 5)]
+    assert list(chunks(range(1, 5), 3)) == [range(1, 4), range(4, 5)]
+
+
+def test_as_dict_handler() -> None:
+    class C:
+        def as_dict(self) -> dict[str, Any]:
+            return {"foo": "bar"}
+
+    assert as_dict_handler(C()) == {"foo": "bar"}
+    assert as_dict_handler(1) is None
+    assert as_dict_handler("foo") is None
+    assert as_dict_handler([1, 2, 3]) is None
+    assert as_dict_handler({"foo": "bar"}) is None

tests/test_init.py (+1 -25)

@@ -1,35 +1,11 @@
 from __future__ import annotations
 
 import os
-from typing import Any
 
-from matbench_discovery import ROOT, as_dict_handler, chunks, timestamp, today
+from matbench_discovery import ROOT, timestamp, today
 
 
 def test_has_globals() -> None:
     assert os.path.isdir(ROOT)
     assert today == timestamp.split("@")[0]
     assert len(timestamp) == 19
-
-
-def test_chunks() -> None:
-    assert list(chunks([], 1)) == []
-    assert list(chunks([1], 1)) == [[1]]
-    assert list(chunks([1, 2], 1)) == [[1], [2]]
-    assert list(chunks([1, 2, 3], 1)) == [[1], [2], [3]]
-    assert list(chunks([1, 2, 3], 2)) == [[1, 2], [3]]
-    assert list(chunks(range(1, 4), 2)) == [range(1, 3), range(3, 4)]
-    assert list(chunks(range(1, 5), 2)) == [range(1, 3), range(3, 5)]
-    assert list(chunks(range(1, 5), 3)) == [range(1, 4), range(4, 5)]
-
-
-def test_as_dict_handler() -> None:
-    class C:
-        def as_dict(self) -> dict[str, Any]:
-            return {"foo": "bar"}
-
-    assert as_dict_handler(C()) == {"foo": "bar"}
-    assert as_dict_handler(1) is None
-    assert as_dict_handler("foo") is None
-    assert as_dict_handler([1, 2, 3]) is None
-    assert as_dict_handler({"foo": "bar"}) is None

tests/test_structure.py (+7 -17)

@@ -1,32 +1,22 @@
 from __future__ import annotations
 
 import numpy as np
-import pytest
-from pymatgen.core import Lattice, Structure
+from pymatgen.core import Structure
 
 from matbench_discovery.structure import perturb_structure
 
 
-@pytest.fixture
-def struct() -> Structure:
-    return Structure(
-        lattice=Lattice.cubic(5),
-        species=("Fe", "O"),
-        coords=((0, 0, 0), (0.5, 0.5, 0.5)),
-    )
-
-
-def test_perturb_structure(struct: Structure) -> None:
+def test_perturb_structure(dummy_struct: Structure) -> None:
     np.random.seed(0)
-    perturbed = perturb_structure(struct)
-    assert len(perturbed) == len(struct)
+    perturbed = perturb_structure(dummy_struct)
+    assert len(perturbed) == len(dummy_struct)
 
-    for site, new in zip(struct, perturbed):
+    for site, new in zip(dummy_struct, perturbed):
         assert site.specie == new.specie
        assert tuple(site.coords) != tuple(new.coords)
 
     # test that the perturbation is reproducible
     np.random.seed(0)
-    assert perturbed == perturb_structure(struct)
+    assert perturbed == perturb_structure(dummy_struct)
     # but different on subsequent calls
-    assert perturb_structure(struct) != perturb_structure(struct)
+    assert perturb_structure(dummy_struct) != perturb_structure(dummy_struct)
