@@ -65,6 +65,7 @@ def get_setup_cfg(name, version, install_requires=None, extra_requires=None):
65
65
def get_tox_ini (
66
66
basepython = None ,
67
67
disable_light_the_torch = None ,
68
+ pytorch_channel = None ,
68
69
pytorch_force_cpu = None ,
69
70
force_cpu = None ,
70
71
deps = None ,
@@ -90,6 +91,8 @@ def get_tox_ini(
90
91
lines .append ("extras = extra" )
91
92
if disable_light_the_torch is not None :
92
93
lines .append (f"disable_light_the_torch = { disable_light_the_torch } " )
94
+ if pytorch_channel is not None :
95
+ lines .append (f"pytorch_channel = { pytorch_channel } " )
93
96
if pytorch_force_cpu is not None :
94
97
lines .append (f"pytorch_force_cpu = { pytorch_force_cpu } " )
95
98
if force_cpu is not None :
@@ -110,6 +113,7 @@ def tox_ltt_initproj_(
110
113
install_requires = None ,
111
114
extra_requires = None ,
112
115
disable_light_the_torch = None ,
116
+ pytorch_channel = None ,
113
117
pytorch_force_cpu = None ,
114
118
force_cpu = None ,
115
119
deps = None ,
@@ -130,6 +134,7 @@ def tox_ltt_initproj_(
130
134
usedevelop = usedevelop ,
131
135
extra = extra_requires is not None ,
132
136
disable_light_the_torch = disable_light_the_torch ,
137
+ pytorch_channel = pytorch_channel ,
133
138
pytorch_force_cpu = pytorch_force_cpu ,
134
139
force_cpu = force_cpu ,
135
140
deps = deps ,
@@ -151,6 +156,7 @@ def test_help_ini(cmd):
151
156
result = cmd ("--help-ini" )
152
157
result .assert_success (is_run_test_env = False )
153
158
assert "disable_light_the_torch" in result .out
159
+ assert "pytorch_channel" in result .out
154
160
assert "pytorch_force_cpu" in result .out
155
161
156
162
@@ -165,6 +171,21 @@ def test_tox_ltt_disabled(patch_extract_dists, tox_ltt_initproj, cmd):
165
171
mock .assert_not_called ()
166
172
167
173
174
+ @pytest .mark .slow
175
+ def test_tox_ltt_pytorch_channel (patch_find_links , tox_ltt_initproj , cmd , install_mock ):
176
+ channel = "channel"
177
+
178
+ mock = patch_find_links ()
179
+ tox_ltt_initproj (deps = ("torch" ,), pytorch_channel = channel )
180
+
181
+ result = cmd ()
182
+
183
+ result .assert_success (is_run_test_env = False )
184
+
185
+ _ , kwargs = mock .call_args
186
+ assert kwargs ["channel" ] == channel
187
+
188
+
168
189
@pytest .mark .slow
169
190
def test_tox_ltt_pytorch_force_cpu (
170
191
patch_find_links , tox_ltt_initproj , cmd , install_mock
0 commit comments