Skip to content

Commit 811f581

Browse files
committed
rename load_wbm() -> load_train_test() and add ability to download MP training files too
increase flake8 max-line-length = 95
1 parent f5057ac commit 811f581

File tree

3 files changed

+68
-32
lines changed

3 files changed

+68
-32
lines changed

matbench_discovery/data.py

+37-23
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,18 @@
99
from pymatgen.entries.computed_entries import ComputedStructureEntry
1010
from tqdm import tqdm
1111

12-
data_files = {
13-
"summary": "2022-10-19-wbm-summary.csv",
14-
"initial-structures": "2022-10-19-wbm-init-structs.json.bz2",
15-
"computed-structure-entries": "2022-10-19-wbm-cses.json.bz2",
12+
DATA_FILENAMES = {
13+
"wbm-summary": "wbm/2022-10-19-wbm-summary.csv",
14+
"wbm-initial-structures": "wbm/2022-10-19-wbm-init-structs.json.bz2",
15+
"wbm-computed-structure-entries": "wbm/2022-10-19-wbm-cses.json.bz2",
16+
"mp-energies": "mp/2022-08-13-mp-energies.json.gz",
17+
"mp-computed-structure-entries": "mp/2022-09-16-mp-computed-structure-entries.json.gz",
18+
"mp-patched-phase-diagram": "mp/2022-09-18-ppd-mp.pkl.gz",
19+
"mp-elemental-ref-energies": "mp/2022-09-19-mp-elemental-ref-energies.json",
1620
}
1721

18-
base_url = "https://raw.githubusercontent.com/janosh/matbench-discovery/main/data/wbm"
19-
default_cache_loc = os.path.expanduser("~/.cache/matbench-discovery")
22+
RAW_REPO_URL = "https://raw.githubusercontent.com/janosh/matbench-discovery"
23+
default_cache_dir = os.path.expanduser("~/.cache/matbench-discovery")
2024

2125

2226
def chunks(xs: Sequence[Any], n: int) -> Generator[Sequence[Any], None, None]:
@@ -35,21 +39,26 @@ def as_dict_handler(obj: Any) -> dict[str, Any] | None:
3539
# removes e.g. non-serializable AseAtoms from M3GNet relaxation trajectories
3640

3741

38-
def load_wbm(
39-
parts: Sequence[str] = ("summary",),
42+
def load_train_test(
43+
parts: str | Sequence[str] = ("summary",),
4044
version: int = 1,
41-
cache_dir: str | None = default_cache_loc,
45+
cache_dir: str | None = default_cache_dir,
4246
hydrate: bool = False,
4347
) -> pd.DataFrame | dict[str, pd.DataFrame]:
44-
"""_summary_
48+
"""Download the MP training data and WBM test data in parts or in full as pandas
49+
DataFrames. The full training and test sets are each about ~500 MB as compressed
50+
JSON and will be cached locally for faster re-loading unless cache_dir is set to None.
51+
52+
Hint: Import DATA_FILENAMES from the same module as this function and
53+
print(list(DATA_FILENAMES)) to see permissible data names.
4554
4655
Args:
47-
parts (str, optional): Which parts of the WBM dataset to load. Can be any subset
48-
of {'summary', 'initial-structures', 'computed-structure-entries'}. Defaults
49-
to ["summary"], a dataframe with columns for material properties like VASP
50-
energy, formation energy, energy above the convex hull (3 columns with old,
51-
new and no Materials Project energy corrections applied for each), volume,
52-
band gap, number of sites per unit cell, and more.
56+
parts (str | list[str], optional): Which parts of the MP/WBM dataset to load.
57+
Can be any subset of list(DATA_FILENAMES). Defaults to ["summary"], a dataframe
58+
with columns for material properties like VASP energy, formation energy,
59+
energy above the convex hull (3 columns with old, new and no Materials
60+
Project energy corrections applied for each), volume, band gap, number of
61+
sites per unit cell, and more.
5362
version (int, optional): Which version of the dataset to load. Defaults to 1
5463
(currently the only available option).
5564
cache_dir (str, optional): Where to cache data files on local drive. Defaults to
@@ -60,31 +69,36 @@ def load_wbm(
6069
False as it noticeably increases load time.
6170
6271
Raises:
63-
ValueError: On bad version or bad keys for which data parts to load.
72+
ValueError: On bad version number or bad part names.
6473
6574
Returns:
6675
pd.DataFrame | dict[str, pd.DataFrame]: Single dataframe or dictionary of
6776
dataframes if multiple data parts were requested.
6877
"""
78+
if parts == "all":
79+
parts = list(DATA_FILENAMES)
80+
elif isinstance(parts, str):
81+
parts = [parts]
82+
6983
if version != 1:
7084
raise ValueError(f"Only version 1 currently available, got {version=}")
71-
if missing := set(parts) - set(data_files):
72-
raise ValueError(f"{missing} must be subset of {set(data_files)}")
85+
if missing := set(parts) - set(DATA_FILENAMES):
86+
raise ValueError(f"{missing} must be subset of {set(DATA_FILENAMES)}")
7387

7488
dfs = {}
7589
for key in parts:
76-
file = data_files[key]
90+
file = DATA_FILENAMES[key]
7791
reader = pd.read_csv if file.endswith(".csv") else pd.read_json
7892

7993
cache_path = f"{cache_dir}/{file}"
8094
if os.path.isfile(cache_path):
8195
df = reader(cache_path)
8296
else:
83-
url = f"{base_url}/{file}"
84-
print(f"Downloading {url=}")
97+
url = f"{RAW_REPO_URL}/{version}.0.0/data/{file}"
98+
print(f"Downloading {key} from {url}")
8599
df = reader(url)
86100
if cache_dir and not os.path.isfile(cache_path):
87-
os.makedirs(cache_dir, exist_ok=True)
101+
os.makedirs(os.path.dirname(cache_path), exist_ok=True)
88102
if ".csv" in file:
89103
df.to_csv(cache_path)
90104
elif ".json" in file:

tests/conftest.py

+2
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
import pandas as pd
44
import pytest
55
from pymatgen.core import Lattice, Structure
6+
from pymatgen.entries.computed_entries import ComputedStructureEntry
67

78

89
@pytest.fixture
910
def dummy_df_with_structures(dummy_struct: Structure) -> pd.DataFrame:
1011
# create a dummy df with a structure column
1112
df = pd.DataFrame(dict(material_id=range(10), structure=[dummy_struct] * 10))
1213
df["volume"] = [x.volume for x in df.structure]
14+
df["computed_structure_entry"] = [ComputedStructureEntry(dummy_struct, 0)] * 10
1315
return df
1416

1517

tests/test_data.py

+29-9
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,13 @@
88
import pytest
99
from pymatgen.core import Lattice, Structure
1010

11-
from matbench_discovery.data import as_dict_handler, chunks, data_files, load_wbm
11+
from matbench_discovery.data import (
12+
DATA_FILENAMES,
13+
RAW_REPO_URL,
14+
as_dict_handler,
15+
chunks,
16+
load_train_test,
17+
)
1218

1319
structure = Structure(
1420
lattice=Lattice.cubic(5),
@@ -20,24 +26,38 @@
2026
@pytest.mark.parametrize(
2127
"parts, cache_dir, hydrate",
2228
[
23-
(["summary"], None, True),
24-
(["initial-structures"], TemporaryDirectory().name, True),
25-
(["computed-structure-entries"], None, False),
26-
(["summary", "initial-structures"], TemporaryDirectory().name, True),
29+
(["wbm-summary"], None, True),
30+
(["wbm-initial-structures"], TemporaryDirectory().name, True),
31+
(["wbm-computed-structure-entries"], None, False),
32+
(["wbm-summary", "wbm-initial-structures"], TemporaryDirectory().name, True),
33+
(["mp-elemental-ref-energies"], None, True),
34+
(["mp-energies"], None, True),
2735
],
2836
)
2937
def test_load_wbm(
3038
parts: list[str],
3139
cache_dir: str | None,
3240
hydrate: bool,
3341
dummy_df_with_structures: pd.DataFrame,
42+
capsys: pytest.CaptureFixture,
3443
) -> None:
3544
# intercept HTTP requests to GitHub raw user content and return dummy df instead
3645
with patch("matbench_discovery.data.pd.read_csv") as read_csv, patch(
3746
"matbench_discovery.data.pd.read_json"
3847
) as read_json:
3948
read_csv.return_value = read_json.return_value = dummy_df_with_structures
40-
out = load_wbm(parts, cache_dir=cache_dir, hydrate=hydrate)
49+
out = load_train_test(parts, cache_dir=cache_dir, hydrate=hydrate)
50+
51+
stdout, stderr = capsys.readouterr()
52+
53+
assert (
54+
"\n".join(
55+
f"Downloading {part} from {RAW_REPO_URL}/1.0.0/data/{DATA_FILENAMES[part]}"
56+
for part in parts
57+
)
58+
in stdout
59+
)
60+
assert "" == stderr
4161

4262
assert read_json.call_count + read_csv.call_count == len(parts)
4363

@@ -53,14 +73,14 @@ def test_load_wbm(
5373
def test_load_wbm_raises() -> None:
5474
with pytest.raises(
5575
ValueError,
56-
match=f"must be subset of {set(data_files)}",
76+
match=f"must be subset of {set(DATA_FILENAMES)}",
5777
):
58-
load_wbm(["invalid-part"])
78+
load_train_test(["invalid-part"])
5979

6080
with pytest.raises(
6181
ValueError, match="Only version 1 currently available, got version=2"
6282
):
63-
load_wbm(version=2)
83+
load_train_test(version=2)
6484

6585

6686
def test_chunks() -> None:

0 commit comments

Comments
 (0)