Skip to content

Commit a7b6449

Browse files
Add support for MTC phase-2 drivers.
1 parent 7a53baa commit a7b6449

File tree

4 files changed

+125
-4
lines changed

4 files changed

+125
-4
lines changed

dags/common/vm_resource.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,13 @@ class XpkClusters:
270270
project=Project.TPU_PROD_ENV_ONE_VM.value,
271271
zone=Zone.SOUTHAMERICA_WEST1_A.value,
272272
)
273-
273+
TPU_V6E_16_IN_MEM_CLUSTER = XpkClusterConfig(
274+
name="in-mem-airflow-v6e-16",
275+
device_version=TpuVersion.TRILLIUM,
276+
core_count=16,
277+
project=Project.TPU_PROD_ENV_ONE_VM.value,
278+
zone=Zone.US_EAST5_C.value,
279+
)
274280
GPU_A3_CLUSTER = XpkClusterConfig(
275281
name="ninacai-maxtext-a3",
276282
device_version=GpuVersion.XPK_H100,
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
A DAG to run MaxText multi-tier checkpointing tests.
17+
"""
18+
import datetime
19+
from airflow import models
20+
from dags import composer_env, gcs_bucket
21+
from dags.common import test_owner
22+
from dags.common.vm_resource import TpuVersion, Zone, DockerImage, XpkClusters
23+
from dags.multipod.configs import gke_config
24+
from dags.multipod.configs.common import SetupMode # Run once a day at 10 am UTC (2 am PST)
25+
26+
SCHEDULED_TIME = "0 10 * * *" if composer_env.is_prod_env() else None
27+
28+
with models.DAG(
29+
dag_id="maxtext_muti_tier_p2_checkpointing",
30+
schedule=SCHEDULED_TIME,
31+
tags=[
32+
"multipod_team",
33+
"maxtext",
34+
"multi_tier_checkpointing_p2",
35+
"nightly",
36+
],
37+
start_date=datetime.datetime(2025, 4, 17),
38+
catchup=False,
39+
concurrency=2,
40+
) as dag:
41+
base_output_directory = (
42+
f"{gcs_bucket.BASE_OUTPUT_DIR}/maxtext_multi_tier_p2_checkpointing"
43+
)
44+
dataset_path = gcs_bucket.MAXTEXT_DIR
45+
docker_images = [
46+
(SetupMode.NIGHTLY, DockerImage.MAXTEXT_TPU_JAX_NIGHTLY),
47+
]
48+
test_configs = {
49+
# accelerator: list of slices to test
50+
"v6e-16": [3],
51+
}
52+
clusters = {
53+
# accelerator: cluster name
54+
"v6e-16": XpkClusters.TPU_V6E_16_IN_MEM_CLUSTER,
55+
}
56+
57+
for mode, image in docker_images:
58+
for accelerator, slices in test_configs.items():
59+
for slice_num in slices:
60+
command = (
61+
"bash end_to_end/test_mtc_phase_2_save_path.sh"
62+
f" multi_tier_checkpointing-{slice_num}x-{accelerator}"
63+
f" {base_output_directory} {dataset_path}",
64+
)
65+
maxtext_v6e_chkpt_save_test = gke_config.get_gke_config(
66+
num_slices=slice_num,
67+
cluster=clusters[accelerator],
68+
time_out_in_min=60,
69+
test_name="maxtext-multi-tier-checkpointing-p2-save",
70+
run_model_cmds=command,
71+
docker_image=image.value,
72+
test_owner=test_owner.ABHINAV_S,
73+
).run(ramdisk_directory="local", mtc_enabled=True)
74+
75+
command = "rm -rf /local/*"
76+
ramdisk_single_slice_cleanup = gke_config.get_gke_config(
77+
num_slices=1,
78+
cluster=clusters[accelerator],
79+
time_out_in_min=60,
80+
test_name="maxtext-multi-tier-checkpointing-p2-emulate-disruption",
81+
run_model_cmds=command,
82+
docker_image=image.value,
83+
test_owner=test_owner.ABHINAV_S,
84+
).run(ramdisk_directory="local", mtc_enabled=True)
85+
command = (
86+
"bash end_to_end/test_mtc_phase_2_save_path.sh"
87+
f" multi_tier_checkpointing-{slice_num}x-{accelerator}"
88+
f" {base_output_directory} {dataset_path}",
89+
)
90+
91+
maxtext_v6e_chkpt_restore_test = gke_config.get_gke_config(
92+
num_slices=slice_num,
93+
cluster=clusters[accelerator],
94+
time_out_in_min=60,
95+
test_name="maxtext-multi-tier-checkpointing-p2-restore",
96+
run_model_cmds=command,
97+
docker_image=image.value,
98+
test_owner=test_owner.ABHINAV_S,
99+
).run(ramdisk_directory="local", mtc_enabled=True)
100+
101+
(
102+
maxtext_v6e_chkpt_save_test
103+
>> ramdisk_single_slice_cleanup
104+
>> maxtext_v6e_chkpt_restore_test
105+
)

xlml/apis/task.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ def run(
173173
use_pathways: bool = False,
174174
skip_post_process: bool = False,
175175
ramdisk_directory: str = "",
176+
mtc_enabled: bool = False,
176177
xpk_branch: str = xpk.MAIN_BRANCH,
177178
) -> DAGNode:
178179
"""Run a test job within a docker image.
@@ -192,6 +193,7 @@ def run(
192193
use_vertex_tensorboard,
193194
use_pathways,
194195
ramdisk_directory,
196+
mtc_enabled,
195197
xpk_branch,
196198
)
197199
if not skip_post_process:
@@ -258,6 +260,7 @@ def run_model(
258260
use_vertex_tensorboard: bool = False,
259261
use_pathways: bool = False,
260262
ramdisk_directory: str = "",
263+
mtc_enabled: bool = False,
261264
xpk_branch: str = xpk.MAIN_BRANCH,
262265
) -> DAGNode:
263266
"""Run the TPU/GPU test in `task_test_config` using xpk.
@@ -285,6 +288,7 @@ def run_model(
285288
use_vertex_tensorboard,
286289
use_pathways,
287290
ramdisk_directory,
291+
mtc_enabled,
288292
xpk_branch,
289293
)
290294
wait_for_workload_completion = xpk.wait_for_workload_completion.override(
@@ -318,6 +322,7 @@ def launch_workload(
318322
use_vertex_tensorboard: bool,
319323
use_pathways: bool = False,
320324
ramdisk_directory: str = "",
325+
mtc_enabled: bool = False,
321326
xpk_branch: str = xpk.MAIN_BRANCH,
322327
) -> DAGNode:
323328
"""Create the workload and wait for it to provision."""
@@ -339,6 +344,7 @@ def launch_workload(
339344
use_vertex_tensorboard=use_vertex_tensorboard,
340345
use_pathways=use_pathways,
341346
ramdisk_directory=ramdisk_directory,
347+
mtc_enabled=mtc_enabled,
342348
xpk_branch=xpk_branch,
343349
)
344350
wait_for_workload_start = xpk.wait_for_workload_start.override(
@@ -702,9 +708,10 @@ def _get_job_manifest(self):
702708
},
703709
},
704710
"spec": {
705-
"activeDeadlineSeconds": (
706-
int(self.task_test_config.timeout.total_seconds()) or 3600
707-
),
711+
"activeDeadlineSeconds": int(
712+
self.task_test_config.timeout.total_seconds()
713+
)
714+
or 3600,
708715
"backoffLimit": 0,
709716
"completionMode": "Indexed",
710717
"completions": self.task_test_config.num_hosts,

xlml/utils/xpk.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def run_workload(
8989
use_vertex_tensorboard: bool = False,
9090
use_pathways: bool = False,
9191
ramdisk_directory: str = "", # Directory for enabling emergency checkpointing
92+
mtc_enabled: bool = False, # It enables MTC phase-2 drivers
9293
xpk_branch: str = MAIN_BRANCH,
9394
):
9495
"""Run workload through xpk tool."""
@@ -116,6 +117,8 @@ def run_workload(
116117
)
117118
if ramdisk_directory:
118119
workload_create_cmd += f" --ramdisk-directory={ramdisk_directory}"
120+
if mtc_enabled:
121+
workload_create_cmd += " --mtc-enabled"
119122

120123
# If using a valid GPU and the XPK branch is set to "main", then branch is switch to "v0.4.1".
121124
if is_valid_gpu_version(accelerator_type) and xpk_branch == MAIN_BRANCH:

0 commit comments

Comments
 (0)