Skip to content

Commit 580237c

Browse files
committed
[nnx] add flaxlib
1 parent 9eb0a61 commit 580237c

File tree

15 files changed

+575
-34
lines changed

15 files changed

+575
-34
lines changed

.github/workflows/pythonpublish.yml renamed to .github/workflows/flax_publish.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
# This workflows will upload a Python Package using Twine when a release is created
22
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
33

4-
name: Upload Python Package
4+
name: Flax - Build and upload to PyPI
55

66
on:
77
release:
8-
types: [created]
8+
types: [published]
99

1010
jobs:
1111
deploy:

.github/workflows/build.yml renamed to .github/workflows/flax_test.yml

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
22
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
33

4-
name: Build
4+
name: Flax - Test
55

66
on:
77
push:
@@ -70,7 +70,7 @@ jobs:
7070
uses: actions/setup-python@v4
7171
with:
7272
python-version: ${{ matrix.python-version }}
73-
- uses: yezz123/setup-uv@v4
73+
- uses: astral-sh/setup-uv@v2
7474
with:
7575
uv-version: "0.3.0"
7676
- name: Install standalone dependencies only
@@ -104,23 +104,17 @@ jobs:
104104
uses: actions/setup-python@v4
105105
with:
106106
python-version: ${{ matrix.python-version }}
107-
- uses: yezz123/setup-uv@v4
107+
- name: Setup uv
108+
uses: astral-sh/setup-uv@v2
108109
with:
109-
uv-version: "0.3.0"
110-
- name: Cached virtual environment
111-
id: venv_cache
112-
uses: actions/cache@v3
113-
with:
114-
path: .venv
115-
key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('uv.lock') }}
116-
- name: Install Dependencies for cache
117-
if: steps.venv_cache.outputs.cache-hit != 'true'
118-
run: |
119-
if [ -d ".venv" ]; then rm -rf .venv; fi
120-
uv sync --locked --all-extras
121-
- name: Check lockfile
110+
version: "0.3.0"
111+
- name: Setup Rust (flaxlib)
112+
uses: actions-rust-lang/setup-rust-toolchain@v1
113+
114+
- name: Install dependencies
122115
run: |
123-
uv sync --locked --all-extras
116+
uv sync --locked --extra all --extra testing --extra docs
117+
uv pip install ./flaxlib
124118
- name: Install JAX
125119
run: |
126120
if [[ "${{ matrix.jax-version }}" == "newest" ]]; then

