Skip to content

Commit a7e7857

Browse files
committed
benchmark sample run
1 parent e873cde commit a7e7857

File tree

3 files changed

+502
-8
lines changed

3 files changed

+502
-8
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Sample job to run Aotc reproducibility benchmarks."""
16+
import sys
17+
import os
18+
19+
script_dir = os.path.dirname(os.path.abspath(__file__))
20+
project_root = os.path.abspath(os.path.join(script_dir, "..", "..", ".."))
21+
22+
print(f"Script directory: {script_dir}")
23+
print(f"Project root: {project_root}")
24+
25+
if project_root not in sys.path:
26+
sys.path.insert(0, project_root)
27+
28+
import datetime
29+
from dags.map_reproducibility.utils.constants import Image
30+
from dags.map_reproducibility.internal_runs.dag_configs import DAG_CONFIGS_ULTRA
31+
from dags.map_reproducibility.utils.sample_workload_utils import run_internal_sample_aotc_workload
32+
33+
utc_date = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d")
34+
NIGHTLY_IMAGE = f"{Image.MAXTEXT_JAX_STABLE_NIGHTLY}:{utc_date}"
35+
RELEASE_IMAGE = f"{Image.MAXTEXT_JAX_STABLE_RELEASE}:{utc_date}"
36+
RELEASE_IMAGE = f"{Image.MAXTEXT_JAX_STABLE_RELEASE}:2025-04-17"
37+
SAMPLE_RUN_BUCKET_NAME = "yujunzou-dev-supercomputer-testing"
38+
39+
40+
# Setup configuratio
41+
relative_config_yaml_path = (
42+
"recipes/a3ultra/a3ultra_llama3.1-8b_8gpus_bf16_maxtext.yaml"
43+
)
44+
config_name = relative_config_yaml_path.replace(".yaml", "")
45+
timeout = DAG_CONFIGS_ULTRA[relative_config_yaml_path]["timeout_minutes"]
46+
base_recipe_repo_root = f"{project_root}/../internal-gpu-recipes"
47+
48+
run_internal_sample_aotc_workload(
49+
relative_config_yaml_path=relative_config_yaml_path,
50+
base_recipe_repo_root=base_recipe_repo_root,
51+
timeout=timeout,
52+
image_version=RELEASE_IMAGE,
53+
sample_run_bucket_name=SAMPLE_RUN_BUCKET_NAME,
54+
)

dags/map_reproducibility/utils/common_utils.py

+14-8
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
import string
2323
import time
2424
import subprocess
25+
import getpass
2526

26-
from google.cloud import storage
2727
from airflow.decorators import task
2828
from airflow.hooks.subprocess import SubprocessHook
2929
from xlml.utils import metric
@@ -115,11 +115,14 @@ def get_internal_pre_workload_cmds(job_name):
115115
return prepare_workload_cmds
116116

117117

118-
def get_internal_pre_workload_job_name(model_id, framework):
118+
def get_internal_pre_workload_job_name(model_id, framework, is_sample_run=True):
119119
helm_model_id = model_id.replace(".", "-")
120120
random_id = "".join(random.choices(string.ascii_lowercase, k=4))
121121
now = int(time.time())
122122
job_name = f"coreml-{helm_model_id}-{now}-{framework}-{random_id}"
123+
if is_sample_run:
124+
job_name = f"{getpass.getuser()}-{job_name}"
125+
print(f"NAME: {job_name}")
123126
return job_name
124127

125128

@@ -207,16 +210,17 @@ def helm_apply_cmds_internal_run(
207210
kueue_name: str = "a3-ultra",
208211
additional_cmds: str = "",
209212
test_run=False,
213+
bucket_name=BUCKET_NAME,
210214
):
211215
gcs_cmd = ""
212216
if framework == "maxtext":
213-
gcs_cmd += f" --set volumes.gcsMounts[0].bucketName={BUCKET_NAME} "
217+
gcs_cmd += f" --set volumes.gcsMounts[0].bucketName={bucket_name} "
214218

215219
if hypercomputer == "a3ultra":
216220
if framework != "maxtext":
217221
gcs_cmd += f" --set queue={kueue_name} "
218222
else:
219-
gcs_cmd += f" --set workload.gcsBucketForDataCataPath={BUCKET_NAME} "
223+
gcs_cmd += f" --set workload.gcsBucketForDataCataPath={bucket_name} "
220224

221225
cluster_cmd = ""
222226
if framework == "nemo" and hypercomputer == "a3ultra":
@@ -230,8 +234,9 @@ def helm_apply_cmds_internal_run(
230234
if aotc:
231235
set_aotc = " --set-string workload.aotc=true "
232236

233-
if test_run:
234-
helm_template_path = f"/home/airflow/gcs/dags/dags/map_reproducibility/helm-charts/{hypercomputer}/{framework}-training"
237+
local_helm_template_path = f"/home/airflow/gcs/dags/dags/map_reproducibility/helm-charts/{hypercomputer}/{framework}-training"
238+
if test_run and os.path.exists(local_helm_template_path):
239+
helm_template_path = local_helm_template_path
235240
else:
236241
helm_template_path = f"{recipe_repo_root}/src/helm-charts/{hypercomputer}/{framework}-training"
237242

@@ -321,8 +326,8 @@ def copy_bucket_cmds_nemo(recipe_repo_root, hypercomputer: str = "a3mega"):
321326
return copy_bucket_contents
322327

323328

324-
def copy_bucket_cmds_maxtext(tmpdir, recipe_repo_root):
325-
gcs_location = f"gs://{BUCKET_NAME}/maxtext/"
329+
def copy_bucket_cmds_maxtext(tmpdir, bucket_name=BUCKET_NAME):
330+
gcs_location = f"gs://{bucket_name}/maxtext/"
326331

327332
cmds = (
328333
f"METRICS_FILE={tmpdir}/tflog/metrics",
@@ -704,6 +709,7 @@ def parse_internal_config_filename(filename, config=None):
704709
return config
705710

706711

712+
@staticmethod
707713
def parse_internal_config_content(yaml_path, config=None):
708714
"""
709715
Parse the internal content of a config YAML file and update the existing config.

0 commit comments

Comments
 (0)