|
3 | 3 | import os
|
4 | 4 | import urllib.request
|
5 | 5 | from pathlib import Path
|
| 6 | +from random import random |
6 | 7 | from tempfile import TemporaryDirectory
|
7 | 8 | from typing import Any
|
8 | 9 | from unittest.mock import patch
|
|
18 | 19 | PRED_FILENAMES,
|
19 | 20 | RAW_REPO_URL,
|
20 | 21 | as_dict_handler,
|
21 |
| - chunks, |
22 | 22 | df_wbm,
|
23 | 23 | glob_to_df,
|
24 | 24 | load_df_wbm_with_preds,
|
|
38 | 38 |
|
39 | 39 |
|
40 | 40 | @pytest.mark.parametrize(
|
41 |
| - "data_names, cache_dir, hydrate", |
| 41 | + "data_names, hydrate", |
42 | 42 | [
|
43 |
| - (["wbm-summary"], None, True), |
44 |
| - (["wbm-initial-structures"], TemporaryDirectory().name, True), |
45 |
| - (["wbm-computed-structure-entries"], None, False), |
46 |
| - (["wbm-summary", "wbm-initial-structures"], TemporaryDirectory().name, True), |
47 |
| - (["mp-elemental-ref-energies"], None, True), |
48 |
| - (["mp-energies"], None, True), |
| 43 | + (["wbm-summary"], True), |
| 44 | + (["wbm-initial-structures"], True), |
| 45 | + (["wbm-computed-structure-entries"], False), |
| 46 | + (["wbm-summary", "wbm-initial-structures"], True), |
| 47 | + (["mp-elemental-ref-energies"], True), |
| 48 | + (["mp-energies"], True), |
49 | 49 | ],
|
50 | 50 | )
|
51 | 51 | def test_load_train_test(
|
52 | 52 | data_names: list[str],
|
53 |
| - cache_dir: str | None, |
54 | 53 | hydrate: bool,
|
55 | 54 | dummy_df_with_structures: pd.DataFrame,
|
56 | 55 | capsys: CaptureFixture[str],
|
| 56 | + tmp_path: Path, |
57 | 57 | ) -> None:
|
58 | 58 | # intercept HTTP requests to GitHub raw user content and return dummy df instead
|
59 | 59 | with patch("matbench_discovery.data.pd.read_csv") as read_csv, patch(
|
60 | 60 | "matbench_discovery.data.pd.read_json"
|
61 | 61 | ) as read_json:
|
62 | 62 | read_csv.return_value = read_json.return_value = dummy_df_with_structures
|
63 |
| - out = load_train_test(data_names, cache_dir=cache_dir, hydrate=hydrate) |
| 63 | + out = load_train_test( |
| 64 | + data_names, |
| 65 | + hydrate=hydrate, |
| 66 | + # test both str and Path cache_dir |
| 67 | + cache_dir=TemporaryDirectory().name if random() < 0.5 else tmp_path, |
| 68 | + ) |
64 | 69 |
|
65 | 70 | stdout, stderr = capsys.readouterr()
|
66 | 71 |
|
@@ -152,17 +157,6 @@ def test_load_train_test_no_mock(
|
152 | 157 | )
|
153 | 158 |
|
154 | 159 |
|
155 |
| -def test_chunks() -> None: |
156 |
| - assert list(chunks([], 1)) == [] |
157 |
| - assert list(chunks([1], 1)) == [[1]] |
158 |
| - assert list(chunks([1, 2], 1)) == [[1], [2]] |
159 |
| - assert list(chunks([1, 2, 3], 1)) == [[1], [2], [3]] |
160 |
| - assert list(chunks([1, 2, 3], 2)) == [[1, 2], [3]] |
161 |
| - assert list(chunks(range(1, 4), 2)) == [range(1, 3), range(3, 4)] |
162 |
| - assert list(chunks(range(1, 5), 2)) == [range(1, 3), range(3, 5)] |
163 |
| - assert list(chunks(range(1, 5), 3)) == [range(1, 4), range(4, 5)] |
164 |
| - |
165 |
| - |
166 | 160 | def test_as_dict_handler() -> None:
|
167 | 161 | class C:
|
168 | 162 | def as_dict(self) -> dict[str, Any]:
|
|
0 commit comments