CI - Wheel Tests (Continuous) #1571
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# CI - Wheel Tests (Continuous)
#
# This workflow builds JAX artifacts and runs CPU, CUDA, and TPU tests.
#
# It orchestrates the following:
# 1. build-jax-artifact: Calls the `build_artifacts.yml` workflow to build the jax artifact and
#                        uploads it to a GCS bucket.
# 2. build-jaxlib-artifact: Calls the `build_artifacts.yml` workflow to build jaxlib and
#                           uploads it to a GCS bucket.
# 3. run-pytest-cpu: Calls the `pytest_cpu.yml` workflow which downloads the jaxlib wheel
#                    that was built in the previous step and runs CPU tests.
# 4. build-cuda-artifacts: Calls the `build_artifacts.yml` workflow to build CUDA artifacts and
#                          uploads them to a GCS bucket.
# 5. run-bazel-test-cpu-py-import: Calls the `bazel_cpu_py_import_rbe.yml` workflow which
#                                  runs Bazel CPU tests with py_import on RBE.
# 6. run-pytest-cuda: Calls the `pytest_cuda.yml` workflow which downloads the jaxlib and CUDA
#                     artifacts that were built in the previous steps and runs the CUDA tests.
# 7. run-bazel-test-cuda: Calls the `bazel_cuda_non_rbe.yml` workflow which downloads the jaxlib
#                         and CUDA artifacts that were built in the previous steps and runs the
#                         CUDA tests using Bazel.
# 8. run-pytest-tpu: Calls the `pytest_tpu.yml` workflow to run the TPU tests using the
#                    artifacts that were built in the previous steps.
name: "CI - Wheel Tests (Continuous)"

# The workflow token only gets read access to the repository contents.
permissions:
  contents: read

on:
  # Scheduled trigger: run once every 3 hours.
  schedule:
    - cron: '0 */3 * * *'
  # Manual trigger: allows starting a workflow run by hand.
  workflow_dispatch:
# Runs are grouped per workflow and per ref (head_ref for pull requests, ref otherwise).
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
  # Do not cancel in-progress runs on main or on release branches. `github.ref` holds the
  # fully-formed ref (e.g. `refs/heads/main`), so it must be compared against that form; a
  # bare 'main' never equals it, which would leave this expression true on main.
  cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'refs/heads/main' }}
jobs: | |
# Builds the "jax" artifact via the reusable build workflow and uploads it to GCS.
build-jax-artifact:
  name: "Build jax artifact"
  uses: ./.github/workflows/build_artifacts.yml
  with:
    artifact: "jax"
    # jax is a pure Python package, so the runner OS and Python values used here do not
    # matter; cloning main XLA has no effect either.
    runner: "linux-x86-n2-16"
    upload_artifacts_to_gcs: true
    gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
# Builds the jaxlib artifact for every runner in the matrix and uploads it to GCS.
build-jaxlib-artifact:
  # Note: for reasons unknown, GitHub Actions groups jobs with the same top-level name in the
  # dashboard only if the "name" field uses an expression. Otherwise it appends the matrix
  # values to the name and creates a separate entry for each matrix combination.
  name: "Build ${{ format('{0}', 'jaxlib') }} artifacts"
  uses: ./.github/workflows/build_artifacts.yml
  strategy:
    fail-fast: false  # a failing matrix job must not cancel the others
    matrix:
      # Runner OS and Python values need to match the matrix strategy in the CPU tests job.
      runner:
        - "linux-x86-n2-16"
        - "linux-arm64-t2a-48"
        - "windows-x86-n2-16"
      artifact:
        - "jaxlib"
      python:
        - "3.11"
  with:
    artifact: ${{ matrix.artifact }}
    runner: ${{ matrix.runner }}
    python: ${{ matrix.python }}
    clone_main_xla: 1
    upload_artifacts_to_gcs: true
    gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
