Skip to content

Commit ad2e9eb

Browse files
authored
stable release runs (#655)
1 parent af480b8 commit ad2e9eb

File tree

4 files changed

+127
-42
lines changed

4 files changed

+127
-42
lines changed

dags/map_reproducibility/internal_runs/a3ultra_maxtext_benchmarking_dags.py

+87-35
Original file line numberDiff line numberDiff line change
@@ -15,66 +15,118 @@
1515
"""DAGs to run Aotc reproducibility benchmarks."""
1616

1717
import datetime
18+
import os
1819

1920
from airflow import models
20-
from dags.map_reproducibility.utils.constants import Schedule
21+
from dags.map_reproducibility.utils.constants import Schedule, Image
2122
from dags.map_reproducibility.utils.internal_aotc_workload import run_internal_aotc_workload
2223

2324

25+
# Configuration parameters
2426
TEST_RUN = False
2527
TURN_ON_SCHEDULE = True
2628
BACKFILL = False
2729

28-
# List of configuration setups as a dictionary with schedule times
29-
config_yamls = {
30-
# a3ultra_llama3.1-8b
31-
"recipes/a3ultra/a3ultra_llama3.1-8b_8gpus_bf16_maxtext.yaml": Schedule.DAILY_PST_6PM_EXCEPT_THURSDAY, # < 10mins
32-
"recipes/a3ultra/a3ultra_llama3.1-8b_8gpus_fp8_maxtext.yaml": Schedule.DAILY_PST_6PM_EXCEPT_THURSDAY,
33-
"recipes/a3ultra/a3ultra_llama3.1-8b_16gpus_bf16_maxtext.yaml": Schedule.DAILY_PST_6PM_EXCEPT_THURSDAY,
34-
"recipes/a3ultra/a3ultra_llama3.1-8b_16gpus_fp8_maxtext.yaml": Schedule.DAILY_PST_6PM_EXCEPT_THURSDAY,
35-
# a3ultra_mixtral-8x7
36-
"recipes/a3ultra/a3ultra_mixtral-8x7b_8gpus_bf16_maxtext.yaml": Schedule.DAILY_PST_6PM_EXCEPT_THURSDAY,
37-
"recipes/a3ultra/a3ultra_mixtral-8x7b_16gpus_bf16_maxtext.yaml": Schedule.DAILY_PST_6PM_EXCEPT_THURSDAY,
38-
# a3ultra_llama3.1-70b
39-
"recipes/a3ultra/a3ultra_llama3.1-70b_256gpus_bf16_maxtext.yaml": Schedule.DAILY_PST_6_30PM_EXCEPT_THURSDAY, # ~10min
40-
"recipes/a3ultra/a3ultra_llama3.1-70b_256gpus_fp8_maxtext.yaml": Schedule.DAILY_PST_6_30PM_EXCEPT_THURSDAY,
41-
# a3ultra_llama3.1-405b
42-
"recipes/a3ultra/a3ultra_llama3.1-405b_256gpus_fp8_maxtext.yaml": Schedule.DAILY_PST_7PM_EXCEPT_THURSDAY, # ~30mins
43-
"recipes/a3ultra/a3ultra_llama3.1-405b_256gpus_bf16_maxtext.yaml": Schedule.DAILY_PST_7_30PM_EXCEPT_THURSDAY, # ~30mins
44-
# Add more config paths as needed
45-
}
30+
# Get current date for image tags
31+
utc_date = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d")
32+
NIGHTLY_IMAGE = f"{Image.MAXTEXT_JAX_STABLE_NIGHTLY}:{utc_date}"
33+
RELEASE_IMAGE = f"{Image.MAXTEXT_JAX_STABLE_RELEASE}:{utc_date}"
4634

47-
# Define common tags
48-
common_tags = [
35+
# Common DAG tags
36+
DAG_TAGS = [
4937
"reproducibility",
5038
"experimental",
5139
"xlml",
52-
"v1.15",
40+
"v1.16",
5341
"internal",
5442
"regressiontests",
5543
"a3ultra",
5644
]
5745

58-
# Create a DAG for each config
59-
for relative_config_yaml_path, schedule_time in config_yamls.items():
46+
# Model configurations with schedule and timeout settings
47+
MODEL_CONFIGS = {
48+
# a3ultra_llama3.1-8b
49+
"recipes/a3ultra/a3ultra_llama3.1-8b_8gpus_bf16_maxtext.yaml": {
50+
"schedule": Schedule.DAILY_PDT_6PM_EXCEPT_THURSDAY,
51+
"timeout_minutes": 15,
52+
},
53+
"recipes/a3ultra/a3ultra_llama3.1-8b_8gpus_fp8_maxtext.yaml": {
54+
"schedule": Schedule.DAILY_PDT_6PM_EXCEPT_THURSDAY,
55+
"timeout_minutes": 15,
56+
},
57+
"recipes/a3ultra/a3ultra_llama3.1-8b_16gpus_bf16_maxtext.yaml": {
58+
"schedule": Schedule.DAILY_PDT_6PM_EXCEPT_THURSDAY,
59+
"timeout_minutes": 15,
60+
},
61+
"recipes/a3ultra/a3ultra_llama3.1-8b_16gpus_fp8_maxtext.yaml": {
62+
"schedule": Schedule.DAILY_PDT_6PM_EXCEPT_THURSDAY,
63+
"timeout_minutes": 15,
64+
},
65+
# a3ultra_mixtral-8x7
66+
"recipes/a3ultra/a3ultra_mixtral-8x7b_8gpus_bf16_maxtext.yaml": {
67+
"schedule": Schedule.DAILY_PDT_6PM_EXCEPT_THURSDAY,
68+
"timeout_minutes": 15,
69+
},
70+
"recipes/a3ultra/a3ultra_mixtral-8x7b_16gpus_bf16_maxtext.yaml": {
71+
"schedule": Schedule.DAILY_PDT_6PM_EXCEPT_THURSDAY,
72+
"timeout_minutes": 15,
73+
},
74+
# a3ultra_llama3.1-70b
75+
"recipes/a3ultra/a3ultra_llama3.1-70b_256gpus_bf16_maxtext.yaml": {
76+
"schedule": Schedule.DAILY_PDT_6_30PM_EXCEPT_THURSDAY,
77+
"timeout_minutes": 15,
78+
},
79+
"recipes/a3ultra/a3ultra_llama3.1-70b_256gpus_fp8_maxtext.yaml": {
80+
"schedule": Schedule.DAILY_PDT_7PM_EXCEPT_THURSDAY,
81+
"timeout_minutes": 15,
82+
},
83+
# a3ultra_llama3.1-405b
84+
"recipes/a3ultra/a3ultra_llama3.1-405b_256gpus_fp8_maxtext.yaml": {
85+
"schedule": Schedule.DAILY_PDT_7_30PM_EXCEPT_THURSDAY,
86+
"timeout_minutes": 30,
87+
},
88+
"recipes/a3ultra/a3ultra_llama3.1-405b_256gpus_bf16_maxtext.yaml": {
89+
"schedule": Schedule.DAILY_PDT_8PM_EXCEPT_THURSDAY,
90+
"timeout_minutes": 40,
91+
},
92+
}
93+
94+
95+
# Create DAGs for each configuration
96+
for config_path, config_info in MODEL_CONFIGS.items():
6097
# Extract config name for the DAG ID
61-
config_yaml_name = relative_config_yaml_path.rsplit("/", maxsplit=1)[
62-
-1
63-
].replace(".yaml", "")
64-
actual_schedule = schedule_time if TURN_ON_SCHEDULE else None
65-
dag_id = f"new_internal_{config_yaml_name}"
98+
config_name = os.path.basename(config_path).replace(".yaml", "")
99+
schedule = config_info["schedule"] if TURN_ON_SCHEDULE else None
100+
timeout = config_info["timeout_minutes"]
101+
102+
# Create DAG for nightly build
103+
with models.DAG(
104+
dag_id=f"new_internal_{config_name}",
105+
schedule=schedule,
106+
tags=DAG_TAGS,
107+
start_date=datetime.datetime(2025, 4, 3),
108+
catchup=False,
109+
) as dag:
110+
run_internal_aotc_workload(
111+
relative_config_yaml_path=config_path,
112+
test_run=TEST_RUN,
113+
backfill=BACKFILL,
114+
timeout=timeout,
115+
image_version=NIGHTLY_IMAGE,
116+
)
66117

67-
# Define the DAG
118+
# Create DAG for stable release
68119
with models.DAG(
69-
dag_id=dag_id,
70-
schedule=actual_schedule, # Use the specific schedule time
71-
tags=common_tags,
120+
dag_id=f"new_internal_stable_release_{config_name}_",
121+
schedule=schedule,
122+
tags=DAG_TAGS,
72123
start_date=datetime.datetime(2025, 4, 3),
73124
catchup=False,
74125
) as dag:
75-
# Create the workload for this specific config
76126
run_internal_aotc_workload(
77-
relative_config_yaml_path=relative_config_yaml_path,
127+
relative_config_yaml_path=config_path,
78128
test_run=TEST_RUN,
79129
backfill=BACKFILL,
130+
timeout=timeout,
131+
image_version=RELEASE_IMAGE,
80132
)

dags/map_reproducibility/utils/common_utils.py

+13
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,19 @@ def wait_for_jobs_cmds():
282282
return wait_for_job
283283

284284

285+
def internal_wait_for_jobs_cmds(timeout="100m"):
286+
timeout = str(timeout)
287+
if not timeout.endswith("m"):
288+
timeout += "m"
289+
wait_for_job = (
290+
"kubectl get pods --selector=job-name=$JOB_NAME --namespace=default",
291+
"echo 'will wait for jobs to finish'",
292+
"kubectl wait --for=condition=complete "
293+
f"job/$JOB_NAME --namespace=default --timeout={timeout}",
294+
)
295+
return wait_for_job
296+
297+
285298
def copy_bucket_cmds_nemo(recipe_repo_root, hypercomputer: str = "a3mega"):
286299
gcs_location = ""
287300
if hypercomputer in ("a3ultra", "a4"):

dags/map_reproducibility/utils/constants.py

+19
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,22 @@ class Schedule:
1616
DAILY_PST_6_30PM_EXCEPT_THURSDAY = "30 2 * * 1,2,3,4,6,0"
1717
DAILY_PST_7PM_EXCEPT_THURSDAY = "0 3 * * 1,2,3,4,6,0"
1818
DAILY_PST_7_30PM_EXCEPT_THURSDAY = "30 3 * * 1,2,3,4,6,0"
19+
20+
DAILY_PDT_6PM_EXCEPT_THURSDAY = "0 3 * * 1,2,3,4,6,0"
21+
DAILY_PDT_6_30PM_EXCEPT_THURSDAY = "30 3 * * 1,2,3,4,6,0"
22+
DAILY_PDT_7PM_EXCEPT_THURSDAY = "0 4 * * 1,2,3,4,6,0"
23+
DAILY_PDT_7_30PM_EXCEPT_THURSDAY = "30 4 * * 1,2,3,4,6,0"
24+
DAILY_PDT_8PM_EXCEPT_THURSDAY = "0 5 * * 1,2,3,4,6,0"
25+
DAILY_PDT_8_30PM_EXCEPT_THURSDAY = "30 5 * * 1,2,3,4,6,0"
26+
DAILY_PDT_9PM_EXCEPT_THURSDAY = "0 6 * * 1,2,3,4,6,0"
27+
28+
29+
class Image:
30+
MAXTEXT_JAX_STABLE_NIGHTLY = (
31+
"gcr.io/tpu-prod-env-multipod/maxtext_gpu_stable_stack_nightly_jax"
32+
)
33+
MAXTEXT_JAX_STABLE_RELEASE = (
34+
"gcr.io/tpu-prod-env-multipod/maxtext_gpu_jax_stable_stack"
35+
)
36+
MAXTEXT_JAX_STABLE_NIGHTLY_OLD = "gcr.io/supercomputer-testing/jax3p_nightly"
37+
MAXTEXT_JAX_STABLE_RELEASE_OLD = "gcr.io/supercomputer-testing/jax3p_stable"

dags/map_reproducibility/utils/internal_aotc_workload.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from dags.map_reproducibility.utils.common_utils import BUCKET_NAME, configure_project_and_cluster
2323
from dags.map_reproducibility.utils.common_utils import install_helm_cmds
2424
from dags.map_reproducibility.utils.common_utils import namespace_cmds
25-
from dags.map_reproducibility.utils.common_utils import wait_for_jobs_cmds
25+
from dags.map_reproducibility.utils.common_utils import internal_wait_for_jobs_cmds
2626
from dags.map_reproducibility.utils.common_utils import cleanup_cmds
2727
from dags.map_reproducibility.utils.common_utils import git_cookie_authdaemon
2828
from dags.map_reproducibility.utils.common_utils import clone_recipes_gob, clone_internal_recipes_gob
@@ -34,7 +34,6 @@
3434
from dags.map_reproducibility.utils.common_utils import get_bq_writer_path
3535
from dags.map_reproducibility.utils.common_utils import get_recipe_repo_path, get_internal_recipe_repo_path
3636
from dags.map_reproducibility.utils.common_utils import get_cluster
37-
from dags.map_reproducibility.utils.common_utils import get_internal_docker_image
3837
from dags.map_reproducibility.utils.common_utils import calculate_maxtext_metrics
3938
from dags.map_reproducibility.utils.common_utils import copy_bucket_cmds_maxtext
4039
from dags.map_reproducibility.utils.common_utils import parse_internal_config_filename
@@ -44,7 +43,11 @@
4443

4544
@task
4645
def run_internal_aotc_workload(
47-
relative_config_yaml_path, test_run=False, backfill=False
46+
relative_config_yaml_path,
47+
test_run=False,
48+
backfill=False,
49+
timeout=None,
50+
image_version=None,
4851
):
4952
"""Runs the AOTC workload benchmark.
5053
@@ -59,9 +62,7 @@ def run_internal_aotc_workload(
5962

6063
# Get derived configuration
6164
cluster, cluster_region = get_cluster(config.HYPERCOMPUTER)
62-
docker_image = get_internal_docker_image(
63-
config.HYPERCOMPUTER, config.FRAMEWORK
64-
)
65+
docker_image = image_version
6566
values_name = f"{config.HYPERCOMPUTER}_{config.FRAMEWORK}_values"
6667

6768
with tempfile.TemporaryDirectory() as tmpdir:
@@ -149,7 +150,7 @@ def run_internal_aotc_workload(
149150
additional_cmds=f" --set workload.gpus={config.NUM_GPUS} ",
150151
test_run=test_run,
151152
)
152-
+ wait_for_jobs_cmds()
153+
+ internal_wait_for_jobs_cmds(timeout=timeout)
153154
+ copy_bucket_cmds_maxtext(
154155
tmpdir, recipe_repo_root=recipe_repo_root
155156
)

0 commit comments

Comments
 (0)