Skip to content

Commit 5084886

Browse files
committed
fix load_train_test() caching all data versions to same directory
improve load_train_test() progress reporting; expand load_train_test() test coverage; bump flake8 max-complexity = 16 -> 18
1 parent d491439 commit 5084886

File tree

4 files changed

+118
-58
lines changed

4 files changed

+118
-58
lines changed

.pre-commit-config.yaml

+5-5
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@ default_install_hook_types: [pre-commit, commit-msg]
77

88
repos:
99
- repo: https://github.com/PyCQA/isort
10-
rev: 5.10.1
10+
rev: 5.11.4
1111
hooks:
1212
- id: isort
1313

1414
- repo: https://github.com/psf/black
15-
rev: 22.10.0
15+
rev: 22.12.0
1616
hooks:
1717
- id: black
1818

@@ -23,7 +23,7 @@ repos:
2323
additional_dependencies: [flake8-bugbear]
2424

2525
- repo: https://github.com/asottile/pyupgrade
26-
rev: v3.2.2
26+
rev: v3.3.1
2727
hooks:
2828
- id: pyupgrade
2929
args: [--py39-plus]
@@ -63,7 +63,7 @@ repos:
6363
- id: autoflake
6464

6565
- repo: https://github.com/pre-commit/mirrors-prettier
66-
rev: v3.0.0-alpha.0
66+
rev: v3.0.0-alpha.4
6767
hooks:
6868
- id: prettier
6969
args: [--write] # edit files in-place
@@ -74,7 +74,7 @@ repos:
7474
exclude: ^figures/.*$
7575

7676
- repo: https://github.com/pre-commit/mirrors-eslint
77-
rev: v8.24.0
77+
rev: v8.30.0
7878
hooks:
7979
- id: eslint
8080
types: [file]

matbench_discovery/data.py

+34-27
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from __future__ import annotations
22

33
import os
4+
import urllib.error
45
from collections.abc import Generator, Sequence
56
from glob import glob
7+
from pathlib import Path
68
from typing import Any, Callable
79

810
import pandas as pd
@@ -46,11 +48,12 @@ def as_dict_handler(obj: Any) -> dict[str, Any] | None:
4648

4749

4850
def load_train_test(
49-
parts: str | Sequence[str] = ("summary",),
50-
version: int = 1,
51-
cache_dir: str | None = default_cache_dir,
51+
data_names: str | Sequence[str] = ("summary",),
52+
version: str = "1.0.0",
53+
cache_dir: str | Path | None = default_cache_dir,
5254
hydrate: bool = False,
53-
) -> pd.DataFrame | dict[str, pd.DataFrame]:
55+
**kwargs: Any,
56+
) -> pd.DataFrame:
5457
"""Download parts of or the full MP training data and WBM test data as pandas
5558
DataFrames. The full training and test sets are each about ~500 MB as compressed
5659
JSON which will be cached locally to cache_dir for faster re-loading unless
@@ -62,46 +65,50 @@ def load_train_test(
6265
https://matbench-discovery.janosh.dev/how-to-use for brief data descriptions.
6366
6467
Args:
65-
parts (str | list[str], optional): Which parts of the MP/WBM dataset to load.
66-
Can be any subset of the above data names. Defaults to ["summary"].
67-
version (int, optional): Which version of the dataset to load. Defaults to 1
68-
(currently the only available option).
68+
data_names (str | list[str], optional): Which parts of the MP/WBM dataset to load.
69+
Can be any subset of the above data names or 'all'. Defaults to ["summary"].
70+
version (str, optional): Which version of the dataset to load. Defaults to
71+
'1.0.0'. Can be any git tag, branch or commit hash.
6972
cache_dir (str, optional): Where to cache data files on local drive. Defaults to
7073
'~/.cache/matbench-discovery'. Set to None to disable caching.
7174
hydrate (bool, optional): Whether to hydrate pymatgen objects. If False,
7275
Structures and ComputedStructureEntries are returned as dictionaries which
7376
can be hydrated on-demand with df.col.map(Structure.from_dict). Defaults to
7477
False as it noticeably increases load time.
78+
**kwargs: Additional keyword arguments passed to pandas.read_json or read_csv,
79+
depending on which file is loaded.
7580
7681
Raises:
77-
ValueError: On bad version number or bad part names.
82+
ValueError: On bad version number or bad data names.
7883
7984
Returns:
80-
pd.DataFrame | dict[str, pd.DataFrame]: Single dataframe of dictionary of
81-
multiple data parts were requested.
85+
pd.DataFrame: Single dataframe, or dictionary of dataframes if
86+
multiple data names were requested.
8287
"""
83-
if parts == "all":
84-
parts = list(DATA_FILENAMES)
85-
elif isinstance(parts, str):
86-
parts = [parts]
87-
88-
if version != 1:
89-
raise ValueError(f"Only version 1 currently available, got {version=}")
90-
if missing := set(parts) - set(DATA_FILENAMES):
88+
if data_names == "all":
89+
data_names = list(DATA_FILENAMES)
90+
elif isinstance(data_names, str):
91+
data_names = [data_names]
92+
93+
if missing := set(data_names) - set(DATA_FILENAMES):
9194
raise ValueError(f"{missing} must be subset of {set(DATA_FILENAMES)}")
9295