.github/workflows/flaxlib_publish.yml

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
name: Flaxlib - Build and upload to PyPI
2+
3+
# for testing only:
4+
on:
5+
push:
6+
branches: [main]
7+
paths: ['flaxlib/**']
8+
release:
9+
types: [published]
10+
11+
jobs:
12+
build_wheels:
13+
if: github.event_name == 'push' && contains(github.event.head_commit.modified, 'flaxlib/')
14+
name: Build wheels on ${{ matrix.os }}
15+
runs-on: ${{ matrix.os }}
16+
strategy:
17+
matrix:
18+
# macos-13 is an intel runner, macos-14 is apple silicon
19+
os: [ubuntu-latest, windows-latest, macos-13, macos-14]
20+
21+
steps:
22+
- uses: actions/checkout@v4
23+
24+
- uses: actions/setup-python@v5
25+
26+
- name: Setup Rust
27+
uses: actions-rust-lang/setup-rust-toolchain@v1
28+
29+
- name: Install cibuildwheel
30+
run: python -m pip install cibuildwheel==2.21.0
31+
32+
- name: Build wheels
33+
run: python -m cibuildwheel --output-dir ./flaxlib/wheelhouse ./flaxlib
34+
env:
35+
# rust doesn't seem to be available for musl linux on i686
36+
CIBW_SKIP: "*-musllinux_i686"
37+
CIBW_ENVIRONMENT: 'PATH="$HOME/.cargo/bin:$PATH" CARGO_TERM_COLOR="always"'
38+
CIBW_ENVIRONMENT_WINDOWS: 'PATH="$UserProfile\.cargo\bin;$PATH"'
39+
CIBW_BEFORE_BUILD: rustup show
40+
CIBW_BEFORE_BUILD_LINUX: |
41+
curl https://sh.rustup.rs -sSf | sh -s -- --default-toolchain=stable --profile=minimal -y &&
42+
rustup show
43+
44+
- uses: actions/upload-artifact@v4
45+
with:
46+
name: cibw-wheels-${{ matrix.os }}-${{ strategy.job-index }}
47+
path: ./flaxlib/wheelhouse/*.whl
48+
49+
build_sdist:
50+
if: github.event_name == 'push' && contains(github.event.head_commit.modified, 'flaxlib/')
51+
name: Build source distribution
52+
runs-on: ubuntu-latest
53+
steps:
54+
- uses: actions/checkout@v4
55+
56+
- name: Setup Rust
57+
uses: actions-rust-lang/setup-rust-toolchain@v1
58+
59+
- name: Build sdist
60+
run: pipx run build --sdist flaxlib
61+
62+
- uses: actions/upload-artifact@v4
63+
with:
64+
name: cibw-sdist
65+
path: ./flaxlib/dist/*.tar.gz
66+
67+
upload_pypi:
68+
if: github.event_name == 'push' && contains(github.event.head_commit.modified, 'flaxlib/')
69+
name: Upload to PyPI
70+
needs: [build_wheels, build_sdist]
71+
runs-on: ubuntu-latest
72+
permissions:
73+
id-token: write
74+
steps:
75+
- uses: actions/setup-python@v1
76+
with:
77+
python-version: '3.x'
78+
- name: Install dependencies
79+
run: |
80+
python -m pip install --upgrade pip
81+
pip install setuptools build wheel twine
82+
- uses: actions/download-artifact@v4
83+
with:
84+
# unpacks all CIBW artifacts into dist/
85+
pattern: cibw-*
86+
path: ./flaxlib/dist
87+
merge-multiple: true
88+
89+
- name: Build and publish
90+
env:
91+
TWINE_USERNAME: __token__
92+
TWINE_PASSWORD: ${{ secrets.FLAXLIB_PYPI_TOKEN }}
93+
run: |
94+
twine upload flaxlib/dist/*

flax/nnx/graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
AuxData = tp.TypeVar('AuxData')
5252

5353
StateLeaf = VariableState[tp.Any]
54-
NodeLeaf = VariableState[tp.Any]
54+
NodeLeaf = Variable[tp.Any]
5555
GraphState = State[Key, StateLeaf]
5656
GraphFlatState = FlatState[StateLeaf]
5757

flaxlib/.gitignore

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
/target
2+
3+
# Byte-compiled / optimized / DLL files
4+
__pycache__/
5+
.pytest_cache/
6+
*.py[cod]
7+
8+
# C extensions
9+
*.so
10+
11+
# Distribution / packaging
12+
.Python
13+
.venv/
14+
env/
15+
bin/
16+
build/
17+
develop-eggs/
18+
dist/
19+
eggs/
20+
lib/
21+
lib64/
22+
parts/
23+
sdist/
24+
var/
25+
include/
26+
man/
27+
venv/
28+
*.egg-info/
29+
.installed.cfg
30+
*.egg
31+
32+
# Installer logs
33+
pip-log.txt
34+
pip-delete-this-directory.txt
35+
pip-selfcheck.json
36+
37+
# Unit test / coverage reports
38+
htmlcov/
39+
.tox/
40+
.coverage
41+
.cache
42+
nosetests.xml
43+
coverage.xml
44+
45+
# Translations
46+
*.mo
47+
48+
# Mr Developer
49+
.mr.developer.cfg
50+
.project
51+
.pydevproject
52+
53+
# Rope
54+
.ropeproject
55+
56+
# Django stuff:
57+
*.log
58+
*.pot
59+
60+
.DS_Store
61+
62+
# Sphinx documentation
63+
docs/_build/
64+
65+
# PyCharm
66+
.idea/
67+
68+
# VSCode
69+
.vscode/
70+
71+
# Pyenv
72+
.python-version
73+
74+
# cibuildwheel
75+
/wheelhouse

0 commit comments

Comments
 (0)