@@ -173,6 +173,7 @@ def run(
173
173
use_pathways : bool = False ,
174
174
skip_post_process : bool = False ,
175
175
ramdisk_directory : str = "" ,
176
+
176
177
) -> DAGNode :
177
178
"""Run a test job within a docker image.
178
179
@@ -248,6 +249,7 @@ def run_model(
248
249
use_vertex_tensorboard : bool = False ,
249
250
use_pathways : bool = False ,
250
251
ramdisk_directory : str = "" ,
252
+ mtc_enabled : bool = False ,
251
253
) -> DAGNode :
252
254
"""Run the TPU/GPU test in `task_test_config` using xpk.
253
255
@@ -274,6 +276,7 @@ def run_model(
274
276
use_vertex_tensorboard ,
275
277
use_pathways ,
276
278
ramdisk_directory ,
279
+ mtc_enabled ,
277
280
)
278
281
wait_for_workload_completion = xpk .wait_for_workload_completion .override (
279
282
timeout = int (self .task_test_config .timeout .total_seconds ()),
@@ -306,6 +309,7 @@ def launch_workload(
306
309
use_vertex_tensorboard : bool ,
307
310
use_pathways : bool = False ,
308
311
ramdisk_directory : str = "" ,
312
+ mtc_enabled : bool = False ,
309
313
) -> DAGNode :
310
314
"""Create the workload and wait for it to provision."""
311
315
with TaskGroup (group_id = "launch_workload" ) as group :
@@ -326,6 +330,7 @@ def launch_workload(
326
330
use_vertex_tensorboard = use_vertex_tensorboard ,
327
331
use_pathways = use_pathways ,
328
332
ramdisk_directory = ramdisk_directory ,
333
+ mtc_enabled = mtc_enabled ,
329
334
)
330
335
wait_for_workload_start = xpk .wait_for_workload_start .override (
331
336
timeout = self .workload_provision_timeout .total_seconds ()
0 commit comments