
Commit d2f3b4f

Browse files
Add support for MTC phase-2 drivers.
1 parent 8e2bbb9 commit d2f3b4f
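
The change threads a new mtc_enabled flag (default False) from XpkTask.run_model through launch_workload into xpk.run_workload, where it appends --mtc-enabled to the workload create command. The remaining hunks are formatting-only reflows of statements that now fit on a single line.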

2 files changed: +32 -58 lines changed


xlml/apis/task.py (+25 -46)
@@ -85,9 +85,7 @@ def run_queued_resource_test(
     post_process and clean_up.
   """
 
-  with TaskGroup(
-      group_id=task_test_config.benchmark_id, prefix_group_id=True
-  ) as test:
+  with TaskGroup(group_id=task_test_config.benchmark_id, prefix_group_id=True) as test:
     with TaskGroup(group_id="provision") as provision:
       with TaskGroup(group_id="initialize"):
         tpu_name = tpu.generate_tpu_name(
@@ -161,9 +159,7 @@ class XpkTask(BaseTask):
   ]
   task_gcp_config: gcp_config.GCPConfig
   task_metric_config: Optional[metric_config.MetricConfig] = None
-  workload_provision_timeout: datetime.timedelta = datetime.timedelta(
-      minutes=300
-  )
+  workload_provision_timeout: datetime.timedelta = datetime.timedelta(minutes=300)
 
   def run(
       self,
@@ -215,9 +211,7 @@ def run_with_run_name_generation(self, use_pathways: bool = False) -> DAGNode:
     with TaskGroup(
         group_id=self.task_test_config.benchmark_id, prefix_group_id=True
     ) as group:
-      run_name = name_format.generate_run_name(
-          self.task_test_config.benchmark_id
-      )
+      run_name = name_format.generate_run_name(self.task_test_config.benchmark_id)
       tb_file_location = name_format.generate_tb_file_location(
           run_name, self.task_metric_config.tensorboard_summary.file_location
       )
@@ -229,9 +223,7 @@ def run_with_run_name_generation(self, use_pathways: bool = False) -> DAGNode:
       self.task_test_config.run_model_cmds = new_run_model_cmds
 
       # Update tensorboard file location
-      self.task_metric_config.tensorboard_summary.file_location = (
-          tb_file_location
-      )
+      self.task_metric_config.tensorboard_summary.file_location = tb_file_location
 
       (
           run_name
@@ -248,6 +240,7 @@ def run_model(
       use_vertex_tensorboard: bool = False,
       use_pathways: bool = False,
       ramdisk_directory: str = "",
+      mtc_enabled: bool = False,
   ) -> DAGNode:
     """Run the TPU/GPU test in `task_test_config` using xpk.
 
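For illustration, a caller opts in through the new keyword argument. A minimal sketch, assuming `task` is an already-configured XpkTask; only the parameters visible in this diff appear, values are placeholders, and any other required arguments of run_model are omitted:

    # Hypothetical usage sketch; `task` is an already-configured XpkTask.
    # Only parameters visible in this diff are shown; values are placeholders.
    workload = task.run_model(
        use_vertex_tensorboard=False,
        use_pathways=False,
        ramdisk_directory="/tmp/ramdisk",  # placeholder path
        mtc_enabled=True,  # opt in to MTC phase-2 drivers
    )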
@@ -274,6 +267,7 @@ def run_model(
         use_vertex_tensorboard,
         use_pathways,
         ramdisk_directory,
+        mtc_enabled,
     )
     wait_for_workload_completion = xpk.wait_for_workload_completion.override(
         timeout=int(self.task_test_config.timeout.total_seconds()),
@@ -306,12 +300,11 @@ def launch_workload(
       use_vertex_tensorboard: bool,
       use_pathways: bool = False,
       ramdisk_directory: str = "",
+      mtc_enabled: bool = False,
   ) -> DAGNode:
     """Create the workload and wait for it to provision."""
     with TaskGroup(group_id="launch_workload") as group:
-      run_workload = xpk.run_workload.override(
-          owner=self.task_test_config.task_owner
-      )(
+      run_workload = xpk.run_workload.override(owner=self.task_test_config.task_owner)(
           task_id="run_workload",
           cluster_project=self.task_gcp_config.project_name,
           zone=self.task_gcp_config.zone,
@@ -326,6 +319,7 @@ def launch_workload(
           use_vertex_tensorboard=use_vertex_tensorboard,
           use_pathways=use_pathways,
           ramdisk_directory=ramdisk_directory,
+          mtc_enabled=mtc_enabled,
       )
       wait_for_workload_start = xpk.wait_for_workload_start.override(
           timeout=self.workload_provision_timeout.total_seconds()
@@ -411,9 +405,7 @@ def run(self) -> DAGNode:
         self.task_metric_config
         and self.task_metric_config.use_runtime_generated_gcs_folder
     ):
-      env_variable = {
-          f"{metric_config.SshEnvVars.GCS_OUTPUT.name}": gcs_location
-      }
+      env_variable = {f"{metric_config.SshEnvVars.GCS_OUTPUT.name}": gcs_location}
     else:
       env_variable = None
     run_model = self.run_model(ip_address, ssh_keys, env_variable)
@@ -445,9 +437,7 @@ def run_with_existing_instance(self) -> DAGNode:
         self.task_metric_config
         and self.task_metric_config.use_runtime_generated_gcs_folder
     ):
-      env_variable = {
-          f"{metric_config.SshEnvVars.GCS_OUTPUT.name}": gcs_location
-      }
+      env_variable = {f"{metric_config.SshEnvVars.GCS_OUTPUT.name}": gcs_location}
     else:
       env_variable = None
     post_process = self.post_process(gcs_location)
@@ -458,7 +448,12 @@ def run_with_existing_instance(self) -> DAGNode:
 
   def provision_via_existing_instance(
       self,
-  ) -> Tuple[DAGNode, airflow.XComArg, airflow.XComArg, airflow.XComArg,]:
+  ) -> Tuple[
+      DAGNode,
+      airflow.XComArg,
+      airflow.XComArg,
+      airflow.XComArg,
+  ]:
     """Provision an existing GPU accelerator.
 
     Returns:
@@ -575,9 +570,7 @@ def post_process(
     )
     return group
 
-  def clean_up(
-      self, resource: airflow.XComArg, project_id: str, zone: str
-  ) -> DAGNode:
+  def clean_up(self, resource: airflow.XComArg, project_id: str, zone: str) -> DAGNode:
     """Clean up GPU resources created by `provision`.
 
     Args:
@@ -590,9 +583,7 @@ def clean_up(
     Raises:
       AirflowTaskTimeout: An error occurs when execution_timeout is breached.
     """
-    return gpu.delete_resource.override(group_id="clean_up")(
-        resource, project_id, zone
-    )
+    return gpu.delete_resource.override(group_id="clean_up")(resource, project_id, zone)
 
   def clean_up_existing_instance(self, ssh_keys: airflow.XComArg) -> DAGNode:
     """Clean up existing GPU resources - remove the one-time use generated ssh_keys.
@@ -655,9 +646,7 @@ def run(self) -> DAGNode:
     gcs_location >> gke_run >> post_process
     return group
 
-  def post_process(
-      self, result_location: Optional[airflow.XComArg] = None
-  ) -> DAGNode:
+  def post_process(self, result_location: Optional[airflow.XComArg] = None) -> DAGNode:
     """Process metrics and metadata, and insert them into BigQuery tables.
 
     Returns:
@@ -688,9 +677,7 @@ def _get_job_manifest(self):
            },
        },
        "spec": {
-            "activeDeadlineSeconds": int(
-                self.task_test_config.timeout.total_seconds()
-            )
+            "activeDeadlineSeconds": int(self.task_test_config.timeout.total_seconds())
            or 3600,
            "backoffLimit": 0,
            "completionMode": "Indexed",
@@ -713,12 +700,8 @@ def _get_job_manifest(self):
                    "name": "main",
                    "image": self.task_test_config.docker_image,
                    "imagePullPolicy": "Always",
-                    "command": shlex.split(
-                        self.task_test_config.setup_script
-                    ),
-                    "args": shlex.split(
-                        self.task_test_config.test_script
-                    ),
+                    "command": shlex.split(self.task_test_config.setup_script),
+                    "args": shlex.split(self.task_test_config.test_script),
                    "resources": {
                        "limits": {
                            "nvidia.com/gpu": accelerator.count,
@@ -728,17 +711,13 @@ def _get_job_manifest(self):
                        {
                            "name": "POD_NAME",
                            "valueFrom": {
-                                "fieldRef": {
-                                    "fieldPath": "metadata.name"
-                                }
+                                "fieldRef": {"fieldPath": "metadata.name"}
                            },
                        },
                        {
                            "name": "POD_NAMESPACE",
                            "valueFrom": {
-                                "fieldRef": {
-                                    "fieldPath": "metadata.namespace"
-                                }
+                                "fieldRef": {"fieldPath": "metadata.namespace"}
                            },
                        },
                        {

xlml/utils/xpk.py (+7 -12)
@@ -77,6 +77,7 @@ def run_workload(
     use_vertex_tensorboard: bool = False,
     use_pathways: bool = False,
     ramdisk_directory: str = "",  # Directory for enabling emergency checkpointing
+    mtc_enabled: bool = False,  # It enables MTC phase-2 drivers
 ):
   """Run workload through xpk tool."""
 
@@ -103,6 +104,8 @@ def run_workload(
   )
   if ramdisk_directory:
     workload_create_cmd += f" --ramdisk-directory={ramdisk_directory}"
+  if mtc_enabled:
+    workload_create_cmd += " --mtc-enabled"
   cmds = get_xpk_setup_cmd(tmpdir)
   if accelerator_type == GpuVersion.XPK_H100_MEGA.value:
     workload_create_cmd += " --scheduler=gke.io/topology-aware-auto"
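The net effect on the generated invocation is one extra flag. A minimal sketch of the string assembly; the base command below is a placeholder, not the real command built earlier in run_workload:

    # Sketch of the flag handling added above; the base command is a placeholder.
    workload_create_cmd = "xpk workload create --cluster my-cluster"
    ramdisk_directory = "/tmp/ramdisk"
    mtc_enabled = True

    if ramdisk_directory:
      workload_create_cmd += f" --ramdisk-directory={ramdisk_directory}"
    if mtc_enabled:
      workload_create_cmd += " --mtc-enabled"
    # workload_create_cmd now ends with:
    #   " --ramdisk-directory=/tmp/ramdisk --mtc-enabled"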
@@ -118,9 +121,7 @@
       ["bash", "-c", ";".join(cmds)],
       env={**os.environ, "KUBECONFIG": os.path.join(tmpdir, "xpk.conf")},
   )
-  assert (
-      result.exit_code == 0
-  ), f"XPK command failed with code {result.exit_code}"
+  assert result.exit_code == 0, f"XPK command failed with code {result.exit_code}"
 
 
 def _get_core_api_client(
@@ -155,9 +156,7 @@ def _get_batch_api_client(
 
   # Initilize the client
   batch_api = k8s_client.BatchV1Api(client)
-  logging.info(
-      "Successful initilize k8s batch api client from cluster response."
-  )
+  logging.info("Successful initilize k8s batch api client from cluster response.")
   return batch_api
 
 
@@ -210,9 +209,7 @@ def wait_for_workload_completion(
   batch_api = _get_batch_api_client(project_id, region, cluster_name)
   job = _get_workload_job(batch_api, workload_id)
   if job is None:
-    logging.info(
-        f"No pods or jobs were found for workload selector: {workload_id}"
-    )
+    logging.info(f"No pods or jobs were found for workload selector: {workload_id}")
     return False
 
   if any(condition.type == "Failed" for condition in job.status.conditions):
@@ -280,6 +277,4 @@ def clean_up_workload(
       ["bash", "-c", ";".join(cmds)],
       env={**os.environ, "KUBECONFIG": os.path.join(tmpdir, "xpk.conf")},
   )
-  assert (
-      result.exit_code == 0
-  ), f"XPK clean-up failed with code {result.exit_code}"
+  assert result.exit_code == 0, f"XPK clean-up failed with code {result.exit_code}"
