Skip to content

Commit a943140

Browse files
authored
upgrade light-the-torch requirements (#5)
* upgrade light-the-torch requirements * extract force_cpu help from light-the-torch
1 parent 993aa58 commit a943140

File tree

3 files changed

+77
-32
lines changed

3 files changed

+77
-32
lines changed

setup.cfg

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ packages = find:
2525
include_package_data = True
2626
python_requires = >=3.6
2727
install_requires =
28-
light-the-torch>=0.1.1
28+
light-the-torch>=0.2
2929
tox
3030

3131
[options.packages.find]

tests/test_plugin.py

+56-25
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,35 @@
33
from light_the_torch.computation_backend import CPUBackend
44

55

6-
@pytest.mark.slow
7-
def test_help_ini(cmd):
8-
result = cmd("--help-ini")
9-
result.assert_success(is_run_test_env=False)
10-
assert "disable_light_the_torch" in result.out
11-
assert "force_cpu" in result.out
6+
@pytest.fixture
7+
def patch_extract_dists(mocker):
8+
def patch_extract_dists_(return_value=None):
9+
if return_value is None:
10+
return_value = []
11+
return mocker.patch(
12+
"tox_ltt.plugin.ltt.extract_dists", return_value=return_value
13+
)
14+
return mocker.patch()
15+
16+
return patch_extract_dists_
17+
18+
19+
@pytest.fixture
20+
def patch_find_links(mocker):
21+
def patch_find_links_(return_value=None):
22+
if return_value is None:
23+
return_value = []
24+
return mocker.patch(
25+
"tox_ltt.plugin.ltt.find_links", return_value=return_value
26+
)
27+
return mocker.patch()
28+
29+
return patch_find_links_
30+
31+
32+
@pytest.fixture
33+
def install_mock(mocker):
34+
return mocker.patch("tox.venv.VirtualEnv.run_install_command")
1235

1336

1437
def get_pyproject_toml():
@@ -100,14 +123,16 @@ def tox_ltt_initproj_(
100123
return tox_ltt_initproj_
101124

102125

103-
@pytest.fixture
104-
def install_mock(mocker):
105-
return mocker.patch("tox.venv.VirtualEnv.run_install_command")
126+
def test_help_ini(cmd):
127+
result = cmd("--help-ini")
128+
result.assert_success(is_run_test_env=False)
129+
assert "disable_light_the_torch" in result.out
130+
assert "force_cpu" in result.out
106131

107132

108133
@pytest.mark.slow
109-
def test_tox_ltt_disabled(mocker, tox_ltt_initproj, cmd):
110-
mock = mocker.patch("tox_ltt.plugin.ltt.resolve_dists")
134+
def test_tox_ltt_disabled(patch_extract_dists, tox_ltt_initproj, cmd):
135+
mock = patch_extract_dists()
111136
tox_ltt_initproj(disable_light_the_torch=True)
112137

113138
result = cmd()
@@ -117,9 +142,8 @@ def test_tox_ltt_disabled(mocker, tox_ltt_initproj, cmd):
117142

118143

119144
@pytest.mark.slow
120-
def test_tox_ltt_force_cpu(mocker, tox_ltt_initproj, cmd, install_mock):
121-
mock = mocker.patch("tox_ltt.plugin.ltt.find_links", return_value=[])
122-
145+
def test_tox_ltt_force_cpu(patch_find_links, tox_ltt_initproj, cmd, install_mock):
146+
mock = patch_find_links()
123147
tox_ltt_initproj(deps=("torch",), force_cpu=True)
124148

125149
result = cmd()
@@ -130,9 +154,10 @@ def test_tox_ltt_force_cpu(mocker, tox_ltt_initproj, cmd, install_mock):
130154
assert kwargs["computation_backend"] == CPUBackend()
131155

132156

133-
def test_tox_ltt_no_requirements(mocker, tox_ltt_initproj, cmd, install_mock):
134-
mock = mocker.patch("tox_ltt.plugin.ltt.resolve_dists")
135-
157+
def test_tox_ltt_no_requirements(
158+
patch_extract_dists, tox_ltt_initproj, cmd, install_mock
159+
):
160+
mock = patch_extract_dists()
136161
tox_ltt_initproj(skip_install=True)
137162

138163
result = cmd()
@@ -142,8 +167,10 @@ def test_tox_ltt_no_requirements(mocker, tox_ltt_initproj, cmd, install_mock):
142167

143168

144169
@pytest.mark.slow
145-
def test_tox_ltt_no_pytorch_dists(mocker, tox_ltt_initproj, cmd, install_mock):
146-
mock = mocker.patch("tox_ltt.plugin.ltt.find_links")
170+
def test_tox_ltt_no_pytorch_dists(
171+
patch_find_links, tox_ltt_initproj, cmd, install_mock
172+
):
173+
mock = patch_find_links()
147174

148175
deps = ("light-the-torch",)
149176
tox_ltt_initproj(deps=deps)
@@ -155,8 +182,10 @@ def test_tox_ltt_no_pytorch_dists(mocker, tox_ltt_initproj, cmd, install_mock):
155182

156183

157184
@pytest.mark.slow
158-
def test_tox_ltt_direct_pytorch_dists(mocker, tox_ltt_initproj, cmd, install_mock):
159-
mock = mocker.patch("tox_ltt.plugin.ltt.find_links", return_value=[])
185+
def test_tox_ltt_direct_pytorch_dists(
186+
patch_find_links, tox_ltt_initproj, cmd, install_mock
187+
):
188+
mock = patch_find_links()
160189

161190
deps = ("torch", "torchaudio", "torchtext", "torchvision")
162191
dists = set(deps)
@@ -171,8 +200,10 @@ def test_tox_ltt_direct_pytorch_dists(mocker, tox_ltt_initproj, cmd, install_moc
171200

172201

173202
@pytest.mark.slow
174-
def test_tox_ltt_indirect_pytorch_dists(mocker, tox_ltt_initproj, cmd, install_mock):
175-
mock = mocker.patch("tox_ltt.plugin.ltt.find_links", return_value=[])
203+
def test_tox_ltt_indirect_pytorch_dists(
204+
patch_find_links, tox_ltt_initproj, cmd, install_mock
205+
):
206+
mock = patch_find_links()
176207

177208
deps = ("git+https://github.com/pmeier/[email protected]",)
178209
dists = {"torch>=1.5.0", "torchvision>=0.6.0"}
@@ -187,9 +218,9 @@ def test_tox_ltt_indirect_pytorch_dists(mocker, tox_ltt_initproj, cmd, install_m
187218

188219

189220
def test_tox_ltt_project_pytorch_dists(
190-
subtests, mocker, tox_ltt_initproj, cmd, install_mock
221+
subtests, patch_find_links, tox_ltt_initproj, cmd, install_mock
191222
):
192-
mock = mocker.patch("tox_ltt.plugin.ltt.find_links", return_value=[])
223+
mock = patch_find_links()
193224

194225
install_requires = ("torch>=1.5.0", "torchvision>=0.6.0")
195226
dists = set(install_requires)

tox_ltt/plugin.py

+20-6
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional
1+
from typing import Any, Optional, Sequence, cast
22

33
import tox
44
from tox import reporter
@@ -7,9 +7,25 @@
77
from tox.venv import VirtualEnv
88

99
import light_the_torch as ltt
10+
from light_the_torch.cli import make_ltt_parser
1011
from light_the_torch.computation_backend import CPUBackend
1112

1213

14+
def extract_force_cpu_help() -> str:
15+
def extract(seq: Sequence, attr: str, eq_cond: Any) -> Any:
16+
reduced_seq = [item for item in seq if getattr(item, attr) == eq_cond]
17+
assert len(reduced_seq) == 1
18+
return reduced_seq[0]
19+
20+
ltt_parser = make_ltt_parser()
21+
22+
argument_group = extract(ltt_parser._action_groups, "title", "subcommands")
23+
sub_parsers = extract(argument_group._actions, "dest", "subcommand")
24+
install_parser = sub_parsers.choices["install"]
25+
force_cpu = extract(install_parser._actions, "dest", "force_cpu")
26+
return cast(str, force_cpu.help)
27+
28+
1329
@tox.hookimpl
1430
def tox_addoption(parser: Parser) -> None:
1531
parser.add_testenv_attribute(
@@ -18,11 +34,9 @@ def tox_addoption(parser: Parser) -> None:
1834
help="disable installing PyTorch distributions with light-the-torch",
1935
default=False,
2036
)
37+
2138
parser.add_testenv_attribute(
22-
name="force_cpu",
23-
type="bool",
24-
help="force CPU as computation backend",
25-
default=False,
39+
name="force_cpu", type="bool", help=extract_force_cpu_help(), default=False,
2640
)
2741

2842

@@ -54,7 +68,7 @@ def tox_testenv_install_deps(venv: VirtualEnv, action: Action) -> None:
5468

5569
action.setactivity("finddeps-light-the-torch", "")
5670

57-
dists = ltt.resolve_dists(requirements)
71+
dists = ltt.extract_dists(requirements)
5872

5973
if not dists:
6074
reporter.verbosity1(

0 commit comments

Comments
 (0)