From 8c182294ac04ef12c34f53d404e6ae92c4e59ae0 Mon Sep 17 00:00:00 2001
From: Steven Ingram
Date: Tue, 28 Jan 2025 22:19:54 +0000
Subject: [PATCH 1/9] Updated helm install command to set clusterName.

---
 dags/map_reproducibility/a3ultra_mixtral_8_7b_nemo.py | 6 ++++--
 dags/map_reproducibility/utils/common_utils.py        | 6 ++++--
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/dags/map_reproducibility/a3ultra_mixtral_8_7b_nemo.py b/dags/map_reproducibility/a3ultra_mixtral_8_7b_nemo.py
index d218733a4..01fb09135 100644
--- a/dags/map_reproducibility/a3ultra_mixtral_8_7b_nemo.py
+++ b/dags/map_reproducibility/a3ultra_mixtral_8_7b_nemo.py
@@ -55,7 +55,7 @@
 VALUE_YAML_PATH = (
     f"training/{HYPERCOMPUTER}/{MODEL_ID}/nemo-pretraining-gke/values.yaml"
 )
-CLUSTER = "gke-a3ultra-map"
+CLUSTER = "gke-a3u-map-01-31"
 CLUSTER_REGION = "europe-west1"
 SOFTWARE_ID = "pytorch_nemo"
 IMAGE_VERSION = "nemo_workload:24.07"
@@ -130,7 +130,9 @@ def run_aotc_workload():
             accelerator_type,
             tmpdir,
         )
-        + cleanup_cmds()
+        # DEBUG: to clean-up, get manifest by doing: helm list | grep regression | awk '{print $1}'
+        # + cleanup_cmds()
     ),
 ],
 cwd=tmpdir,
diff --git a/dags/map_reproducibility/utils/common_utils.py b/dags/map_reproducibility/utils/common_utils.py
index 53242cfaf..4928a7353 100644
--- a/dags/map_reproducibility/utils/common_utils.py
+++ b/dags/map_reproducibility/utils/common_utils.py
@@ -120,10 +120,12 @@ def helm_apply_cmds(
   gcs_cmd = ""
   if hypercomputer == "a3ultra":
     gcs_cmd = f" --set volumes.gcsMounts[0].bucketName={BUCKET_NAME}"
+    network_prefix = "gke-a3u-map-01-31"
+    gcs_cmd += f" --set clusterName={network_prefix}"
   else:
     gcs_cmd = f" --set workload.gcsBucketForDataCataPath={BUCKET_NAME}"
   set_aotc = ""
-  if aotc is True:
+  if aotc:
     set_aotc = " --set-string workload.aotc=true "
   helm_cmds = (
       " helm install -f values.yaml "
@@ -178,10 +180,10 @@ def get_nemo_metrics_cmds(

 def cleanup_cmds():
   cleanup = (
+      "helm uninstall $JOB_NAME",
       "kubectl get pods "
       "--no-headers=true | awk '{print $1}' "
       "| grep $JOB_NAME | xargs kubectl delete pods",
-      "helm uninstall $JOB_NAME",
   )
   return cleanup

From 54ba2597932adc515f8c83bd3ac179932956d769 Mon Sep 17 00:00:00 2001
From: Yifei Teng
Date: Tue, 28 Jan 2025 21:10:46 -0800
Subject: [PATCH 2/9] Add a multi-slice llama training test (#588)

Specifically, two slices of v5p-8. This is to ensure that both r2.6 and
ptxla nightly continue to work in multi-slice (DCN networking involved)
environments.

I noticed that we are not testing llama3 in nightly, so I added that
too, copying from our release tests.
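For reference, a sketch of how one of the new 2-slice cases is wired up
(the real definitions are in the diff below; the only new knob is
`num_slices`, threaded through `JSonnetTpuVmTest.from_pytorch` into the
test config):

    llama_3_train_v5p_2_slices = task.run_queued_resource_test(
        test_config.JSonnetTpuVmTest.from_pytorch(
            "pt-nightly-llama3-train-2-slice-func-v5p-8-1vm",
            reserved=True,
            network=V5_NETWORKS,
            subnetwork=V5P_SUBNETWORKS,
            num_slices=2,  # two v5p-8 slices, connected over DCN
        ),
        US_EAST5_A_TPU_PROD_ENV_AUTOMATED,
    )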
---
 .../tests/pytorch/nightly/llama2-model.libsonnet | 100 ++++++++++++++++++
 .../tests/pytorch/r2.6/llama2-model.libsonnet    |  31 +++++-
 dags/pytorch_xla/nightly.py                      |  27 +++++
 dags/pytorch_xla/r2_6.py                         |  10 ++
 xlml/apis/test_config.py                         |   4 +
 5 files changed, 171 insertions(+), 1 deletion(-)

diff --git a/dags/legacy_test/tests/pytorch/nightly/llama2-model.libsonnet b/dags/legacy_test/tests/pytorch/nightly/llama2-model.libsonnet
index 14c096ad7..eca45071d 100644
--- a/dags/legacy_test/tests/pytorch/nightly/llama2-model.libsonnet
+++ b/dags/legacy_test/tests/pytorch/nightly/llama2-model.libsonnet
@@ -130,14 +130,114 @@ local utils = import 'templates/utils.libsonnet';
       ||| % common.HuggingfacePipVersionConstraints,
     },
   },
+  local llama3_train = self.llama3_train,
+  llama3_train:: common.PyTorchTest + common.Functional + common.PyTorchTpuVmMixin {
+    modelName: 'llama3-train',
+    command: [
+      'python',
+      'transformers/examples/pytorch/language-modeling/run_clm.py',
+      '--dataset_name=wikitext',
+      '--dataset_config_name=wikitext-2-raw-v1',
+      '--per_device_train_batch_size=4',
+      '--do_train',
+      '--output_dir=./tmp/test-clm',
+      '--overwrite_output_dir',
+      '--config_name=./llama_3/config.json',
+      '--cache_dir=./cache',
+      '--tokenizer_name=./llama_3/tokenizer/',
+      '--block_size=8192',
+      '--optim=adafactor',
+      '--save_strategy=no',
+      '--logging_strategy=no',
+      '--fsdp=full_shard',
+      '--fsdp_config=./llama_3/fsdp_config.json',
+      '--torch_dtype=bfloat16',
+      '--dataloader_drop_last=yes',
+      '--flash_attention',
+      '--max_steps=10',
+    ],
+    tpuSettings+: {
+      tpuVmExports+: |||
+        export PJRT_DEVICE=TPU
+        export XLA_USE_SPMD=1
+      |||,
+      tpuVmExtraSetup: |||
+        cat > ~/hf-constraints.txt << 'HF_CONSTRAINTS_EOF'
+        %s
+        HF_CONSTRAINTS_EOF
+
+        git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
+
+        # install tokenizer model
+        curl -O https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-cli-linux-x86_64.tar.gz
+        tar -xf google-cloud-cli-linux-x86_64.tar.gz
+        yes | ./google-cloud-sdk/install.sh
+        google-cloud-sdk/bin/gsutil cp -r gs://pytorch-airflow/llama_3/ .
+
+        cd transformers
+        sudo pip3 install -e . -c ~/hf-constraints.txt
+        pip3 install 'torch_xla[pallas]' -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
+        pip3 install datasets evaluate scikit-learn accelerate -c ~/hf-constraints.txt
+      ||| % common.HuggingfacePipVersionConstraints,
+    },
+  },
+
+  local llama3_train_2_slice = self.llama3_train_2_slice,
+  llama3_train_2_slice:: llama3_train {
+    modelName: 'llama3-train-2-slice',
+    command: [
+      'python',
+      'transformers/examples/pytorch/language-modeling/run_clm.py',
+      '--dataset_name=wikitext',
+      '--dataset_config_name=wikitext-2-raw-v1',
+      '--per_device_train_batch_size=8',
+      '--do_train',
+      '--output_dir=./tmp/test-clm',
+      '--overwrite_output_dir',
+      '--config_name=./llama_3/config.json',
+      '--cache_dir=./cache',
+      '--tokenizer_name=./llama_3/tokenizer/',
+      '--block_size=8192',
+      '--optim=adafactor',
+      '--save_strategy=no',
+      '--logging_strategy=no',
+      '--fsdp=full_shard',
+      '--fsdp_config=./llama_3/fsdp_config.json',
+      '--torch_dtype=bfloat16',
+      '--dataloader_drop_last=yes',
+      '--flash_attention',
+      '--max_steps=10',
+    ]
+  },
 
   local v4_8 = self.v4_8,
   v4_8:: {
     accelerator: tpus.v4_8,
   },
+  local v5p_8 = self.v5p_8,
+  v5p_8:: {
+    tpuSettings+: {
+      softwareVersion: 'v2-alpha-tpuv5',
+    },
+    accelerator: tpus.v5p_8,
+  },
+
+  local trillium_4 = self.trillium_4,
+  trillium_4:: {
+    tpuSettings+: {
+      softwareVersion: 'v2-alpha-tpuv6e',
+    },
+    accelerator: tpus.trillium_4,
+  },
+
   configs: [
     llama2 + infer + v4_8 + timeouts.Hours(3),
     llama2 + spmd + v4_8 + timeouts.Hours(3),
+    llama2 + infer + v5p_8 + timeouts.Hours(3),
+    llama2 + spmd + v5p_8 + timeouts.Hours(3),
+    llama3_train + v5p_8 + timeouts.Hours(3),
+    llama3_train + trillium_4 + timeouts.Hours(3),
+    llama3_train_2_slice + v5p_8 + timeouts.Hours(3),
   ],
 }
diff --git a/dags/legacy_test/tests/pytorch/r2.6/llama2-model.libsonnet b/dags/legacy_test/tests/pytorch/r2.6/llama2-model.libsonnet
index aa53442aa..eca45071d 100644
--- a/dags/legacy_test/tests/pytorch/r2.6/llama2-model.libsonnet
+++ b/dags/legacy_test/tests/pytorch/r2.6/llama2-model.libsonnet
@@ -138,7 +138,7 @@ local utils = import 'templates/utils.libsonnet';
       'transformers/examples/pytorch/language-modeling/run_clm.py',
       '--dataset_name=wikitext',
       '--dataset_config_name=wikitext-2-raw-v1',
-      '--per_device_train_batch_size=2',
+      '--per_device_train_batch_size=4',
       '--do_train',
       '--output_dir=./tmp/test-clm',
       '--overwrite_output_dir',
@@ -182,6 +182,34 @@ local utils = import 'templates/utils.libsonnet';
     },
   },
 
+  local llama3_train_2_slice = self.llama3_train_2_slice,
+  llama3_train_2_slice:: llama3_train {
+    modelName: 'llama3-train-2-slice',
+    command: [
+      'python',
+      'transformers/examples/pytorch/language-modeling/run_clm.py',
+      '--dataset_name=wikitext',
+      '--dataset_config_name=wikitext-2-raw-v1',
+      '--per_device_train_batch_size=8',
+      '--do_train',
+      '--output_dir=./tmp/test-clm',
+      '--overwrite_output_dir',
+      '--config_name=./llama_3/config.json',
+      '--cache_dir=./cache',
+      '--tokenizer_name=./llama_3/tokenizer/',
+      '--block_size=8192',
+      '--optim=adafactor',
+      '--save_strategy=no',
+      '--logging_strategy=no',
+      '--fsdp=full_shard',
+      '--fsdp_config=./llama_3/fsdp_config.json',
+      '--torch_dtype=bfloat16',
+      '--dataloader_drop_last=yes',
+      '--flash_attention',
+      '--max_steps=10',
+    ]
+  },
+
   local v4_8 = self.v4_8,
   v4_8:: {
     accelerator: tpus.v4_8,
@@ -210,5 +238,6 @@ local utils = import 'templates/utils.libsonnet';
     llama2 + spmd + v5p_8 + timeouts.Hours(3),
     llama3_train + v5p_8 + timeouts.Hours(3),
     llama3_train + trillium_4 + timeouts.Hours(3),
+    llama3_train_2_slice + v5p_8 + timeouts.Hours(3),
   ],
 }
diff --git a/dags/pytorch_xla/nightly.py b/dags/pytorch_xla/nightly.py
index ddfe1cc3d..d21a82ee8 100644
--- a/dags/pytorch_xla/nightly.py
+++ b/dags/pytorch_xla/nightly.py
@@ -204,6 +204,33 @@ def llama():
     ),
     US_CENTRAL2_B,
   )
+  llama_3_train_trillium = task.run_queued_resource_test(
+      test_config.JSonnetTpuVmTest.from_pytorch(
+          "pt-nightly-llama3-train-func-v6e-4-1vm",
+          network=V5_NETWORKS,
+          subnetwork=V6E_SUBNETWORKS,
+      ),
+      US_CENTRAL2_B_TPU_PROD_ENV,
+  )
+  llama_3_train_v5p_2_slices = task.run_queued_resource_test(
+      test_config.JSonnetTpuVmTest.from_pytorch(
+          "pt-nightly-llama3-train-2-slice-func-v5p-8-1vm",
+          reserved=True,
+          network=V5_NETWORKS,
+          subnetwork=V5P_SUBNETWORKS,
+          num_slices=2,
+      ),
+      US_EAST5_A_TPU_PROD_ENV_AUTOMATED,
+  )
+  llama_3_train_v5p_8 = task.run_queued_resource_test(
+      test_config.JSonnetTpuVmTest.from_pytorch(
+          "pt-nightly-llama3-train-func-v5p-8-1vm",
+          reserved=True,
+          network=V5_NETWORKS,
+          subnetwork=V5P_SUBNETWORKS,
+      ),
+      US_EAST5_A_TPU_PROD_ENV_AUTOMATED,
+  )
 
 
 with models.DAG(
diff --git a/dags/pytorch_xla/r2_6.py b/dags/pytorch_xla/r2_6.py
index 6f774adbc..5e4f1c0cd 100644
--- a/dags/pytorch_xla/r2_6.py
+++ b/dags/pytorch_xla/r2_6.py
@@ -230,6 +230,16 @@ def llama():
     ),
     US_CENTRAL2_B_TPU_PROD_ENV,
   )
+  llama_3_train_v5p_2_slices = task.run_queued_resource_test(
+      test_config.JSonnetTpuVmTest.from_pytorch(
+          "pt-2-6-llama3-train-2-slice-func-v5p-8-1vm",
+          reserved=True,
+          network=V5_NETWORKS,
+          subnetwork=V5P_SUBNETWORKS,
+          num_slices=2,
+      ),
+      US_EAST5_A_TPU_PROD_ENV_AUTOMATED,
+  )
   llama_3_train_v5p_8 = task.run_queued_resource_test(
       test_config.JSonnetTpuVmTest.from_pytorch(
           "pt-2-6-llama3-train-func-v5p-8-1vm",
diff --git a/xlml/apis/test_config.py b/xlml/apis/test_config.py
index 1704b95a9..abc2299f7 100644
--- a/xlml/apis/test_config.py
+++ b/xlml/apis/test_config.py
@@ -396,6 +396,7 @@ def _from_json_helper(
       reserved: bool,
       network: str,
       subnetwork: str,
+      num_slices: int = 1,
   ):
     return JSonnetTpuVmTest(
         test_name=test['testName'],
@@ -414,6 +415,7 @@ def _from_json_helper(
         exports=exports,
         test_command=test_command,
         timeout=datetime.timedelta(seconds=test['timeout']),
+        num_slices=num_slices,
     )
 
   @staticmethod
@@ -442,6 +444,7 @@ def from_pytorch(
       reserved: bool = False,
       network='default',
       subnetwork='default',
+      num_slices: int = 1,
   ):
     """Parses a compiled legacy JSonnet test config from `tests/pytorch`."""
     test = _load_compiled_jsonnet(test_name)
@@ -455,6 +458,7 @@ def from_pytorch(
         reserved=reserved,
         network=network,
         subnetwork=subnetwork,
+        num_slices=num_slices,
     )
 
   @property

From 9ee7608dd9917969b3757f0a2af9e66d413ab8bd Mon Sep 17 00:00:00 2001
From: Pei Zhang
Date: Wed, 29 Jan 2025 10:16:02 -0800
Subject: [PATCH 3/9] fix lock issue (#579)

---
 dags/legacy_test/tests/pytorch/nightly/common.libsonnet | 2 +-
 dags/legacy_test/tests/pytorch/r2.6/common.libsonnet    | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/dags/legacy_test/tests/pytorch/nightly/common.libsonnet b/dags/legacy_test/tests/pytorch/nightly/common.libsonnet
index 05c3309f9..f70deacaa 100644
--- a/dags/legacy_test/tests/pytorch/nightly/common.libsonnet
+++ b/dags/legacy_test/tests/pytorch/nightly/common.libsonnet
@@ -95,8 +95,8 @@ local volumes = import 'templates/volumes.libsonnet';
         sudo systemctl disable unattended-upgrades || true
         sudo killall --signal SIGKILL unattended-upgrades || true
         sudo dpkg --configure -a || true
-        sudo apt purge unattended-upgrades -y || true
         sudo rm /var/lib/dpkg/lock-frontend || true
+        sudo apt purge unattended-upgrades -y || true
         echo "unattended-upgrades stopped."
 
         sudo apt-get -y update
diff --git a/dags/legacy_test/tests/pytorch/r2.6/common.libsonnet b/dags/legacy_test/tests/pytorch/r2.6/common.libsonnet
index f41cca9a9..ac72870b6 100644
--- a/dags/legacy_test/tests/pytorch/r2.6/common.libsonnet
+++ b/dags/legacy_test/tests/pytorch/r2.6/common.libsonnet
@@ -97,8 +97,8 @@ local rcVersion = 'rc10';
         sudo systemctl disable unattended-upgrades || true
         sudo killall --signal SIGKILL unattended-upgrades || true
        sudo dpkg --configure -a || true
-        sudo apt purge unattended-upgrades -y || true
         sudo rm /var/lib/dpkg/lock-frontend || true
+        sudo apt purge unattended-upgrades -y || true
         echo "unattended-upgrades stopped."
 
         sudo apt-get -y update

From bd3982ab45428e888ad0abc52003c04d6662532c Mon Sep 17 00:00:00 2001
From: Yifei Teng
Date: Wed, 29 Jan 2025 13:39:19 -0800
Subject: [PATCH 4/9] Test final 2.6.0 wheel (#589)

---
 .../tests/pytorch/r2.6/common.libsonnet     | 16 +++++++---------
 .../configs/pytorchxla_torchbench_config.py |  6 +++---
 2 files changed, 10 insertions(+), 12 deletions(-)

diff --git a/dags/legacy_test/tests/pytorch/r2.6/common.libsonnet b/dags/legacy_test/tests/pytorch/r2.6/common.libsonnet
index ac72870b6..7e4ceaf68 100644
--- a/dags/legacy_test/tests/pytorch/r2.6/common.libsonnet
+++ b/dags/legacy_test/tests/pytorch/r2.6/common.libsonnet
@@ -18,15 +18,13 @@ local mixins = import 'templates/mixins.libsonnet';
 local utils = import 'templates/utils.libsonnet';
 local volumes = import 'templates/volumes.libsonnet';
 
-local rcVersion = 'rc10';
-
 {
   local r2_6 = {
     frameworkPrefix: 'pt-2-6',
     tpuSettings+: {
       softwareVersion: 'tpu-ubuntu2204-base',
     },
-    imageTag: 'r2.6.0-%(rc)s_3.10' % {rc: rcVersion},
+    imageTag: 'r2.6.0_3.10',
   },
   PyTorchTest:: common.PyTorchTest + r2_6 {
     local config = self,
@@ -109,13 +107,13 @@
       pip install torch==2.6 --index-url https://download.pytorch.org/whl/test/cpu
       # torchvision commit reference: https://github.com/pytorch/pytorch/blob/release/2.6/.github/ci_commit_pins/vision.txt
       pip install --user --no-use-pep517 "git+https://github.com/pytorch/vision.git@d23a6e1664d20707c11781299611436e1f0c104f"
-      pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0%(rc)s-cp310-cp310-manylinux_2_28_x86_64.whl
+      pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0-cp310-cp310-manylinux_2_28_x86_64.whl
       pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
       pip install pillow
       git clone --depth=1 https://github.com/pytorch/pytorch.git
       cd pytorch
-      git clone -b v2.6.0-%(rc)s https://github.com/pytorch/xla.git
-    ||| % {rc: rcVersion},
+      git clone -b v2.6.0 https://github.com/pytorch/xla.git
+    |||,
   },
   podTemplate+:: {
     spec+: {
             pip uninstall -y torch torchvision
             pip install torch==2.6 --index-url https://download.pytorch.org/whl/test/cpu
             pip install --user --no-use-pep517 "git+https://github.com/pytorch/vision.git@d23a6e1664d20707c11781299611436e1f0c104f"
-            pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0%(rc)s-cp310-cp310-manylinux_2_28_x86_64.whl
+            pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0-cp310-cp310-manylinux_2_28_x86_64.whl
             mkdir -p pytorch/xla
-            git clone -b v2.6.0-%(rc)s https://github.com/pytorch/xla.git pytorch/xla
+            git clone -b v2.6.0 https://github.com/pytorch/xla.git pytorch/xla
             %(cmd)s
 
             # Run whatever is in `command` here
             "${@:0}"
-          ||| % {cmd: config.tpuSettings.tpuVmExports, rc: rcVersion},
+          ||| % {cmd: config.tpuSettings.tpuVmExports},
         ],
         command: [
           'torchrun',
diff --git a/dags/pytorch_xla/configs/pytorchxla_torchbench_config.py b/dags/pytorch_xla/configs/pytorchxla_torchbench_config.py
index 4fe46f1e5..a9d63a5fe 100644
--- a/dags/pytorch_xla/configs/pytorchxla_torchbench_config.py
+++ b/dags/pytorch_xla/configs/pytorchxla_torchbench_config.py
@@ -111,12 +111,12 @@ class R2_5_1(enum.Enum):
   TORCH_XLA_REPO_BRANCH = "-b v2.5.1"
 
 
 class R2_6(enum.Enum):
-  TORCH_XLA_TPU_WHEEL = "https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0rc10+cxx11-cp311-cp311-manylinux_2_28_x86_64.whl"
-  TORCH_XLA_CUDA_WHEEL = "https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.6.0rc10+cxx11-cp310-cp310-linux_x86_64.whl"
+  TORCH_XLA_TPU_WHEEL = "https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0+cxx11-cp311-cp311-manylinux_2_28_x86_64.whl"
+  TORCH_XLA_CUDA_WHEEL = "https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.6.0+cxx11-cp310-cp310-linux_x86_64.whl"
   TORCH = "torch==2.6.0"
   TORCHVISION = "torchvision"
   TORCHAUDIO = "torchaudio"
-  TORCH_XLA_GPU_DOCKER = "us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0rc10_3.10_cuda_12.1"
+  TORCH_XLA_GPU_DOCKER = "us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_cuda_12.1"
   TORCH_INDEX_CPU_URL = "https://download.pytorch.org/whl/test/cpu"
   TORCH_INDEX_CUDA_URL = "https://download.pytorch.org/whl/test/cu121"
   TORCH_REPO_BRANCH = "-b release/2.6"

From 1faf9d5d4200479075413060eb6631c0cedbc142 Mon Sep 17 00:00:00 2001
From: Daniel Li
Date: Wed, 29 Jan 2025 17:29:24 -0800
Subject: [PATCH 5/9] Migrate vLLM:TPU XLML Test from Python Setup to Docker
 with Latest Dependencies (#583)

* Run the vLLM TPU test in Docker
* Use Half-width comma
* Use "sudo docker exec $CONTAINER_NAME /bin/bash -c"
* Add HF_TOKEN
* Use a static VLLM_TPU_CONTAINER name
* Remove unused argument
* Try jq
* use escape
* Try \\"
* escape the single quote
* Try 5: escape of single quote + double escape of double quote
* Try 6: double escape of double quote
* Try 7: double quote for command inside container and escape of double quotes
* Try 8
* Try 9: double escape of double quotes
* Escape \n
* Try \\\n
* Try \\$
* Try \\\"
* Get the GCS destination path *before* constructing the command.
* Get the GCS destination path *before* constructing the command. OUTSIDE the list.
* Add \" and debug information
* passing GCS as an environment variable
* pkill vllm
* Address comments
* Reformat using command "pre-commit run --files dags/solutions_team/configs/vllm/vllm_benchmark_config.py"
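In miniature, the quoting scheme the final version settles on (a sketch
with a hypothetical script name and flag; the real commands are in the
diff below): the inner command line is handed to `bash -c` inside
backslash-escaped double quotes, and the JSON metadata keeps its own
double quotes escaped so they survive both the Python string and the
container shell:

    metadata = {"test_run_id": "hypothetical-run"}
    escaped = json.dumps(metadata).replace('"', '\\"')
    cmd = (
        "sudo docker exec $CONTAINER_NAME /bin/bash -c "
        f"\"python bench.py --metadata '{escaped}'\""
    )
    # The bash inside the container then sees:
    #   python bench.py --metadata '{"test_run_id": "hypothetical-run"}'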
---
 .../configs/vllm/vllm_benchmark_config.py | 135 ++++++++++++++----
 1 file changed, 104 insertions(+), 31 deletions(-)

diff --git a/dags/solutions_team/configs/vllm/vllm_benchmark_config.py b/dags/solutions_team/configs/vllm/vllm_benchmark_config.py
index c1fda9194..a8554fb93 100644
--- a/dags/solutions_team/configs/vllm/vllm_benchmark_config.py
+++ b/dags/solutions_team/configs/vllm/vllm_benchmark_config.py
@@ -30,6 +30,8 @@
 RUNTIME_IMAGE = RuntimeVersion.TPU_UBUNTU2204_BASE.value
 GCS_SUBFOLDER_PREFIX = test_owner.Team.SOLUTIONS_TEAM.value
 HF_TOKEN = Variable.get("HF_TOKEN", None)
+VLLM_TPU_DOCKER_IMAGE = "gcr.io/cloud-tpu-v2-images/vllm-tpu-nightly:latest"
+VLLM_TPU_CONTAINER = "vllm-tpu-container"
 
 
 def get_vllm_gpu_setup_cmds():
@@ -53,33 +55,25 @@ def get_vllm_gpu_setup_cmds():
 
 def get_vllm_tpu_setup_cmds():
   setup_cmds = (
-      # Update environment and installs basic deps
-      "pip install --upgrade pip",
-      "sudo apt-get -y update",
-      "sudo apt install -y libopenblas-base libopenblas-dev",
-      "sudo apt-get -y install python3.10-venv",
-      "sudo apt-get -y install jq",
-      "python -m venv .env",
-      "source .env/bin/activate",
-      # Install vllm at head
-      "rm -rf vllm && git clone https://github.com/vllm-project/vllm",
-      "cd vllm",
-      # From https://docs.vllm.ai/en/latest/getting_started/tpu-installation.html
-      "pip uninstall torch torch-xla -y",
-      "pip install -r requirements-tpu.txt",
-      # Build vLLM
-      'VLLM_TARGET_DEVICE="tpu" python setup.py develop',
-      # Download dataset
-      "cd .. && wget --no-verbose https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json",
-      # Download benchmark
-      "pip install --upgrade google-cloud-storage",
-      "rm -rf inference-benchmark && git clone https://github.com/AI-Hypercomputer/inference-benchmark",
+      # Download and start the vLLM TPU Docker container
+      f"export CONTAINER_NAME={VLLM_TPU_CONTAINER}",
+      f"sudo docker run --name $CONTAINER_NAME -d --privileged --network host -v /dev/shm:/dev/shm {VLLM_TPU_DOCKER_IMAGE} tail -f /dev/null",
+      # Download dataset inside the container
+      "sudo docker exec $CONTAINER_NAME /bin/bash -c 'wget --no-verbose https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json'",
+      # Download benchmark inside the container
+      "sudo docker exec $CONTAINER_NAME /bin/bash -c 'pip install --upgrade google-cloud-storage'",
+      "sudo docker exec $CONTAINER_NAME /bin/bash -c 'rm -rf inference-benchmark && git clone https://github.com/AI-Hypercomputer/inference-benchmark'",
+      # Download Google Cloud SDK inside the container, which is needed for the gsutil command.
+      "sudo docker exec $CONTAINER_NAME /bin/bash -c 'echo \"deb [signed-by=/usr/share/keyrings/cloud.google.gpg] http://packages.cloud.google.com/apt cloud-sdk main\" > /etc/apt/sources.list.d/google-cloud-sdk.list'",
+      "sudo docker exec $CONTAINER_NAME /bin/bash -c 'curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add -'",
+      "sudo docker exec $CONTAINER_NAME /bin/bash -c 'apt-get update && apt-get install -y google-cloud-sdk'",
+      "sudo docker exec $CONTAINER_NAME /bin/bash -c 'apt-get -y install jq'",
   )
 
   return setup_cmds
 
 
-def get_vllm_benchmark_cmds(
+def _get_vllm_benchmark_parameters(
     model_id: str, num_chips: int, test_run_id: str, model_configs: Dict = {}
 ):
   base_model_id = model_id.split("/")[-1]
@@ -87,6 +81,39 @@
   instance_type = model_configs["instance_type"]
   num_prompts = 1000
 
+  # Group metrics together using test_run_id.
+  metadata = {
+      "test_run_id": test_run_id,
+      "instance_type": instance_type,
+      "num_accelerators": num_chips,
+  }
+
+  # Get the GCS destination path *before* constructing the command, OUTSIDE the list.
+  gcs_destination = metric_config.SshEnvVars.GCS_OUTPUT.value
+  if not gcs_destination:
+    raise ValueError("GCS_OUTPUT environment variable is not set or is empty.")
+  # Debug Print
+  print(f"DEBUG: GCS Destination: {gcs_destination}")
+
+  return base_model_id, request_rates, num_prompts, metadata, gcs_destination
+
+
+def get_gpu_vllm_benchmark_cmds(
+    model_id: str, num_chips: int, test_run_id: str, model_configs: Dict = {}
+):
+  (
+      base_model_id,
+      request_rates,
+      num_prompts,
+      metadata,
+      gcs_destination,
+  ) = _get_vllm_benchmark_parameters(
+      model_id=model_id,
+      num_chips=num_chips,
+      test_run_id=test_run_id,
+      model_configs=model_configs,
+  )
+
   run_cmds = [
       "export PATH=$PATH:/home/cloud-ml-auto-solutions/vllm:/home/cloud-ml-auto-solutions/.local/bin",
       # HF_TOKEN is set in Composer environment variables
@@ -99,12 +126,6 @@
       "sleep 600",
   ]
 
-  # Group metrics together using test_run_id.
-  metadata = {
-      "test_run_id": test_run_id,
-      "instance_type": instance_type,
-      "num_accelerators": num_chips,
-  }
 
   for request_rate in request_rates:
     benchmark_cmd_fmt = "python inference-benchmark/benchmark_serving.py --host localhost --port 8000 --num-prompts {num_prompts} --max-input-length 1024 --max-output-length 1024 --dataset ShareGPT_V3_unfiltered_cleaned_split.json --save-json-results --model '{model_id}' --tokenizer '{model_id}' --request-rate {request_rate} --additional-metadata-metrics-to-save '{additional_metadata}'"
 
@@ -132,7 +153,59 @@
       # Kill background process
       "pkill -P $$",
       # Copy metrics as the last step
-      f"gsutil cp metric_report.jsonl {metric_config.SshEnvVars.GCS_OUTPUT.value}",
+      f"gsutil cp metric_report.jsonl {gcs_destination}",
   ])
 
   return tuple(run_cmds)
 
 
+def get_tpu_vllm_benchmark_cmds(
+    model_id: str, num_chips: int, test_run_id: str, model_configs: Dict = {}
+):
+  (
+      base_model_id,
+      request_rates,
+      num_prompts,
+      metadata,
+      gcs_destination,
+  ) = _get_vllm_benchmark_parameters(
+      model_id=model_id,
+      num_chips=num_chips,
+      test_run_id=test_run_id,
+      model_configs=model_configs,
+  )
+
+  run_cmds = [
+      f"export CONTAINER_NAME={VLLM_TPU_CONTAINER}",
+      # Start vllm in the background and wait for server to come up
+      f"sudo docker exec $CONTAINER_NAME /bin/bash -c 'export HF_TOKEN={HF_TOKEN} && vllm serve {model_id} --swap-space 16 --disable-log-requests --tensor_parallel_size={num_chips} --max-model-len=2048 --num-scheduler-steps=4 & sleep 600'",
+  ]
+
+  for request_rate in request_rates:
+    benchmark_cmd_fmt = "sudo docker exec $CONTAINER_NAME /bin/bash -c \"export HF_TOKEN={HF_TOKEN} && python inference-benchmark/benchmark_serving.py --host localhost --port 8000 --num-prompts {num_prompts} --max-input-length 1024 --max-output-length 1024 --dataset ShareGPT_V3_unfiltered_cleaned_split.json --save-json-results --model {model_id} --tokenizer {model_id} --request-rate {request_rate} --additional-metadata-metrics-to-save '{additional_metadata}'\""
 
+    benchmark_cmds = [
+        # Run benchmark inside the container
+        benchmark_cmd_fmt.format(
+            HF_TOKEN=HF_TOKEN,
+            num_prompts=num_prompts,
+            model_id=model_id,
+            request_rate=request_rate,
+            additional_metadata=json.dumps(metadata).replace('"', '\\"'),
+        ),
+        # Process result json files inside the container
+        f"sudo docker exec $CONTAINER_NAME /bin/bash -c \"export OUTPUT_FORMAT='*vllm*{base_model_id}*' && export BENCHMARK_OUTPUT=\\$(find . -name \\$OUTPUT_FORMAT -type f -printf \\\"%T@ %Tc %p\n\\\" | sort -n | head -1 | awk 'NF>1{{print \\$NF}}') && cat \\$BENCHMARK_OUTPUT >> metric_report.jsonl && rm \\$BENCHMARK_OUTPUT\"",
+        "sudo docker exec $CONTAINER_NAME /bin/bash -c \"echo '' >> metric_report.jsonl\"",
+    ]
+    run_cmds.extend(benchmark_cmds)
+
+  run_cmds.extend([
+      # Kill background process
+      "sudo docker exec $CONTAINER_NAME /bin/bash -c 'pkill vllm'",
+      # Copy metrics
+      f"sudo docker exec -e GCS=\"{gcs_destination}\" $CONTAINER_NAME /bin/bash -c 'gsutil cp metric_report.jsonl $GCS'",
+      # Stop the container
+      "sudo docker stop $CONTAINER_NAME",
   ])
 
   return tuple(run_cmds)
 
@@ -163,7 +236,7 @@
   set_up_cmds = get_vllm_gpu_setup_cmds()
   model_configs["instance_type"] = machine_version.value
 
-  run_model_cmds = get_vllm_benchmark_cmds(
+  run_model_cmds = get_gpu_vllm_benchmark_cmds(
       model_id=model_configs["model_id"],
       num_chips=count,
       test_run_id=test_run_id,
@@ -235,7 +308,7 @@
   set_up_cmds = get_vllm_tpu_setup_cmds()
   model_configs["instance_type"] = tpu_version.value
 
-  run_model_cmds = get_vllm_benchmark_cmds(
+  run_model_cmds = get_tpu_vllm_benchmark_cmds(
       model_id=model_configs["model_id"],
       num_chips=tpu_cores,
       test_run_id=test_run_id,

From 23a25de414421ee9746e0f68a05914f8223a47a3 Mon Sep 17 00:00:00 2001
From: Ran Ran
Date: Fri, 31 Jan 2025 14:41:09 -0800
Subject: [PATCH 6/9] Move MoE TPU tests out of Quarantine (#591)

---
 dags/common/quarantined_tests.py | 12 ------------
 1 file changed, 12 deletions(-)

diff --git a/dags/common/quarantined_tests.py b/dags/common/quarantined_tests.py
index fa658b7a8..36865fcb0 100644
--- a/dags/common/quarantined_tests.py
+++ b/dags/common/quarantined_tests.py
@@ -80,18 +80,6 @@ class QuarantineTests:
     # DAG: maxtext_end_to_end
     "chained_tests_gemma-7b_stable": TestInfo(team.LLM_DEVX, "2024-11-12"),
     "chained_tests_gemma-7b_nightly": TestInfo(team.LLM_DEVX, "2024-11-12"),
-    "chained_tests_mixtral-8x7b_stable": TestInfo(
-        team.SPARSITY_DIFFUSION_DEVX, "2024-11-12"
-    ),
-    "chained_tests_mixtral-8x7b_nightly": TestInfo(
-        team.SPARSITY_DIFFUSION_DEVX, "2024-11-12"
-    ),
-    "maxtext_stable_mixtral-8x22b-v4-128": TestInfo(
-        team.SPARSITY_DIFFUSION_DEVX, "2024-11-12"
-    ),
-    "maxtext_nightly_mixtral-8x22b-v4-128": TestInfo(
-        team.SPARSITY_DIFFUSION_DEVX, "2024-11-12"
-    ),
     "chained_tests_llama2-70b_stable": TestInfo(team.LLM_DEVX, "2024-11-12"),
     "chained_tests_llama2-70b_nightly": TestInfo(team.LLM_DEVX, "2024-11-12"),
     # DAG: jax_stable_stack_gpu_e2e

From 2d10e30cf52f320e6dc812d71aaa83091c9cc6ba Mon Sep 17 00:00:00 2001
From: Luke Baumann
Date: Mon, 3 Feb 2025 08:23:59 -0800
Subject: [PATCH 7/9] Adding a check to see if the job is complete while
 waiting for workload completion. (#575)
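Pathways jobs delete all of their pods on failure, so "no pods left" is
ambiguous on its own. When no pods remain, the sensor now also consults
the Job's status conditions before deciding (a sketch of the decision
order, mirroring the change below):

    job = _get_workload_job(batch_api, workload_id)
    if job is None:
        return False  # nothing found yet; keep polling
    if any(c.type == "Failed" for c in job.status.conditions):
        raise AirflowFailException('Job has condition type: "Failed"')
    if any(c.type == "Complete" for c in job.status.conditions):
        return True  # pods were garbage-collected after success
    return False  # job exists but is still running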
---
 xlml/utils/xpk.py | 57 +++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 57 insertions(+)

diff --git a/xlml/utils/xpk.py b/xlml/utils/xpk.py
index 5786c2701..b49130425 100644
--- a/xlml/utils/xpk.py
+++ b/xlml/utils/xpk.py
@@ -144,6 +144,42 @@ def _list_workload_pods(
   return pods
 
 
+def _get_batch_api_client(
+    project_id: str, region: str, cluster_name: str
+) -> k8s_client.BatchV1Api:
+  """Create a batch API client for the given cluster."""
+  client = gke.get_authenticated_client(project_id, region, cluster_name)
+
+  # Initialize the client
+  batch_api = k8s_client.BatchV1Api(client)
+  logging.info(
+      "Successfully initialized k8s batch API client from cluster response."
+  )
+  return batch_api
+
+
+def _get_workload_job(
+    batch_api: k8s_client.BatchV1Api, workload_id: str
+) -> k8s_client.V1Job:
+  """Get the job for a given workload."""
+  logging.info(f"Getting job for workload_id: {workload_id}")
+  jobs = batch_api.list_namespaced_job(
+      label_selector=f"jobset.sigs.k8s.io/jobset-name={workload_id}",
+      namespace="default",
+  )
+  if len(jobs.items) == 0:
+    logging.info(f"No job found for workload_id: {workload_id}")
+    return None
+
+  if len(jobs.items) > 1:
+    logging.info(f"Got more than one job for workload_id: {workload_id}")
+    for i, job in enumerate(jobs.items):
+      logging.info(f"Job {i=}")
+      logging.info(f"{job}")
+
+  return jobs.items[0]
+
+
 @task.sensor(poke_interval=60, timeout=600, mode="reschedule")
 def wait_for_workload_start(
     workload_id: str, project_id: str, region: str, cluster_name: str
@@ -165,6 +201,27 @@ def wait_for_workload_completion(
 
   if not pods.items:
     logging.info(f"No pods found for workload selector: {workload_id}.")
+
+    # Pathways jobs delete all pods on failure so we must also check if the job
+    # is complete
+    batch_api = _get_batch_api_client(project_id, region, cluster_name)
+    job = _get_workload_job(batch_api, workload_id)
+    if job is None:
+      logging.info(
+          f"No pods or jobs were found for workload selector: {workload_id}"
+      )
+      return False
+
+    if any(condition.type == "Failed" for condition in job.status.conditions):
+      # Don't keep retrying if the job has failed
+      raise AirflowFailException('Job has condition type: "Failed"')
+
+    if any(condition.type == "Complete" for condition in job.status.conditions):
+      logging.info(
+          f"No pods found but job is complete for workload selector: {workload_id}"
+      )
+      return True
+
     return False
 
   if any(pod.status.phase in ["Pending", "Running"] for pod in pods.items):

From 46132b10d258564c3f338bbc64d6672b7862a156 Mon Sep 17 00:00:00 2001
From: raymondzouu <31597464+raymondzouu@users.noreply.github.com>
Date: Mon, 3 Feb 2025 08:40:38 -0800
Subject: [PATCH 8/9] Update mxla test to use llama3 8B and remove v4 tests
 (#585)

---
 dags/common/quarantined_tests.py          | 13 -------
 dags/common/vm_resource.py                |  6 +--
 dags/multipod/configs/gke_config.py       |  2 +-
 dags/multipod/mxla_maxtext_nightly_gke.py | 46 ++---------------------
 4 files changed, 8 insertions(+), 59 deletions(-)

diff --git a/dags/common/quarantined_tests.py b/dags/common/quarantined_tests.py
index 36865fcb0..73dd9c02e 100644
--- a/dags/common/quarantined_tests.py
+++ b/dags/common/quarantined_tests.py
@@ -165,19 +165,6 @@ class QuarantineTests:
     "mxla-gpt3-6b-nightly-gke-8xv5p-8": TestInfo(
         team.PERFORMANCE, "2024-11-12"
     ),
-    # DAG: mxla_maxtext_nightly_gke
-    "mxla-maxtext-nightly-gke-v5p-8": TestInfo(
-        team.PERFORMANCE, "2024-11-12"
-    ),
-    "mxla-maxtext-nightly-gke-2xv5p-8": TestInfo(
-        team.PERFORMANCE, "2024-11-12"
-    ),
-    "mxla-maxtext-nightly-gke-4xv5p-8": TestInfo(
-        team.PERFORMANCE, "2024-11-12"
-    ),
-    "mxla-maxtext-nightly-gke-8xv5p-8": TestInfo(
-        team.PERFORMANCE, "2024-11-12"
-    ),
     # DAG: maxtext_trillium_configs_perf
     "maxtext-llama2_70b_4096-stable-3-2xv6e-256": TestInfo(
         team.PERFORMANCE, "2024-11-12"
     ),
diff --git a/dags/common/vm_resource.py b/dags/common/vm_resource.py
index 64b894bfc..5b9e6a171 100644
--- a/dags/common/vm_resource.py
+++ b/dags/common/vm_resource.py
@@ -231,11 +231,11 @@ class XpkClusters:
       zone=Zone.US_CENTRAL2_B.value,
   )
   TPU_V5P_8_CLUSTER = XpkClusterConfig(
-      name="v5p-8-bodaborg-us-east5-a",
+      name="v5p-8-bodaborg-europe-west4-b",
       device_version=TpuVersion.V5P,
       core_count=8,
-      project=Project.TPU_PROD_ENV_LARGE_CONT.value,
-      zone=Zone.US_EAST5_A.value,
+      project=Project.CLOUD_TPU_MULTIPOD_DEV.value,
+      zone=Zone.EUROPE_WEST4_B.value,
   )
   TPU_V5E_256_CLUSTER = XpkClusterConfig(
       name="v5e-256-bodaborg-europe-west4",
diff --git a/dags/multipod/configs/gke_config.py b/dags/multipod/configs/gke_config.py
index cecbd5f55..fa01deb78 100644
--- a/dags/multipod/configs/gke_config.py
+++ b/dags/multipod/configs/gke_config.py
@@ -116,7 +116,7 @@ def get_gke_maxtext_nightly_config(
       f" python3 MaxText/train.py MaxText/configs/base.yml run_name={run_name}"
       f" base_output_directory={base_output_directory}"
       " dataset_path=gs://max-datasets-rogue dataset_type=synthetic"
-      " per_device_batch_size=12 reuse_example_batch=1 global_parameter_scale=1 metrics_file='metrics.txt'"
+      " model_name=llama3-8b per_device_batch_size=12 reuse_example_batch=1 metrics_file='metrics.txt'"
      " steps=50 enable_checkpointing=false profiler=xplane upload_all_profiler_results=true skip_first_n_steps_for_profiler=10 profiler_steps=10 gcs_metrics=true"
   ),
 )
diff --git a/dags/multipod/mxla_maxtext_nightly_gke.py b/dags/multipod/mxla_maxtext_nightly_gke.py
index 9acc3a97d..e02e5ef73 100644
--- a/dags/multipod/mxla_maxtext_nightly_gke.py
+++ b/dags/multipod/mxla_maxtext_nightly_gke.py
@@ -40,43 +40,12 @@
       group_id="Quarantine", dag=dag, prefix_group_id=False
   )
 
-  maxtext_nightly_1slice_v4_8 = gke_config.get_gke_maxtext_nightly_config(
-      time_out_in_min=60,
-      test_name=default_test_name,
-      docker_image=jax_nightly_image.value,
-      test_owner=test_owner.TONY_C,
-  ).run_with_quarantine(quarantine_task_group)
-
-  maxtext_nightly_2slice_v4_8 = gke_config.get_gke_maxtext_nightly_config(
-      num_slices=2,
-      time_out_in_min=60,
-      test_name=default_test_name,
-      docker_image=jax_nightly_image.value,
-      test_owner=test_owner.TONY_C,
-  ).run_with_quarantine(quarantine_task_group)
-
-  maxtext_nightly_4slice_v4_8 = gke_config.get_gke_maxtext_nightly_config(
-      num_slices=4,
-      time_out_in_min=60,
-      test_name=default_test_name,
-      docker_image=jax_nightly_image.value,
-      test_owner=test_owner.TONY_C,
-  ).run_with_quarantine(quarantine_task_group)
-
-  maxtext_nightly_8slice_v4_8 = gke_config.get_gke_maxtext_nightly_config(
-      num_slices=8,
-      time_out_in_min=60,
-      test_name=default_test_name,
-      docker_image=jax_nightly_image.value,
-      test_owner=test_owner.TONY_C,
-  ).run_with_quarantine(quarantine_task_group)
-
   maxtext_nightly_1slice_v5p_8 = gke_config.get_gke_maxtext_nightly_config(
       cluster=XpkClusters.TPU_V5P_8_CLUSTER,
       time_out_in_min=60,
       test_name=default_test_name,
       docker_image=jax_nightly_image.value,
-      test_owner=test_owner.TONY_C,
+      test_owner=test_owner.RAYMOND_Z,
   ).run_with_quarantine(quarantine_task_group)
 
   maxtext_nightly_2slice_v5p_8 = gke_config.get_gke_maxtext_nightly_config(
       num_slices=2,
       cluster=XpkClusters.TPU_V5P_8_CLUSTER,
       time_out_in_min=60,
       test_name=default_test_name,
       docker_image=jax_nightly_image.value,
-      test_owner=test_owner.TONY_C,
+      test_owner=test_owner.RAYMOND_Z,
   ).run_with_quarantine(quarantine_task_group)
 
   maxtext_nightly_4slice_v5p_8 = gke_config.get_gke_maxtext_nightly_config(
       num_slices=4,
       cluster=XpkClusters.TPU_V5P_8_CLUSTER,
       time_out_in_min=60,
       test_name=default_test_name,
       docker_image=jax_nightly_image.value,
-      test_owner=test_owner.TONY_C,
+      test_owner=test_owner.RAYMOND_Z,
   ).run_with_quarantine(quarantine_task_group)
 
   maxtext_nightly_8slice_v5p_8 = gke_config.get_gke_maxtext_nightly_config(
       num_slices=8,
       cluster=XpkClusters.TPU_V5P_8_CLUSTER,
       time_out_in_min=60,
       test_name=default_test_name,
       docker_image=jax_nightly_image.value,
-      test_owner=test_owner.TONY_C,
+      test_owner=test_owner.RAYMOND_Z,
   ).run_with_quarantine(quarantine_task_group)
 
-  (
-      maxtext_nightly_1slice_v4_8
-      >> maxtext_nightly_2slice_v4_8
-      >> maxtext_nightly_4slice_v4_8
-      >> maxtext_nightly_8slice_v4_8
-  )
-
   (
       maxtext_nightly_1slice_v5p_8
       >> maxtext_nightly_2slice_v5p_8

From 5cd937bf14fc9fa46f9336d0db9427b279f8a4d1 Mon Sep 17 00:00:00 2001
From: raymondzouu <31597464+raymondzouu@users.noreply.github.com>
Date: Mon, 3 Feb 2025 09:13:48 -0800
Subject: [PATCH 9/9] Update trillium and v5e perf model configs and update
 test script to use benchmark_runner.py (#590)
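Each perf DAG now iterates one of the shared enums and hands the enum
value straight to MaxText's benchmark runner, e.g. (a sketch; the flags
match the DAG changes below):

    from dags.common.model_configs import MaxTextTrilliumModelConfigs

    model = MaxTextTrilliumModelConfigs.LLAMA3_1_8B_8192
    run_cmd = (
        "python3 benchmarks/benchmark_runner.py on-device "
        f"--base_output_directory={BASE_OUTPUT_DIRECTORY} "
        f"--model_name={model.value} "  # "llama3_1_8b_8192"
        "--libtpu_type=maxtext-docker --num_steps=15"
    )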
---
 dags/common/model_configs.py                  | 42 +++++++++++++++++++
 .../multipod/maxtext_trillium_configs_perf.py | 18 ++++----
 dags/multipod/maxtext_v5e_configs_perf.py     | 24 ++++-------
 3 files changed, 61 insertions(+), 23 deletions(-)
 create mode 100644 dags/common/model_configs.py

diff --git a/dags/common/model_configs.py b/dags/common/model_configs.py
new file mode 100644
index 000000000..5501a541e
--- /dev/null
+++ b/dags/common/model_configs.py
@@ -0,0 +1,42 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Common model perf configs"""
+
+import enum
+
+
+class MaxTextV5eModelConfigs(enum.Enum):
+  # Refers to model configs in https://github.com/AI-Hypercomputer/maxtext/blob/main/benchmarks/maxtext_v5e_model_configs.py
+  DEFAULT_16B = "default_16b_v5e_256"
+  DEFAULT_32B = "default_32b_v5e_256"
+  DEFAULT_64B = "default_64b_v5e_256"
+  DEFAULT_128B = "default_128b_v5e_256"
+  GPT3_175B = "gpt_3_175b_v5e_256"
+  LLAMA2_7B = "llama2_7b_v5e_256"
+  LLAMA2_13B = "llama2_13b_v5e_256"
+  LLAMA2_70B = "llama2_70b_v5e_256"
+
+
+class MaxTextTrilliumModelConfigs(enum.Enum):
+  # Refers to model configs in https://github.com/AI-Hypercomputer/maxtext/blob/main/benchmarks/maxtext_trillium_model_configs.py
+  GPT3_175B = "gpt_3_175b"
+  LLAMA2_70B_4096 = "llama2_70b_4096_synthetic"
+  LLAMA3_1_8B_8192 = "llama3_1_8b_8192"
+  LLAMA3_1_70B_8192 = "llama3_1_70b_8192"
+  LLAMA3_1_70B_129024 = "llama3_1_70b_129024"
+  LLAMA3_1_405B_8192 = "llama3_1_405b_8192_fsdp_dcn"
+  MIXTRAL_8X7B_DROPLESS = "mixtral_8x7b_dropless"
+  MIXTRAL_8X7B_DROPPED = "mixtral_8x7b_dropped"
+  MIXTRAL_8X7B_DROPPED_INT8 = "mixtral_8x7b_dropped_int8"
diff --git a/dags/multipod/maxtext_trillium_configs_perf.py b/dags/multipod/maxtext_trillium_configs_perf.py
index c6948f878..c05586c70 100644
--- a/dags/multipod/maxtext_trillium_configs_perf.py
+++ b/dags/multipod/maxtext_trillium_configs_perf.py
@@ -21,18 +21,17 @@
 from dags import composer_env
 from dags.common import test_owner
 from dags.common.vm_resource import TpuVersion, Zone, Project, XpkClusters, DockerImage
+from dags.common.model_configs import MaxTextTrilliumModelConfigs
 from dags.multipod.configs import maxtext_sweep_gke_config
 from dags.multipod.configs.common import SetupMode
 from xlml.apis import metric_config
 
 # Run once a day at 4 am UTC (8 pm PST / 9 pm PDT)
 SCHEDULED_TIME = "0 4 * * *" if composer_env.is_prod_env() else None
-MODEL_CONFIGS = ["gpt3_175b", "llama2_7b_4096", "mixtral_8x7b"]
 DOCKER_IMAGES = [
     (SetupMode.STABLE, DockerImage.MAXTEXT_TPU_JAX_STABLE_STACK),
     (SetupMode.NIGHTLY, DockerImage.MAXTEXT_TPU_JAX_NIGHTLY),
 ]
-QUANTIZATION_SWEEP = {"M_QUANTIZATION": ["", "int8"]}
 BASE_OUTPUT_DIRECTORY = "gs://runner-maxtext-logs"
 
 with models.DAG(
@@ -46,10 +45,15 @@
       group_id="Quarantine", dag=dag, prefix_group_id=False
   )
   for mode, image in DOCKER_IMAGES:
-    for model in MODEL_CONFIGS:
+    for model in MaxTextTrilliumModelConfigs:
       base_run_model_cmds = [
-          f"bash MaxText/configs/trillium/{model}.sh OUTPUT_PATH={BASE_OUTPUT_DIRECTORY} DATASET_PATH=gs://max-datasets-rogue",
+          f"python3 benchmarks/benchmark_runner.py on-device --base_output_directory={BASE_OUTPUT_DIRECTORY} --model_name={model.value} --libtpu_type=maxtext-docker --num_steps=15",
       ]
+      num_slices = (
+          [2]
+          if model == MaxTextTrilliumModelConfigs.LLAMA3_1_405B_8192
+          else [1, 2]
+      )
       maxtext_sweep_gke_test = (
           maxtext_sweep_gke_config.get_maxtext_sweep_gke_config(
               test_owner=test_owner.RAYMOND_Z,
               cluster=XpkClusters.TPU_V6E_256_MLPERF_CLUSTER,
               time_out_in_min=360,
               base_output_directory=BASE_OUTPUT_DIRECTORY,
-              num_slices=[1, 2],
+              num_slices=num_slices,
               docker_image=image.value,
-              run_name_prefix=f"maxtext-{model}-{mode.value}",
+              run_name_prefix=f"maxtext-{model.name.lower()}-{mode.value}",
               base_run_model_cmds=base_run_model_cmds,
-              sweep_params=QUANTIZATION_SWEEP,
+              sweep_params={},
           )
       )
diff --git a/dags/multipod/maxtext_v5e_configs_perf.py b/dags/multipod/maxtext_v5e_configs_perf.py
index 722889423..ab6e665f1 100644
--- a/dags/multipod/maxtext_v5e_configs_perf.py
+++ b/dags/multipod/maxtext_v5e_configs_perf.py
@@ -21,22 +21,13 @@
 from dags import composer_env
 from dags.common import test_owner
 from dags.common.vm_resource import TpuVersion, Zone, Project, XpkClusters, DockerImage
+from dags.common.model_configs import MaxTextV5eModelConfigs
 from dags.multipod.configs import maxtext_sweep_gke_config
 from dags.multipod.configs.common import SetupMode
 from xlml.apis import metric_config
 
 # Run once a day at 4 am UTC (8 pm PST / 9 pm PDT)
 SCHEDULED_TIME = "0 4 * * *" if composer_env.is_prod_env() else None
-MODEL_CONFIGS = [
-    "16b",
-    "32b",
-    "64b",
-    "128b",
-    "gpt3_175b",
-    "llama2_7b",
-    "llama2_13b",
-    "llama2_70b",
-]
 DOCKER_IMAGES = [
     (SetupMode.STABLE, DockerImage.MAXTEXT_TPU_JAX_STABLE_STACK),
     (SetupMode.NIGHTLY, DockerImage.MAXTEXT_TPU_JAX_NIGHTLY),
@@ -55,9 +46,10 @@
       group_id="Quarantine", dag=dag, prefix_group_id=False
   )
   for mode, image in DOCKER_IMAGES:
-    for model in MODEL_CONFIGS:
+    for model in MaxTextV5eModelConfigs:
       base_run_model_cmds = [
-          f"bash MaxText/configs/v5e/{model}.sh OUTPUT_PATH={BASE_OUTPUT_DIRECTORY} DATASET_PATH=gs://max-datasets-rogue",
+          "bash preflight.sh",
+          f"python3 benchmarks/benchmark_runner.py on-device --base_output_directory={BASE_OUTPUT_DIRECTORY} --model_name={model.value} --libtpu_type=maxtext-docker --num_steps=15",
       ]
       maxtext_sweep_gke_test = (
           maxtext_sweep_gke_config.get_maxtext_sweep_gke_config(
               base_output_directory=BASE_OUTPUT_DIRECTORY,
               num_slices=[1, 2],
               docker_image=image.value,
-              run_name_prefix=f"maxtext-{model}-{mode.value}",
+              run_name_prefix=f"maxtext-{model.name.lower()}-{mode.value}",
               base_run_model_cmds=base_run_model_cmds,
               sweep_params=QUANTIZATION_SWEEP,
           )
       )
@@ -103,9 +95,9 @@
       group_id="Quarantine", dag=dag, prefix_group_id=False
   )
   for mode, image in DOCKER_IMAGES:
-    for model in MODEL_CONFIGS:
+    for model in MaxTextV5eModelConfigs:
       base_run_model_cmds = [
-          f"bash MaxText/configs/v5e/{model}.sh OUTPUT_PATH={BASE_OUTPUT_DIRECTORY} DATASET_PATH=gs://max-datasets-rogue RUN_PREFLIGHT=false",
+          f"python3 benchmarks/benchmark_runner.py on-device --base_output_directory={BASE_OUTPUT_DIRECTORY} --model_name={model.value} --libtpu_type=maxtext-docker --num_steps=15",
       ]
       maxtext_sweep_gke_test = (
           maxtext_sweep_gke_config.get_maxtext_sweep_gke_config(
               base_output_directory=BASE_OUTPUT_DIRECTORY,
               num_slices=[1, 2],
               docker_image=image.value,
-              run_name_prefix=f"p-maxtext-{model}-{mode.value}",
+              run_name_prefix=f"p-maxtext-{model.name.lower()}-{mode.value}",
               base_run_model_cmds=base_run_model_cmds,
               sweep_params=QUANTIZATION_SWEEP,
           )