Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[regression test] Update run_mantaray_jobs.py for split test order for PyTorch regression test on TPU #572

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
30 changes: 26 additions & 4 deletions dags/mantaray/run_mantaray_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,21 @@
)
xlml_jobs = yaml.safe_load(xlml_jobs_yaml)

# Create a DAG for PyTorch/XLA tests
pattern = r"^(ptxla|pytorchxla).*"
# Create two DAGs for PyTorch/XLA tests
pattern = r"^(ptxla|pytorchxla_part1).*"
pattern2 = r"^(pytorchxla_part2).*"
workload_file_name_list = []
workload_file_name_list_2 = []
for job in xlml_jobs:
if re.match(pattern, job["task_name"]):
workload_file_name_list.append(job["file_name"])
elif re.match(pattern2, job["task_name"]):
workload_file_name_list_2.append(job["file_name"])

# merge all PyTorch/XLA tests into one DAG
with models.DAG(
dag_id="pytorch_xla_model_regression_test_on_trillium",
schedule="0 0 * * *", # everyday at midnight # job["schedule"],
schedule="0 0 * * *", # everyday at midnight
tags=["mantaray", "pytorchxla", "xlml"],
start_date=datetime.datetime(2024, 4, 22),
catchup=False,
Expand All @@ -54,9 +58,27 @@
)
run_workload

# split out sd2 model test
with models.DAG(
dag_id="pytorch_xla_model_regression_test_on_trillium_share_zone_2",
schedule="0 0 * * *", # everyday at midnight # job["schedule"],
tags=["mantaray", "pytorchxla", "xlml"],
start_date=datetime.datetime(2024, 4, 22),
catchup=False,
) as dag:
for workload_file_name in workload_file_name_list_2:
run_workload = mantaray.run_workload.override(
task_id=workload_file_name.split(".")[0]
)(
workload_file_name=workload_file_name,
)
run_workload

# Create a DAG for each job from maxtext
for job in xlml_jobs:
if not re.match(pattern, job["task_name"]):
if (not re.match(pattern, job["task_name"])) and (
not re.match(pattern2, job["task_name"])
):
with models.DAG(
dag_id=job["task_name"],
schedule=job["schedule"],
Expand Down
Loading