forked from jax-ml/jax
-
Notifications
You must be signed in to change notification settings - Fork 4
CI: 04/09/25 upstream sync #351
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
rocm-repo-management-api-2
wants to merge
588
commits into
rocm-main
Choose a base branch
from
ci-upstream-sync-168_1
base: rocm-main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
…t2, uint2, int4, uint4, float4_e2m1fn subbyte types in CPU/GPU callbacks. PiperOrigin-RevId: 742253272
Add a first benchmark for tracing/lowering pallas splash attention. Sample results below taken on a GCP n2d-standard-128 instance with 512GB Ram and 128 vCPU AMD EPYC Milan. --------------------------------------------------------------------------------- Benchmark Time CPU Iterations --------------------------------------------------------------------------------- test_pallas_mqa_splash_attention_trace 39.8 ms 39.8 ms 19 test_pallas_mqa_splash_attention_lower 42.1 ms 41.9 ms 18 PiperOrigin-RevId: 742259409
Following the compatibility timeline described here: https://docs.jax.dev/en/latest/export/export.html#ensuring-forward-and-backward-compatibility It has been 6 months since the release of 0.4.33 which is the relevant release for this kernel. PiperOrigin-RevId: 742261532
…rrier` PiperOrigin-RevId: 742265860
PiperOrigin-RevId: 742302906
Bumps [tsickert/discord-webhook](https://github.com/tsickert/discord-webhook) from 5.3.0 to 7.0.0. - [Release notes](https://github.com/tsickert/discord-webhook/releases) - [Commits](tsickert/discord-webhook@c840d45...b217a69) --- updated-dependencies: - dependency-name: tsickert/discord-webhook dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] <[email protected]>
Bumps [actions/setup-python](https://github.com/actions/setup-python) from 5.4.0 to 5.5.0. - [Release notes](https://github.com/actions/setup-python/releases) - [Commits](actions/setup-python@4237552...8d9ed9a) --- updated-dependencies: - dependency-name: actions/setup-python dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] <[email protected]>
Previously, XLA's command buffers (CUDA graphs) would be disabled both for PGLE profile collection and when re-compiling using the profile data. With this change, they are only disabled when collecting the profile data.
… pbroadcast insertion for axis_index_p in the traceable PiperOrigin-RevId: 742334213
…/actions/setup-python-5.5.0 PiperOrigin-RevId: 742345857
PiperOrigin-RevId: 742431349
PiperOrigin-RevId: 742447929
PiperOrigin-RevId: 742464022
…nversions This will be submitted automatically once the compatibility window has passed PiperOrigin-RevId: 742464046
PiperOrigin-RevId: 742531956
http://github.com/openxla/xla/commit/b1971cc2b3407e87fada2674a057d72897b79acc. PiperOrigin-RevId: 742646393
Add the tests to the Bazel presubmit RBE jobs (except `arm64`/`aarch64` jobs that use RBE cross-compilation). PiperOrigin-RevId: 742724458
PiperOrigin-RevId: 742737384
Fixes jax-ml#16076 Co-authored-by: Roy Frostig <[email protected]>
…/tsickert/discord-webhook-7.0.0 PiperOrigin-RevId: 742756295
PiperOrigin-RevId: 742771594
PiperOrigin-RevId: 742774171
PiperOrigin-RevId: 745156931
See jax-ml#18711 check_rep uses rep=None to indicate when an argument is a constant, and that's useful specifically when checking the backward pass for integer_pow, which has a multiplication by a constant that didn't get a pbroadcast applied to it. That is, we use rep=None as a special carve-out for constants. The standard rules were compatible with rep=None, but the rules for higher-order primitives like scan and cond were not. So we had to upgrade them.
fixes jax-ml#27683 In b7715e2, specifically this line: jax-ml@b7715e2#diff-8a1ad6e3b750565d66d30dbf4c9df0825bf5e87c4721e3352f44efbfb8b4a29cR193 we started ignoring the value dtype completely when it was weakly typed. But that could lead to surprising implicit bitcasts like in jax-ml#27683. A repro looks like: ```python import jax.numpy as jnp from jax._src import core v = core.mutable_array(jnp.array([0, 0, 0])) v[...] += 1.0 print(v) # MutableArray([1065353216, 1065353216, 1065353216], dtype=int32) ``` We can't easily just drop this behavior because it seems many GPU x64 tests depend on it. So in this change we're trying to 1. do the casting outside the bind, so that in jaxpr typechecking we can assert the value to assign has to match the ref dtype; 2. make that casting more restrictive, supporting only casts on weak-typed values between different precisions of floats or ints; and 3. do an ordinary cast rather than a bitcast. I left a TODO to change this behavior, since it seems a bit ad-hoc. But we may not want to remove all implicit casting; for example, it's probably reasonable to support implicit casting of Python builtin numeric types when we don't lose any precision, e.g. ```python v = core.mutable_array(jnp.array(0, dtype='bfloat16')) v[...] += 1.0 # don't error! ``` But we can do that with special-purpose carve-outs for Python builtin numerictypes. I left one way to do it in a comment. PiperOrigin-RevId: 745198669
PiperOrigin-RevId: 745201720
So duplicated load/store ops can be removed. PiperOrigin-RevId: 745209849
PiperOrigin-RevId: 745212009
PiperOrigin-RevId: 745215941
PiperOrigin-RevId: 745216021
PiperOrigin-RevId: 745216259
Pass pytype_srcs as data to the pybind_extension rule. PiperOrigin-RevId: 745238783
PiperOrigin-RevId: 745247778
These should be used directly from ml_dtypes. PiperOrigin-RevId: 745256523
PiperOrigin-RevId: 745261995
Now that jax-ml@db11efa has landed, we're free to split up xla_extension without creating binary size problems or having to be quite so careful about cross-module dependencies. Here weakref_lru_cache has absolutely nothing to do with XLA. There's no reason weakref_lru_cache is in the same Python extension as everything else. PiperOrigin-RevId: 745271825
…es` to True. The main changes here are: * Don't take the `_efficient_transpose_rewrite` transformation path anymore. In other words, `RewriteTrace` and all the rewriting machinery is dead. * Wherever internally we were setting `check_rep=False` explicitly like `_prim_applier`, `_match`, `_unmatch`, `_shard_map_partial_eval`, `_shard_map_partial_eval_custom` (for remat), don't do that anymore. Instead set `check_rep` to the `check_rep` value so that it can be True if the user hasn't passed `check_rep=False`. * Introduce an internal `_check_rep` context manager and set it wherever `extend_axis_env_nd` is used so that if `check_rep=False` on `shard_map`, JAX will set `vma` in `ShapedArray` to empty `frozenset`. * Because of point (2), if `check_rep=True`, we can't set `in_specs` and `out_specs` of shmap internally to all manual axes of the mesh on the 0th dim. It needs to be whatever the argument was varying on. Co-authored-by: Matthew Johnson <[email protected]> PiperOrigin-RevId: 745276474
…JAX_SKIP_SLOW_TESTS=true Description: - Disable second order vjp tests in RunStateHypothesisTest.test_vjp if JAX_SKIP_SLOW_TESTS=true to reduce the test execution time - especially for TSAN CI job where this test takes ~700 seconds to pass with the recent 3.13 cpython - Removed optional deps for 3.14
I expected Mosaic can canonicalize 2 same strided loads to one but it did not. (We will fix this in Mosaic). For now, manually converting to one strided load boosts 20~35% speedup in both v6e and v5e single chip for Meta-Llama-3-8B. PiperOrigin-RevId: 745294058
PiperOrigin-RevId: 745322012
This (private) API will shortly be deleted, and hlo_to_stablehlo is its replacement. PiperOrigin-RevId: 745333506
PiperOrigin-RevId: 745341315
PiperOrigin-RevId: 745342103
…houldn't expose this to public API and have users use `psum` instead which will dispatch to `psum_invariant` when `check_rep=True`. PiperOrigin-RevId: 745352875
PiperOrigin-RevId: 745375892
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Daily sync with upstream