22
22
import string
23
23
import time
24
24
import subprocess
25
+ import getpass
25
26
26
- from google .cloud import storage
27
27
from airflow .decorators import task
28
28
from airflow .hooks .subprocess import SubprocessHook
29
29
from xlml .utils import metric
@@ -115,11 +115,14 @@ def get_internal_pre_workload_cmds(job_name):
115
115
return prepare_workload_cmds
116
116
117
117
118
- def get_internal_pre_workload_job_name (model_id , framework ):
118
+ def get_internal_pre_workload_job_name (model_id , framework , is_sample_run = False ):
119
119
helm_model_id = model_id .replace ("." , "-" )
120
120
random_id = "" .join (random .choices (string .ascii_lowercase , k = 4 ))
121
121
now = int (time .time ())
122
122
job_name = f"coreml-{ helm_model_id } -{ now } -{ framework } -{ random_id } "
123
+ if is_sample_run :
124
+ job_name = f"{ getpass .getuser ()} -{ job_name } "
125
+ print (f"NAME: { job_name } " )
123
126
return job_name
124
127
125
128
@@ -207,16 +210,17 @@ def helm_apply_cmds_internal_run(
207
210
kueue_name : str = "a3-ultra" ,
208
211
additional_cmds : str = "" ,
209
212
test_run = False ,
213
+ bucket_name = BUCKET_NAME ,
210
214
):
211
215
gcs_cmd = ""
212
216
if framework == "maxtext" :
213
- gcs_cmd += f" --set volumes.gcsMounts[0].bucketName={ BUCKET_NAME } "
217
+ gcs_cmd += f" --set volumes.gcsMounts[0].bucketName={ bucket_name } "
214
218
215
219
if hypercomputer == "a3ultra" :
216
220
if framework != "maxtext" :
217
221
gcs_cmd += f" --set queue={ kueue_name } "
218
222
else :
219
- gcs_cmd += f" --set workload.gcsBucketForDataCataPath={ BUCKET_NAME } "
223
+ gcs_cmd += f" --set workload.gcsBucketForDataCataPath={ bucket_name } "
220
224
221
225
cluster_cmd = ""
222
226
if framework == "nemo" and hypercomputer == "a3ultra" :
@@ -230,8 +234,9 @@ def helm_apply_cmds_internal_run(
230
234
if aotc :
231
235
set_aotc = " --set-string workload.aotc=true "
232
236
233
- if test_run :
234
- helm_template_path = f"/home/airflow/gcs/dags/dags/map_reproducibility/helm-charts/{ hypercomputer } /{ framework } -training"
237
+ local_helm_template_path = f"/home/airflow/gcs/dags/dags/map_reproducibility/helm-charts/{ hypercomputer } /{ framework } -training"
238
+ if test_run and os .path .exists (local_helm_template_path ):
239
+ helm_template_path = local_helm_template_path
235
240
else :
236
241
helm_template_path = f"{ recipe_repo_root } /src/helm-charts/{ hypercomputer } /{ framework } -training"
237
242
@@ -321,8 +326,8 @@ def copy_bucket_cmds_nemo(recipe_repo_root, hypercomputer: str = "a3mega"):
321
326
return copy_bucket_contents
322
327
323
328
324
- def copy_bucket_cmds_maxtext (tmpdir , recipe_repo_root ):
325
- gcs_location = f"gs://{ BUCKET_NAME } /maxtext/"
329
+ def copy_bucket_cmds_maxtext (tmpdir , bucket_name = BUCKET_NAME ):
330
+ gcs_location = f"gs://{ bucket_name } /maxtext/"
326
331
327
332
cmds = (
328
333
f"METRICS_FILE={ tmpdir } /tflog/metrics" ,
@@ -704,6 +709,7 @@ def parse_internal_config_filename(filename, config=None):
704
709
return config
705
710
706
711
712
+ @staticmethod
707
713
def parse_internal_config_content (yaml_path , config = None ):
708
714
"""
709
715
Parse the internal content of a config YAML file and update the existing config.
0 commit comments