Skip to content

CI - jaxlib head with NumPy nightly #48

CI - jaxlib head with NumPy nightly

CI - jaxlib head with NumPy nightly #48

Workflow file for this run

name: CI - jaxlib head with NumPy nightly
# Builds and tests JAX against NumPy nightly releases. The job installs the NumPy and SciPy
# nightly wheels, builds ml_dtypes from its main branch using the NumPy nightly ABI, builds jax
# and jaxlib at HEAD, and finally runs the CPU pytest suite against NumPy nightly.
# Triggers: a recurring schedule, plus manual dispatch with an option to make
# the run wait for a remote connection.
on:
  schedule:
    - cron: '0 */3 * * *'  # Run once every 3 hours
  workflow_dispatch:
    inputs:
      halt-for-connection:
        description: 'Should this workflow run wait for a remote connection?'
        type: choice
        options: ['yes', 'no']
        default: 'no'
        required: true
# Empty mapping: the workflow requests no GITHUB_TOKEN permissions.
permissions: {}
concurrency:
  # One group per workflow and ref; head_ref is used when it is set, otherwise ref.
  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
  # Don't cancel in-progress jobs for main/release branches. github.ref is the
  # fully-formed ref (e.g. refs/heads/main), so it must be compared in that form;
  # comparing against the bare branch name 'main' never matches.
  cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'refs/heads/main' }}
jobs:
  # Single job: install NumPy/SciPy nightlies, build ml_dtypes, jax and jaxlib
  # from source, then run the CPU test suite.
  test-nightly-numpy:
    name: "CI - jaxlib head with NumPy nightly"
    runs-on: "linux-x86-n2-64"
    container: "us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest"
    defaults:
      run:
        shell: bash
    strategy:
      matrix:
        # Kept quoted so the version is a string, not a float.
        python: ["3.13"]
    env:
      # JAXCI_* variables are presumably consumed by the ci/ scripts invoked
      # below; their exact semantics are defined there, not in this file.
      JAXCI_HERMETIC_PYTHON_VERSION: "${{ matrix.python }}"
      JAXCI_PYTHON: "python${{ matrix.python }}"
      # Quoted: environment values are strings.
      JAXCI_BUILD_ARTIFACT_WITH_RBE: "1"
      JAXCI_CLONE_MAIN_XLA: "1"
    steps:
      # Check out this repository without persisting credentials in git config.
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2
        with:
          persist-credentials: false
      # Check out ml_dtypes at main into ./ml_dtypes so it can be built below.
      - name: Checkout ml_dtypes
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2
        with:
          repository: jax-ml/ml_dtypes
          ref: main
          path: ml_dtypes
          persist-credentials: false
      # Halt for testing
      - name: Wait For Connection
        uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c
        with:
          halt-dispatch-input: ${{ inputs.halt-for-connection }}
      # Pull pre-release wheels from the scientific-python nightly index, then
      # print the installed NumPy version so it is visible in the log.
      - name: Install numpy & scipy development versions
        run: |
          "$JAXCI_PYTHON" -m uv pip install \
            --system \
            -i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple \
            --no-deps \
            --pre \
            --upgrade \
            numpy \
            scipy
          "$JAXCI_PYTHON" -c "import numpy; print(f'{numpy.__version__=}')"
      # --no-build-isolation makes the build use the already-installed NumPy
      # rather than a fresh one fetched into an isolated build environment.
      - name: Build ml_dtypes with NumPy nightly
        run: |
          pushd ml_dtypes
          git submodule init
          git submodule update
          "$JAXCI_PYTHON" -m uv pip install . --no-build-isolation
          popd
      - name: Build jax at HEAD
        run: ./ci/build_artifacts.sh jax
      - name: Build jaxlib at HEAD
        run: ./ci/build_artifacts.sh jaxlib
      # Interpreter path is quoted like every other invocation in this job; a
      # block scalar is used because a value starting with '"' would otherwise
      # be read as a YAML double-quoted scalar.
      - name: Install test dependencies
        run: |
          "$JAXCI_PYTHON" -m uv pip install -r build/test-requirements.txt
      - name: Run Pytest CPU tests
        timeout-minutes: 30
        run: ./ci/run_pytest_cpu.sh