Add MTC Phase-2 XLML tests and option for enabling MTC drivers in xpk #651

Merged
merged 1 commit into from Apr 21, 2025
8 changes: 7 additions & 1 deletion dags/common/vm_resource.py
@@ -270,7 +270,13 @@ class XpkClusters:
project=Project.TPU_PROD_ENV_ONE_VM.value,
zone=Zone.SOUTHAMERICA_WEST1_A.value,
)

TPU_V6E_16_IN_MEM_CLUSTER = XpkClusterConfig(
name="in-mem-airflow-v6e-16",
device_version=TpuVersion.TRILLIUM,
core_count=16,
project=Project.TPU_PROD_ENV_ONE_VM.value,
zone=Zone.US_EAST5_C.value,
)
GPU_A3_CLUSTER = XpkClusterConfig(
name="ninacai-maxtext-a3",
device_version=GpuVersion.XPK_H100,
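For context, a minimal sketch (using only names defined in this diff) of how the new cluster constant can be looked up by accelerator name, mirroring the DAG added below:

from dags.common.vm_resource import XpkClusters

# Map accelerator name to the in-memory v6e-16 cluster added above.
clusters = {"v6e-16": XpkClusters.TPU_V6E_16_IN_MEM_CLUSTER}
assert clusters["v6e-16"].core_count == 16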
105 changes: 105 additions & 0 deletions dags/multipod/maxtext_multi_tier_p2_checkpointing.py
@@ -0,0 +1,105 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
A DAG to run MaxText multi-tier checkpointing tests.
"""
import datetime
from airflow import models
from dags import composer_env, gcs_bucket
from dags.common import test_owner
from dags.common.vm_resource import TpuVersion, Zone, DockerImage, XpkClusters
from dags.multipod.configs import gke_config
from dags.multipod.configs.common import SetupMode

# Run once a day at 10 am UTC (2 am PST).
SCHEDULED_TIME = "0 10 * * *" if composer_env.is_prod_env() else None

with models.DAG(
    dag_id="maxtext_multi_tier_p2_checkpointing",
schedule=SCHEDULED_TIME,
tags=[
"multipod_team",
"maxtext",
"multi_tier_checkpointing_p2",
"nightly",
],
start_date=datetime.datetime(2025, 4, 17),
catchup=False,
concurrency=2,
) as dag:
base_output_directory = (
f"{gcs_bucket.BASE_OUTPUT_DIR}/maxtext_multi_tier_p2_checkpointing"
)
dataset_path = gcs_bucket.MAXTEXT_DIR
docker_images = [
(SetupMode.NIGHTLY, DockerImage.MAXTEXT_TPU_JAX_NIGHTLY),
]
test_configs = {
# accelerator: list of slices to test
"v6e-16": [3],
}
clusters = {
# accelerator: cluster name
"v6e-16": XpkClusters.TPU_V6E_16_IN_MEM_CLUSTER,
}

for mode, image in docker_images:
for accelerator, slices in test_configs.items():
for slice_num in slices:
command = (
"bash end_to_end/test_mtc_phase_2_save_path.sh"
f" multi_tier_checkpointing-{slice_num}x-{accelerator}"
f" {base_output_directory} {dataset_path}",
)
maxtext_v6e_chkpt_save_test = gke_config.get_gke_config(
num_slices=slice_num,
cluster=clusters[accelerator],
time_out_in_min=60,
test_name="maxtext-multi-tier-checkpointing-p2-save",
run_model_cmds=command,
docker_image=image.value,
test_owner=test_owner.ABHINAV_S,
).run(ramdisk_directory="local", mtc_enabled=True)

        command = ("rm -rf /local/*",)  # wipe the ramdisk to emulate a disruption
ramdisk_single_slice_cleanup = gke_config.get_gke_config(
num_slices=1,
cluster=clusters[accelerator],
time_out_in_min=60,
test_name="maxtext-multi-tier-checkpointing-p2-emulate-disruption",
run_model_cmds=command,
docker_image=image.value,
test_owner=test_owner.ABHINAV_S,
).run(ramdisk_directory="local", mtc_enabled=True)
command = (
"bash end_to_end/test_mtc_phase_2_save_path.sh"
f" multi_tier_checkpointing-{slice_num}x-{accelerator}"
f" {base_output_directory} {dataset_path}",
)

maxtext_v6e_chkpt_restore_test = gke_config.get_gke_config(
num_slices=slice_num,
cluster=clusters[accelerator],
time_out_in_min=60,
test_name="maxtext-multi-tier-checkpointing-p2-restore",
run_model_cmds=command,
docker_image=image.value,
test_owner=test_owner.ABHINAV_S,
).run(ramdisk_directory="local", mtc_enabled=True)

(
maxtext_v6e_chkpt_save_test
>> ramdisk_single_slice_cleanup
>> maxtext_v6e_chkpt_restore_test
)
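The trailing `>>` chain orders the three workloads: save a checkpoint, wipe the ramdisk on one slice to emulate a disruption, then restore. As a sketch, assuming the nodes returned by `.run(...)` expose Airflow's standard dependency API, the chain is equivalent to:

# save >> cleanup >> restore, written explicitly
maxtext_v6e_chkpt_save_test.set_downstream(ramdisk_single_slice_cleanup)
ramdisk_single_slice_cleanup.set_downstream(maxtext_v6e_chkpt_restore_test)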
13 changes: 10 additions & 3 deletions xlml/apis/task.py
@@ -173,6 +173,7 @@ def run(
use_pathways: bool = False,
skip_post_process: bool = False,
ramdisk_directory: str = "",
mtc_enabled: bool = False,
xpk_branch: str = xpk.MAIN_BRANCH,
) -> DAGNode:
"""Run a test job within a docker image.
@@ -192,6 +193,7 @@
use_vertex_tensorboard,
use_pathways,
ramdisk_directory,
mtc_enabled,
xpk_branch,
)
if not skip_post_process:
@@ -258,6 +260,7 @@ def run_model(
use_vertex_tensorboard: bool = False,
use_pathways: bool = False,
ramdisk_directory: str = "",
mtc_enabled: bool = False,
xpk_branch: str = xpk.MAIN_BRANCH,
) -> DAGNode:
"""Run the TPU/GPU test in `task_test_config` using xpk.
@@ -285,6 +288,7 @@
use_vertex_tensorboard,
use_pathways,
ramdisk_directory,
mtc_enabled,
xpk_branch,
)
wait_for_workload_completion = xpk.wait_for_workload_completion.override(
@@ -318,6 +322,7 @@ def launch_workload(
use_vertex_tensorboard: bool,
use_pathways: bool = False,
ramdisk_directory: str = "",
mtc_enabled: bool = False,
xpk_branch: str = xpk.MAIN_BRANCH,
) -> DAGNode:
"""Create the workload and wait for it to provision."""
@@ -339,6 +344,7 @@
use_vertex_tensorboard=use_vertex_tensorboard,
use_pathways=use_pathways,
ramdisk_directory=ramdisk_directory,
mtc_enabled=mtc_enabled,
xpk_branch=xpk_branch,
)
wait_for_workload_start = xpk.wait_for_workload_start.override(
@@ -702,9 +708,10 @@ def _get_job_manifest(self):
},
},
"spec": {
"activeDeadlineSeconds": (
int(self.task_test_config.timeout.total_seconds()) or 3600
),
"activeDeadlineSeconds": int(
self.task_test_config.timeout.total_seconds()
)
or 3600,
"backoffLimit": 0,
"completionMode": "Indexed",
"completions": self.task_test_config.num_hosts,
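Both the old and new forms of activeDeadlineSeconds compute the same value: the configured timeout in whole seconds, falling back to 3600 when the timeout is zero or empty. A quick illustration with hypothetical timeouts:

from datetime import timedelta

# A 90-minute timeout yields 5400 seconds.
print(int(timedelta(minutes=90).total_seconds()) or 3600)  # 5400
# A zero timeout coerces to 0, which is falsy, so the default applies.
print(int(timedelta().total_seconds()) or 3600)  # 3600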
3 changes: 3 additions & 0 deletions xlml/utils/xpk.py
@@ -89,6 +89,7 @@ def run_workload(
use_vertex_tensorboard: bool = False,
use_pathways: bool = False,
ramdisk_directory: str = "", # Directory for enabling emergency checkpointing
    mtc_enabled: bool = False,  # Enables MTC phase-2 drivers
xpk_branch: str = MAIN_BRANCH,
):
"""Run workload through xpk tool."""
Expand Down Expand Up @@ -116,6 +117,8 @@ def run_workload(
)
if ramdisk_directory:
workload_create_cmd += f" --ramdisk-directory={ramdisk_directory}"
if mtc_enabled:
workload_create_cmd += " --mtc-enabled"

  # If a valid GPU is used and the XPK branch is set to "main", the branch is switched to "v0.4.1".
if is_valid_gpu_version(accelerator_type) and xpk_branch == MAIN_BRANCH:
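Putting the two options together, a minimal sketch (base command abbreviated, values hypothetical) of the string run_workload builds when both are set:

workload_create_cmd = "python3 xpk.py workload create ..."  # abbreviated base command
ramdisk_directory = "local"
mtc_enabled = True

if ramdisk_directory:
  workload_create_cmd += f" --ramdisk-directory={ramdisk_directory}"
if mtc_enabled:
  workload_create_cmd += " --mtc-enabled"

print(workload_create_cmd)
# python3 xpk.py workload create ... --ramdisk-directory=local --mtc-enabled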