# Builds the CUDA plugin and PJRT artifacts and uploads them to GCS.
build-cuda-artifacts:
  # The format() expression keeps all matrix jobs under a single dashboard entry.
  name: "Build ${{ format('{0}', 'CUDA') }} artifacts"
  uses: ./.github/workflows/build_artifacts.yml
  strategy:
    fail-fast: false  # a failing matrix job must not cancel the others
    matrix:
      # Python values need to match the matrix strategy in the CUDA tests job below.
      runner:
        - "linux-x86-n2-16"
      artifact:
        - "jax-cuda-plugin"
        - "jax-cuda-pjrt"
      python:
        - "3.11"
  with:
    artifact: ${{ matrix.artifact }}
    runner: ${{ matrix.runner }}
    python: ${{ matrix.python }}
    clone_main_xla: 1
    upload_artifacts_to_gcs: true
    gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
# Runs the CPU pytest suite through the reusable `pytest_cpu.yml` workflow.
run-pytest-cpu:
  name: "Pytest CPU (JAX artifacts version = ${{ format('{0}', 'head') }})"
  # Run the test jobs even if a build job fails. This avoids losing test coverage when a
  # single unrelated build job fails, e.g. the Windows build fails but everything else
  # succeeds; the tests for the other platforms should still run in that case.
  if: ${{ !cancelled() }}
  needs:
    - build-jax-artifact
    - build-jaxlib-artifact
  uses: ./.github/workflows/pytest_cpu.yml
  strategy:
    fail-fast: false  # a failing matrix job must not cancel the others
    matrix:
      # Runner OS and Python values need to match the matrix strategy in the
      # build-jaxlib-artifact job above.
      runner:
        - "linux-x86-n2-64"
        - "linux-arm64-t2a-48"
        - "windows-x86-n2-64"
      python:
        - "3.11"
      enable-x64:
        - 1
        - 0
  with:
    runner: ${{ matrix.runner }}
    python: ${{ matrix.python }}
    enable-x64: ${{ matrix.enable-x64 }}
    gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
# Runs the CUDA pytest suite through the reusable `pytest_cuda.yml` workflow.
run-pytest-cuda:
  name: "Pytest CUDA (JAX artifacts version = ${{ format('{0}', 'head') }}, CUDA Pip packages = ${{ matrix.cuda.use-nvidia-pip-wheels }})"
  # Run the test jobs even if a build job fails. This avoids losing test coverage when a
  # single unrelated build job fails, e.g. the Windows build fails but everything else
  # succeeds; the tests for the other platforms should still run in that case.
  if: ${{ !cancelled() }}
  needs: [build-jax-artifact, build-jaxlib-artifact, build-cuda-artifacts]
  uses: ./.github/workflows/pytest_cuda.yml
  strategy:
    fail-fast: false  # a failing matrix job must not cancel the others
    matrix:
      # Python values need to match the matrix strategy in the artifact build jobs above.
      # See the exclusions below for what is fully tested.
      runner:
        - "linux-x86-g2-48-l4-4gpu"
        - "linux-x86-a3-8g-h100-8gpu"
        - "linux-x86-a4-224-b200-1gpu"
      python:
        - "3.11"
      cuda:
        - version: "12.1"
          use-nvidia-pip-wheels: false
        - version: "12.8"
          use-nvidia-pip-wheels: true
      enable-x64: [1, 0]
      exclude:
        # The enable-x64 exclusions use the integer 0, the same scalar type as the
        # enable-x64 matrix values, so they match irrespective of how strictly the matrix
        # comparison treats string vs. number.
        #
        # H100 runs only a single config: CUDA 12.8, enable-x64 = 1.
        - runner: "linux-x86-a3-8g-h100-8gpu"
          cuda:
            version: "12.1"
        - runner: "linux-x86-a3-8g-h100-8gpu"
          enable-x64: 0
        # B200 runs only a single config: CUDA 12.8, enable-x64 = 1.
        - runner: "linux-x86-a4-224-b200-1gpu"
          cuda:
            version: "12.1"
        - runner: "linux-x86-a4-224-b200-1gpu"
          enable-x64: 0
  with:
    runner: ${{ matrix.runner }}
    python: ${{ matrix.python }}
    cuda-version: ${{ matrix.cuda.version }}
    use-nvidia-pip-wheels: ${{ matrix.cuda.use-nvidia-pip-wheels }}
    enable-x64: ${{ matrix.enable-x64 }}
    # GCS upload URI is the same for both artifact build jobs
    gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