9396
dfs = {}
94-
for key in parts:
97+
for key in data_names:
9598
file = DATA_FILENAMES[key]
9699
reader = pd.read_csv if file.endswith(".csv") else pd.read_json
97100

98-
cache_path = f"{cache_dir}/{file}"
101+
cache_path = f"{cache_dir}/{version}/{file}"
99102
if os.path.isfile(cache_path):
100-
df = reader(cache_path)
103+
print(f"Loading '{key}' from cached file at '{cache_path}'")
104+
df = reader(cache_path, **kwargs)
101105
else:
102-
url = f"{RAW_REPO_URL}/{version}.0.0/data/{file}"
103-
print(f"Downloading {key} from {url}")
104-
df = reader(url)
106+
url = f"{RAW_REPO_URL}/{version}/data/{file}"
107+
print(f"Downloading '{key}' from {url}")
108+
try:
109+
df = reader(url)
110+
except urllib.error.HTTPError as exc:
111+
raise ValueError(f"Bad {url=}") from exc
105112
if cache_dir and not os.path.isfile(cache_path):
106113
os.makedirs(os.path.dirname(cache_path), exist_ok=True)
107114
if ".csv" in file:
@@ -128,8 +135,8 @@ def load_train_test(
128135

129136
dfs[key] = df
130137

131-
if len(parts) == 1:
132-
return dfs[parts[0]]
138+
if len(data_names) == 1:
139+
return dfs[data_names[0]]
133140
return dfs
134141

135142

readme.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ Matbench Discovery
88
[![Tests](https://github.com/janosh/matbench-discovery/actions/workflows/test.yml/badge.svg)](https://github.com/janosh/matbench-discovery/actions/workflows/test.yml)
99
[![GitHub Pages](https://github.com/janosh/matbench-discovery/actions/workflows/gh-pages.yml/badge.svg)](https://github.com/janosh/matbench-discovery/actions/workflows/gh-pages.yml)
1010
[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/janosh/matbench-discovery/main.svg?badge_token=Qza33izjRxSbegTqeSyDvA)](https://results.pre-commit.ci/latest/github/janosh/matbench-discovery/main?badge_token=Qza33izjRxSbegTqeSyDvA)
11-
[![Requires Python 3.9+](https://img.shields.io/badge/Python-3.9+-blue.svg?logo=python)](https://python.org/downloads)
12-
[![PyPI](https://img.shields.io/pypi/v/matbench-discovery?logo=pypi)](https://pypi.org/project/matbench-discovery?logo=pypi)
11+
[![Requires Python 3.9+](https://img.shields.io/badge/Python-3.9+-blue.svg?logo=python&logoColor=white)](https://python.org/downloads)
12+
[![PyPI](https://img.shields.io/pypi/v/matbench-discovery?logo=pypi&logoColor=white)](https://pypi.org/project/matbench-discovery?logo=pypi&logoColor=white)
1313

1414
</h4>
1515

tests/test_data.py

+77-24
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
from __future__ import annotations
22

33
import os
4+
import urllib.request
5+
from pathlib import Path
46
from tempfile import TemporaryDirectory
57
from typing import Any
68
from unittest.mock import patch
79

810
import pandas as pd
911
import pytest
1012
from pymatgen.core import Lattice, Structure
13+
from pytest import CaptureFixture
1114

1215
from matbench_discovery import ROOT
1316
from matbench_discovery.data import (
@@ -28,9 +31,14 @@
2831
coords=((0, 0, 0), (0.5, 0.5, 0.5)),
2932
)
3033

34+
try:
35+
website_down = urllib.request.urlopen(RAW_REPO_URL).status != 200
36+
except Exception:
37+
website_down = True
38+
3139

3240
@pytest.mark.parametrize(
33-
"parts, cache_dir, hydrate",
41+
"data_names, cache_dir, hydrate",
3442
[
3543
(["wbm-summary"], None, True),
3644
(["wbm-initial-structures"], TemporaryDirectory().name, True),
@@ -41,64 +49,109 @@
4149
],
4250
)
4351
def test_load_train_test(
44-
parts: list[str],
52+
data_names: list[str],
4553
cache_dir: str | None,
4654
hydrate: bool,
4755
dummy_df_with_structures: pd.DataFrame,
48-
capsys: pytest.CaptureFixture,
56+
capsys: CaptureFixture[str],
4957
) -> None:
5058
# intercept HTTP requests to GitHub raw user content and return dummy df instead
5159
with patch("matbench_discovery.data.pd.read_csv") as read_csv, patch(
5260
"matbench_discovery.data.pd.read_json"
5361
) as read_json:
5462
read_csv.return_value = read_json.return_value = dummy_df_with_structures
55-
out = load_train_test(parts, cache_dir=cache_dir, hydrate=hydrate)
63+
out = load_train_test(data_names, cache_dir=cache_dir, hydrate=hydrate)
5664

5765
stdout, stderr = capsys.readouterr()
5866

59-
assert (
60-
"\n".join(
61-
f"Downloading {part} from {RAW_REPO_URL}/1.0.0/data/{DATA_FILENAMES[part]}"
62-
for part in parts
63-
)
64-
in stdout
67+
expected_out = "\n".join(
68+
f"Downloading '{name}' from {RAW_REPO_URL}/1.0.0/data/{DATA_FILENAMES[name]}"
69+
for name in data_names
6570
)
71+
assert expected_out in stdout
6672
assert "" == stderr
6773

68-
assert read_json.call_count + read_csv.call_count == len(parts)
74+
assert read_json.call_count + read_csv.call_count == len(data_names)
6975

70-
if len(parts) > 1:
76+
if len(data_names) > 1:
7177
assert isinstance(out, dict)
72-
assert list(out) == parts
78+
assert list(out) == data_names
7379
for df in out.values():
7480
assert isinstance(df, pd.DataFrame)
7581
else:
7682
assert isinstance(out, pd.DataFrame)
7783

7884

79-
def test_load_train_test_raises() -> None:
80-
with pytest.raises(
81-
ValueError,
82-
match=f"must be subset of {set(DATA_FILENAMES)}",
83-
):
84-
load_train_test(["invalid-part"])
85+
def test_load_train_test_raises(tmp_path: Path) -> None:
86+
# bad data name
87+
with pytest.raises(ValueError, match=f"must be subset of {set(DATA_FILENAMES)}"):
88+
load_train_test(["bad-data-name"])
89+
90+
# bad_version
91+
version = "not-a-real-branch"
92+
with pytest.raises(ValueError) as exc_info:
93+
load_train_test("wbm-summary", version=version, cache_dir=tmp_path)
8594

86-
with pytest.raises(
87-
ValueError, match="Only version 1 currently available, got version=2"
88-
):
89-
load_train_test(version=2)
95+
assert (
96+
str(exc_info.value)
97+
== "Bad url='https://raw.githubusercontent.com/janosh/matbench-discovery"
98+
f"/{version}/data/wbm/2022-10-19-wbm-summary.csv'"
99+
)
90100

91101

92102
def test_load_train_test_doc_str() -> None:
93103
doc_str = load_train_test.__doc__
94104
assert isinstance(doc_str, str) # mypy type narrowing
95105

96-
assert all(key in doc_str for key in DATA_FILENAMES)
106+
for name in DATA_FILENAMES:
107+
assert name in doc_str, f"Missing data {name=} in load_train_test() docstring"
97108

98109
# TODO refactor to load site URL from site/package.json for SSoT
99110
assert "https://matbench-discovery.janosh.dev" in doc_str
100111

101112

113+
@pytest.mark.skipif(website_down, reason=f"{RAW_REPO_URL} unreachable")
114+
@pytest.mark.parametrize("version", ["main"]) # , "d00d475"
115+
def test_load_train_test_no_mock(
116+
version: str, capsys: CaptureFixture[str], tmp_path: Path
117+
) -> None:
118+
# this function runs the download from GitHub raw user content for real
119+
# hence takes some time and requires being online
120+
df_wbm = load_train_test("wbm-summary", version=version, cache_dir=tmp_path)
121+
assert df_wbm.shape == (256963, 17)
122+
assert set(df_wbm) > {
123+
"bandgap_pbe",
124+
"e_form_per_atom_mp2020_corrected",
125+
"e_form_per_atom_uncorrected",
126+
"e_form_per_atom_wbm",
127+
"e_hull_wbm",
128+
"formula",
129+
"n_sites",
130+
"uncorrected_energy",
131+
"uncorrected_energy_from_cse",
132+
"volume",
133+
"wyckoff_spglib",
134+
}, "Loaded df missing columns"
135+
136+
stdout, stderr = capsys.readouterr()
137+
assert stderr == ""
138+
assert (
139+
stdout
140+
== "Downloading 'wbm-summary' from https://raw.githubusercontent.com/janosh"
141+
f"/matbench-discovery/{version}/data/wbm/2022-10-19-wbm-summary.csv\n"
142+
)
143+
144+
df_wbm = load_train_test("wbm-summary", version=version, cache_dir=tmp_path)
145+
146+
stdout, stderr = capsys.readouterr()
147+
assert stderr == ""
148+
assert (
149+
stdout
150+
== f"Loading 'wbm-summary' from cached file at '{tmp_path}/main/wbm/2022-10-19-"
151+
"wbm-summary.csv'\n"
152+
)
153+
154+
102155
def test_chunks() -> None:
103156
assert list(chunks([], 1)) == []
104157
assert list(chunks([1], 1)) == [[1]]

0 commit comments

Comments
 (0)