Add support for MTC phase-2 drivers in XLML tests #651

Open · wants to merge 1 commit into master
71 changes: 25 additions & 46 deletions xlml/apis/task.py
@@ -85,9 +85,7 @@ def run_queued_resource_test(
post_process and clean_up.
"""

with TaskGroup(
group_id=task_test_config.benchmark_id, prefix_group_id=True
) as test:
with TaskGroup(group_id=task_test_config.benchmark_id, prefix_group_id=True) as test:
with TaskGroup(group_id="provision") as provision:
with TaskGroup(group_id="initialize"):
tpu_name = tpu.generate_tpu_name(
@@ -161,9 +159,7 @@ class XpkTask(BaseTask):
]
task_gcp_config: gcp_config.GCPConfig
task_metric_config: Optional[metric_config.MetricConfig] = None
workload_provision_timeout: datetime.timedelta = datetime.timedelta(
minutes=300
)
workload_provision_timeout: datetime.timedelta = datetime.timedelta(minutes=300)

def run(
self,
@@ -215,9 +211,7 @@ def run_with_run_name_generation(self, use_pathways: bool = False) -> DAGNode:
with TaskGroup(
group_id=self.task_test_config.benchmark_id, prefix_group_id=True
) as group:
run_name = name_format.generate_run_name(
self.task_test_config.benchmark_id
)
run_name = name_format.generate_run_name(self.task_test_config.benchmark_id)
tb_file_location = name_format.generate_tb_file_location(
run_name, self.task_metric_config.tensorboard_summary.file_location
)
@@ -229,9 +223,7 @@ def run_with_run_name_generation(self, use_pathways: bool = False) -> DAGNode:
self.task_test_config.run_model_cmds = new_run_model_cmds

# Update tensorboard file location
self.task_metric_config.tensorboard_summary.file_location = (
tb_file_location
)
self.task_metric_config.tensorboard_summary.file_location = tb_file_location

(
run_name
@@ -248,6 +240,7 @@ def run_model(
use_vertex_tensorboard: bool = False,
use_pathways: bool = False,
ramdisk_directory: str = "",
mtc_enabled: bool = False,
) -> DAGNode:
"""Run the TPU/GPU test in `task_test_config` using xpk.

@@ -274,6 +267,7 @@ def run_model(
use_vertex_tensorboard,
use_pathways,
ramdisk_directory,
mtc_enabled,
)
wait_for_workload_completion = xpk.wait_for_workload_completion.override(
timeout=int(self.task_test_config.timeout.total_seconds()),
@@ -306,12 +300,11 @@ def launch_workload(
use_vertex_tensorboard: bool,
use_pathways: bool = False,
ramdisk_directory: str = "",
mtc_enabled: bool = False,
) -> DAGNode:
"""Create the workload and wait for it to provision."""
with TaskGroup(group_id="launch_workload") as group:
run_workload = xpk.run_workload.override(
owner=self.task_test_config.task_owner
)(
run_workload = xpk.run_workload.override(owner=self.task_test_config.task_owner)(
task_id="run_workload",
cluster_project=self.task_gcp_config.project_name,
zone=self.task_gcp_config.zone,
@@ -326,6 +319,7 @@ def launch_workload(
use_vertex_tensorboard=use_vertex_tensorboard,
use_pathways=use_pathways,
ramdisk_directory=ramdisk_directory,
mtc_enabled=mtc_enabled,
)
wait_for_workload_start = xpk.wait_for_workload_start.override(
timeout=self.workload_provision_timeout.total_seconds()
@@ -411,9 +405,7 @@ def run(self) -> DAGNode:
self.task_metric_config
and self.task_metric_config.use_runtime_generated_gcs_folder
):
env_variable = {
f"{metric_config.SshEnvVars.GCS_OUTPUT.name}": gcs_location
}
env_variable = {f"{metric_config.SshEnvVars.GCS_OUTPUT.name}": gcs_location}
else:
env_variable = None
run_model = self.run_model(ip_address, ssh_keys, env_variable)
@@ -445,9 +437,7 @@ def run_with_existing_instance(self) -> DAGNode:
self.task_metric_config
and self.task_metric_config.use_runtime_generated_gcs_folder
):
env_variable = {
f"{metric_config.SshEnvVars.GCS_OUTPUT.name}": gcs_location
}
env_variable = {f"{metric_config.SshEnvVars.GCS_OUTPUT.name}": gcs_location}
else:
env_variable = None
post_process = self.post_process(gcs_location)
@@ -458,7 +448,12 @@ def run_with_existing_instance(self) -> DAGNode:

def provision_via_existing_instance(
self,
) -> Tuple[DAGNode, airflow.XComArg, airflow.XComArg, airflow.XComArg,]:
) -> Tuple[
DAGNode,
airflow.XComArg,
airflow.XComArg,
airflow.XComArg,
]:
"""Provision an existing GPU accelerator.

Returns:
@@ -575,9 +570,7 @@ def post_process(
)
return group

def clean_up(
self, resource: airflow.XComArg, project_id: str, zone: str
) -> DAGNode:
def clean_up(self, resource: airflow.XComArg, project_id: str, zone: str) -> DAGNode:
"""Clean up GPU resources created by `provision`.

Args:
@@ -590,9 +583,7 @@ def clean_up(
Raises:
AirflowTaskTimeout: An error occurs when execution_timeout is breached.
"""
return gpu.delete_resource.override(group_id="clean_up")(
resource, project_id, zone
)
return gpu.delete_resource.override(group_id="clean_up")(resource, project_id, zone)

def clean_up_existing_instance(self, ssh_keys: airflow.XComArg) -> DAGNode:
"""Clean up existing GPU resources - remove the one-time use generated ssh_keys.
@@ -655,9 +646,7 @@ def run(self) -> DAGNode:
gcs_location >> gke_run >> post_process
return group

def post_process(
self, result_location: Optional[airflow.XComArg] = None
) -> DAGNode:
def post_process(self, result_location: Optional[airflow.XComArg] = None) -> DAGNode:
"""Process metrics and metadata, and insert them into BigQuery tables.

Returns:
@@ -688,9 +677,7 @@ def _get_job_manifest(self):
},
},
"spec": {
"activeDeadlineSeconds": int(
self.task_test_config.timeout.total_seconds()
)
"activeDeadlineSeconds": int(self.task_test_config.timeout.total_seconds())
or 3600,
"backoffLimit": 0,
"completionMode": "Indexed",
@@ -713,12 +700,8 @@ def _get_job_manifest(self):
"name": "main",
"image": self.task_test_config.docker_image,
"imagePullPolicy": "Always",
"command": shlex.split(
self.task_test_config.setup_script
),
"args": shlex.split(
self.task_test_config.test_script
),
"command": shlex.split(self.task_test_config.setup_script),
"args": shlex.split(self.task_test_config.test_script),
"resources": {
"limits": {
"nvidia.com/gpu": accelerator.count,
@@ -728,17 +711,13 @@ def _get_job_manifest(self):
{
"name": "POD_NAME",
"valueFrom": {
"fieldRef": {
"fieldPath": "metadata.name"
}
"fieldRef": {"fieldPath": "metadata.name"}
},
},
{
"name": "POD_NAMESPACE",
"valueFrom": {
"fieldRef": {
"fieldPath": "metadata.namespace"
}
"fieldRef": {"fieldPath": "metadata.namespace"}
},
},
{
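For context, a minimal caller-side sketch of how a DAG would opt in to the new behavior. This is not part of the PR: `my_xpk_task` is a placeholder for a fully configured `XpkTask`, and the sketch assumes the remaining `run_model` parameters keep their defaults.

```python
# Hypothetical usage sketch (not from this PR): enable MTC phase-2 drivers for a
# workload, alongside the existing emergency-checkpointing ramdisk directory.
# `my_xpk_task` stands in for an XpkTask built from real test/GCP/metric configs.
mtc_run = my_xpk_task.run_model(
    ramdisk_directory="/local",  # example value; enables emergency checkpointing
    mtc_enabled=True,  # new flag, forwarded down to `xpk workload create --mtc-enabled`
)
```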
19 changes: 7 additions & 12 deletions xlml/utils/xpk.py
@@ -77,6 +77,7 @@ def run_workload(
use_vertex_tensorboard: bool = False,
use_pathways: bool = False,
ramdisk_directory: str = "", # Directory for enabling emergency checkpointing
mtc_enabled: bool = False, # Enables MTC phase-2 drivers
):
"""Run workload through xpk tool."""

@@ -103,6 +104,8 @@ def run_workload(
)
if ramdisk_directory:
workload_create_cmd += f" --ramdisk-directory={ramdisk_directory}"
if mtc_enabled:
workload_create_cmd += " --mtc-enabled"
cmds = get_xpk_setup_cmd(tmpdir)
if accelerator_type == GpuVersion.XPK_H100_MEGA.value:
workload_create_cmd += " --scheduler=gke.io/topology-aware-auto"
@@ -118,9 +121,7 @@
["bash", "-c", ";".join(cmds)],
env={**os.environ, "KUBECONFIG": os.path.join(tmpdir, "xpk.conf")},
)
assert (
result.exit_code == 0
), f"XPK command failed with code {result.exit_code}"
assert result.exit_code == 0, f"XPK command failed with code {result.exit_code}"


def _get_core_api_client(
@@ -155,9 +156,7 @@ def _get_batch_api_client(

# Initialize the client
batch_api = k8s_client.BatchV1Api(client)
logging.info(
"Successfully initialized k8s batch api client from cluster response."
)
logging.info("Successfully initialized k8s batch api client from cluster response.")
return batch_api


@@ -210,9 +209,7 @@ def wait_for_workload_completion(
batch_api = _get_batch_api_client(project_id, region, cluster_name)
job = _get_workload_job(batch_api, workload_id)
if job is None:
logging.info(
f"No pods or jobs were found for workload selector: {workload_id}"
)
logging.info(f"No pods or jobs were found for workload selector: {workload_id}")
return False

if any(condition.type == "Failed" for condition in job.status.conditions):
@@ -280,6 +277,4 @@ def clean_up_workload(
["bash", "-c", ";".join(cmds)],
env={**os.environ, "KUBECONFIG": os.path.join(tmpdir, "xpk.conf")},
)
assert (
result.exit_code == 0
), f"XPK clean-up failed with code {result.exit_code}"
assert result.exit_code == 0, f"XPK clean-up failed with code {result.exit_code}"
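As a quick sanity check on the plumbing, a standalone sketch of the command-string construction the `--mtc-enabled` hunk modifies. The base command string is illustrative only; in `run_workload` it is assembled from the cluster, docker image, and accelerator arguments. Only the two `if` branches mirror the diff.

```python
# Illustrative reconstruction of the flag-appending logic in run_workload().
# Only the two `if` branches reflect the actual code; everything else is a placeholder.
ramdisk_directory = "/local"  # example value; an empty string skips the flag
mtc_enabled = True  # new parameter introduced by this PR

workload_create_cmd = (
    "python xpk.py workload create --cluster=my-cluster --workload=my-workload"
)
if ramdisk_directory:
    workload_create_cmd += f" --ramdisk-directory={ramdisk_directory}"
if mtc_enabled:
    workload_create_cmd += " --mtc-enabled"

print(workload_create_cmd)
# ... --ramdisk-directory=/local --mtc-enabled
```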