diff --git a/dags/sparsity_diffusion_devx/configs/common.py b/dags/sparsity_diffusion_devx/configs/common.py index 17a82a6a..833c9864 100644 --- a/dags/sparsity_diffusion_devx/configs/common.py +++ b/dags/sparsity_diffusion_devx/configs/common.py @@ -34,3 +34,17 @@ def set_up_nightly_jax() -> Tuple[str]: ), "pip install git+https://github.com/google/jax", ) + + +def set_up_jax_version(version) -> Tuple[str]: + return ( + ( + f"pip install jax[tpu]=={version} -f " + "https://storage.googleapis.com/jax-releases/libtpu_releases.html" + ), + ( + f"pip install --pre jaxlib=={version} -f" + " https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html" + ), + f"pip install jax=={version}", + ) diff --git a/dags/sparsity_diffusion_devx/configs/project_bite_config.py b/dags/sparsity_diffusion_devx/configs/project_bite_config.py index 0002a548..ebbbf71a 100644 --- a/dags/sparsity_diffusion_devx/configs/project_bite_config.py +++ b/dags/sparsity_diffusion_devx/configs/project_bite_config.py @@ -98,18 +98,21 @@ def get_bite_tpu_config( ) -def get_bite_tpu_unittests_config( - tpu_version: TpuVersion, - tpu_cores: int, - tpu_zone: str, - runtime_version: str, - time_out_in_min: int, - task_owner: str, - is_tpu_reserved: bool = False, - pinned_version: Optional[str] = None, -): - unittest_setupcmds = ( - # create configuration files needed +def dockerfile_build_cmd(jax_version): + # Generate pip commands to install certain version of JAX/libTPU e.g. + # pip install --pre jaxlib==0.5.1 -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html + # pip install jax[tpu]==0.5.1 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + # pip install jax==0.5.1 + if jax_version: + pip_tpu_jax_install = "\n".join( + ["RUN " + x for x in common.set_up_jax_version(jax_version)] + ) + else: + pip_tpu_jax_install = "\n".join( + ["RUN " + x for x in common.set_up_nightly_jax()] + ) + + return ( """cat > Dockerfile_CI < run_tpu_tests.sh <