1 | 1 | from __future__ import annotations
2 | 2 |
3 | 3 | import os
| 4 | +import urllib.request
| 5 | +from pathlib import Path
4 | 6 | from tempfile import TemporaryDirectory
5 | 7 | from typing import Any
6 | 8 | from unittest.mock import patch
7 | 9 |
8 | 10 | import pandas as pd
9 | 11 | import pytest
10 | 12 | from pymatgen.core import Lattice, Structure
| 13 | +from pytest import CaptureFixture
11 | 14 |
12 | 15 | from matbench_discovery import ROOT
13 | 16 | from matbench_discovery.data import (

28 | 31 |     coords=((0, 0, 0), (0.5, 0.5, 0.5)),
29 | 32 | )
30 | 33 |
| 34 | +try:
| 35 | +    website_down = urllib.request.urlopen(RAW_REPO_URL).status != 200
| 36 | +except Exception:
| 37 | +    website_down = True
| 38 | +
31 | 39 |
32 | 40 | @pytest.mark.parametrize(
33 | | -    "parts, cache_dir, hydrate",
| 41 | +    "data_names, cache_dir, hydrate",
34 | 42 |     [
35 | 43 |         (["wbm-summary"], None, True),
36 | 44 |         (["wbm-initial-structures"], TemporaryDirectory().name, True),

41 | 49 |     ],
42 | 50 | )
43 | 51 | def test_load_train_test(
44 | | -    parts: list[str],
| 52 | +    data_names: list[str],
45 | 53 |     cache_dir: str | None,
46 | 54 |     hydrate: bool,
47 | 55 |     dummy_df_with_structures: pd.DataFrame,
48 | | -    capsys: pytest.CaptureFixture,
| 56 | +    capsys: CaptureFixture[str],
49 | 57 | ) -> None:
50 | 58 |     # intercept HTTP requests to GitHub raw user content and return dummy df instead
51 | 59 |     with patch("matbench_discovery.data.pd.read_csv") as read_csv, patch(
52 | 60 |         "matbench_discovery.data.pd.read_json"
53 | 61 |     ) as read_json:
54 | 62 |         read_csv.return_value = read_json.return_value = dummy_df_with_structures
55 | | -        out = load_train_test(parts, cache_dir=cache_dir, hydrate=hydrate)
| 63 | +        out = load_train_test(data_names, cache_dir=cache_dir, hydrate=hydrate)
56 | 64 |
57 | 65 |     stdout, stderr = capsys.readouterr()
58 | 66 |
59 | | -    assert (
60 | | -        "\n".join(
61 | | -            f"Downloading {part} from {RAW_REPO_URL}/1.0.0/data/{DATA_FILENAMES[part]}"
62 | | -            for part in parts
63 | | -        )
64 | | -        in stdout
| 67 | +    expected_out = "\n".join(
| 68 | +        f"Downloading '{name}' from {RAW_REPO_URL}/1.0.0/data/{DATA_FILENAMES[name]}"
| 69 | +        for name in data_names
65 | 70 |     )
| 71 | +    assert expected_out in stdout
66 | 72 |     assert "" == stderr
67 | 73 |
68 | | -    assert read_json.call_count + read_csv.call_count == len(parts)
| 74 | +    assert read_json.call_count + read_csv.call_count == len(data_names)
69 | 75 |
70 | | -    if len(parts) > 1:
| 76 | +    if len(data_names) > 1:
71 | 77 |         assert isinstance(out, dict)
72 | | -        assert list(out) == parts
| 78 | +        assert list(out) == data_names
73 | 79 |         for df in out.values():
74 | 80 |             assert isinstance(df, pd.DataFrame)
75 | 81 |     else:
76 | 82 |         assert isinstance(out, pd.DataFrame)
77 | 83 |
78 | 84 |
79 | | -def test_load_train_test_raises() -> None:
80 | | -    with pytest.raises(
81 | | -        ValueError,
82 | | -        match=f"must be subset of {set(DATA_FILENAMES)}",
83 | | -    ):
84 | | -        load_train_test(["invalid-part"])
| 85 | +def test_load_train_test_raises(tmp_path: Path) -> None:
| 86 | +    # bad data name
| 87 | +    with pytest.raises(ValueError, match=f"must be subset of {set(DATA_FILENAMES)}"):
| 88 | +        load_train_test(["bad-data-name"])
| 89 | +
| 90 | +    # bad_version
| 91 | +    version = "not-a-real-branch"
| 92 | +    with pytest.raises(ValueError) as exc_info:
| 93 | +        load_train_test("wbm-summary", version=version, cache_dir=tmp_path)
85 | 94 |
86 | | -    with pytest.raises(
87 | | -        ValueError, match="Only version 1 currently available, got version=2"
88 | | -    ):
89 | | -        load_train_test(version=2)
| 95 | +    assert (
| 96 | +        str(exc_info.value)
| 97 | +        == "Bad url='https://raw.githubusercontent.com/janosh/matbench-discovery"
| 98 | +        f"/{version}/data/wbm/2022-10-19-wbm-summary.csv'"
| 99 | +    )
90 | 100 |
91 | 101 |
92 | 102 | def test_load_train_test_doc_str() -> None:
93 | 103 |     doc_str = load_train_test.__doc__
94 | 104 |     assert isinstance(doc_str, str)  # mypy type narrowing
95 | 105 |
96 | | -    assert all(key in doc_str for key in DATA_FILENAMES)
| 106 | +    for name in DATA_FILENAMES:
| 107 | +        assert name in doc_str, f"Missing data {name=} in load_train_test() docstring"
97 | 108 |
98 | 109 |     # TODO refactor to load site URL from site/package.json for SSoT
99 | 110 |     assert "https://matbench-discovery.janosh.dev" in doc_str
100 | 111 |
101 | 112 |
| 113 | +@pytest.mark.skipif(website_down, reason=f"{RAW_REPO_URL} unreachable")
| 114 | +@pytest.mark.parametrize("version", ["main"])  # , "d00d475"
| 115 | +def test_load_train_test_no_mock(
| 116 | +    version: str, capsys: CaptureFixture[str], tmp_path: Path
| 117 | +) -> None:
| 118 | +    # this function runs the download from GitHub raw user content for real
| 119 | +    # hence takes some time and requires being online
| 120 | +    df_wbm = load_train_test("wbm-summary", version=version, cache_dir=tmp_path)
| 121 | +    assert df_wbm.shape == (256963, 17)
| 122 | +    assert set(df_wbm) > {
| 123 | +        "bandgap_pbe",
| 124 | +        "e_form_per_atom_mp2020_corrected",
| 125 | +        "e_form_per_atom_uncorrected",
| 126 | +        "e_form_per_atom_wbm",
| 127 | +        "e_hull_wbm",
| 128 | +        "formula",
| 129 | +        "n_sites",
| 130 | +        "uncorrected_energy",
| 131 | +        "uncorrected_energy_from_cse",
| 132 | +        "volume",
| 133 | +        "wyckoff_spglib",
| 134 | +    }, "Loaded df missing columns"
| 135 | +
| 136 | +    stdout, stderr = capsys.readouterr()
| 137 | +    assert stderr == ""
| 138 | +    assert (
| 139 | +        stdout
| 140 | +        == "Downloading 'wbm-summary' from https://raw.githubusercontent.com/janosh"
| 141 | +        f"/matbench-discovery/{version}/data/wbm/2022-10-19-wbm-summary.csv\n"
| 142 | +    )
| 143 | +
| 144 | +    df_wbm = load_train_test("wbm-summary", version=version, cache_dir=tmp_path)
| 145 | +
| 146 | +    stdout, stderr = capsys.readouterr()
| 147 | +    assert stderr == ""
| 148 | +    assert (
| 149 | +        stdout
| 150 | +        == f"Loading 'wbm-summary' from cached file at '{tmp_path}/main/wbm/2022-10-19-"
| 151 | +        "wbm-summary.csv'\n"
| 152 | +    )
| 153 | +
| 154 | +
102 | 155 | def test_chunks() -> None:
103 | 156 |     assert list(chunks([], 1)) == []
104 | 157 |     assert list(chunks([1], 1)) == [[1]]