Skip to content

CI: 04/09/25 upstream sync #351

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 588 commits into
base: rocm-main
Choose a base branch
from

Conversation

rocm-repo-management-api-2[bot]
Copy link

Daily sync with upstream

danielsuo and others added 30 commits March 31, 2025 07:09
…t2, uint2, int4, uint4, float4_e2m1fn subbyte types in CPU/GPU callbacks.

PiperOrigin-RevId: 742253272
Add a first benchmark for tracing/lowering pallas splash attention.

Sample results below taken on a GCP n2d-standard-128 instance with 512GB Ram and 128 vCPU AMD EPYC Milan.

---------------------------------------------------------------------------------
Benchmark                                       Time             CPU   Iterations
---------------------------------------------------------------------------------
test_pallas_mqa_splash_attention_trace       39.8 ms         39.8 ms           19
test_pallas_mqa_splash_attention_lower       42.1 ms         41.9 ms           18

PiperOrigin-RevId: 742259409

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
Following the compatibility timeline described here: https://docs.jax.dev/en/latest/export/export.html#ensuring-forward-and-backward-compatibility

It has been 6 months since the release of 0.4.33 which is the relevant release for this kernel.

PiperOrigin-RevId: 742261532
Bumps [tsickert/discord-webhook](https://github.com/tsickert/discord-webhook) from 5.3.0 to 7.0.0.
- [Release notes](https://github.com/tsickert/discord-webhook/releases)
- [Commits](tsickert/discord-webhook@c840d45...b217a69)

---
updated-dependencies:
- dependency-name: tsickert/discord-webhook
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <[email protected]>
Bumps [actions/setup-python](https://github.com/actions/setup-python) from 5.4.0 to 5.5.0.
- [Release notes](https://github.com/actions/setup-python/releases)
- [Commits](actions/setup-python@4237552...8d9ed9a)

---
updated-dependencies:
- dependency-name: actions/setup-python
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <[email protected]>
Previously, XLA's command buffers (CUDA graphs) would be disabled both
for PGLE profile collection and when re-compiling using the profile
data. With this change, they are only disabled when collecting the
profile data.
… pbroadcast insertion for axis_index_p in the traceable

PiperOrigin-RevId: 742334213
…/actions/setup-python-5.5.0

PiperOrigin-RevId: 742345857
…nversions

This will be submitted automatically once the compatibility window has passed

PiperOrigin-RevId: 742464046
Add the tests to the Bazel presubmit RBE jobs (except `arm64`/`aarch64` jobs that use RBE cross-compilation).

PiperOrigin-RevId: 742724458
PiperOrigin-RevId: 742737384
…/tsickert/discord-webhook-7.0.0

PiperOrigin-RevId: 742756295
hawkinsp and others added 24 commits April 8, 2025 08:33
See jax-ml#18711

check_rep uses rep=None to indicate when an argument is a constant, and that's useful specifically when checking the backward pass for integer_pow, which has a multiplication by a constant that didn't get a pbroadcast applied to it. That is, we use rep=None as a special carve-out for constants.

The standard rules were compatible with rep=None, but the rules for higher-order primitives like scan and cond were not. So we had to upgrade them.
fixes jax-ml#27683

In b7715e2, specifically this line:

jax-ml@b7715e2#diff-8a1ad6e3b750565d66d30dbf4c9df0825bf5e87c4721e3352f44efbfb8b4a29cR193

we started ignoring the value dtype completely when it was weakly typed. But that could lead to surprising implicit bitcasts like in jax-ml#27683. A repro looks like:

```python
import jax.numpy as jnp
from jax._src import core

v = core.mutable_array(jnp.array([0, 0, 0]))
v[...] += 1.0
print(v)  # MutableArray([1065353216, 1065353216, 1065353216], dtype=int32)
```

We can't easily just drop this behavior because it seems many GPU x64 tests depend on it.

So in this change we're trying to
1. do the casting outside the bind, so that in jaxpr typechecking we can assert the value to assign has to match the ref dtype;
2. make that casting more restrictive, supporting only casts on weak-typed values between different precisions of floats or ints; and
3. do an ordinary cast rather than a bitcast.

I left a TODO to change this behavior, since it seems a bit ad-hoc. But we may not want to remove all implicit casting; for example, it's probably reasonable to support implicit casting of Python builtin numeric types when we don't lose any precision, e.g.

```python
v = core.mutable_array(jnp.array(0, dtype='bfloat16'))
v[...] += 1.0  # don't error!
```

But we can do that with special-purpose carve-outs for Python builtin numerictypes. I left one way to do it in a comment.

PiperOrigin-RevId: 745198669
PiperOrigin-RevId: 745201720
So duplicated load/store ops can be removed.

PiperOrigin-RevId: 745209849
PiperOrigin-RevId: 745212009
Pass pytype_srcs as data to the pybind_extension rule.

PiperOrigin-RevId: 745238783
PiperOrigin-RevId: 745247778
These should be used directly from ml_dtypes.

PiperOrigin-RevId: 745256523
Now that jax-ml@db11efa has landed, we're free to split up xla_extension without creating binary size problems or having to be quite so careful about cross-module dependencies. Here weakref_lru_cache has absolutely nothing to do with XLA.

There's no reason weakref_lru_cache is in the same Python extension as everything else.

PiperOrigin-RevId: 745271825
…es` to True.

The main changes here are:

* Don't take the `_efficient_transpose_rewrite` transformation path anymore. In other words, `RewriteTrace` and all the rewriting machinery is dead.

* Wherever internally we were setting `check_rep=False` explicitly like `_prim_applier`, `_match`, `_unmatch`, `_shard_map_partial_eval`, `_shard_map_partial_eval_custom` (for remat), don't do that anymore. Instead set `check_rep` to the `check_rep` value so that it can be True if the user hasn't passed `check_rep=False`.

* Introduce an internal `_check_rep` context manager and set it wherever `extend_axis_env_nd` is used so that if `check_rep=False` on `shard_map`, JAX will set `vma` in `ShapedArray` to empty `frozenset`.

* Because of point (2), if `check_rep=True`, we can't set `in_specs` and `out_specs` of shmap internally to all manual axes of the mesh on the 0th dim. It needs to be whatever the argument was varying on.

Co-authored-by: Matthew Johnson <[email protected]>
PiperOrigin-RevId: 745276474
…JAX_SKIP_SLOW_TESTS=true

Description:
- Disable second order vjp tests in RunStateHypothesisTest.test_vjp if JAX_SKIP_SLOW_TESTS=true to reduce the test execution time
  - especially for TSAN CI job where this test takes ~700 seconds to pass with the recent 3.13 cpython
- Removed optional deps for 3.14
I expected Mosaic can canonicalize 2 same strided loads to one but it did not. (We will fix this in Mosaic). For now, manually converting to one strided load boosts 20~35% speedup in both v6e and v5e single chip for Meta-Llama-3-8B.

PiperOrigin-RevId: 745294058
This (private) API will shortly be deleted, and hlo_to_stablehlo is its replacement.

PiperOrigin-RevId: 745333506
…houldn't expose this to public API and have users use `psum` instead which will dispatch to `psum_invariant` when `check_rep=True`.

PiperOrigin-RevId: 745352875
PiperOrigin-RevId: 745375892
@rocm-repo-management-api-2 rocm-repo-management-api-2 bot requested a review from a team as a code owner April 9, 2025 06:02
@rocm-repo-management-api-2 rocm-repo-management-api-2 bot enabled auto-merge (rebase) April 9, 2025 06:02
@charleshofer charleshofer disabled auto-merge April 9, 2025 17:58
@charleshofer charleshofer enabled auto-merge (rebase) April 9, 2025 17:59
@charleshofer charleshofer disabled auto-merge April 15, 2025 16:51
@charleshofer charleshofer enabled auto-merge (rebase) April 15, 2025 16:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.