From d2f3b4f246f387752dce6d1b87a9f256fcb0733d Mon Sep 17 00:00:00 2001
From: Abhinav Singh
Date: Fri, 4 Apr 2025 17:57:09 +0000
Subject: [PATCH] Add support for MTC phase-2 drivers.

---
 xlml/apis/task.py | 71 +++++++++++++++++------------------------------
 xlml/utils/xpk.py | 19 +++++--------
 2 files changed, 32 insertions(+), 58 deletions(-)

diff --git a/xlml/apis/task.py b/xlml/apis/task.py
index 6b07d471..6437c8be 100644
--- a/xlml/apis/task.py
+++ b/xlml/apis/task.py
@@ -85,9 +85,7 @@ def run_queued_resource_test(
     post_process and clean_up.
   """
 
-  with TaskGroup(
-      group_id=task_test_config.benchmark_id, prefix_group_id=True
-  ) as test:
+  with TaskGroup(group_id=task_test_config.benchmark_id, prefix_group_id=True) as test:
     with TaskGroup(group_id="provision") as provision:
       with TaskGroup(group_id="initialize"):
         tpu_name = tpu.generate_tpu_name(
@@ -161,9 +159,7 @@ class XpkTask(BaseTask):
   ]
   task_gcp_config: gcp_config.GCPConfig
   task_metric_config: Optional[metric_config.MetricConfig] = None
-  workload_provision_timeout: datetime.timedelta = datetime.timedelta(
-      minutes=300
-  )
+  workload_provision_timeout: datetime.timedelta = datetime.timedelta(minutes=300)
 
   def run(
       self,
@@ -215,9 +211,7 @@ def run_with_run_name_generation(self, use_pathways: bool = False) -> DAGNode:
     with TaskGroup(
         group_id=self.task_test_config.benchmark_id, prefix_group_id=True
     ) as group:
-      run_name = name_format.generate_run_name(
-          self.task_test_config.benchmark_id
-      )
+      run_name = name_format.generate_run_name(self.task_test_config.benchmark_id)
       tb_file_location = name_format.generate_tb_file_location(
          run_name, self.task_metric_config.tensorboard_summary.file_location
       )
@@ -229,9 +223,7 @@ def run_with_run_name_generation(self, use_pathways: bool = False) -> DAGNode:
       self.task_test_config.run_model_cmds = new_run_model_cmds
 
       # Update tensorboard file location
-      self.task_metric_config.tensorboard_summary.file_location = (
-          tb_file_location
-      )
+      self.task_metric_config.tensorboard_summary.file_location = tb_file_location
 
       (
          run_name
@@ -248,6 +240,7 @@ def run_model(
       use_vertex_tensorboard: bool = False,
       use_pathways: bool = False,
       ramdisk_directory: str = "",
+      mtc_enabled: bool = False,
   ) -> DAGNode:
     """Run the TPU/GPU test in `task_test_config` using xpk.
 
@@ -274,6 +267,7 @@ def run_model(
           use_vertex_tensorboard,
           use_pathways,
           ramdisk_directory,
+          mtc_enabled,
       )
       wait_for_workload_completion = xpk.wait_for_workload_completion.override(
           timeout=int(self.task_test_config.timeout.total_seconds()),
@@ -306,12 +300,11 @@ def launch_workload(
       use_vertex_tensorboard: bool,
       use_pathways: bool = False,
       ramdisk_directory: str = "",
+      mtc_enabled: bool = False,
   ) -> DAGNode:
     """Create the workload and wait for it to provision."""
     with TaskGroup(group_id="launch_workload") as group:
-      run_workload = xpk.run_workload.override(
-          owner=self.task_test_config.task_owner
-      )(
+      run_workload = xpk.run_workload.override(owner=self.task_test_config.task_owner)(
           task_id="run_workload",
           cluster_project=self.task_gcp_config.project_name,
           zone=self.task_gcp_config.zone,
@@ -326,6 +319,7 @@ def launch_workload(
           use_vertex_tensorboard=use_vertex_tensorboard,
           use_pathways=use_pathways,
           ramdisk_directory=ramdisk_directory,
+          mtc_enabled=mtc_enabled,
       )
       wait_for_workload_start = xpk.wait_for_workload_start.override(
           timeout=self.workload_provision_timeout.total_seconds()
@@ -411,9 +405,7 @@ def run(self) -> DAGNode:
         self.task_metric_config
         and self.task_metric_config.use_runtime_generated_gcs_folder
     ):
-      env_variable = {
-          f"{metric_config.SshEnvVars.GCS_OUTPUT.name}": gcs_location
-      }
+      env_variable = {f"{metric_config.SshEnvVars.GCS_OUTPUT.name}": gcs_location}
     else:
       env_variable = None
     run_model = self.run_model(ip_address, ssh_keys, env_variable)
@@ -445,9 +437,7 @@ def run_with_existing_instance(self) -> DAGNode:
         self.task_metric_config
         and self.task_metric_config.use_runtime_generated_gcs_folder
     ):
-      env_variable = {
-          f"{metric_config.SshEnvVars.GCS_OUTPUT.name}": gcs_location
-      }
+      env_variable = {f"{metric_config.SshEnvVars.GCS_OUTPUT.name}": gcs_location}
     else:
       env_variable = None
     post_process = self.post_process(gcs_location)
@@ -458,7 +448,12 @@ def run_with_existing_instance(self) -> DAGNode:
 
   def provision_via_existing_instance(
       self,
-  ) -> Tuple[DAGNode, airflow.XComArg, airflow.XComArg, airflow.XComArg,]:
+  ) -> Tuple[
+      DAGNode,
+      airflow.XComArg,
+      airflow.XComArg,
+      airflow.XComArg,
+  ]:
     """Provision an existing GPU accelerator.
 
     Returns:
@@ -575,9 +570,7 @@ def post_process(
     )
     return group
 
-  def clean_up(
-      self, resource: airflow.XComArg, project_id: str, zone: str
-  ) -> DAGNode:
+  def clean_up(self, resource: airflow.XComArg, project_id: str, zone: str) -> DAGNode:
     """Clean up GPU resources created by `provision`.
 
     Args:
@@ -590,9 +583,7 @@ def clean_up(
     Raises:
       AirflowTaskTimeout: An error occurs when execution_timeout is breached.
     """
-    return gpu.delete_resource.override(group_id="clean_up")(
-        resource, project_id, zone
-    )
+    return gpu.delete_resource.override(group_id="clean_up")(resource, project_id, zone)
 
   def clean_up_existing_instance(self, ssh_keys: airflow.XComArg) -> DAGNode:
     """Clean up existing GPU resources - remove the one-time use generated ssh_keys.
@@ -655,9 +646,7 @@ def run(self) -> DAGNode:
       gcs_location >> gke_run >> post_process
       return group
 
-  def post_process(
-      self, result_location: Optional[airflow.XComArg] = None
-  ) -> DAGNode:
+  def post_process(self, result_location: Optional[airflow.XComArg] = None) -> DAGNode:
     """Process metrics and metadata, and insert them into BigQuery tables.
 
     Returns:
@@ -688,9 +677,7 @@ def _get_job_manifest(self):
             },
         },
         "spec": {
-            "activeDeadlineSeconds": int(
-                self.task_test_config.timeout.total_seconds()
-            )
+            "activeDeadlineSeconds": int(self.task_test_config.timeout.total_seconds())
             or 3600,
             "backoffLimit": 0,
             "completionMode": "Indexed",
@@ -713,12 +700,8 @@ def _get_job_manifest(self):
                             "name": "main",
                             "image": self.task_test_config.docker_image,
                             "imagePullPolicy": "Always",
-                            "command": shlex.split(
-                                self.task_test_config.setup_script
-                            ),
-                            "args": shlex.split(
-                                self.task_test_config.test_script
-                            ),
+                            "command": shlex.split(self.task_test_config.setup_script),
+                            "args": shlex.split(self.task_test_config.test_script),
                             "resources": {
                                 "limits": {
                                     "nvidia.com/gpu": accelerator.count,
@@ -728,17 +711,13 @@ def _get_job_manifest(self):
                                 {
                                     "name": "POD_NAME",
                                     "valueFrom": {
-                                        "fieldRef": {
-                                            "fieldPath": "metadata.name"
-                                        }
+                                        "fieldRef": {"fieldPath": "metadata.name"}
                                     },
                                 },
                                 {
                                     "name": "POD_NAMESPACE",
                                     "valueFrom": {
-                                        "fieldRef": {
-                                            "fieldPath": "metadata.namespace"
-                                        }
+                                        "fieldRef": {"fieldPath": "metadata.namespace"}
                                     },
                                 },
                                 {
diff --git a/xlml/utils/xpk.py b/xlml/utils/xpk.py
index d41f940c..f17cf6f0 100644
--- a/xlml/utils/xpk.py
+++ b/xlml/utils/xpk.py
@@ -77,6 +77,7 @@ def run_workload(
     use_vertex_tensorboard: bool = False,
     use_pathways: bool = False,
     ramdisk_directory: str = "",  # Directory for enabling emergency checkpointing
+    mtc_enabled: bool = False,  # It enables MTC phase-2 drivers
 ):
   """Run workload through xpk tool."""
 
@@ -103,6 +104,8 @@ def run_workload(
     )
     if ramdisk_directory:
       workload_create_cmd += f" --ramdisk-directory={ramdisk_directory}"
+    if mtc_enabled:
+      workload_create_cmd += " --mtc-enabled"
     cmds = get_xpk_setup_cmd(tmpdir)
     if accelerator_type == GpuVersion.XPK_H100_MEGA.value:
       workload_create_cmd += " --scheduler=gke.io/topology-aware-auto"
@@ -118,9 +121,7 @@ def run_workload(
         ["bash", "-c", ";".join(cmds)],
         env={**os.environ, "KUBECONFIG": os.path.join(tmpdir, "xpk.conf")},
     )
-  assert (
-      result.exit_code == 0
-  ), f"XPK command failed with code {result.exit_code}"
+  assert result.exit_code == 0, f"XPK command failed with code {result.exit_code}"
 
 
 def _get_core_api_client(
@@ -155,9 +156,7 @@ def _get_batch_api_client(
 
   # Initilize the client
   batch_api = k8s_client.BatchV1Api(client)
-  logging.info(
-      "Successful initilize k8s batch api client from cluster response."
-  )
+  logging.info("Successful initilize k8s batch api client from cluster response.")
   return batch_api
 
 
@@ -210,9 +209,7 @@ def wait_for_workload_completion(
   batch_api = _get_batch_api_client(project_id, region, cluster_name)
   job = _get_workload_job(batch_api, workload_id)
   if job is None:
-    logging.info(
-        f"No pods or jobs were found for workload selector: {workload_id}"
-    )
+    logging.info(f"No pods or jobs were found for workload selector: {workload_id}")
     return False
 
   if any(condition.type == "Failed" for condition in job.status.conditions):
@@ -280,6 +277,4 @@ def clean_up_workload(
         ["bash", "-c", ";".join(cmds)],
         env={**os.environ, "KUBECONFIG": os.path.join(tmpdir, "xpk.conf")},
     )
-  assert (
-      result.exit_code == 0
-  ), f"XPK clean-up failed with code {result.exit_code}"
+  assert result.exit_code == 0, f"XPK clean-up failed with code {result.exit_code}"
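
Usage note (not part of the patch): the new flag only affects the workload
creation command assembled in xlml/utils/xpk.py. Below is a minimal standalone
sketch of that behaviour; `build_workload_create_cmd` and `base` are
illustrative stand-ins, not the real command string built by run_workload.

    # Sketch only: mirrors the flag handling added in run_workload above.
    def build_workload_create_cmd(
        base: str, ramdisk_directory: str = "", mtc_enabled: bool = False
    ) -> str:
      cmd = base
      if ramdisk_directory:
        # Emergency-checkpointing directory (pre-existing behaviour).
        cmd += f" --ramdisk-directory={ramdisk_directory}"
      if mtc_enabled:
        # New in this patch: ask xpk to enable MTC phase-2 drivers.
        cmd += " --mtc-enabled"
      return cmd

    # Example: build_workload_create_cmd("<xpk workload create ...>", "/local", True)
    # returns "<xpk workload create ...> --ramdisk-directory=/local --mtc-enabled".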
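
DAG-side sketch (also not part of the patch): how a test definition could opt
in, assuming an XpkTask is constructed as elsewhere in this repository and that
run_model's remaining parameters keep their defaults; `test_config`,
`gcp_config`, and the "/local" path are illustrative placeholders.

    from xlml.apis import task

    # Hypothetical wiring; only ramdisk_directory/mtc_enabled come from this patch.
    xpk_test = task.XpkTask(
        task_test_config=test_config,  # placeholder test config
        task_gcp_config=gcp_config,  # placeholder GCP config
    )
    run_group = xpk_test.run_model(
        ramdisk_directory="/local",  # emergency checkpointing directory
        mtc_enabled=True,  # appends --mtc-enabled to the xpk workload create command
    )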