CI: 04/09/25 upstream sync #351

Open: wants to merge 588 commits into base: rocm-main

Commits:
12526ea
[jaxlib] Pack/unpack subbyte types to/from numpy arrays to support in…
danielsuo Mar 31, 2025
b3d851d
Add Jax tracing micro benchmarks.
danielsuo Mar 31, 2025
95497ca
Remove legacy GPU kernel for LU decomposition.
dfm Mar 31, 2025
6b71949
[pallas:mosaic_gpu] Fixed lane-level lowering of `lax.optimization_ba…
superbobry Mar 31, 2025
200f826
[array api] return all devices in devices()
jakevdp Mar 31, 2025
e2df374
Merge pull request #27616 from jakevdp:array-api-devices
Google-ML-Automation Mar 31, 2025
aaa3ebf
Add optimization barrier.
wenscarl Mar 31, 2025
05039fe
Bump tsickert/discord-webhook from 5.3.0 to 7.0.0
dependabot[bot] Mar 31, 2025
5d69e6b
Bump actions/setup-python from 5.4.0 to 5.5.0
dependabot[bot] Mar 31, 2025
1355e7c
AutoPGLE: force-disable graphs less
olupton Mar 5, 2025
d6b4fed
Propagate sharding and vma rule for axis_index_p. There's no need for…
yashk2810 Mar 31, 2025
e2ee262
Merge pull request #27621 from jax-ml:dependabot/github_actions/actio…
Google-ML-Automation Mar 31, 2025
8cda2a2
[Mosaic-GPU] [2/3] Add NVSHMEM support to Mosaic-GPU custom call
nvcastet Mar 20, 2025
ca36047
__jax_array__: add support in jnp.reshape, jnp.transpose, jnp.matrix_…
jakevdp Mar 31, 2025
f59f615
Minor docstring updates for AOT wrappers in error checking
ayaka14732 Mar 31, 2025
5c35454
Merge pull request #27627 from jakevdp:transpose-jax-array
Google-ML-Automation Mar 31, 2025
4003e2d
jnp.power: support __jax_array__ on inputs
jakevdp Mar 27, 2025
16b2b91
Merge pull request #27540 from jakevdp:pow-jax-array
Google-ML-Automation Apr 1, 2025
994af3e
[Pallas TPU] Remove forward compatibility code for float -> signed co…
apaszke Apr 1, 2025
006a6a6
[Easy] Make pallas mesh grid handling more resilient to tuple names.
Google-ML-Automation Apr 1, 2025
6adb728
Clarify documentation of jnp.heaviside
LouisJustinTALLOT Apr 1, 2025
5d1bc00
Update XLA dependency to use revision
Google-ML-Automation Apr 1, 2025
40a3d0c
Create the test targets for the wheel size verification.
Google-ML-Automation Apr 1, 2025
76271d6
Add scan_p and cond_p vma rule.
yashk2810 Apr 1, 2025
a34c462
jnp.select: support __jax_array__ for inputs
jakevdp Apr 1, 2025
a80f627
make random_gamma_grad not a primitive anymore
mattjj Apr 1, 2025
126ef3d
Merge pull request #27620 from jax-ml:dependabot/github_actions/tsick…
Google-ML-Automation Apr 1, 2025
2d2be0b
Update permisisons community_release_actions.yml
ZacCranko Apr 1, 2025
28a1c9a
Merge pull request #27647 from ZacCranko:main
Google-ML-Automation Apr 1, 2025
f4c727a
Merge pull request #26964 from olupton:auto-pgle-with-graphs
Google-ML-Automation Apr 1, 2025
5370ac2
Remove the try/except for Shardy imports.
belitskiy Apr 1, 2025
0b199f4
[jaxlib] Roll back subbyte types due to failing asan tests.
danielsuo Apr 1, 2025
efd621a
Merge pull request #27643 from jakevdp:select-jax-array
Google-ML-Automation Apr 1, 2025
7b04a79
jnp.einsum: add support for __jax_array__
jakevdp Apr 1, 2025
2eb5c8a
Merge pull request #27648 from jakevdp:einsum-jax-array
Google-ML-Automation Apr 1, 2025
4908b2f
cumulative reductions: support __jax_array__ on inputs
jakevdp Apr 1, 2025
05269a8
[mutable-arrays] add vmap rule for mutable_array_p, very basic test
mattjj Apr 1, 2025
ff5a2e8
Enable test_scan_offload in memories_test.
Google-ML-Automation Apr 1, 2025
e75c052
Merge pull request #26674 from nvcastet:split_distributed_gpu_pallas_2
Google-ML-Automation Apr 1, 2025
747c580
Merge pull request #27632 from LouisJustinTALLOT:patch-1
Google-ML-Automation Apr 1, 2025
9bb899d
Merge pull request #27651 from mattjj:mutable-array-vmap
Google-ML-Automation Apr 1, 2025
880884d
Merge pull request #27086 from Amir-19:tma_reduction
Google-ML-Automation Apr 1, 2025
f139192
Add OOB checks to jax.numpy array indexing
ayaka14732 Apr 2, 2025
1875c76
let XLA metadata be unset in nested dynamic scopes
froystig Apr 2, 2025
6fe6d80
upgrade docs from `jax.core` to `jax.extend.core` where needed to fix…
froystig Apr 2, 2025
8e2c1a1
Updates for 3.14
vfdev-5 Jan 16, 2025
076d021
[better_errors] Fix the handling of kwargs for debug_info.
gnecula Mar 26, 2025
398e8b0
Merge pull request #27644 from jakevdp:cumulative-jax-array
Google-ML-Automation Apr 2, 2025
2e8ea62
Merge pull request #27463 from gnecula:debug_info_fix_kwargs
Google-ML-Automation Apr 2, 2025
306059a
Merge pull request #27662 from froystig:docfixes
Google-ML-Automation Apr 2, 2025
5556981
Merge pull request #27628 from mattjj:random-gamma-grad-no-more-primi…
Google-ML-Automation Apr 2, 2025
82ec573
Remove nanobind pin now that nanobind fix landed.
danielsuo Apr 2, 2025
735cec1
[jaxlib] Fix asan tests for subbyte types in CPU/GPU callbacks.
danielsuo Apr 2, 2025
45d577d
Prepare for disallowing `jnp.array(None)`
superbobry Apr 2, 2025
0bee42b
Update XLA dependency to use revision
Google-ML-Automation Apr 2, 2025
10b2cda
Relax the aval check in `select_hlo_lowering_opaque` to only check fo…
yashk2810 Apr 2, 2025
6242ffb
Remove unused Attrs from `lu_pivots_to_permutation` FFI kernel.
dfm Apr 2, 2025
297a4f4
docs: compilation_cache_expect_pgle option
olupton Apr 1, 2025
c18139b
Remove legacy GPU kernels for QR decomposition.
dfm Apr 2, 2025
5768432
Disabled default env var LOBPCG_EMIT_DEBUG_PLOTS=1
vfdev-5 Apr 2, 2025
2e16367
Remove the extra stack frame that was introduce in uniform due to dro…
yashk2810 Apr 2, 2025
3aeabae
jnp.isinf & friends: support __jax_array__
jakevdp Apr 1, 2025
c281907
Merge pull request #27389 from vfdev-5:add-tsan-ft-ci-job-314
Google-ML-Automation Apr 2, 2025
aeaa9e2
Merge pull request #27671 from vfdev-5:disable-env-var-LOBPCG_EMIT_DE…
Google-ML-Automation Apr 2, 2025
2a24b40
Bump actions/cache from 4.2.0 to 4.2.3
dependabot[bot] Apr 2, 2025
3ffc604
Merge pull request #27650 from jakevdp:isinf-jax-array
Google-ML-Automation Apr 2, 2025
056c976
Merge pull request #27660 from froystig:xla-meta-ctx
Google-ML-Automation Apr 2, 2025
3d70fc8
Add pbroadcast insertion for `psum_p` in the traceable. This effectiv…
yashk2810 Apr 2, 2025
e2f08d3
Merge pull request #27379 from jax-ml:dependabot/github_actions/actio…
Google-ML-Automation Apr 2, 2025
92f7aea
Add simple vmap support for lax.ragged_all_to_all.
ghpvnist Apr 2, 2025
a442fec
Fix custom_transpose when composed with custom_jvp and use_direct_lin…
dfm Apr 2, 2025
7f4e8c5
jnp.concat and friends: support __jax_array__
jakevdp Apr 2, 2025
a2d62e2
[array_api] update array_api_version to 2024.12
jakevdp Apr 2, 2025
83fc5d9
Merge pull request #27678 from jakevdp:concat-jax-array
Google-ML-Automation Apr 2, 2025
bff0fa1
Support `conv` `unfused_flops` in roofline.
zacmustin Apr 2, 2025
e75b664
Merge pull request #27680 from jakevdp:array-api-version
Google-ML-Automation Apr 2, 2025
96780f1
jax.numpy: support __jax_array__ in several more functions
jakevdp Apr 2, 2025
9c58a11
`jnp.array` no longer accepts None
superbobry Apr 2, 2025
c8273d7
Merge pull request #24197 from yhtang:add-k8s-ci
Google-ML-Automation Apr 2, 2025
9fa5de7
[pallas] Removed `pl.device_id`. Use `lax.axis_index` instead.
superbobry Apr 2, 2025
5ddec65
Remove asserts
kaixih Apr 3, 2025
2540fcd
add an `out_sharding` option to `jax.random.bits`
froystig Apr 2, 2025
ffbd5ef
Merge pull request #27677 from dfm:dir-lin-custom-transpose
Google-ML-Automation Apr 3, 2025
aa06e16
Merge pull request #27687 from froystig:out-shard-bits
Google-ML-Automation Apr 3, 2025
2f61763
use common `maybe_auto_axes` helper in `random.uniform`
froystig Apr 3, 2025
ab816ed
add an `out_sharding` option to `jax.random.randint`
froystig Apr 3, 2025
f1adec3
[Mosaic GPU] Define the `mosaic_gpu.custom_primitive` dialect op.
bchetioui Apr 3, 2025
45e6808
Merge pull request #27084 from danielsuo:switch-fwd
Google-ML-Automation Apr 3, 2025
862342d
Merge pull request #27688 from froystig:out-shard-randint
Google-ML-Automation Apr 3, 2025
6243ac8
[CI] Enable nightly TPU CI tests for v6e.
MichaelHudgins Apr 3, 2025
ea196da
[pallas:mosaic_gpu] Slightly reworded the docstrings for a few recent…
superbobry Apr 3, 2025
552eea8
[pallas:mosaic_gpu] `emit_pipeline*` now passes the loop indices into…
superbobry Apr 3, 2025
0ec1251
[Mosaic GPU] Get rid of `LayoutAttr` and related comments.
bchetioui Apr 3, 2025
bd92208
Update XLA dependency to use revision
Google-ML-Automation Apr 3, 2025
8d59902
Fix problem finding clang++ when building JAX via build.py on windows.
hawkinsp Apr 3, 2025
68fce71
Merge pull request #27635 from olupton:pgle-docs
Google-ML-Automation Apr 3, 2025
91b0884
Restrict the regex for copying the wheels.
Google-ML-Automation Apr 3, 2025
1941714
[export] Add support for override_lowering_rules to jax.export.
gnecula Apr 3, 2025
f2f9152
Moved the `jax.Array` baseclass to C++
superbobry Apr 3, 2025
e7a5147
Bump up tolerance in ShardMapSystematicTest.test_vmap_closure for GPUs.
belitskiy Apr 3, 2025
d1009a3
Only trigger K8s CI on changes to cluster config and distributed init…
yhtang Apr 3, 2025
42735d0
Not to use dynamic grid in the ragged paged attention Pallas kernel.
Google-ML-Automation Apr 3, 2025
24fef3d
Merge pull request #27304 from wenscarl:nvfp4_grad_ste
Google-ML-Automation Apr 3, 2025
0039d13
Merge pull request #27698 from hawkinsp:win
Google-ML-Automation Apr 3, 2025
51716f6
Merge pull request #27708 from yhtang:add-k8s-ci-scope
Google-ML-Automation Apr 3, 2025
780c882
[Mosaic GPU] Fix index_invariant slot in warp-specialized pipeline.
justinjfu Apr 3, 2025
d7fc04b
Merge pull request #27681 from jakevdp:jax-array
Google-ML-Automation Apr 3, 2025
c2eb9c1
Eliminate DeprecationWarning in python3.12+ in jax pallas for ~.
Google-ML-Automation Apr 3, 2025
1bd0c58
Merge pull request #27691 from gnecula:export_override_lowering
Google-ML-Automation Apr 3, 2025
41868ef
format
kaixih Apr 3, 2025
cb67d56
[Mosaic GPU] Re-enable WS pipelined copy test.
justinjfu Apr 3, 2025
3901014
[pallas:mgpu] General ref transform handling at lowering time.
cperivol Apr 3, 2025
bbdea54
add an `out_sharding` option to `jax.random.permutation`
froystig Apr 3, 2025
7583814
[mgpu:pallas] Changes to allow the use of WGMMA_TRANSPOSED_LAYOUT.
cperivol Apr 3, 2025
a04b5ec
Merge pull request #27717 from froystig:out-shard-perm
Google-ML-Automation Apr 3, 2025
26fc1cd
[pallas:mgpu] Initial version of inline_mgpu op
cperivol Apr 4, 2025
d645172
Delete `PjRtClient.Defragment`.
zacmustin Apr 4, 2025
f8bbe98
require `out_shardings` as a keyword-only argument on public functions
froystig Apr 4, 2025
5b3e419
Add `auto_axes`, `explicit_axes` and `manual_axes` properties to Mesh…
yashk2810 Apr 4, 2025
c1bdd1a
[Mosaic TPU] Allow specify priority in enqueueDMA.
bythew3i Apr 4, 2025
a9bd1e3
[Pallas TPU] Support DMA priority in async copy start
bythew3i Apr 4, 2025
4f00249
[pallas:mosaic_gpu] Do not specify the default `index_map` in tests
superbobry Apr 4, 2025
97cecdf
add an `out_sharding` option to `jax.random.truncated_normal`
froystig Apr 4, 2025
5eb4e7b
[Mosaic GPU] Return the combined softmax residuals.
Rifur13 Apr 4, 2025
12b1a99
fix(docs): corrected the name of the function call in the document
Qazalbash Apr 4, 2025
e619fc0
Avoid double buffering when no windowing info is present.
Google-ML-Automation Apr 4, 2025
1b63d5e
Fixed deadlock in NamedSharding ctor
vfdev-5 Apr 3, 2025
206dec8
[pallas:mosaic_gpu] Added pretty printing to primitives consuming refs
superbobry Apr 4, 2025
b0a920d
[Mosaic GPU] Don't force TiledLayout.lane_dims to partition data
apaszke Apr 4, 2025
bb63850
Merge pull request #27707 from vfdev-5:fix-named-sharding-deadlock
Google-ML-Automation Apr 4, 2025
635805e
[Mosaic GPU] Allow replicating data over warps
apaszke Apr 4, 2025
e4a381c
[pallas:mgpu] Check that swizzle dim is not transposed in copy_smem_t…
cperivol Apr 4, 2025
cbae253
[mgpu:pallas] Typo in `UnswizzleRef.untransform_reshape()` check.
cperivol Apr 4, 2025
5a29311
Update XLA dependency to use revision
Google-ML-Automation Apr 4, 2025
da7b157
[mgpu:pallas] Swizzle elements computed using bitwidth rather than by…
cperivol Apr 4, 2025
53abbd5
[mgpu] Foreach to handle scalar registers in fragmented arrays.
cperivol Apr 4, 2025
b900714
[mgpu:pallas] Fix swizzling check bug where it was comparing w/ #byte…
cperivol Apr 4, 2025
35d7518
`_attempt_rewriting_take_via_slice()`: canonicalize the slice index b…
Google-ML-Automation Apr 4, 2025
5a7dc42
Merge pull request #27730 from froystig:out-shard-trunc-normal
Google-ML-Automation Apr 4, 2025
e2f67e0
Always force synchronous pipelining when we have vmem storage and tri…
Google-ML-Automation Apr 4, 2025
e6b01bd
Parameterize the random tests taking out_sharding argument in pjit_te…
yashk2810 Apr 4, 2025
027f733
Merge pull request #27731 from Qazalbash:fix-doc-grad-checkpoint
Google-ML-Automation Apr 4, 2025
be1a554
Fix a possible race in pjit.cc.
hawkinsp Apr 4, 2025
5d4ac77
PR #26906: [jax.distributed] Allow explicitly setting slice_index
gspschmid Apr 4, 2025
549f1cd
Don't set `memory_kind` to `None` if the mesh is AbstractMesh and the
yashk2810 Apr 4, 2025
d81c0ff
[Mosaic GPU] Limit the maximum number of registers per thread to 255.
Rifur13 Apr 4, 2025
aab6613
[pallas:mosaic_gpu] Fixed a typo in `_barrier_arrive_pp_eqn`
superbobry Apr 4, 2025
fc5d9a4
Check that memory_kind of an aval is always None
yashk2810 Apr 5, 2025
2e62693
Update XLA dependency to use revision
Google-ML-Automation Apr 5, 2025
b1b54f9
Update XLA dependency to use revision
Google-ML-Automation Apr 6, 2025
ad36f7f
Automated Code Change
Google-ML-Automation Apr 6, 2025
477b108
Automated Code Change
Google-ML-Automation Apr 6, 2025
3f083ca
Automated Code Change
Google-ML-Automation Apr 6, 2025
d1c7ba4
Automated Code Change
Google-ML-Automation Apr 6, 2025
7874d79
Automated Code Change
Google-ML-Automation Apr 6, 2025
8a6efa3
Fix deadlock when computing cached Sharding::type() values.
hawkinsp Apr 6, 2025
cccc34d
Raise an error if the type passed to `axis_types` argument of `Mesh` …
yashk2810 Apr 7, 2025
90cfa99
[Mosaic GPU] Support Slice and Transpose in the Pallas WGMMA lowering
dimitar-asenov Apr 7, 2025
245194f
Use `contextlib.nullcontext` instead of `trivial_ctx`
superbobry Apr 7, 2025
695ee8f
Fix a race in pjit under free threading.
hawkinsp Apr 7, 2025
ce7dc85
[export] Add support for serializing functions with PRNG keys as inpu…
gnecula Apr 7, 2025
075d88f
Fix some test timeouts
apaszke Apr 7, 2025
6e93fa3
Removed unused deprecations
superbobry Apr 7, 2025
9b850a9
[Mosaic GPU] Delete mentions of `WGMMARowFragLayout` in `layouts.py`.
bchetioui Apr 7, 2025
4596ee3
Add a missing jaxlib version check in Pallas TPU lowering
apaszke Apr 7, 2025
153fa22
Add more TSAN skips to avoid timeouts
apaszke Apr 7, 2025
c2aa811
`jex.core.Var` is no longer ordered
superbobry Apr 7, 2025
724b13a
Merge pull request #27773 from hawkinsp:race
Google-ML-Automation Apr 7, 2025
5c0f885
Update XLA dependency to use revision
Google-ML-Automation Apr 7, 2025
83572e1
[Mosaic GPU] Add missing to/from tiled layout attributes with replica…
dimitar-asenov Apr 7, 2025
70485e3
Remove accidental exports jax.interpreters.mlir.{hlo,func_dialect}.
hawkinsp Apr 7, 2025
412f88e
Temporarily skip JaxNumpyErrorTests in multi-thread environments
ayaka14732 Apr 7, 2025
a099b28
Reverts 735cec18cb2f8dff2aea5e503fd886a37aee094e
danielsuo Apr 7, 2025
5a3fc60
Deprecate public export of mlir.custom_call.
dfm Apr 7, 2025
dbc3bcd
Apply forwarding in pjit linearization rule to avoid intermediate cop…
dfm Apr 4, 2025
ff00fa9
Removed unused `jax_remat_opt_barrier` config option
superbobry Apr 7, 2025
51c224c
Removed deprecated `jax.core.{full_lower,jaxpr_as_fun,lattice_join}`
superbobry Apr 7, 2025
5581e7d
Merge pull request #27735 from dfm:lin-fwd
Google-ML-Automation Apr 7, 2025
855829e
Add int4, uint4 to test_util.suppported_types
jburnim Apr 7, 2025
7239487
Bump medyagh/setup-minikube from 0.0.18 to 0.0.19
dependabot[bot] Apr 7, 2025
fcf5115
[Pallas Fuser] Add output_fusion_mask support
sharadmv Apr 7, 2025
e1e37f8
[Mosaic TPU] FWD compatibility needs to keep previous version at leas…
bythew3i Apr 7, 2025
f23dd64
Merge pull request #26853 from jeffcarp:scalar-event
Google-ML-Automation Apr 7, 2025
05ca023
[shard-map] in eager shmap, handle all rep rule output cases
mattjj Apr 7, 2025
dc00f9b
Apply output forwarding in lin rule for pjit.
dfm Apr 7, 2025
23b63cd
Update XLA dependency to use revision
Google-ML-Automation Apr 7, 2025
522add2
[CI] Temporarily disable TPU v6 due to runner issues
MichaelHudgins Apr 7, 2025
0d06731
Merge pull request #27797 from mattjj:shmap-fix
Google-ML-Automation Apr 7, 2025
e1b0572
[mgpu] Allow bf16 printing
cperivol Apr 7, 2025
b6e4b93
Add jaxlib_extension_version guard against explicit copying
pschuh Apr 7, 2025
9a3e94d
[shard-map] add while_map rep rule
mattjj Apr 7, 2025
74917ce
Merge pull request #27794 from jax-ml:dependabot/github_actions/medya…
Google-ML-Automation Apr 7, 2025
d3cfff0
jax.numpy: support __jax_array__ in remaining APIs
jakevdp Apr 7, 2025
db11efa
Migrate jaxlib to use a single common .so file for all C++ dependencies.
hawkinsp Apr 7, 2025
96e63ea
jnp.linalg: add symmetrize_input argument & docs
jakevdp Apr 7, 2025
b18dc1d
[Mosaic GPU] Add scaffolding for a new lowering "axis" (UserThreadSem…
justinjfu Apr 7, 2025
64e4bf2
Relax jax dependency constraints to be able to install RC wheels
nitins17 Apr 7, 2025
3a3c145
[shard-map] canonicalize rep=None to be rep={all possible axes}
mattjj Apr 7, 2025
48a9ad0
Reverts 006a6a63feb64bf9984526030ba008186d69d2b4
Google-ML-Automation Apr 7, 2025
3420546
Merge pull request #27716 from jakevdp:jax-array
Google-ML-Automation Apr 7, 2025
2944e3b
Removed `data_dependent_tracing_fallback` config option
superbobry Apr 7, 2025
0a72e85
Add **experimental** `with_dll_constraint` API. This is for cases whe…
yashk2810 Apr 7, 2025
84e04fe
Add custom pretty print rule for the unary ops with accuracy s.t. acc…
hanrach9 Apr 7, 2025
ca6e470
harden cache against jaxlib ver
ZacCranko Apr 7, 2025
9e03686
Merge pull request #27793 from dfm:lin-out-fwd
Google-ML-Automation Apr 8, 2025
4bae9cd
Merge pull request #27814 from ZacCranko:harden-cache
Google-ML-Automation Apr 8, 2025
3158996
Migrate custom_call filecheck to use internal custom_call since the e…
dfm Apr 8, 2025
86de478
Remove unused function jax._src.interpreters.mlir.xla_computation_to_…
hawkinsp Apr 8, 2025
bb515aa
Address previous FP8-related TODOs in jaxlib/XLA.
apivovarov Apr 8, 2025
51dbcd4
[export] Add backwards compatibility test for annotate_device_placement.
gnecula Apr 8, 2025
19fcae1
[Mosaic GPU] Add support for replicated warp_dim parsing and a dedica…
dimitar-asenov Apr 8, 2025
bc11a63
Clarify jax.make_jaxpr docstring
j-towns Apr 8, 2025
c2eaedf
Merge pull request #27776 from gnecula:export_keys
Google-ML-Automation Apr 8, 2025
8ed59d8
Removed `jax._src.raise_to_shaped`
superbobry Apr 8, 2025
af072fe
Removed redundant `pass`es
superbobry Apr 8, 2025
d12cbff
[Mosaic GPU] Refactor and generalize code in `optimization_barrier`.
dimitar-asenov Apr 8, 2025
c4cc94a
[Mosaic GPU] Add warpgroup lowering for `RunState` in Pallas.
dimitar-asenov Apr 8, 2025
12811f0
Removed `eager_pmap` config option
superbobry Apr 8, 2025
5f33280
[pallas:mosaic_gpu] `emit_pipeline*` now allows the grid to be dynamic
superbobry Apr 8, 2025
73ecf0b
Remove unused `return wrapper` in annotate_function that creates a se…
Google-ML-Automation Apr 8, 2025
511f782
Add a skeleton for Pallas:Mosaic GPU documentation
apaszke Apr 8, 2025
aa6e701
Merge pull request #27827 from apaszke:mgpu-docs
Google-ML-Automation Apr 8, 2025
d6524dc
Update XLA dependency to use revision
Google-ML-Automation Apr 8, 2025
b926fac
[Mosaic GPU] Simplify load/store methods now that we have fewer layouts
apaszke Apr 8, 2025
f5d73b8
[pallas:mosaic_gpu] Added test for custom pretty-printing rules
superbobry Apr 8, 2025
b8353d1
[Mosaic TPU] Add support for non-32bit types in vector.extract
apaszke Apr 8, 2025
e02faab
Replace references to jax.readthedocs.io with docs.jax.dev.
hawkinsp Apr 8, 2025
ae95797
change tack...
mattjj Apr 7, 2025
4d2808c
[mutable-arrays] limit implicit ref_swap dtype promotion
mattjj Apr 8, 2025
b7d430f
jnp.repeat: don't cast repeats to array, as they must be static.
jakevdp Apr 8, 2025
03c1bf9
Merge pull request #27803 from mattjj:27644
Google-ML-Automation Apr 8, 2025
29cb6cd
[Mosaic TPU] Add MemRead and MemStore effects to load and store ops.
bythew3i Apr 8, 2025
ef68063
Merge pull request #27809 from mattjj:26621
Google-ML-Automation Apr 8, 2025
b073e8d
Merge pull request #27836 from jakevdp:fix-repeat
Google-ML-Automation Apr 8, 2025
76825a2
Merge pull request #27807 from jakevdp:eigvalsh-symmetrize
Google-ML-Automation Apr 8, 2025
f1bcf3b
Merge pull request #27821 from j-towns:clarify-make-jaxpr-docstr
Google-ML-Automation Apr 8, 2025
a43136b
Simplify handling of type stubs in nanobind extension rules.
hawkinsp Apr 8, 2025
62df2e8
Added `jax.no_tracing` to the API docs
superbobry Apr 8, 2025
09fed2f
Remove reexports of ml_dtypes types from xla_client.py.
hawkinsp Apr 8, 2025
2d44f98
Finalize deprecation of `ffi_call` with inline arguments.
dfm Apr 8, 2025
b4629c2
Split weakref_lru_cache into its own extension.
hawkinsp Apr 8, 2025
8301c30
Make changes to shard_map to prepare for setting `varying_axes_in_typ…
yashk2810 Apr 8, 2025
5a340a9
Disable second order vjp tests in RunStateHypothesisTest.test_vjp if …
vfdev-5 Mar 11, 2025
7b45552
[ragged-paged-attn] Unify kv strided load to one.
bythew3i Apr 8, 2025
b8d9e7f
Merge pull request #27503 from kaixih:enable_doc_scaled_dot
Google-ML-Automation Apr 8, 2025
a516988
[JAX] Remove calls to xla_computation_to_mlir_module.
hawkinsp Apr 8, 2025
373ac2e
Merge pull request #27804 from vfdev-5:ft-adapt-state-test-2
Google-ML-Automation Apr 8, 2025
84016bc
Rename `pbroadcast` to `pvary` and expose it as `jax.lax.pvary`.
yashk2810 Apr 8, 2025
f95f6a8
Rename `psum2` to `psum_invariant` and put it in `lax_parallel`. We s…
yashk2810 Apr 9, 2025
4275135
Fix typo in the error message
yashk2810 Apr 9, 2025
18 changes: 10 additions & 8 deletions .bazelrc
@@ -31,6 +31,7 @@ build -c opt
build --output_filter=DONT_MATCH_ANYTHING

build --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir.
+build --copt=-DNB_DOMAIN=jax

# #############################################################################
# Platform Specific configs below. These are automatically picked up by Bazel
@@ -97,6 +98,7 @@ build:windows --incompatible_strict_action_env=true
# #############################################################################
build:nonccl --define=no_nccl_support=true

+build --repo_env USE_PYWRAP_RULES=1
build:posix --copt=-fvisibility=hidden
build:posix --copt=-Wno-sign-compare
build:posix --cxxopt=-std=c++17
@@ -136,13 +138,13 @@ build:cuda --repo_env TF_NEED_CUDA=1
build:cuda --repo_env TF_NCCL_USE_STUB=1
# "sm" means we emit only cubin, which is forward compatible within a GPU generation.
# "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations.
-build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90"
+build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,sm_90,sm_100,compute_120"
build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
build:cuda --@local_config_cuda//:enable_cuda

# Default hermetic CUDA and CUDNN versions.
-build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2"
-build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1"
+build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.8.0"
+build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.8.0"
build:cuda --@local_config_cuda//cuda:include_cuda_libs=true

# This config is used for building targets with CUDA libraries from stubs.
@@ -260,8 +262,8 @@ build:ci_darwin_arm64 --color=yes
# Windows x86 CI configs
build:ci_windows_amd64 --config=avx_windows
build:ci_windows_amd64 --compiler=clang-cl --config=clang --verbose_failures=true
-build:ci_windows_amd64 --crosstool_top="@xla//tools/toolchains/win/20240424:toolchain"
-build:ci_windows_amd64 --extra_toolchains="@xla//tools/toolchains/win/20240424:cc-toolchain-x64_windows-clang-cl"
+build:ci_windows_amd64 --crosstool_top="@xla//tools/toolchains/win2022/20241118:toolchain"
+build:ci_windows_amd64 --extra_toolchains="@xla//tools/toolchains/win2022/20241118:cc-toolchain-x64_windows-clang-cl"
build:ci_windows_amd64 --host_linkopt=/FORCE:MULTIPLE --linkopt=/FORCE:MULTIPLE
build:ci_windows_amd64 --color=yes

@@ -329,9 +331,9 @@ common:rbe_windows_amd64 --remote_instance_name=projects/tensorflow-testing/inst
build:rbe_windows_amd64 --config=rbe

# Set the host, execution, and target platform
-build:rbe_windows_amd64 --host_platform="@xla//tools/toolchains/win:x64_windows-clang-cl"
-build:rbe_windows_amd64 --extra_execution_platforms="@xla//tools/toolchains/win:x64_windows-clang-cl"
-build:rbe_windows_amd64 --platforms="@xla//tools/toolchains/win:x64_windows-clang-cl"
+build:rbe_windows_amd64 --host_platform="@xla//tools/toolchains/win2022:windows_ltsc2022_clang"
+build:rbe_windows_amd64 --extra_execution_platforms="@xla//tools/toolchains/win2022:windows_ltsc2022_clang"
+build:rbe_windows_amd64 --platforms="@xla//tools/toolchains/win2022:windows_ltsc2022_clang"

build:rbe_windows_amd64 --shell_executable=C:\\tools\\msys64\\usr\\bin\\bash.exe
build:rbe_windows_amd64 --enable_runfiles
3 changes: 2 additions & 1 deletion .github/workflows/bazel_cuda_non_rbe.yml
@@ -47,7 +47,7 @@ jobs:
# Explicitly set the shell to bash
shell: bash
runs-on: ${{ inputs.runner }}
-container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.3-cudnn9.1:latest"
+container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.8-cudnn9.8:latest"

env:
JAXCI_HERMETIC_PYTHON_VERSION: ${{ inputs.python }}
@@ -79,6 +79,7 @@ jobs:
continue-on-error: true
run: >-
mkdir -p $(pwd)/dist &&
+gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ &&
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ &&
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ &&
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/
65 changes: 65 additions & 0 deletions .github/workflows/bazel_optional_cuda.yml
@@ -0,0 +1,65 @@
name: CI - Bazel Optional CUDA tests
on:
  workflow_dispatch:
    inputs:
      halt-for-connection:
        description: 'Should this workflow run wait for a remote connection?'
        type: choice
        required: true
        default: 'no'
        options:
          - 'yes'
          - 'no'
  pull_request:
    branches:
      - main
    types: [ labeled, synchronize ]
  schedule:
    - cron: "0 */2 * * *" # Run once every 2 hours
permissions:
  contents: read
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
  # Don't cancel in-progress jobs for main/release branches.
  cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }}
jobs:
  run_tests:
    if: ${{ github.event.repository.fork == false && (github.event_name == 'schedule' || contains(github.event.pull_request.labels.*.name, 'CI Optional GPU Presubmit')) }}
    runs-on: ${{ matrix.runner }}
    container: 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.8-cudnn9.8:latest'
    strategy:
      matrix:
        # Optional gpus to run against
        runner: ["linux-x86-a4-224-b200-1gpu"]
    name: "Bazel single accelerator CUDA tests (${{ matrix.runner }})"
    # End Presubmit Naming Check github-cuda-presubmits
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
      - name: Wait For Connection
        uses: google-ml-infra/actions/ci_connection@main
        with:
          halt-dispatch-input: ${{ inputs.halt-for-connection }}
      - name: Run Bazel CUDA Tests
        run: |
          nvidia-smi
          bazel test --config=ci_linux_x86_64_cuda \
            --config=resultstore \
            --config=rbe_cache \
            --repo_env=HERMETIC_CUDA_VERSION="12.8.0" \
            --repo_env=HERMETIC_CUDNN_VERSION="9.8.0" \
            --repo_env=HERMETIC_PYTHON_VERSION="3.13" \
            --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \
            --run_under "$(pwd)/build/parallel_accelerator_execute.sh" \
            --test_output=errors \
            --test_env=JAX_ACCELERATOR_COUNT=1 \
            --test_env=JAX_TESTS_PER_ACCELERATOR=32 \
            --local_test_jobs=32 \
            --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow \
            --test_tag_filters=-multiaccelerator \
            --test_env=TF_CPP_MIN_LOG_LEVEL=0 \
            --test_env=JAX_SKIP_SLOW_TESTS=true \
            --action_env=JAX_ENABLE_X64="1" \
            --action_env=NCCL_DEBUG=WARN \
            --color=yes \
            //tests:gpu_tests //tests:backend_independent_tests \
            //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests
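
Reviewer note on the new workflow's `if:` and `cancel-in-progress` expressions: a minimal Python sketch of the same boolean logic, useful for checking which events would actually run the job. The function names `should_run` and `cancel_in_progress` are hypothetical and not part of this PR. GitHub's own expression evaluator compares strings case-insensitively, which this sketch does not model. On branch pushes `github.ref` is normally a fully qualified ref such as `refs/heads/main`, so the `!= 'main'` comparison may not exclude the main branch the way the comment suggests.

```python
def should_run(is_fork: bool, event_name: str, pr_labels: list[str]) -> bool:
    """Approximates the `if:` expression on the run_tests job.

    Runs only outside forks, and only for scheduled runs or PRs that
    carry the 'CI Optional GPU Presubmit' label.
    """
    return (not is_fork) and (
        event_name == "schedule" or "CI Optional GPU Presubmit" in pr_labels
    )


def cancel_in_progress(ref: str) -> bool:
    """Approximates the `cancel-in-progress` expression.

    `ref` stands in for `github.ref` (assumed to be a full ref string).
    """
    return "release/" not in ref and ref != "main"


if __name__ == "__main__":
    # A labeled PR in the main repository would run the job.
    print(should_run(False, "pull_request", ["CI Optional GPU Presubmit"]))
    # With a fully qualified ref, in-progress runs on main are still cancelled.
    print(cancel_in_progress("refs/heads/main"))
```

If the intent is to protect runs on main, comparing against `refs/heads/main` would be one way to do it; that is a suggestion for upstream, not something this sync changes.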
12 changes: 6 additions & 6 deletions .github/workflows/build_artifacts.yml
@@ -16,8 +16,8 @@ on:
default: "linux-x86-n2-16"
options:
- "linux-x86-n2-16"
-- "linux-arm64-c4a-64"
-- "windows-x86-n2-64"
+- "linux-arm64-t2a-48"
+- "windows-x86-n2-16"
artifact:
description: "Which JAX artifact to build?"
type: choice
@@ -119,11 +119,11 @@ jobs:

steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
-- name: Enable RBE if building on Linux x86
-if: contains(inputs.runner, 'linux-x86')
+- name: Enable RBE if building on Linux x86 or Windows x86
+if: contains(inputs.runner, 'linux-x86') || contains(inputs.runner, 'windows-x86')
run: echo "JAXCI_BUILD_ARTIFACT_WITH_RBE=1" >> $GITHUB_ENV
-- name: Enable Bazel remote cache (with writes enabled) if building on Linux Aarch64 or Windows x86
-if: contains(inputs.runner, 'linux-arm64') || contains(inputs.runner, 'windows-x86')
+- name: Enable Bazel remote cache (with writes enabled) if building on Linux Aarch64
+if: contains(inputs.runner, 'linux-arm64')
run: echo "JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE=1" >> $GITHUB_ENV
# Halt for testing
- name: Wait For Connection
14 changes: 7 additions & 7 deletions .github/workflows/ci-build.yaml
@@ -31,11 +31,11 @@ jobs:
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python 3.11
-uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
+uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
with:
python-version: 3.11
- run: python -m pip install pre-commit
-- uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
+- uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3
with:
path: ~/.cache/pre-commit
key: pre-commit-${{ env.pythonLocation }}-${{ hashFiles('.pre-commit-config.yaml', 'setup.py') }}
@@ -70,7 +70,7 @@ jobs:
apt update
apt install -y libssl-dev
- name: Set up Python ${{ matrix.python-version }}
-uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
+uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
@@ -108,7 +108,7 @@ jobs:
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
@@ -142,7 +142,7 @@ jobs:
apt update
apt install -y libssl-dev libsqlite3-dev
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
@@ -168,7 +168,7 @@ jobs:
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
@@ -201,7 +201,7 @@ jobs:
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
with:
python-version: 3.12
- name: Install JAX
1 change: 0 additions & 1 deletion .github/workflows/cloud-tpu-ci-nightly.yml
@@ -26,7 +26,6 @@ jobs:
matrix:
jaxlib-version: ["head", "pypi_latest", "nightly", "nightly+oldest_supported_libtpu"]
tpu: [
# {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available
{type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
{type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
]
34 changes: 34 additions & 0 deletions .github/workflows/community_release_actions.yml
@@ -0,0 +1,34 @@
name: Release Actions

on:
release:
types: [published]

permissions:
contents: read

jobs:
discord_release:
if: github.repository_owner == 'jax-ml'
runs-on: ubuntu-latest
steps:
- name: Get release URL
id: get-release-url
run: |
URL="https://docs.jax.dev/en/latest/changelog.html"
echo "::set-output name=URL::$URL"
- name: Get content
uses: 2428392/gh-truncate-string-action@b3ff790d21cf42af3ca7579146eedb93c8fb0757 # v1.4.1
id: get-content
with:
stringToTruncate: |
JAX [${{ github.event.release.tag_name }}](<${{ steps.get-release-url.outputs.URL }}>) was just released!

${{ github.event.release.body }}
maxLength: 2000
truncationSymbol: "..."
- name: Discord Webhook Action
uses: tsickert/discord-webhook@b217a69502f52803de774ded2b1ab7c282e99645 # v7.0.0
with:
webhook-url: ${{ secrets.DISCORD_WEBHOOK_URL }}
content: ${{ steps.get-content.outputs.string }}
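The `Get release URL` step above uses the `::set-output` workflow command, which GitHub has deprecated in favour of appending `name=value` lines to the file named by `$GITHUB_OUTPUT`. A minimal sketch of an equivalent step body follows, assuming nothing else about the step changes. The `mktemp` fallback is only there so the snippet also runs outside of Actions; it is not part of the workflow.

```shell
# Outside GitHub Actions GITHUB_OUTPUT is unset, so fall back to a temp file
# (assumption for local runs only; Actions always provides the variable).
GITHUB_OUTPUT="${GITHUB_OUTPUT:-$(mktemp)}"

URL="https://docs.jax.dev/en/latest/changelog.html"

# Append the output as name=value; later steps would still read it as
# steps.get-release-url.outputs.URL.
echo "URL=$URL" >> "$GITHUB_OUTPUT"
```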
4 changes: 2 additions & 2 deletions .github/workflows/jax-array-api.yml
@@ -28,11 +28,11 @@ jobs:
with:
repository: data-apis/array-api-tests
# TODO(jakevdp) update this to a stable release/tag when available.
ref: '0b89c5268e4e4a352223a487b8f63dbd1023872d' # Latest commit as of 2025-03-04
ref: 'c48410f96fc58e02eea844e6b7f6cc01680f77ce' # Latest commit as of 2025-04-02
submodules: 'true'
path: 'array-api-tests'
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
109 changes: 109 additions & 0 deletions .github/workflows/k8s.yaml
@@ -0,0 +1,109 @@
name: Distributed run using K8s Jobset

on:
push:
branches:
- main
paths:
- 'jax/distributed.py'
- 'jax/_src/distributed.py'
- 'jax/_src/clusters/**'
pull_request:
branches:
- main
paths:
- 'jax/distributed.py'
- 'jax/_src/distributed.py'
- 'jax/_src/clusters/**'

permissions:
contents: read

concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
cancel-in-progress: true

defaults:
run:
shell: bash -ex -o pipefail {0}

jobs:

distributed-initialize:
runs-on: ubuntu-22.04
steps:
- name: Checkout
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # ratchet:actions/checkout@v4
with:
path: jax

- name: Start Minikube cluster
uses: medyagh/setup-minikube@cea33675329b799adccc9526aa5daccc26cd5052 # ratchet:medyagh/[email protected]

- name: Install K8s Jobset
run: |
kubectl apply --server-side -f https://github.com/kubernetes-sigs/jobset/releases/download/v0.6.0/manifests.yaml

- name: Build image
run: |
cat > Dockerfile <<EOF
FROM ubuntu:22.04
ADD jax /opt/jax
RUN apt-get update && apt-get install -y python-is-python3 python3-pip
RUN pip install -e /opt/jax[k8s]
EOF

minikube image build -t local/jax:latest .

- name: Create service account for K8s job introspection
run: |
kubectl apply -f jax/examples/k8s/svc-acct.yaml

- name: Prepare test job
run: |
export VERSION=v4.44.3
export BINARY=yq_linux_amd64
wget https://github.com/mikefarah/yq/releases/download/${VERSION}/${BINARY} -O /usr/bin/yq && chmod +x /usr/bin/yq

cat jax/examples/k8s/example.yaml |\
yq '.spec.replicatedJobs[0].template.spec.template.spec.containers[0].image = "local/jax:latest"' |\
yq '.spec.replicatedJobs[0].template.spec.template.spec.containers[0].imagePullPolicy = "Never"' |\
tee example.yaml

- name: Submit test job
run: |
kubectl apply -f example.yaml

- name: Check job status
shell: bash -e -o pipefail {0}
run: |
while true; do
status=$(kubectl get jobset example -o yaml | yq .status.conditions[0].type)
timestamp=$(date +"%Y-%m-%d %H:%M:%S")
echo "[$timestamp] Checking job status..."

if [ "$status" == "Completed" ]; then
echo "[$timestamp] Job has completed successfully!"
exit 0
elif [ "$status" == "Failed" ]; then
echo "[$timestamp] Job has failed!"
exit 1
else
echo "[$timestamp] Job is still running. Current pod status:"
kubectl get pods --no-headers
echo "[$timestamp] Waiting for 3 seconds before checking again..."
sleep 3
fi
done

- name: Examine individual pod outputs
if: "!cancelled()"
run: |
set +x
kubectl get pods --no-headers | awk '{print $1}' | while read -s pod; do
echo "========================================"
echo "Pod $pod output:"
echo "----------------------------------------"
kubectl logs $pod
echo "========================================"
done
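The `Check job status` loop above polls until the JobSet reports `Completed` or `Failed`, and has no upper bound of its own if neither condition ever appears. One way to bound it can be sketched as below. The `poll_status` helper, its argument order, and its return codes are hypothetical and not part of this workflow; the helper takes any command that prints the status string.

```shell
# poll_status TIMEOUT_S INTERVAL_S CMD [ARGS...]
# Hypothetical helper: runs CMD repeatedly until it prints "Completed" or
# "Failed", or until TIMEOUT_S seconds have elapsed.
# Returns 0 on Completed, 1 on Failed, 2 on timeout.
# Variables are global because POSIX sh has no `local`.
poll_status() {
  timeout_s="$1"
  interval_s="$2"
  shift 2
  deadline=$(( $(date +%s) + timeout_s ))
  while [ "$(date +%s)" -lt "$deadline" ]; do
    # Tolerate a failing status command; treat it as "not finished yet".
    status="$("$@" || true)"
    case "$status" in
      Completed) echo "completed"; return 0 ;;
      Failed)    echo "failed";    return 1 ;;
    esac
    sleep "$interval_s"
  done
  echo "timed out"
  return 2
}
```

In the workflow this could wrap the existing query, for example `poll_status 600 3 sh -c 'kubectl get jobset example -o yaml | yq ".status.conditions[0].type"'`, where the 600-second limit is an arbitrary illustrative value.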