def set_up_jax_version(version: str) -> Tuple[str, ...]:
  """Build the pip commands that install one specific JAX release.

  The commands are returned as plain strings; nothing is executed here.

  Args:
    version: Version specifier substituted after `==` in every command,
      e.g. "0.5.1".

  Returns:
    A 3-tuple of shell commands, in this order:
      1. `jax[tpu]` pinned to `version`, using the libtpu releases page as an
         extra `-f` (find-links) source.
      2. `jaxlib` pinned to `version` with `--pre`, using the jaxlib nightly
         releases page as an extra `-f` source.
      3. `jax` pinned to `version`.

  Raises:
    ValueError: If `version` is None or blank. Without this check the
      function would emit syntactically invalid requirements such as
      `jax[tpu]==`, which only fail later when pip runs them.
  """
  if version is None or not str(version).strip():
    raise ValueError("version must be a non-empty JAX version string")

  libtpu_links = (
      "https://storage.googleapis.com/jax-releases/libtpu_releases.html"
  )
  jaxlib_links = (
      "https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html"
  )

  return (
      f"pip install jax[tpu]=={version} -f {libtpu_links}",
      f"pip install --pre jaxlib=={version} -f {jaxlib_links}",
      f"pip install jax=={version}",
  )
dockerfile_build_cmd(jax_version): + # Generate pip commands to install certain version of JAX/libTPU e.g. + # pip install --pre jaxlib==0.5.1 -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html + # pip install jax[tpu]==0.5.1 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + # pip install jax==0.5.1 + if jax_version: + pip_tpu_jax_install = "\n".join( + ["RUN " + x for x in common.set_up_jax_version(jax_version)] + ) + else: + pip_tpu_jax_install = "\n".join( + ["RUN " + x for x in common.set_up_nightly_jax()] + ) + + return ( """cat > Dockerfile_CI < run_tpu_tests.sh <