# Runs the Bazel CPU tests with py_import through the reusable RBE workflow.
run-bazel-test-cpu-py-import:
  # The format() expression keeps all matrix jobs under a single dashboard entry.
  name: "Bazel CPU tests with ${{ format('{0}', 'py_import') }}"
  uses: ./.github/workflows/bazel_cpu_py_import_rbe.yml
  strategy:
    fail-fast: false  # a failing matrix job must not cancel the others
    matrix:
      runner:
        - "linux-x86-n2-16"
        - "linux-arm64-t2a-48"
      python:
        - "3.11"
      enable-x64:
        - 1
        - 0
  with:
    runner: ${{ matrix.runner }}
    python: ${{ matrix.python }}
    enable-x64: ${{ matrix.enable-x64 }}
# Runs the CUDA tests with Bazel through the reusable non-RBE workflow.
run-bazel-test-cuda:
  name: "Bazel CUDA Non-RBE (jax version = ${{ format('{0}', 'head') }})"
  # Run the test jobs even if a build job fails. This avoids losing test coverage when a
  # single unrelated build job fails, e.g. the Windows build fails but everything else
  # succeeds; the tests for the other platforms should still run in that case.
  if: ${{ !cancelled() }}
  needs:
    - build-jax-artifact
    - build-jaxlib-artifact
    - build-cuda-artifacts
  uses: ./.github/workflows/bazel_cuda_non_rbe.yml
  strategy:
    fail-fast: false  # a failing matrix job must not cancel the others
    matrix:
      # Python values need to match the matrix strategy in the build artifacts job above.
      runner:
        - "linux-x86-g2-48-l4-4gpu"
      python:
        - "3.11"
      jaxlib-version:
        - "head"
        - "pypi_latest"
      enable-x64:
        - 1
        - 0
  with:
    runner: ${{ matrix.runner }}
    python: ${{ matrix.python }}
    jaxlib-version: ${{ matrix.jaxlib-version }}
    enable-x64: ${{ matrix.enable-x64 }}
    # GCS upload URI is the same for both artifact build jobs
    gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
# Runs the TPU pytest suite through the reusable `pytest_tpu.yml` workflow.
run-pytest-tpu:
  name: "Pytest TPU (JAX artifacts version = ${{ format('{0}', 'head') }})"
  # Run the test jobs even if a build job fails. This avoids losing test coverage when a
  # single unrelated build job fails, e.g. the Windows build fails but everything else
  # succeeds; the tests for the other platforms should still run in that case.
  if: ${{ !cancelled() }}
  needs:
    - build-jax-artifact
    - build-jaxlib-artifact
  uses: ./.github/workflows/pytest_tpu.yml
  strategy:
    fail-fast: false  # a failing matrix job must not cancel the others
    matrix:
      python:
        - "3.11"
      tpu-specs:
        # - {type: "v3-8", cores: "4"}  # Enable when we have the v3 type available
        - type: "v4-8"
          cores: "4"
          runner: "linux-x86-ct4p-240-4tpu"
        - type: "v5e-8"
          cores: "8"
          runner: "linux-x86-ct5lp-224-8tpu"
        - type: "v6e-8"
          cores: "8"
          runner: "linux-x86-ct6e-180-8tpu"
      libtpu-version-type:
        - "nightly"
        - "oldest_supported_libtpu"
      exclude:
        # Run a single config for oldest_supported_libtpu: the v4-8 and v6e-8 combinations
        # are dropped.
        - libtpu-version-type: "oldest_supported_libtpu"
          tpu-specs:
            type: "v4-8"
        - libtpu-version-type: "oldest_supported_libtpu"
          tpu-specs:
            type: "v6e-8"
  with:
    runner: ${{ matrix.tpu-specs.runner }}
    tpu-type: ${{ matrix.tpu-specs.type }}
    cores: ${{ matrix.tpu-specs.cores }}
    python: ${{ matrix.python }}
    libtpu-version-type: ${{ matrix.libtpu-version-type }}
    run-full-tpu-test-suite: "1"
    gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}