Skip to content

Commit 3d1034d

Browse files
Add support for MTC phase-2 drivers.
1 parent 8e2bbb9 commit 3d1034d

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

xlml/apis/task.py

+5
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ def run(
173173
use_pathways: bool = False,
174174
skip_post_process: bool = False,
175175
ramdisk_directory: str = "",
176+
176177
) -> DAGNode:
177178
"""Run a test job within a docker image.
178179
@@ -248,6 +249,7 @@ def run_model(
248249
use_vertex_tensorboard: bool = False,
249250
use_pathways: bool = False,
250251
ramdisk_directory: str = "",
252+
mtc_enabled: bool = False,
251253
) -> DAGNode:
252254
"""Run the TPU/GPU test in `task_test_config` using xpk.
253255
@@ -274,6 +276,7 @@ def run_model(
274276
use_vertex_tensorboard,
275277
use_pathways,
276278
ramdisk_directory,
279+
mtc_enabled,
277280
)
278281
wait_for_workload_completion = xpk.wait_for_workload_completion.override(
279282
timeout=int(self.task_test_config.timeout.total_seconds()),
@@ -306,6 +309,7 @@ def launch_workload(
306309
use_vertex_tensorboard: bool,
307310
use_pathways: bool = False,
308311
ramdisk_directory: str = "",
312+
mtc_enabled: bool = False,
309313
) -> DAGNode:
310314
"""Create the workload and wait for it to provision."""
311315
with TaskGroup(group_id="launch_workload") as group:
@@ -326,6 +330,7 @@ def launch_workload(
326330
use_vertex_tensorboard=use_vertex_tensorboard,
327331
use_pathways=use_pathways,
328332
ramdisk_directory=ramdisk_directory,
333+
mtc_enabled=mtc_enabled,
329334
)
330335
wait_for_workload_start = xpk.wait_for_workload_start.override(
331336
timeout=self.workload_provision_timeout.total_seconds()

xlml/utils/xpk.py

+3
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def run_workload(
7777
use_vertex_tensorboard: bool = False,
7878
use_pathways: bool = False,
7979
ramdisk_directory: str = "", # Directory for enabling emergency checkpointing
80+
mtc_enabled: bool = False, # Enables MTC phase-2 drivers
8081
):
8182
"""Run workload through xpk tool."""
8283

@@ -103,6 +104,8 @@ def run_workload(
103104
)
104105
if ramdisk_directory:
105106
workload_create_cmd += f" --ramdisk-directory={ramdisk_directory}"
107+
if mtc_enabled:
108+
workload_create_cmd += " --mtc-enabled"
106109
cmds = get_xpk_setup_cmd(tmpdir)
107110
if accelerator_type == GpuVersion.XPK_H100_MEGA.value:
108111
workload_create_cmd += " --scheduler=gke.io/topology-aware-auto"

0 commit comments

Comments
 (0)