Skip to content

Commit 9df80db

Browse files
authored
enable pytorch_channel option (#13)
1 parent 6407fbe commit 9df80db

File tree

3 files changed

+33
-6
lines changed

3 files changed

+33
-6
lines changed

setup.cfg

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ packages = find:
2525
include_package_data = True
2626
python_requires = >=3.6
2727
install_requires =
28-
light-the-torch>=0.2
28+
light-the-torch>=0.3
2929
tox
3030

3131
[options.packages.find]

tests/test_plugin.py

+21
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def get_setup_cfg(name, version, install_requires=None, extra_requires=None):
6565
def get_tox_ini(
6666
basepython=None,
6767
disable_light_the_torch=None,
68+
pytorch_channel=None,
6869
pytorch_force_cpu=None,
6970
force_cpu=None,
7071
deps=None,
@@ -90,6 +91,8 @@ def get_tox_ini(
9091
lines.append("extras = extra")
9192
if disable_light_the_torch is not None:
9293
lines.append(f"disable_light_the_torch = {disable_light_the_torch}")
94+
if pytorch_channel is not None:
95+
lines.append(f"pytorch_channel = {pytorch_channel}")
9396
if pytorch_force_cpu is not None:
9497
lines.append(f"pytorch_force_cpu = {pytorch_force_cpu}")
9598
if force_cpu is not None:
@@ -110,6 +113,7 @@ def tox_ltt_initproj_(
110113
install_requires=None,
111114
extra_requires=None,
112115
disable_light_the_torch=None,
116+
pytorch_channel=None,
113117
pytorch_force_cpu=None,
114118
force_cpu=None,
115119
deps=None,
@@ -130,6 +134,7 @@ def tox_ltt_initproj_(
130134
usedevelop=usedevelop,
131135
extra=extra_requires is not None,
132136
disable_light_the_torch=disable_light_the_torch,
137+
pytorch_channel=pytorch_channel,
133138
pytorch_force_cpu=pytorch_force_cpu,
134139
force_cpu=force_cpu,
135140
deps=deps,
@@ -151,6 +156,7 @@ def test_help_ini(cmd):
151156
result = cmd("--help-ini")
152157
result.assert_success(is_run_test_env=False)
153158
assert "disable_light_the_torch" in result.out
159+
assert "pytorch_channel" in result.out
154160
assert "pytorch_force_cpu" in result.out
155161

156162

@@ -165,6 +171,21 @@ def test_tox_ltt_disabled(patch_extract_dists, tox_ltt_initproj, cmd):
165171
mock.assert_not_called()
166172

167173

174+
@pytest.mark.slow
175+
def test_tox_ltt_pytorch_channel(patch_find_links, tox_ltt_initproj, cmd, install_mock):
176+
channel = "channel"
177+
178+
mock = patch_find_links()
179+
tox_ltt_initproj(deps=("torch",), pytorch_channel=channel)
180+
181+
result = cmd()
182+
183+
result.assert_success(is_run_test_env=False)
184+
185+
_, kwargs = mock.call_args
186+
assert kwargs["channel"] == channel
187+
188+
168189
@pytest.mark.slow
169190
def test_tox_ltt_pytorch_force_cpu(
170191
patch_find_links, tox_ltt_initproj, cmd, install_mock

tox_ltt/plugin.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from light_the_torch.computation_backend import CPUBackend
1313

1414

15-
def extract_force_cpu_help() -> str:
15+
def extract_ltt_option_help(subcommand: str, option: str) -> str:
1616
def extract(seq: Sequence, attr: str, eq_cond: Any) -> Any:
1717
reduced_seq = [item for item in seq if getattr(item, attr) == eq_cond]
1818
assert len(reduced_seq) == 1
@@ -22,9 +22,8 @@ def extract(seq: Sequence, attr: str, eq_cond: Any) -> Any:
2222

2323
argument_group = extract(ltt_parser._action_groups, "title", "subcommands")
2424
sub_parsers = extract(argument_group._actions, "dest", "subcommand")
25-
install_parser = sub_parsers.choices["install"]
26-
force_cpu = extract(install_parser._actions, "dest", "force_cpu")
27-
return cast(str, force_cpu.help)
25+
subcommand_parser = sub_parsers.choices[subcommand]
26+
return cast(str, extract(subcommand_parser._actions, "dest", option).help)
2827

2928

3029
@tox.hookimpl
@@ -35,10 +34,16 @@ def tox_addoption(parser: Parser) -> None:
3534
help="disable installing PyTorch distributions with light-the-torch",
3635
default=False,
3736
)
37+
parser.add_testenv_attribute(
38+
name="pytorch_channel",
39+
type="string",
40+
help=extract_ltt_option_help("install", "channel"),
41+
default="stable",
42+
)
3843
parser.add_testenv_attribute(
3944
name="pytorch_force_cpu",
4045
type="bool",
41-
help=extract_force_cpu_help(),
46+
help=extract_ltt_option_help("install", "force_cpu"),
4247
default=False,
4348
)
4449
parser.add_testenv_attribute(
@@ -98,6 +103,7 @@ def tox_testenv_install_deps(venv: VirtualEnv, action: Action) -> None:
98103
links = ltt.find_links(
99104
dists,
100105
computation_backend=get_computation_backend(envconfig),
106+
channel=envconfig.pytorch_channel,
101107
python_version=get_python_version(envconfig),
102108
)
103109

0 commit comments

Comments
 (0)