@@ -85,9 +85,7 @@ def run_queued_resource_test(
85
85
post_process and clean_up.
86
86
"""
87
87
88
- with TaskGroup (
89
- group_id = task_test_config .benchmark_id , prefix_group_id = True
90
- ) as test :
88
+ with TaskGroup (group_id = task_test_config .benchmark_id , prefix_group_id = True ) as test :
91
89
with TaskGroup (group_id = "provision" ) as provision :
92
90
with TaskGroup (group_id = "initialize" ):
93
91
tpu_name = tpu .generate_tpu_name (
@@ -161,9 +159,7 @@ class XpkTask(BaseTask):
161
159
]
162
160
task_gcp_config : gcp_config .GCPConfig
163
161
task_metric_config : Optional [metric_config .MetricConfig ] = None
164
- workload_provision_timeout : datetime .timedelta = datetime .timedelta (
165
- minutes = 300
166
- )
162
+ workload_provision_timeout : datetime .timedelta = datetime .timedelta (minutes = 300 )
167
163
168
164
def run (
169
165
self ,
@@ -215,9 +211,7 @@ def run_with_run_name_generation(self, use_pathways: bool = False) -> DAGNode:
215
211
with TaskGroup (
216
212
group_id = self .task_test_config .benchmark_id , prefix_group_id = True
217
213
) as group :
218
- run_name = name_format .generate_run_name (
219
- self .task_test_config .benchmark_id
220
- )
214
+ run_name = name_format .generate_run_name (self .task_test_config .benchmark_id )
221
215
tb_file_location = name_format .generate_tb_file_location (
222
216
run_name , self .task_metric_config .tensorboard_summary .file_location
223
217
)
@@ -229,9 +223,7 @@ def run_with_run_name_generation(self, use_pathways: bool = False) -> DAGNode:
229
223
self .task_test_config .run_model_cmds = new_run_model_cmds
230
224
231
225
# Update tensorboard file location
232
- self .task_metric_config .tensorboard_summary .file_location = (
233
- tb_file_location
234
- )
226
+ self .task_metric_config .tensorboard_summary .file_location = tb_file_location
235
227
236
228
(
237
229
run_name
@@ -248,6 +240,7 @@ def run_model(
248
240
use_vertex_tensorboard : bool = False ,
249
241
use_pathways : bool = False ,
250
242
ramdisk_directory : str = "" ,
243
+ mtc_enabled : bool = False ,
251
244
) -> DAGNode :
252
245
"""Run the TPU/GPU test in `task_test_config` using xpk.
253
246
@@ -274,6 +267,7 @@ def run_model(
274
267
use_vertex_tensorboard ,
275
268
use_pathways ,
276
269
ramdisk_directory ,
270
+ mtc_enabled ,
277
271
)
278
272
wait_for_workload_completion = xpk .wait_for_workload_completion .override (
279
273
timeout = int (self .task_test_config .timeout .total_seconds ()),
@@ -306,12 +300,11 @@ def launch_workload(
306
300
use_vertex_tensorboard : bool ,
307
301
use_pathways : bool = False ,
308
302
ramdisk_directory : str = "" ,
303
+ mtc_enabled : bool = False ,
309
304
) -> DAGNode :
310
305
"""Create the workload and wait for it to provision."""
311
306
with TaskGroup (group_id = "launch_workload" ) as group :
312
- run_workload = xpk .run_workload .override (
313
- owner = self .task_test_config .task_owner
314
- )(
307
+ run_workload = xpk .run_workload .override (owner = self .task_test_config .task_owner )(
315
308
task_id = "run_workload" ,
316
309
cluster_project = self .task_gcp_config .project_name ,
317
310
zone = self .task_gcp_config .zone ,
@@ -326,6 +319,7 @@ def launch_workload(
326
319
use_vertex_tensorboard = use_vertex_tensorboard ,
327
320
use_pathways = use_pathways ,
328
321
ramdisk_directory = ramdisk_directory ,
322
+ mtc_enabled = mtc_enabled ,
329
323
)
330
324
wait_for_workload_start = xpk .wait_for_workload_start .override (
331
325
timeout = self .workload_provision_timeout .total_seconds ()
@@ -411,9 +405,7 @@ def run(self) -> DAGNode:
411
405
self .task_metric_config
412
406
and self .task_metric_config .use_runtime_generated_gcs_folder
413
407
):
414
- env_variable = {
415
- f"{ metric_config .SshEnvVars .GCS_OUTPUT .name } " : gcs_location
416
- }
408
+ env_variable = {f"{ metric_config .SshEnvVars .GCS_OUTPUT .name } " : gcs_location }
417
409
else :
418
410
env_variable = None
419
411
run_model = self .run_model (ip_address , ssh_keys , env_variable )
@@ -445,9 +437,7 @@ def run_with_existing_instance(self) -> DAGNode:
445
437
self .task_metric_config
446
438
and self .task_metric_config .use_runtime_generated_gcs_folder
447
439
):
448
- env_variable = {
449
- f"{ metric_config .SshEnvVars .GCS_OUTPUT .name } " : gcs_location
450
- }
440
+ env_variable = {f"{ metric_config .SshEnvVars .GCS_OUTPUT .name } " : gcs_location }
451
441
else :
452
442
env_variable = None
453
443
post_process = self .post_process (gcs_location )
@@ -458,7 +448,12 @@ def run_with_existing_instance(self) -> DAGNode:
458
448
459
449
def provision_via_existing_instance (
460
450
self ,
461
- ) -> Tuple [DAGNode , airflow .XComArg , airflow .XComArg , airflow .XComArg ,]:
451
+ ) -> Tuple [
452
+ DAGNode ,
453
+ airflow .XComArg ,
454
+ airflow .XComArg ,
455
+ airflow .XComArg ,
456
+ ]:
462
457
"""Provision an existing GPU accelerator.
463
458
464
459
Returns:
@@ -575,9 +570,7 @@ def post_process(
575
570
)
576
571
return group
577
572
578
- def clean_up (
579
- self , resource : airflow .XComArg , project_id : str , zone : str
580
- ) -> DAGNode :
573
+ def clean_up (self , resource : airflow .XComArg , project_id : str , zone : str ) -> DAGNode :
581
574
"""Clean up GPU resources created by `provision`.
582
575
583
576
Args:
@@ -590,9 +583,7 @@ def clean_up(
590
583
Raises:
591
584
AirflowTaskTimeout: An error occurs when execution_timeout is breached.
592
585
"""
593
- return gpu .delete_resource .override (group_id = "clean_up" )(
594
- resource , project_id , zone
595
- )
586
+ return gpu .delete_resource .override (group_id = "clean_up" )(resource , project_id , zone )
596
587
597
588
def clean_up_existing_instance (self , ssh_keys : airflow .XComArg ) -> DAGNode :
598
589
"""Clean up existing GPU resources - remove the one-time use generated ssh_keys.
@@ -655,9 +646,7 @@ def run(self) -> DAGNode:
655
646
gcs_location >> gke_run >> post_process
656
647
return group
657
648
658
- def post_process (
659
- self , result_location : Optional [airflow .XComArg ] = None
660
- ) -> DAGNode :
649
+ def post_process (self , result_location : Optional [airflow .XComArg ] = None ) -> DAGNode :
661
650
"""Process metrics and metadata, and insert them into BigQuery tables.
662
651
663
652
Returns:
@@ -688,9 +677,7 @@ def _get_job_manifest(self):
688
677
},
689
678
},
690
679
"spec" : {
691
- "activeDeadlineSeconds" : int (
692
- self .task_test_config .timeout .total_seconds ()
693
- )
680
+ "activeDeadlineSeconds" : int (self .task_test_config .timeout .total_seconds ())
694
681
or 3600 ,
695
682
"backoffLimit" : 0 ,
696
683
"completionMode" : "Indexed" ,
@@ -713,12 +700,8 @@ def _get_job_manifest(self):
713
700
"name" : "main" ,
714
701
"image" : self .task_test_config .docker_image ,
715
702
"imagePullPolicy" : "Always" ,
716
- "command" : shlex .split (
717
- self .task_test_config .setup_script
718
- ),
719
- "args" : shlex .split (
720
- self .task_test_config .test_script
721
- ),
703
+ "command" : shlex .split (self .task_test_config .setup_script ),
704
+ "args" : shlex .split (self .task_test_config .test_script ),
722
705
"resources" : {
723
706
"limits" : {
724
707
"nvidia.com/gpu" : accelerator .count ,
@@ -728,17 +711,13 @@ def _get_job_manifest(self):
728
711
{
729
712
"name" : "POD_NAME" ,
730
713
"valueFrom" : {
731
- "fieldRef" : {
732
- "fieldPath" : "metadata.name"
733
- }
714
+ "fieldRef" : {"fieldPath" : "metadata.name" }
734
715
},
735
716
},
736
717
{
737
718
"name" : "POD_NAMESPACE" ,
738
719
"valueFrom" : {
739
- "fieldRef" : {
740
- "fieldPath" : "metadata.namespace"
741
- }
720
+ "fieldRef" : {"fieldPath" : "metadata.namespace" }
742
721
},
743
722
},
744
723
{
0 commit comments