From 132f88e8d57f768ff690f197db30b20de837feb6 Mon Sep 17 00:00:00 2001 From: Charles Hofer Date: Wed, 26 Feb 2025 17:07:44 +0000 Subject: [PATCH 001/483] Fix ROCm builds not finding numa library --- build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm b/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm index 08b6bd3ff8d6..8afe8b17252c 100644 --- a/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm +++ b/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm @@ -9,7 +9,7 @@ ARG ROCM_BUILD_NUM # manylinux base image. However, adding this does fix an issue where Bazel isn't able # to find them. RUN --mount=type=cache,target=/var/cache/dnf \ - dnf install -y gcc-c++-8.5.0-22.el8_10.x86_64 + dnf install -y gcc-c++-8.5.0-22.el8_10.x86_64 numactl-devel RUN --mount=type=cache,target=/var/cache/dnf \ --mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \ From 2bb7dbaa32f2eb42b785edbe377e15b3f5e73f28 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Sun, 9 Mar 2025 01:26:40 +0000 Subject: [PATCH 002/483] add jax.input_saved_vjp to let user pass primal inputs to bwd pass Co-authored-by: Dougal Maclaurin --- jax/_src/api.py | 80 +++++++++++++++++++++++++++++++++++- jax/experimental/__init__.py | 4 ++ tests/api_test.py | 58 ++++++++++++++++++++++++++ 3 files changed, 141 insertions(+), 1 deletion(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 4b14d809621d..9bbaf01d2c50 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -25,6 +25,7 @@ import atexit import collections from collections.abc import Callable, Hashable, Iterable, Sequence +import dataclasses from functools import partial, lru_cache import inspect import math @@ -41,7 +42,8 @@ from jax._src.tree_util import ( tree_map, tree_flatten, tree_unflatten, tree_structure, tree_transpose, tree_leaves, Partial, PyTreeDef, all_leaves, keystr, broadcast_prefix, - prefix_errors, generate_key_paths, tree_flatten_with_path) + prefix_errors, generate_key_paths, tree_flatten_with_path, + equality_errors_pytreedef) from jax._src import config from jax._src import core from jax._src import dispatch @@ -2031,6 +2033,82 @@ def _vjp(fun: lu.WrappedFun, *primals, has_aux=False): return out_primal_py, vjp_py, tree_unflatten(aux_tree, aux) +def saved_input_vjp(f: Callable, which: Sequence[bool], *primals, + allow_unused: bool = True, allow_opaque: bool = True): + if len(which) != len(primals): + raise ValueError( + "length of 'which' argument must equal the number of primal input values, " + f"but got {len(which)=} and {len(primals)=}") + + dbg = debug_info("saved_input_vjp", f, primals, {}) + fun = lu.wrap_init(f, debug_info=dbg) + primals_flat, in_tree = tree_flatten(primals) + fun, out_tree = flatten_fun_nokwargs(fun, in_tree) + out_primals_flat, _, jaxpr, residuals = ad.linearize(fun, *primals_flat) + primals_filt, filt_tree = tree_flatten(tuple(p for w, p in zip(which, primals) if w)) + id_map = {id(x): i for i, x in enumerate(primals_filt)} + opaque_residuals = [] + res_spec = [RSpec(id_map[id(r)], True) if id(r) in id_map else + RSpec(opaque_residuals.append(r) or (len(opaque_residuals) - 1), False) # type: ignore + for r in residuals] + f_vjp = Partial(_saved_input_vjpfun, res_spec, filt_tree, in_tree, out_tree(), + jaxpr, opaque_residuals) + + if not allow_unused and not set(id_map).issubset(res_ids := {id(r) for r in residuals}): + unused = [(i, core.get_aval(x)) for i, (x, w) in enumerate(zip(primals, which)) + if w and id(x) not in res_ids] + assert unused + if len(unused) == 1: + (i, a), = unused + start, was = "an input value", "was" + msg = f" {dbg.arg_names[i]} of type {a.str_short()}" + else: + start, was = "multiple input values", "were" + msg = "\n" + "\n".join(f" * {dbg.arg_names[i]} of type {a.str_short()}" + for i, a in unused) + raise Exception(f"with {allow_unused=}, {start} marked to be saved {was} " + f"not used by the backward pass:{msg}") + + if not allow_opaque and opaque_residuals: + msg = ", ".join(core.get_aval(x).str_short() for x in opaque_residuals) + raise Exception(f"with {allow_opaque=}, the backward pass requires opaque " + f"(non-input) residuals: {msg}") + + out_primals = tree_unflatten(out_tree(), out_primals_flat) + return out_primals, f_vjp + +def _saved_input_vjpfun(res_spec, filtered_tree, in_tree, out_tree, jaxpr, + opaque_residuals, ct, *saved_primals): + primals_filtered, filtered_tree_ = tree_flatten(saved_primals) + if filtered_tree != filtered_tree_: + raise ValueError( + "inputs passed to f_vjp must be a tuple of (pytrees of) " + "arrays with the same structure as\n" + " tuple(x for x, w in zip(inputs, which) if w)\n" + "given the original call\n" + " _, f_vjp = saved_input_vjp(f, which, *inputs, ...)\n" + "but the structures differ:\n" + + "\n".join(f" * inputs{keystr(path)} was a {thing1} in the original " + f"call, but a {thing2} here, so {explanation}" + for path, thing1, thing2, explanation + in equality_errors_pytreedef(filtered_tree, filtered_tree_))) + + residuals = [primals_filtered[i.idx] if i.primal else opaque_residuals[i.idx] + for i in res_spec] + dummy_args = [ad.UndefinedPrimal(v.aval) for v in jaxpr.invars] + cts_flat, out_tree_ = tree_flatten(ct) + assert out_tree_ == out_tree + arg_cts = ad.backward_pass(jaxpr, True, residuals, dummy_args, cts_flat) + return tree_unflatten(in_tree, arg_cts) + +@dataclasses.dataclass(frozen=True) +class RSpec: + idx: int + primal: bool + +si_vjp = saved_input_vjp + + def linear_transpose(fun: Callable, *primals, reduce_axes=()) -> Callable: """Transpose a function that is promised to be linear. diff --git a/jax/experimental/__init__.py b/jax/experimental/__init__.py index 375d058d0edc..6c37635df1b0 100644 --- a/jax/experimental/__init__.py +++ b/jax/experimental/__init__.py @@ -19,6 +19,10 @@ enable_x64 as enable_x64, disable_x64 as disable_x64, ) +from jax._src.api import ( + saved_input_vjp as saved_input_vjp, + si_vjp as si_vjp +) from jax._src.callback import ( io_callback as io_callback ) diff --git a/tests/api_test.py b/tests/api_test.py index c9cf28e0af28..39590f744dcc 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -11496,5 +11496,63 @@ def wsc_as_noop(ctx, operand, *args, **kwargs): self.assertNotIn("stablehlo.custom_call @Sharding", lowered_ir) +class InputSavedVJPTest(jtu.JaxTestCase): + + def test_basic(self): + def f(x, y): + return x * y + + primals = 2., 3. + y, f_vjp = api.si_vjp(f, [True, True], *primals) + arg_cts = f_vjp(1., *primals) + self.assertAllClose(y, 6.) + self.assertAllClose(arg_cts, (3., 2.)) + + def test_basic_unused(self): + f = jnp.sin + primals = 3., + y, f_vjp = api.si_vjp(f, [True], *primals) + x_ct, = f_vjp(1., *primals) + self.assertAllClose(y, jnp.sin(3.)) + self.assertAllClose(x_ct, jnp.cos(3.)) + + with self.assertRaisesRegex(Exception, "not used by the backward pass: x"): + _ = api.si_vjp(f, [True], *primals, allow_unused=False) + + def test_basic_opaque(self): + f = jnp.sin + primals = 3., + with self.assertRaisesRegex(Exception, "the backward pass requires opaque"): + _ = api.si_vjp(f, [True], *primals, allow_opaque=False) + + def test_basic_pytree_error(self): + def f(x): + return [x['hi'] * x['bye']] + + y, f_vjp = api.si_vjp(f, [True], {'hi': 2., 'bye': 3.}) + arg_ct, = f_vjp([1.], {'hi': 2., 'bye': 3.}) + self.assertAllClose(y, [6.]) + self.assertAllClose(arg_ct, {'hi': 3., 'bye': 2.}) + + with self.assertRaisesRegex(ValueError, "but the structures differ"): + f_vjp(1., {'hi': 2.}) + + def test_fsdp(self): + # see https://github.com/jax-ml/jax/pull/27017 for why this is called "fsdp" + def f2(x, w): + x = 1. * x + x = x @ w + x = 2. * x + return x + + x = jnp.ones((3, 4)) + w = jnp.ones((4, 4)) + y, f2_sivjp = api.si_vjp(f2, [False, True], x, w) + y_grad = jnp.ones_like(y) + x_grad, w_grad = f2_sivjp(y_grad, w) + self.assertAllClose(x_grad, 2. * y_grad @ w.T) + self.assertAllClose(w_grad, 2. * x.T @ y_grad) + + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From eb8c908d2daff31559ac556be8564ed09e001509 Mon Sep 17 00:00:00 2001 From: Yu-Hang Tang Date: Mon, 23 Sep 2024 18:33:46 +0000 Subject: [PATCH 003/483] Add CI workflow for JAX distibuted initialize in K8s jobsets --- .github/workflows/k8s.yaml | 105 +++++++++++++++++++++++++++++++ .pre-commit-config.yaml | 1 + examples/k8s/example.yaml | 40 ++++++++++++ examples/k8s/svc-acct.yaml | 31 +++++++++ jax/_src/clusters/k8s_cluster.py | 20 +++--- 5 files changed, 188 insertions(+), 9 deletions(-) create mode 100644 .github/workflows/k8s.yaml create mode 100644 examples/k8s/example.yaml create mode 100644 examples/k8s/svc-acct.yaml diff --git a/.github/workflows/k8s.yaml b/.github/workflows/k8s.yaml new file mode 100644 index 000000000000..5149e79f14b4 --- /dev/null +++ b/.github/workflows/k8s.yaml @@ -0,0 +1,105 @@ +name: Distributed run using K8s Jobset + +on: + push: + branches: + - main + pull_request: + branches: + - main + +permissions: + contents: read + pull-requests: read + actions: write # to cancel previous workflows + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true + +defaults: + run: + shell: bash -ex -o pipefail {0} + +jobs: + + distributed-initialize: + runs-on: ubuntu-22.04 + outputs: + TAG: ${{ steps.metadata.outputs.tags }} + steps: + - name: Checkout + uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # ratchet:actions/checkout@v4 + with: + path: jax + + - name: Start Minikube cluster + uses: medyagh/setup-minikube@d8c0eb871f6f455542491d86a574477bd3894533 # ratchet:medyagh/setup-minikube@v0.0.18 + + - name: Install K8s Jobset + run: | + kubectl apply --server-side -f https://github.com/kubernetes-sigs/jobset/releases/download/v0.6.0/manifests.yaml + + - name: Build image + run: | + cat > Dockerfile < 1 + assert len(jax.devices()) > len(jax.local_devices()) diff --git a/examples/k8s/svc-acct.yaml b/examples/k8s/svc-acct.yaml new file mode 100644 index 000000000000..d05fb9b0cd2a --- /dev/null +++ b/examples/k8s/svc-acct.yaml @@ -0,0 +1,31 @@ +apiVersion: v1 +kind: ServiceAccount +metadata: + name: training-job-sa + namespace: default +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: Role +metadata: + name: pod-reader +rules: + - apiGroups: [""] + resources: ["pods"] + verbs: ["get", "list", "watch"] + - apiGroups: ["batch"] + resources: ["jobs"] + verbs: ["get", "list", "watch"] +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: RoleBinding +metadata: + name: pod-reader-binding + namespace: default +subjects: + - kind: ServiceAccount + name: training-job-sa + namespace: default +roleRef: + kind: Role + name: pod-reader + apiGroup: rbac.authorization.k8s.io diff --git a/jax/_src/clusters/k8s_cluster.py b/jax/_src/clusters/k8s_cluster.py index 1274724b8ebd..a3b415df580a 100644 --- a/jax/_src/clusters/k8s_cluster.py +++ b/jax/_src/clusters/k8s_cluster.py @@ -35,15 +35,17 @@ def is_env_present(cls) -> bool: try: import kubernetes as k8s # pytype: disable=import-error except ImportError as e: - warnings.warn(textwrap.fill( - "Kubernetes environment detected, but the `kubernetes` package is " - "not installed to enable automatic bootstrapping in this " - "environment. To enable automatic boostrapping, please install " - "jax with the [k8s] extra. For example:" - " pip install jax[k8s]" - " OR" - " pip install jax[k8s,]" - )) + warnings.warn( + '\n'.join([ + textwrap.fill( + "Kubernetes environment detected, but the `kubernetes` package " + "is not installed to enable automatic bootstrapping in this " + "environment. To enable automatic boostrapping, please install " + "jax with the [k8s] extra. For example:"), + " pip install jax[k8s]", + " pip install jax[k8s,]", + ]) + ) return False k8s.config.load_incluster_config() From c6ef01d1618510deafde9ee6b91ea658bd105e0d Mon Sep 17 00:00:00 2001 From: Yu-Hang Tang Date: Mon, 17 Mar 2025 18:21:01 +0000 Subject: [PATCH 004/483] address review comments --- .github/workflows/k8s.yaml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/.github/workflows/k8s.yaml b/.github/workflows/k8s.yaml index 5149e79f14b4..4da6a69775c2 100644 --- a/.github/workflows/k8s.yaml +++ b/.github/workflows/k8s.yaml @@ -10,8 +10,6 @@ on: permissions: contents: read - pull-requests: read - actions: write # to cancel previous workflows concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} @@ -25,8 +23,6 @@ jobs: distributed-initialize: runs-on: ubuntu-22.04 - outputs: - TAG: ${{ steps.metadata.outputs.tags }} steps: - name: Checkout uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # ratchet:actions/checkout@v4 From ec994986ca5e0bacf15c324e24e514fd1f4005b8 Mon Sep 17 00:00:00 2001 From: Yu-Hang Tang Date: Mon, 17 Mar 2025 18:28:18 +0000 Subject: [PATCH 005/483] update ratchet action pin --- .github/workflows/k8s.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/k8s.yaml b/.github/workflows/k8s.yaml index 4da6a69775c2..31ee05a03482 100644 --- a/.github/workflows/k8s.yaml +++ b/.github/workflows/k8s.yaml @@ -25,7 +25,7 @@ jobs: runs-on: ubuntu-22.04 steps: - name: Checkout - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # ratchet:actions/checkout@v4 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # ratchet:actions/checkout@v4 with: path: jax From ed43119a86069a777a4e0c045c90bbbbe7accccd Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 18 Mar 2025 21:38:14 -0400 Subject: [PATCH 006/483] JAX release v0.5.3 --- CHANGELOG.md | 2 +- jax/version.py | 2 +- setup.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c30877ecae14..9faff67cf305 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,7 +14,7 @@ Remember to align the itemized text with the first line of an item within a list When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md. --> -## Unreleased +## jax 0.5.3 * New Features diff --git a/jax/version.py b/jax/version.py index be20aca06358..13df5f00a11b 100644 --- a/jax/version.py +++ b/jax/version.py @@ -146,7 +146,7 @@ def make_release_tree(self, base_dir, files): __version__ = _get_version_string() -_minimum_jaxlib_version = "0.5.1" +_minimum_jaxlib_version = "0.5.3" def _version_as_tuple(version_str): return tuple(int(i) for i in version_str.split(".") if i.isdigit()) diff --git a/setup.py b/setup.py index 80f45285ba61..a5c8500dc1cf 100644 --- a/setup.py +++ b/setup.py @@ -19,11 +19,11 @@ project_name = 'jax' -_current_jaxlib_version = '0.5.1' +_current_jaxlib_version = '0.5.3' # The following should be updated after each new jaxlib release. _latest_jaxlib_version_on_pypi = '0.5.1' -_libtpu_version = '0.0.10.*' +_libtpu_version = '0.0.11.*' def load_version_module(pkg_path): spec = importlib.util.spec_from_file_location( From 8a493129e7dcf4c2c3a3187b4a6ea0ca780ceb04 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Wed, 19 Mar 2025 02:40:10 -0700 Subject: [PATCH 007/483] [mosaic_gpu] Fix usage of `absl::Cleanup` in CUDA events timer. PiperOrigin-RevId: 738315605 --- jaxlib/mosaic/gpu/mosaic_gpu_ext.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc index 4f804c9e2116..8f52dce3b021 100644 --- a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc +++ b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc @@ -100,10 +100,10 @@ static const auto* kEventElapsed = gpuStreamSynchronize(stream); auto start_event = std::make_unique(); auto end_event = std::make_unique(); - absl::MakeCleanup([&]() { + absl::Cleanup cleanup = [&]() { gpuEventDestroy(*start_event); gpuEventDestroy(*end_event); - }); + }; gpuMemcpy(start_event.get(), start.untyped_data(), sizeof(gpuEvent_t), gpuMemcpyDeviceToHost); gpuMemcpy(end_event.get(), end.untyped_data(), sizeof(gpuEvent_t), From 00ce0bee56361cf88de49e11eaf61484895b047c Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Wed, 19 Mar 2025 03:06:24 -0700 Subject: [PATCH 008/483] [mosaic_gpu] Remove unnecessary allocations in CUDA events timer. PiperOrigin-RevId: 738321801 --- jaxlib/mosaic/gpu/mosaic_gpu_ext.cc | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc index 8f52dce3b021..a726acd4d662 100644 --- a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc +++ b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc @@ -98,19 +98,21 @@ static const auto* kEventElapsed = .Ret>() // elapsed_ms .To([](gpuStream_t stream, auto start, auto end, auto out) { gpuStreamSynchronize(stream); - auto start_event = std::make_unique(); - auto end_event = std::make_unique(); + gpuEvent_t start_event = nullptr; + gpuEvent_t end_event = nullptr; + absl::Cleanup cleanup = [&]() { - gpuEventDestroy(*start_event); - gpuEventDestroy(*end_event); + gpuEventDestroy(start_event); + gpuEventDestroy(end_event); }; - gpuMemcpy(start_event.get(), start.untyped_data(), sizeof(gpuEvent_t), + + gpuMemcpy(&start_event, start.untyped_data(), sizeof(gpuEvent_t), gpuMemcpyDeviceToHost); - gpuMemcpy(end_event.get(), end.untyped_data(), sizeof(gpuEvent_t), + gpuMemcpy(&end_event, end.untyped_data(), sizeof(gpuEvent_t), gpuMemcpyDeviceToHost); + float elapsed; - if (auto res = - gpuEventElapsedTime(&elapsed, *start_event, *end_event); + if (auto res = gpuEventElapsedTime(&elapsed, start_event, end_event); res) { return ffi::Error::Internal(absl::StrCat( "Failed to get elapsed time between events: ", ToString(res))); From b0865508a63089d9ce4e0ca4d372e3fd7f2d5cfd Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 19 Mar 2025 04:32:29 -0700 Subject: [PATCH 009/483] [pallas:mosaic_gpu] Dialect lowering can now handle `lax.cond` PiperOrigin-RevId: 738342517 --- jax/_src/pallas/mosaic_gpu/lowering.py | 12 +- .../mosaic/gpu/dialect_lowering.py | 192 +++++++++++++----- .../mosaic/gpu/fragmented_array.py | 62 ++++-- .../mosaic/gpu/layout_inference.py | 42 +++- tests/pallas/mosaic_gpu_test.py | 10 +- 5 files changed, 223 insertions(+), 95 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 6b06e6b7dfc2..2ae51a8b22e8 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -2036,16 +2036,16 @@ def _yielded_values(outs, avals): ret.append(_ensure_ir_value(out, aval.dtype)) return ret - # We need the branch return mlir types in order to construct the - # switch operation. To avoid leaking information about what kind of - # mlir types are internal to FragmentedArrays and other mgpu types, - # we run one of the branches in a dummy module that we throw away to - # extract the return types + # We need to know the result types ahead of time to construct the switch + # operation. Below we lower the first branch in a throw-away module to + # extract them. with ir.InsertionPoint(ir.Module.create().body): outs = lower_jaxpr_to_mosaic_gpu( ctx.module_ctx, ctx.launch_ctx, branches[0].jaxpr, args ) - yielded_types = [v.type for v in jax.tree.leaves(_yielded_values(outs, ctx.avals_out))] + yielded_types = [ + v.type for v in jax.tree.leaves(_yielded_values(outs, ctx.avals_out)) + ] del outs switch_op = scf_dialect.IndexSwitchOp( diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index fedde5a00887..6e7f4a981f9d 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -14,10 +14,11 @@ """Lowering rules and pass for the MLIR Mosaic GPU dialect.""" -from collections.abc import Callable +from collections.abc import Callable, Iterable import dataclasses import functools import itertools +import math import operator from typing import Any, Sequence, Type, cast @@ -34,6 +35,7 @@ from jax._src.lib.mlir.dialects import nvvm from jax._src.lib.mlir.dialects import scf from jax._src.lib.mlir.dialects import vector +from jax.experimental.mosaic.gpu import layouts as layouts_lib import numpy as np from . import fragmented_array as fa @@ -872,6 +874,66 @@ def _slice_smem(result: ir.Type, offset: ir.Value): return memref.view(result, smem_base, offset, []) +# The metadata needed to recostruct a vector from its flattened representation. +_VectorTemplate = tuple[Sequence[int], fa.FragmentedLayout, ir.VectorType] + + +def _flatten_ir_values( + values: Sequence[ir.Value], fa_layouts: Iterable[ir.Attribute] +) -> tuple[Sequence[ir.Value], Sequence[_VectorTemplate | None]]: + """Flattens a sequence of values. + + Non-vector values are preserved as is. Vectors are mapped to fragmented + arrays and then flattened into per-register values. + + Args: + values: The sequence of values to flatten. + fa_layouts: The layouts of vectors in ``values``. + + Returns: + A tuple of (flattened values, templates). The templates are used to + reconstruct the vectors from the per-register values. + """ + fa_layouts_it = iter(fa_layouts) + result = [] + templates = [] + for v in values: + if ir.VectorType.isinstance(v.type): + fa = _fragmented_array_from_ir(v, next(fa_layouts_it)) + result.extend(fa.registers.flat) + templates.append((fa.registers.shape, fa.layout, ir.VectorType(v.type))) + else: + result.append(v) + templates.append(None) + return result, templates + + +def _unflatten_ir_values( + flat_values: Sequence[ir.Value], templates: Sequence[_VectorTemplate | None] +) -> Sequence[ir.Value]: + """The inverse of ``_flatten_ir_values``.""" + result = [] + flat_values_it = iter(flat_values) + for template in templates: + if template is None: + result.append(next(flat_values_it)) + continue + registers_shape, layout, vec_type = template + value_registers = np.asarray( + [next(flat_values_it) for _ in range(math.prod(registers_shape))], + dtype=object, + ) + value = fa.FragmentedArray( + _registers=value_registers.reshape(registers_shape), + _layout=layout, + _is_signed=False + if ir.IntegerType.isinstance(vec_type.element_type) + else None, + ) + result.append(_fragmented_array_to_ir(value, vec_type)) + return result + + @_register_lowering(scf.ForOp) def _for_op_lowering_rule( ctx: LoweringContext, for_op: scf.ForOp @@ -884,60 +946,22 @@ def _for_op_lowering_rule( yield_layouts = inference_utils.in_layouts(yield_op) if in_layouts != out_layouts or in_layouts != yield_layouts: raise ValueError("Layout mismatch") - fa_layouts = in_layouts - - fa_layouts_it = iter(fa_layouts) - arg_template = [ - (_fragmented_array_from_ir(arg, next(fa_layouts_it)), arg.type) - if ir.VectorType.isinstance(arg.type) - else (arg, arg.type) - for arg in for_op.initArgs - ] - def lower_carry(carry): - fa_layouts_it = iter(fa_layouts) - carry_with_fas = [ - _fragmented_array_from_ir(arg, next(fa_layouts_it)) - if ir.VectorType.isinstance(arg.type) - else arg - for arg in carry - ] - lowered_carry = [] - for c in carry_with_fas: - if isinstance(c, fa.FragmentedArray): - lowered_carry.extend(c.registers.flat) - else: - lowered_carry.append(c) - return lowered_carry - - def recreate_carry(lowered_carry): - recreated_carry = [] - arg_it = iter(lowered_carry) - for arg_value, arg_type in arg_template: - if isinstance(arg_value, fa.FragmentedArray): - carry_registers = np.asarray( - [next(arg_it) for _ in arg_value.registers.flat], dtype=object - ) - carry_registers = carry_registers.reshape(arg_value.registers.shape) - carry = fa.FragmentedArray( - _registers=carry_registers, - _layout=arg_value.layout, - _is_signed=arg_value.is_signed, - ) - recreated_carry.append(_fragmented_array_to_ir(carry, arg_type)) - else: - recreated_carry.append(next(arg_it)) - return recreated_carry + flat_init_args, args_template = _flatten_ir_values( + for_op.initArgs, in_layouts + ) new_for_op = scf.ForOp( for_op.lowerBound, for_op.upperBound, for_op.step, - lower_carry(for_op.initArgs), + flat_init_args, ) with ir.InsertionPoint(new_for_op.body): - recreated_carry = recreate_carry(new_for_op.body.arguments[1:]) + recreated_carry = _unflatten_ir_values( + new_for_op.body.arguments[1:], args_template + ) ops_to_lower = [] - for op in for_op.body: + for op in [*for_op.body]: if op == yield_op: continue mgpu.private_operation_remove_from_parent(op) @@ -952,16 +976,80 @@ def recreate_carry(lowered_carry): ctx.lower_op(op) with ir.InsertionPoint(new_for_op.body): - new_yield_operands = lower_carry(yield_op.operands) + flat_operands, _ = _flatten_ir_values(yield_op.operands, in_layouts) yield_op.erase() - scf.yield_(new_yield_operands) - return recreate_carry(new_for_op.results) + scf.yield_(flat_operands) + + return _unflatten_ir_values(new_for_op.results, args_template) + + +def _infer_flat_result_types( + op: ir.OpView, out_layouts: Sequence[ir.Attribute] +) -> Sequence[ir.Type]: + result_types: list[ir.Type] = [] + out_layouts_it = iter(out_layouts) + for r in op.results: + if not ir.VectorType.isinstance(r.type): + result_types.append(r.type) + continue + vec_type = ir.VectorType(r.type) + layout = layouts_lib.from_layout_attr(next(out_layouts_it)) + result_types.extend( + [layout.registers_element_type(vec_type.element_type)] + * math.prod(layout.registers_shape(tuple(vec_type.shape))) + ) + return result_types + + +@_register_lowering(scf.IfOp) +def _if_op_lowering_rule( + ctx: LoweringContext, if_op: scf.IfOp +) -> MlirLoweringRuleResult: + if not inference_utils.should_have_layout(if_op): + return _traverse_op_lowering_rule(ctx, if_op) + + raise NotImplementedError + + +@_register_lowering(scf.IndexSwitchOp) +def _index_switch_op_lowering_rule( + ctx: LoweringContext, switch_op: scf.IndexSwitchOp +) -> MlirLoweringRuleResult: + if not inference_utils.should_have_layout(switch_op): + return _traverse_op_lowering_rule(ctx, switch_op) + + out_layouts = inference_utils.out_layouts(switch_op) + new_switch_op = scf.IndexSwitchOp( + _infer_flat_result_types(switch_op, out_layouts), + switch_op.arg, + switch_op.cases, + len(switch_op.regions) - 1, + ) + + results_template: Sequence[_VectorTemplate | None] = [] + for region, new_region in zip( + switch_op.regions, new_switch_op.regions, strict=True + ): + [block] = region.blocks + new_block = new_region.blocks.append() + with ir.InsertionPoint(new_block): + for op in [*block]: + if not isinstance(op, scf.YieldOp): + mgpu.private_operation_remove_from_parent(op) + mgpu.private_block_append_owned_operation(new_block, op) + ctx.lower_op(op) + continue + if inference_utils.in_layouts(op) != out_layouts: + raise ValueError("Layout mismatch") + flat_results, results_template = _flatten_ir_values( + op.operands, out_layouts + ) + scf.yield_(flat_results) + return _unflatten_ir_values(new_switch_op.results, results_template) @_register_lowering(func.FuncOp) @_register_lowering(gpu.LaunchOp) -@_register_lowering(scf.IfOp) # TODO(apaszke,bchetioui): Add a proper rule. -@_register_lowering(scf.IndexSwitchOp) # TODO(apaszke,bchetioui): Add a proper rule. def _traverse_op_lowering_rule( ctx: LoweringContext, op: ir.OpView ) -> MlirLoweringRuleResult: diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 5daed8416589..4bbfd0dd8afe 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -319,6 +319,9 @@ def tiled_tiling_rank(self) -> int: def vector_length(self) -> int: return self.tiled_tiling_shape[self.vector_dim] + def registers_element_type(self, t: ir.Type) -> ir.Type: + return ir.VectorType.get((self.vector_length,), t) + def registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: """Returns the shape of the register array needed to represent an array of the given logical shape.""" tiled_shape = list(self.tiling.tile_shape(shape)) @@ -386,6 +389,19 @@ def _tiled_wgmma_layout(shape: tuple[int, ...]): class WGMMARowFragLayout: """[m] matrix, where m % 64 == 0.""" + def registers_element_type(self, t: ir.Type) -> ir.Type: + return t + + def registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: + """Returns the shape of the register array needed to represent an array of the given logical shape.""" + if len(shape) != 1: + raise ValueError("WGMMARowFragLayout requires a 1D shape") + if shape[0] % 64: + raise ValueError( + "WGMMARowFragLayout requires shape[0] to be a multiple of 64" + ) + return (shape[0] // 64, 2) + def thread_idxs(self, shape): index = ir.IndexType.get() assert len(shape) == 1 @@ -435,6 +451,14 @@ def can_broadcast_to(self, shape) -> bool: """ return all(dim1 == dim2 or dim1 == 1 for dim1, dim2 in zip(self.shape[::-1], shape[::-1])) + def registers_element_type(self, t: ir.Type) -> ir.Type: + return t + + def registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: + """Returns the shape of the register array needed to represent an array of the given logical shape.""" + del shape # Unused. + return () + def thread_idxs(self, shape): assert shape == self.shape raise NotImplementedError @@ -469,6 +493,15 @@ def from_shaped_type(cls, shaped_ty: ir.Type): shape=tuple(shaped_ty.shape), vec_size=min(8 // bw, max_vec_size) ) + def registers_element_type(self, t: ir.Type) -> ir.Type: + return ir.VectorType.get((self.vec_size,), t) + + def registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: + """Returns the shape of the register array needed to represent an array of the given logical shape.""" + if shape != self.shape: + raise ValueError(f"Shape {shape} is not compatible with {self}") + return (math.prod(self.shape) // (WARPGROUP_SIZE * self.vec_size),) + def thread_idxs(self, shape): assert shape == self.shape index = ir.IndexType.get() @@ -626,8 +659,8 @@ def __init__( != math.prod(_registers.shape) * WARPGROUP_SIZE * reg_size ): raise ValueError( - "Invalid register array shape: math.prod({_registers.shape}) *" - " {WARPGROUP_SIZE} * {reg_size}, want: math.prod({shape})" + f"Invalid register array shape: math.prod({_registers.shape}) *" + f" {WARPGROUP_SIZE} * {reg_size}, want: math.prod({shape})" ) # Just a single register @@ -703,30 +736,15 @@ def load_wgmma_row( def splat(cls, value, shape, layout=None, *, is_signed: bool | None = None): layout = layout or WGSplatFragLayout(shape) match layout: - case WGMMARowFragLayout(): - if len(shape) != 1: - raise ValueError("WGMMARowFragLayout requires a 1D shape") - if shape[0] % 64: - raise ValueError( - "WGMMARowFragLayout requires shape[0] to be a multiple of 64" - ) - reg_shape = (shape[0] // 64, 2) - case WGStridedFragLayout(vec_size=vec_size): - assert shape == layout.shape - elems = np.prod(shape) - reg_shape = (elems // (WARPGROUP_SIZE * vec_size),) - value = vector.splat(ir.VectorType.get((vec_size,), value.type), value) - case WGSplatFragLayout(): - assert shape == layout.shape - reg_shape = () - case TiledLayout(): - value = vector.splat(ir.VectorType.get((layout.vector_length,), value.type), value) - reg_shape = layout.registers_shape(shape) + case WGMMARowFragLayout() | WGSplatFragLayout(): + pass + case WGStridedFragLayout() | TiledLayout(): + value = vector.splat(layout.registers_element_type(value.type), value) case _: raise NotImplementedError(layout) return cls( - _registers=np.full(reg_shape, value, dtype=object), + _registers=np.full(layout.registers_shape(shape), value, dtype=object), _layout=layout, _is_signed=is_signed, ) diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py index 0d2811bb5610..470b0d328d8e 100644 --- a/jax/experimental/mosaic/gpu/layout_inference.py +++ b/jax/experimental/mosaic/gpu/layout_inference.py @@ -306,23 +306,46 @@ def _infer_yield_op_layout(op: scf.YieldOp) -> OptionalLayouts: return (layouts, []) -@partial(_add_layout_inference_rule, scf.ForOp) -def _infer_for_op_layout(op: scf.ForOp) -> OptionalLayouts: - yield_op = op.body.operations[len(op.body.operations) - 1] - assert isinstance(yield_op, scf.YieldOp) - - if inference_utils.has_in_layouts_set(yield_op): - yield_layouts = list(inference_utils.in_layouts(yield_op)) +def _infer_from_yield_ops(op: ir.Operation) -> list[ir.Attribute] | None: + candidates = [] + for region in op.regions: + [block] = region.blocks + yield_op = block.operations[len(block.operations) - 1] + assert isinstance(yield_op, scf.YieldOp) + if not inference_utils.has_in_layouts_set(yield_op): + continue + yield_layouts = inference_utils.in_layouts(yield_op) if any( layouts_lib.is_splat_fragmented_layout(layout) for layout in yield_layouts ): - return None - return (yield_layouts, yield_layouts) + continue + candidates.append(yield_layouts) + if not candidates: + return None + return [_choose_representative_layout(set(c)) for c in zip(*candidates)] + +@partial(_add_layout_inference_rule, scf.ForOp) +def _infer_for_op_layout(op: scf.ForOp) -> OptionalLayouts: # TODO(bchetioui): we don't attempt to propagate from outside for the moment. # For the existing kernels, propagating from the YieldOp should be enough. + if layouts := _infer_from_yield_ops(op): + return layouts, layouts + return None + +@partial(_add_layout_inference_rule, scf.IfOp) +def _infer_if_op_layout(op: scf.IfOp) -> OptionalLayouts: + if layouts := _infer_from_yield_ops(op): + return [], layouts + return None + + +@partial(_add_layout_inference_rule, scf.IndexSwitchOp) +def _infer_index_switch_op_layout(op: scf.IndexSwitchOp) -> OptionalLayouts: + if layouts := _infer_from_yield_ops(op): + return [], layouts return None @@ -333,7 +356,6 @@ def _infer_splat_op_layout(splat_op: vector.SplatOp) -> OptionalLayouts: shape=cast(ir.ShapedType, splat_op.result.type).shape ) ) - return [], [layout] diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index b3c3ddb84e09..6792ddfaa9a8 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1121,13 +1121,13 @@ def test_cond_returning_array(self, thread_semantics): ), ) def kernel(x_ref, o_ref): - acc = _sum_same_dtype(x_ref[...]) + acc_sum = _sum_same_dtype(x_ref[...]) acc2, acc = jax.lax.cond( - acc % 2 == 0, - lambda: (acc * 2, acc), - lambda: (acc, acc * 2), + acc_sum % 2 == 0, + lambda: (acc_sum * 2, x_ref[...]), + lambda: (acc_sum, x_ref[...]), ) - o_ref[...] = jnp.broadcast_to(acc + acc2, o_ref.shape) + o_ref[...] = jnp.broadcast_to(_sum_same_dtype(acc) + acc2, o_ref.shape) x = jnp.arange(256, dtype=jnp.int32) np.testing.assert_array_equal(kernel(x), jnp.broadcast_to(jnp.sum(x) * 3, [256])) From 30f770970404785016d3503ab1543540c8c88df0 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 19 Mar 2025 05:14:52 -0700 Subject: [PATCH 010/483] Update XLA dependency to use revision http://github.com/openxla/xla/commit/0d20d73f2c8f21c21b9f343c4363a76e980f032e. PiperOrigin-RevId: 738352930 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 73bf2eb3850d..f81e3931b1dc 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "df971129bd82e381954da0185b534220e21798a4" -XLA_SHA256 = "11e9a568320cf7e7d61819620fd369927527ecefb68d5d1154b1521456bbdb72" +XLA_COMMIT = "0d20d73f2c8f21c21b9f343c4363a76e980f032e" +XLA_SHA256 = "9df61c200b0a54b7a5c55155fa7a454e33d660e6a49239b6980f5a10305fecc5" def repo(): tf_http_archive( From c8032a9904eeb2410995425f817929f507fe22d5 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 19 Mar 2025 08:56:02 -0400 Subject: [PATCH 011/483] Fix line continuation character in Windows wheel build. --- .github/workflows/wheel_win_x64.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/wheel_win_x64.yml b/.github/workflows/wheel_win_x64.yml index 444bc83f2889..912088428fd5 100644 --- a/.github/workflows/wheel_win_x64.yml +++ b/.github/workflows/wheel_win_x64.yml @@ -38,7 +38,7 @@ jobs: JAXLIB_RELEASE: true run: | python -m pip install uv~=0.5.30 - python -m uv pip install -r build/test-requirements.txt \ + python -m uv pip install -r build/test-requirements.txt ` --upgrade numpy==2.0.0 scipy==1.13.1 "C:\\msys64\\;C:\\msys64\\usr\\bin\\;" >> $env:GITHUB_PATH python.exe build\build.py build --wheels=jaxlib ` @@ -58,7 +58,7 @@ jobs: JAX_SKIP_SLOW_TESTS: true PY_COLORS: 1 run: | - python -m uv pip install --find-links ${{ github.workspace }}\dist jaxlib \ + python -m uv pip install --find-links ${{ github.workspace }}\dist jaxlib ` -e ${{ github.workspace }} echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS" pytest -n auto --tb=short tests examples From 133a885e3b7a8347c121dce99eb3a920b6333a9e Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 19 Mar 2025 06:44:35 -0700 Subject: [PATCH 012/483] `use_mesh` and `use_concrete_mesh` should error when used under jit PiperOrigin-RevId: 738376533 --- jax/_src/array.py | 4 ++-- jax/_src/pjit.py | 8 +++++--- jax/_src/sharding_impls.py | 17 +++++++++-------- 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/jax/_src/array.py b/jax/_src/array.py index b0793d2c3330..e49963ccda9c 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -43,7 +43,7 @@ from jax._src.sharding_impls import ( PmapSharding, SingleDeviceSharding, device_replica_id_map, hashed_index, num_addressable_indices, - local_to_global_shape, use_concrete_mesh) # pyformat: disable + local_to_global_shape, _internal_use_concrete_mesh) # pyformat: disable from jax._src.typing import ArrayLike, DLDeviceType, DTypeLike from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method, cache import numpy as np @@ -1149,7 +1149,7 @@ def shard_device_array(x, devices, indices, sharding): else: # TODO(yashkatariya): Maybe this should be set when we call the handler in # InputsHandler.__call__? - with use_concrete_mesh(None): + with _internal_use_concrete_mesh(None): shards = x._multi_slice(start_indices, limit_indices, removed_dims) aval = core.shaped_abstractify(x) return pxla.batched_device_put(aval, sharding, shards, devices) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index f7a4361ffee2..38ccb4513766 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -68,7 +68,7 @@ NamedSharding, GSPMDSharding, SingleDeviceSharding, PmapSharding, AUTO, UNSPECIFIED, UnspecifiedValue, prepare_axis_resources, parse_flatten_op_sharding, canonicalize_sharding, - flatten_spec) + flatten_spec, _internal_use_concrete_mesh) from jax._src.layout import Layout, DeviceLocalLayout, AutoLayout from jax._src.state import discharge as state_discharge, RefEffect, AbstractRef from jax._src.traceback_util import api_boundary @@ -689,8 +689,10 @@ def _infer_params_cached( def _infer_params( fun: Callable, ji: PjitInfo, args: tuple[Any, ...], kwargs: dict[str, Any] ) -> tuple[PjitParams, list[Any]]: - if ji.use_resource_env: - with sharding_impls.use_mesh(mesh_lib.thread_resources.env.physical_mesh): + if ji.use_resource_env: # pjit + phys_mesh = mesh_lib.thread_resources.env.physical_mesh + with (_internal_use_concrete_mesh(phys_mesh), + mesh_lib.use_abstract_mesh(phys_mesh.abstract_mesh)): return _infer_params_internal(fun, ji, args, kwargs) return _infer_params_internal(fun, ji, args, kwargs) diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 2bbf913783e3..f3295a75cf7a 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -1382,10 +1382,8 @@ def use_mesh(mesh: mesh_lib.Mesh): if not isinstance(mesh, mesh_lib.Mesh): raise ValueError( f"Expected mesh of type `jax.sharding.Mesh`. Got {type(mesh)}") - - # TODO(yashkatariya): Enable this. - # if not core.trace_state_clean(): - # raise ValueError('`use_mesh` can only be used outside of `jax.jit`') + if not core.trace_state_clean(): + raise ValueError('`use_mesh` can only be used outside of `jax.jit`') with mesh_lib.use_abstract_mesh(mesh.abstract_mesh), use_concrete_mesh(mesh): yield @@ -1410,13 +1408,16 @@ def set_mesh(mesh: mesh_lib.Mesh | None) -> mesh_lib.Mesh | None: @contextlib.contextmanager def use_concrete_mesh(mesh: mesh_lib.Mesh | None): + if not core.trace_state_clean(): + raise ValueError('`use_concrete_mesh` can only be used outside of `jax.jit`.') + with _internal_use_concrete_mesh(mesh): + yield + +@contextlib.contextmanager +def _internal_use_concrete_mesh(mesh: mesh_lib.Mesh | None): if mesh is not None and not isinstance(mesh, mesh_lib.Mesh): raise ValueError( f"Expected mesh of type `jax.sharding.Mesh`. Got {type(mesh)}") - # TODO(yashkatariya): Enable this. - # if not core.trace_state_clean(): - # raise ValueError('`use_concrete_mesh` can only be used outside of `jax.jit`.') - prev_val = config.device_context.swap_local(mesh) try: yield From 1e25c44d67a024daa2333b652c578ac3535bb803 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Wed, 19 Mar 2025 07:42:54 -0700 Subject: [PATCH 013/483] [mosaic_gpu] Only `jit` function to profile with cupti if it not already `jit`ted. PiperOrigin-RevId: 738393973 --- jax/experimental/mosaic/gpu/profiler.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/jax/experimental/mosaic/gpu/profiler.py b/jax/experimental/mosaic/gpu/profiler.py index 0c128f88d169..99fefc1adc9c 100644 --- a/jax/experimental/mosaic/gpu/profiler.py +++ b/jax/experimental/mosaic/gpu/profiler.py @@ -21,6 +21,7 @@ import warnings import jax +from jax._src import stages from jax._src.lib import xla_client import jax.numpy as jnp from jaxlib.mlir import ir @@ -98,10 +99,13 @@ def run(*args, **kwargs): def _measure_cupti(f, aggregate): + if not isinstance(f, (stages.Wrapped, stages.Compiled)): + f = jax.jit(f) + def run(*args, **kwargs): mosaic_gpu_lib._mosaic_gpu_ext._cupti_init() try: - results = jax.block_until_ready(jax.jit(f)(*args, **kwargs)) + results = jax.block_until_ready(f(*args, **kwargs)) finally: timings = mosaic_gpu_lib._mosaic_gpu_ext._cupti_get_timings() return results, timings From d7d0aa943e825b89d9d696066f3a7389b1e9bb9e Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Wed, 19 Mar 2025 07:56:34 -0700 Subject: [PATCH 014/483] Move PRNG GPU lowering from jaxlib into JAX. PiperOrigin-RevId: 738398099 --- jax/_src/prng.py | 36 ++++++++++----------- jaxlib/gpu_prng.py | 79 ++++++---------------------------------------- 2 files changed, 25 insertions(+), 90 deletions(-) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 2fa9b2b37aa4..5fdd673b3454 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -31,6 +31,7 @@ from jax._src import core from jax._src import dispatch from jax._src import dtypes +from jax._src import ffi from jax._src import pretty_printer as pp from jax._src import source_info_util from jax._src import tree_util as tree_util_internal @@ -64,6 +65,13 @@ UINT_DTYPES = { 8: jnp.uint8, 16: jnp.uint16, 32: jnp.uint32, 64: jnp.uint64} +if hasattr(gpu_prng, "registrations"): + for platform, targets in gpu_prng.registrations().items(): + for name, value, api_version in targets: + ffi.register_ffi_target( + name, value, platform=platform, api_version=api_version + ) + # -- PRNG implementation interface class PRNGImpl(NamedTuple): @@ -902,7 +910,7 @@ def _threefry2x32_lowering(key1, key2, x1, x2, use_rolled_loops=True): multiple_results=True) -def _threefry2x32_gpu_lowering_rule(lowering_func, ctx, k1, k2, x1, x2): +def _threefry2x32_gpu_lowering_rule(ctx, k1, k2, x1, x2, *, target_name_prefix): if not config.threefry_gpu_kernel_lowering.value: # back to default lowering return _threefry2x32_lowering_rule(ctx, k1, k2, x1, x2) @@ -917,23 +925,11 @@ def _broadcast(x, aval): return mlir.broadcast_in_dim(ctx, x, aval_out, broadcast_dimensions=range(rank - len(aval.shape), rank)) - out_len = reduce(op.mul, aval_out.shape, 1) - if not core.is_constant_dim(out_len): - length = mlir.eval_dynamic_shape_as_tensor(ctx, [out_len]) - length = mlir.hlo.convert( - ir.RankedTensorType.get((1,), ir.IntegerType.get_signless(64)), - length) - output_shape = mlir.eval_dynamic_shape_as_tensor(ctx, aval_out.shape) - else: - length = int(out_len) # will be passed statically - output_shape = None - - return lowering_func( - (_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)), - (_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)), length, - output_shape, - False, # forward_compatibility_mode - ) + sub_ctx = ctx.replace(avals_in=(aval_out,) * 4) + rule = ffi.ffi_lowering( + f"{target_name_prefix}_threefry2x32_ffi") + return rule(sub_ctx, _broadcast(k1, k1_aval), _broadcast(k2, k2_aval), + _broadcast(x1, x1_aval), _broadcast(x2, x2_aval)) threefry2x32_p = core.Primitive("threefry2x32") @@ -947,11 +943,11 @@ def _broadcast(x, aval): threefry2x32_p, _threefry2x32_cpu_lowering_rule, platform='cpu') mlir.register_lowering( threefry2x32_p, - partial(_threefry2x32_gpu_lowering_rule, gpu_prng.cuda_threefry2x32), + partial(_threefry2x32_gpu_lowering_rule, target_name_prefix='cu'), platform='cuda') mlir.register_lowering( threefry2x32_p, - partial(_threefry2x32_gpu_lowering_rule, gpu_prng.rocm_threefry2x32), + partial(_threefry2x32_gpu_lowering_rule, target_name_prefix='hip'), platform='rocm') diff --git a/jaxlib/gpu_prng.py b/jaxlib/gpu_prng.py index 6f74d5813ce4..b112534c0575 100644 --- a/jaxlib/gpu_prng.py +++ b/jaxlib/gpu_prng.py @@ -12,79 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations -from functools import partial -import itertools +from typing import Any -import jaxlib.mlir.ir as ir - -from jaxlib import xla_client - -from .hlo_helpers import custom_call from .plugin_support import import_from_plugin _cuda_prng = import_from_plugin("cuda", "_prng") _hip_prng = import_from_plugin("rocm", "_prng") -if _cuda_prng: - for _name, _value in _cuda_prng.registrations().items(): - # TODO(danfm): remove after JAX 0.5.1 release - api_version = 1 if "_ffi" in _name else 0 - xla_client.register_custom_call_target(_name, _value, platform="CUDA", - api_version=api_version) - -if _hip_prng: - for _name, _value in _hip_prng.registrations().items(): - # TODO(danfm): remove after JAX 0.5.1 release - api_version = 1 if "_ffi" in _name else 0 - xla_client.register_custom_call_target(_name, _value, platform="ROCM", - api_version=api_version) - - -def _threefry2x32_lowering(prng, platform: str, keys, data, - length: int | ir.Value | None = None, - output_shape: ir.Value | None = None, - forward_compatibility_mode: bool = False): - """ThreeFry2x32 kernel for GPU. - - In presence of dynamic shapes, `length` is an `ir.Value` and `output_shape` - is a 1D tensor describing the shape of the two outputs. - """ - del forward_compatibility_mode - assert len(keys) == 2, keys - assert len(data) == 2, data - assert (ir.RankedTensorType(keys[0].type).element_type == - ir.IntegerType.get_unsigned(32)), keys[0].type - - typ = keys[0].type - dims = ir.RankedTensorType(typ).shape - - for x in itertools.chain(keys, data): - assert x.type == typ, (x.type, typ) - ndims = len(dims) - layout = tuple(range(ndims - 1, -1, -1)) - operand_layouts = [layout] * 4 - operands = [keys[0], keys[1], data[0], data[1]] - - opaque = {} # Use if not forward_compatibility_mode to trigger the FFI (v4). - if isinstance(length, int): - result_shapes = None - else: - assert output_shape is not None - # We also need to pass separately the shapes of the outputs. - result_shapes = [output_shape, output_shape] - - custom_call_target = f"{platform}_threefry2x32_ffi" - return custom_call( - custom_call_target, - api_version=4, - result_types=[typ, typ], - operands=operands, - backend_config=opaque, - operand_layouts=operand_layouts, - result_layouts=[layout] * 2, - result_shapes=result_shapes).results - - -cuda_threefry2x32 = partial(_threefry2x32_lowering, _cuda_prng, "cu") -rocm_threefry2x32 = partial(_threefry2x32_lowering, _hip_prng, "hip") +def registrations() -> dict[str, list[tuple[str, Any, int]]]: + registrations = {"CUDA": [], "ROCM": []} + for platform, module in [("CUDA", _cuda_prng), ("ROCM", _hip_prng)]: + if module: + registrations[platform].extend( + (name, value, int(name.endswith("_ffi"))) + for name, value in module.registrations().items()) + return registrations From 1dcf872c64dbd6bf93dfeebdde869847d9ac5b53 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 19 Mar 2025 08:15:10 -0700 Subject: [PATCH 015/483] Move //jaxlib:pass_boilerplate to //jaxlib/mosaic:pass_boilerplate. This code is Mosaic specific, move it to the Mosaic directory. PiperOrigin-RevId: 738404429 --- jaxlib/BUILD | 11 ----------- jaxlib/mosaic/BUILD | 15 +++++++++++++-- jaxlib/mosaic/dialect/tpu/transforms/serde.h | 2 +- jaxlib/mosaic/gpu/BUILD | 2 +- jaxlib/mosaic/gpu/passes.cc | 2 +- jaxlib/mosaic/gpu/serde.h | 2 +- jaxlib/{ => mosaic}/pass_boilerplate.h | 6 +++--- 7 files changed, 20 insertions(+), 20 deletions(-) rename jaxlib/{ => mosaic}/pass_boilerplate.h (94%) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index a35eabc9a505..a5e8cee08cdc 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -171,17 +171,6 @@ cc_library( ], ) -cc_library( - name = "pass_boilerplate", - hdrs = ["pass_boilerplate.h"], - # compatible with libtpu - deps = [ - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Support", - ], -) - cc_library( name = "handle_pool", hdrs = ["handle_pool.h"], diff --git a/jaxlib/mosaic/BUILD b/jaxlib/mosaic/BUILD index 4cc2530dd7ca..775c34c8e7c7 100644 --- a/jaxlib/mosaic/BUILD +++ b/jaxlib/mosaic/BUILD @@ -60,9 +60,9 @@ cc_library( ]), # compatible with libtpu deps = [ + ":pass_boilerplate", + ":serde", ":tpu_inc_gen", - "//jaxlib:pass_boilerplate", - "//jaxlib/mosaic:serde", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/hash", @@ -279,6 +279,17 @@ filegroup( # compatible with libtpu ) +cc_library( + name = "pass_boilerplate", + hdrs = ["pass_boilerplate.h"], + # compatible with libtpu + deps = [ + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + ], +) + cc_library( name = "serde", srcs = ["serde.cc"], diff --git a/jaxlib/mosaic/dialect/tpu/transforms/serde.h b/jaxlib/mosaic/dialect/tpu/transforms/serde.h index 8685918d3b39..64753a22e7be 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/serde.h +++ b/jaxlib/mosaic/dialect/tpu/transforms/serde.h @@ -9,7 +9,7 @@ #include "mlir/include/mlir/Interfaces/DataLayoutInterfaces.h" #include "mlir/include/mlir/Pass/Pass.h" #include "mlir/include/mlir/Pass/PassRegistry.h" -#include "jaxlib/pass_boilerplate.h" +#include "jaxlib/mosaic/pass_boilerplate.h" namespace mlir::tpu { diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD index 9249ae256901..abe326474808 100644 --- a/jaxlib/mosaic/gpu/BUILD +++ b/jaxlib/mosaic/gpu/BUILD @@ -52,7 +52,7 @@ cc_library( "serde.h", ], deps = [ - "//jaxlib:pass_boilerplate", + "//jaxlib/mosaic:pass_boilerplate", "//jaxlib/mosaic:serde", "@llvm-project//llvm:Support", "@llvm-project//mlir:DataLayoutInterfaces", diff --git a/jaxlib/mosaic/gpu/passes.cc b/jaxlib/mosaic/gpu/passes.cc index b8c3fbb74c81..1815e18ca927 100644 --- a/jaxlib/mosaic/gpu/passes.cc +++ b/jaxlib/mosaic/gpu/passes.cc @@ -31,7 +31,7 @@ limitations under the License. #include "mlir/include/mlir/Pass/PassRegistry.h" #include "mlir/include/mlir/Support/LLVM.h" #include "mlir/include/mlir/Transforms/DialectConversion.h" -#include "jaxlib/pass_boilerplate.h" +#include "jaxlib/mosaic/pass_boilerplate.h" namespace mosaic { namespace gpu { diff --git a/jaxlib/mosaic/gpu/serde.h b/jaxlib/mosaic/gpu/serde.h index 6187d72b4cd5..d1e25e3f0912 100644 --- a/jaxlib/mosaic/gpu/serde.h +++ b/jaxlib/mosaic/gpu/serde.h @@ -25,7 +25,7 @@ limitations under the License. #include "mlir/include/mlir/Interfaces/DataLayoutInterfaces.h" #include "mlir/include/mlir/Pass/Pass.h" #include "mlir/include/mlir/Pass/PassRegistry.h" -#include "jaxlib/pass_boilerplate.h" +#include "jaxlib/mosaic/pass_boilerplate.h" namespace mosaic::gpu { diff --git a/jaxlib/pass_boilerplate.h b/jaxlib/mosaic/pass_boilerplate.h similarity index 94% rename from jaxlib/pass_boilerplate.h rename to jaxlib/mosaic/pass_boilerplate.h index b9754a8738ee..546981feeef7 100644 --- a/jaxlib/pass_boilerplate.h +++ b/jaxlib/mosaic/pass_boilerplate.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef JAXLIB_PASS_BOILERPLATE_H_ -#define JAXLIB_PASS_BOILERPLATE_H_ +#ifndef JAXLIB_MOSAIC_PASS_BOILERPLATE_H_ +#define JAXLIB_MOSAIC_PASS_BOILERPLATE_H_ #include @@ -64,4 +64,4 @@ class Pass : public ::mlir::OperationPass { } // namespace mlir } // namespace jaxlib -#endif // JAXLIB_PASS_BOILERPLATE_H_ +#endif // JAXLIB_MOSAIC_PASS_BOILERPLATE_H_ From 4893c08441231bd15c20dd76c4acb5d36890cd79 Mon Sep 17 00:00:00 2001 From: Thai Nguyen Date: Wed, 19 Mar 2025 08:32:49 -0700 Subject: [PATCH 016/483] Support bfloat16 and other scalar values in broadcast PiperOrigin-RevId: 738410122 --- jax/_src/pallas/mosaic/lowering.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 10b9de7487eb..2bfd2f357510 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -2346,13 +2346,13 @@ def _bcast(x, y, x_aval, y_aval, out_aval): y_dtype = x_aval.dtype elif x_aval.weak_type: x_dtype = y_aval.dtype - if isinstance(x, (np.ndarray, np.number, int, float)): + if not isinstance(x, ir.Value): if getattr(y, "type", None) == ir.IndexType.get(): mlir_type = y.type else: mlir_type = _dtype_to_ir_type(x_dtype) x = ir_constant(x, mlir_type) - if isinstance(y, (np.ndarray, np.number, int, float)): + if not isinstance(y, ir.Value): if getattr(x, "type", None) == ir.IndexType.get(): mlir_type = x.type else: From fd23fa8cf0f67e3bc82940b827536a61465b08b7 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Wed, 19 Mar 2025 08:58:39 -0700 Subject: [PATCH 017/483] [Mosaic GPU] Remove `transpose_{a,b}` attributes from `mosaic_gpu.WGMMAOp`. Now that we have full control over strides in the lowering, these attributes are no longer necessary. PiperOrigin-RevId: 738418852 --- jax/experimental/mosaic/gpu/dialect_lowering.py | 3 --- jaxlib/mosaic/dialect/gpu/mosaic_gpu.td | 16 ++++------------ 2 files changed, 4 insertions(+), 15 deletions(-) diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index 6e7f4a981f9d..d605a2dea8f9 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -763,9 +763,6 @@ def _bitcast_op_lowering_rule( def _mgpu_wgmma_op_lowering_rule( _: LoweringContext, wgmma_op: mgpu.WGMMAOp ) -> Sequence[ir.Value]: - if wgmma_op.transpose_a or wgmma_op.transpose_b: - raise ValueError("Transpose arguments are to be deleted.") - fa_layouts = ( *inference_utils.in_layouts(wgmma_op), *inference_utils.out_layouts(wgmma_op), diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td index 0882986fcf5e..108ff952b571 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td @@ -394,19 +394,14 @@ def MosaicGPU_WGMMAOp : Op { This operation supports larger inputs than the PTX-level WGMMA operation and will schedule as many PTX-level WGMMA operations as needed to - accomplish the calculation. The `b` matrix, and optionally `a`, needs to be - provided as a 2-dimensional memref. All memrefs may have transforms that - define swizzling, tiling, and transposition. + accomplish the calculation. The `b` matrix, and optionally `a`, need to be + provided as a 2-dimensional memref. The inputs should have the following shapes: - a: [groups_m * 64, groups_k * s] - b: [groups_k * s, groups_n * s] - accumulator: [groups_m * 64, groups_n * s] - Where: - - `s == swizzle/element_bytediwth` (for `kNoSwizzle`, `swizzle` is 16.) - and the tilings are [64, s] for `a` and [s, s] for `b`. - - `a` and/or `b` may be transposed if the corresponding attribute is set - to `true`. + where `s == swizzle / element_bytewidth`. The output has an identical shape and type as the input accumulator. @@ -429,10 +424,7 @@ def MosaicGPU_WGMMAOp : Op { AnyTypeOf<[ MemRefOf<[MosaicGPU_WGMMASupportedType]>, VectorOfAnyRankOf<[MosaicGPU_WGMMASupportedType]>]>:$a, - MemRefOf<[MosaicGPU_WGMMASupportedType]>:$b, - - DefaultValuedOptionalAttr:$transpose_a, - DefaultValuedOptionalAttr:$transpose_b + MemRefOf<[MosaicGPU_WGMMASupportedType]>:$b ); let results = (outs VectorOfAnyRankOf<[MosaicGPU_WGMMASupportedType]>); From af5b2efd3e2fd7b071b95581f01d555451e95c32 Mon Sep 17 00:00:00 2001 From: Jacob Burnim Date: Tue, 18 Mar 2025 16:30:13 -0700 Subject: [PATCH 018/483] Fix input_output_aliases for non-HBM kernel args in TPU interpret mode. --- jax/_src/pallas/mosaic/interpret.py | 140 +++++++++++----------- tests/pallas/tpu_pallas_interpret_test.py | 42 +++++++ 2 files changed, 111 insertions(+), 71 deletions(-) diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py index 1ad7be8154cd..3384026c1f5b 100644 --- a/jax/_src/pallas/mosaic/interpret.py +++ b/jax/_src/pallas/mosaic/interpret.py @@ -13,7 +13,6 @@ # limitations under the License. import collections -from collections.abc import Iterable, Sequence import dataclasses import enum import functools @@ -1283,23 +1282,6 @@ def f(*args, jaxpr): return jax.util.safe_map(read, jaxpr.outvars) -def _initialize_output_vals( - block_mappings_output: Iterable[BlockMapping], - input_args, input_output_aliases, - interpret_params: TPUInterpretParams, -) -> Sequence[jax.Array]: - oi_map = {v: k for k, v in input_output_aliases} - output_vals = [] - for i, bm in enumerate(block_mappings_output): - if i in oi_map: - output_vals.append(input_args[oi_map[i]]) - else: - output_vals.append(_uninitialized_value( - bm.array_shape_dtype.shape, - bm.array_shape_dtype.dtype, - interpret_params)) - return output_vals - def _compute_start_indices(block_mapping, loop_idx, *args): block_indices = ( jax_core.jaxpr_as_fun(block_mapping.index_map_jaxpr)(*loop_idx, *args)) @@ -1423,30 +1405,52 @@ def interpret_pallas_call( for a, bs in zip(input_args, block_shapes[:num_inputs]) ] - # Allocate buffers in HBM for outputs. - output_buffer_ids = [] - output_buffer_shapes = [] - output_vals = _initialize_output_vals( - grid_mapping.block_mappings_output, - scalars + input_args, - input_output_aliases, - interpret_params) - num_outputs = grid_mapping.num_outputs - output_block_shapes = block_shapes[num_inputs : num_inputs + num_outputs] - for out_val, bs in zip(output_vals, output_block_shapes): - padded_val = _pad_to_block_dimension(out_val, bs, interpret_params) - output_buffer_shapes.append(padded_val.shape) - output_buffer_ids.append(callback.io_callback( + # Allocate HBM buffers for pallas_call inputs. + # + # TODO(jburnim): As an optimization, skip allocating buffers for inputs that + # are neither aliased nor passed to the kernel in HBM? + input_buffer_ids = [] + for i, var in enumerate( + jaxpr.invars[grid_mapping.num_index_operands:][:grid_mapping.num_inputs]): + input_buffer_ids.append(callback.io_callback( _allocate_buffer, jax.ShapeDtypeStruct((), jnp.int16), device_id, TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY], - padded_val, + input_args[i], ordered=True)) - # Allocate buffers for all kernel arguments (e.g., scalars, inputs, - # outputs, scratch). - io_alias_map = dict(input_output_aliases) + + # Allocate buffers in HBM for pallas_call outputs. oi_alias_map = {v: k for k, v in input_output_aliases} + output_buffer_ids = [] + output_buffer_shapes = [] + output_vals = [] + num_outputs = grid_mapping.num_outputs + output_block_shapes = block_shapes[num_inputs : num_inputs + num_outputs] + for i, bm in enumerate(grid_mapping.block_mappings_output): + if i in oi_alias_map: + # Re-use the HBM buffer for the aliased pallas_call input. + output_buffer_ids.append(input_buffer_ids[oi_alias_map[i]]) + output_buffer_shapes.append(input_args[oi_alias_map[i]].shape) + output_vals.append(input_args[oi_alias_map[i]]) + else: + out_val = _uninitialized_value(bm.array_shape_dtype.shape, + bm.array_shape_dtype.dtype, + interpret_params) + padded_val = _pad_to_block_dimension( + out_val, output_block_shapes[i], interpret_params) + output_buffer_ids.append(callback.io_callback( + _allocate_buffer, + jax.ShapeDtypeStruct((), jnp.int16), + device_id, + TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY], + padded_val, + ordered=True)) + output_buffer_shapes.append(padded_val.shape) + output_vals.append(out_val) + + # Allocate buffers for non-HBM kernel arguments (e.g., scalars, inputs, + # outputs, scratch). kernel_buffer_ids = [] for _, val in zip(jaxpr.invars[grid_mapping.slice_index_ops], scalars): kernel_buffer_ids.append(callback.io_callback( @@ -1467,23 +1471,18 @@ def interpret_pallas_call( device_id, var.aval.shape, ordered=True)) - elif is_output and _is_any(var.aval.memory_space): - # Use the already-allocated HBM output buffer. + elif _is_any(var.aval.memory_space): + # Use the already-allocated HBM input or output buffer. # - # TODO(jburnim): For kernel args in HBM, check that block shape is the - # same as for the corresponding pallas_call input, and that the index_map + # TODO(jburnim): For kernel args in HBM, check that block shape eqals the + # shape of the corresponding pallas_call input, and that the index_map # is trivial. - kernel_buffer_ids.append(output_buffer_ids[output_idx]) - elif is_output and (output_idx in oi_alias_map): - # Use the already-allocated (non-HBM) input buffer. - kernel_buffer_ids.append(kernel_buffer_ids[oi_alias_map[output_idx]]) - elif is_input and (i in io_alias_map) and _is_any(var.aval.memory_space): - # Use the already-allocated HBM output buffer. - kernel_buffer_ids.append(output_buffer_ids[io_alias_map[i]]) + assert is_input ^ is_output + if is_input: + kernel_buffer_ids.append(input_buffer_ids[i]) + if is_output: + kernel_buffer_ids.append(output_buffer_ids[output_idx]) else: - # TODO(jburnim): For kernel args in HBM, check that block shape is the - # same as for the corresponding pallas_call input, and that the index_map - # is trivial. kernel_buffer_ids.append(callback.io_callback( _allocate_buffer, jax.ShapeDtypeStruct((), jnp.int16), @@ -1499,24 +1498,6 @@ def interpret_pallas_call( input_vars, output_vars = split_list( jaxpr.invars[grid_mapping.slice_block_ops], [num_inputs]) - # For kernel inputs that are in HBM, we populate the buffer once before - # any kernel invocations. - for buffer_id, var, val in zip(input_ids, input_vars, input_args): - if not _is_any(var.aval.memory_space): - continue - if (val.shape != var.aval.shape) or (val.dtype != var.aval.dtype): - # TODO(jburnim): Also check that the index_map is trivial. - raise ValueError() - callback.io_callback( - store, - (), - device_id, - TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY], - buffer_id, - (), - val, - ordered=True) - if grid: num_iterations = functools.reduce(jnp.multiply, grid) # type: ignore[arg-type] else: @@ -1547,9 +1528,26 @@ def body(carry): for j, var in enumerate(input_vars): if _is_any(var.aval.memory_space): continue - sliced_val = _maybe_dynamic_slice(start_indices[j], block_shapes[j], - input_args[j], is_indexing_dim[j]) - assert(sliced_val.shape == var.aval.shape) + # Copy from the HBM buffer for the pallas_call input to the kernel + # input buffer. + # TODO(jburnim): Just use input_args[j] when the input is not aliased? + transform = indexing.NDIndexer( + indices=tuple(indexing.ds(st, sz) if not iid else st + for st, sz, iid in zip(start_indices[j], + block_shapes[j], + is_indexing_dim[j])), + shape=input_args[j].shape, + int_indexer_shape=()) + sliced_val = callback.io_callback( + # TODO(jburnim): Pass source_info from the pallas_call, in case this + # read is involved in a data race. + get, + jax.ShapeDtypeStruct(var.aval.shape, var.aval.dtype), + device_id, + TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY], + input_buffer_ids[j], + (transform,), + ordered=True) callback.io_callback( # TODO(jburnim): Pass source_info from the pallas_call, in case this # store is involved in a data race. diff --git a/tests/pallas/tpu_pallas_interpret_test.py b/tests/pallas/tpu_pallas_interpret_test.py index bc589855b836..9b8a5b46865d 100644 --- a/tests/pallas/tpu_pallas_interpret_test.py +++ b/tests/pallas/tpu_pallas_interpret_test.py @@ -69,6 +69,7 @@ def matmul(x: jax.Array, y: jax.Array): np.testing.assert_allclose(z, x @ y, atol=1e-4) def test_dynamic_grid_and_aliasing(self): + self.skipTest('Broken pending fix to extra reads/writes of inputs/outputs') def kernel(s_ref, x_ref, o_ref): o_ref[...] = x_ref[...] + s_ref[0].astype(x_ref.dtype) @@ -91,8 +92,49 @@ def f(s, x): s = jnp.array([1], dtype=jnp.int32) x = jnp.arange(32 * 128.).reshape((32, 128)) y = f(s, x) + # NOTE: No matter how many times the kernel body is run, the kernel input + # buffer will only be written once by the pallas_call machinery, just + # before the first iteration. So the output will be x + 1 , despite the + # aliasing in HBM. np.testing.assert_allclose(y, x + 1.0) + def test_aliasing(self): + def kernel(x_ref, o_ref, s_ref): + @pl.when((pl.program_id(0) == 0) & (pl.program_id(1) == 0)) + def _(): + s_ref[0] = jnp.int32(0) + s = s_ref[0] + s_ref[0] = s + 1 + o_ref[:] = x_ref[:] + s.astype(x_ref.dtype) + + x = jnp.zeros((4 * 8, 4 * 128)) + y = pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + grid=(4, 4), + in_specs=[ + pl.BlockSpec(block_shape=(8, 128), index_map=lambda i, j: (i, j)), + ], + out_specs=pl.BlockSpec( + block_shape=(8, 128), index_map=lambda i, j: (j, i)), + scratch_shapes=(pltpu.SMEM((1,), jnp.int32),), + input_output_aliases={0: 0}, + interpret=mosaic_interpret.TPUInterpretParams(), + )(x) + + expected = np.zeros((4, 4)) + t = 0 + for i in range(4): + for j in range(4): + expected[j, i] = expected[i, j] + t + t += 1 + # NOTE: expected is + # [[0, 5, 10, 15], + # [1, 5, 15, 20], + # [2, 6, 10, 25], + # [3, 7, 11, 15]] + np.testing.assert_allclose(y[::8, ::128], expected) + @parameterized.parameters('eager', 'on_wait') def test_race_detection(self, dma_execution_mode): def kernel_without_race(x_ref, o_ref, t_ref, sem): From dde861af5fcf7d56863cce5afd671720df975cf4 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 19 Mar 2025 09:05:05 -0700 Subject: [PATCH 019/483] Remove the jax Array migration guide from the TOC tree but keep the doc around PiperOrigin-RevId: 738421256 --- docs/jax_array_migration.md | 3 +++ docs/notes.rst | 4 ---- jax/_src/pjit.py | 8 +++----- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/docs/jax_array_migration.md b/docs/jax_array_migration.md index 95d4a632a295..a557f4ae7efc 100644 --- a/docs/jax_array_migration.md +++ b/docs/jax_array_migration.md @@ -1,3 +1,6 @@ +--- +orphan: true +--- (jax-array-migration)= # jax.Array migration diff --git a/docs/notes.rst b/docs/notes.rst index 08265638000e..24a9dc8594cd 100644 --- a/docs/notes.rst +++ b/docs/notes.rst @@ -9,9 +9,6 @@ Dependencies and version compatibility: - :doc:`api_compatibility` outlines JAX's policies with regard to API compatibility across releases. - :doc:`deprecation` outlines JAX's policies with regard to compatibility with Python and NumPy. -Migrations and deprecations: - - :doc:`jax_array_migration` summarizes the changes to the default array type in jax v 0.4.1 - Memory and computation usage: - :doc:`async_dispatch` describes JAX's asynchronous dispatch model. - :doc:`concurrency` describes how JAX interacts with other Python concurrency. @@ -27,7 +24,6 @@ Programmer guardrails: api_compatibility deprecation - jax_array_migration async_dispatch concurrency gpu_memory_allocation diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 38ccb4513766..b6024dcdfedd 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1573,14 +1573,12 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding] 'Passing non-trivial shardings for numpy ' 'inputs is not allowed. To fix this error, either specify a ' 'replicated sharding explicitly or use ' - '`jax.experimental.multihost_utils.host_local_array_to_global_array(...)` ' + '`jax.make_array_from_process_local_data(...)` ' 'to convert your host local numpy inputs to a jax.Array which you ' - 'can pass to pjit. ' + 'can pass to jit. ' 'If the numpy input is the same on each process, then you can use ' '`jax.make_array_from_callback(...) to create a `jax.Array` which ' - 'you can pass to pjit. ' - 'Please see the jax.Array migration guide for more information ' - 'https://jax.readthedocs.io/en/latest/jax_array_migration.html#handling-of-host-local-inputs-to-pjit-like-batch-etc. ' + 'you can pass to jit. ' f'Got arg shape: {arg.shape}, arg value: {arg}') if not isinstance(arg_s, UnspecifiedValue) and arg_s._is_concrete: # jax.jit does not allow resharding across different memory kinds even From dd93eeae2e603352f2a57afb5c8af432bacbbdcf Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 19 Mar 2025 09:21:12 -0700 Subject: [PATCH 020/483] [JAX] Move py_client_gpu into JAX. This callback functionality is only used by JAX and shipped as part of its CUDA and ROCM GPU plugins. Move it into JAX, as part of a wider move of xla/python pieces that belong to JAX into JAX. PiperOrigin-RevId: 738426489 --- jaxlib/cuda/BUILD | 44 ++++ jaxlib/cuda/cuda_plugin_extension.cc | 13 ++ jaxlib/gpu/BUILD | 3 +- jaxlib/gpu/gpu_plugin_extension.cc | 9 - jaxlib/gpu/py_client_gpu.cc | 295 +++++++++++++++++++++++++++ jaxlib/gpu/py_client_gpu.h | 37 ++++ jaxlib/gpu/vendor.h | 2 + jaxlib/rocm/BUILD | 44 ++++ jaxlib/rocm/rocm_plugin_extension.cc | 12 ++ 9 files changed, 449 insertions(+), 10 deletions(-) create mode 100644 jaxlib/gpu/py_client_gpu.cc create mode 100644 jaxlib/gpu/py_client_gpu.h diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index a9bd35b7768d..23ab64aa2d01 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -657,11 +657,55 @@ py_library( ], ) +cc_library( + name = "py_client_gpu", + srcs = ["//jaxlib/gpu:py_client_gpu.cc"], + hdrs = ["//jaxlib/gpu:py_client_gpu.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":cuda_vendor", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@nanobind", + "@xla//xla:comparison_util", + "@xla//xla:shape_util", + "@xla//xla/ffi", + "@xla//xla/ffi:ffi_api", + "@xla//xla/pjrt:exceptions", + "@xla//xla/pjrt:host_callback", + "@xla//xla/pjrt:transpose", + "@xla//xla/python:callback", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:py_host_callback", + "@xla//xla/python:types", + "@xla//xla/python/ifrt", + "@xla//xla/service:custom_call_status", + "@xla//xla/service:custom_call_target_registry", + "@xla//xla/service:platform_util", + "@xla//xla/tsl/concurrency:ref_count", + "@xla//xla/tsl/platform:errors", + "@xla//xla/tsl/platform:statusor", + ], +) + nanobind_extension( name = "cuda_plugin_extension", srcs = ["cuda_plugin_extension.cc"], module_name = "cuda_plugin_extension", deps = [ + ":py_client_gpu", + "//jaxlib:kernel_nanobind_helpers", "//jaxlib/gpu:gpu_plugin_extension", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", diff --git a/jaxlib/cuda/cuda_plugin_extension.cc b/jaxlib/cuda/cuda_plugin_extension.cc index 8d8514bd2740..789227e273b6 100644 --- a/jaxlib/cuda/cuda_plugin_extension.cc +++ b/jaxlib/cuda/cuda_plugin_extension.cc @@ -21,12 +21,15 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "third_party/gpus/cuda/include/cuda.h" #include "jaxlib/gpu/gpu_plugin_extension.h" +#include "jaxlib/gpu/py_client_gpu.h" +#include "jaxlib/kernel_nanobind_helpers.h" #include "xla/pjrt/status_casters.h" namespace nb = nanobind; namespace xla { namespace { + static std::string ToString(CUresult result) { const char* error_name; if (cuGetErrorName(result, &error_name)) { @@ -38,10 +41,20 @@ static std::string ToString(CUresult result) { } return absl::StrCat(error_name, ": ", error_string); } + +nb::dict Registrations() { + nb::dict dict; + dict["xla_python_gpu_callback"] = + jax::EncapsulateFunction(jax::cuda::XlaPythonGpuCallback); + return dict; +} + } // namespace NB_MODULE(cuda_plugin_extension, m) { BuildGpuPluginExtension(m); + m.def("registrations", &Registrations); + m.def( "get_device_ordinal", [](std::intptr_t data_value) { diff --git a/jaxlib/gpu/BUILD b/jaxlib/gpu/BUILD index b5292746dd10..de55989bf73f 100644 --- a/jaxlib/gpu/BUILD +++ b/jaxlib/gpu/BUILD @@ -52,6 +52,8 @@ exports_files(srcs = [ "prng_kernels.cc", "prng_kernels.cu.cc", "prng_kernels.h", + "py_client_gpu.cc", + "py_client_gpu.h", "rnn.cc", "rnn_kernels.cc", "rnn_kernels.h", @@ -115,7 +117,6 @@ cc_library( "@xla//xla/pjrt/c:pjrt_c_api_hdrs", "@xla//xla/pjrt/c:pjrt_c_api_helpers", "@xla//xla/pjrt/c:pjrt_c_api_triton_extension_hdrs", - "@xla//xla/python:py_client_gpu", "@xla//xla/tsl/python/lib/core:numpy", ], ) diff --git a/jaxlib/gpu/gpu_plugin_extension.cc b/jaxlib/gpu/gpu_plugin_extension.cc index b56cb8337f1b..5726e0929ee5 100644 --- a/jaxlib/gpu/gpu_plugin_extension.cc +++ b/jaxlib/gpu/gpu_plugin_extension.cc @@ -35,7 +35,6 @@ limitations under the License. #include "xla/pjrt/c/pjrt_c_api_helpers.h" #include "xla/pjrt/c/pjrt_c_api_triton_extension.h" #include "xla/pjrt/status_casters.h" -#include "xla/python/py_client_gpu.h" #include "xla/tsl/python/lib/core/numpy.h" #include "xla/util.h" @@ -202,13 +201,6 @@ absl::Status RegisterCustomTypeId(const PJRT_Api* c_api, return absl::OkStatus(); } -nb::dict Registrations() { - nb::dict dict; - dict["xla_python_gpu_callback"] = - jax::EncapsulateFunction(xla::XlaPythonGpuCallback); - return dict; -} - } // namespace void BuildGpuPluginExtension(nanobind::module_& m) { @@ -264,7 +256,6 @@ void BuildGpuPluginExtension(nanobind::module_& m) { type_name_size, std::move(type_id))); }, nb::arg("c_api"), nb::arg("type_name"), nb::arg("type_id")); - m.def("registrations", &Registrations); } } // namespace xla diff --git a/jaxlib/gpu/py_client_gpu.cc b/jaxlib/gpu/py_client_gpu.cc new file mode 100644 index 000000000000..d6faa1859eb8 --- /dev/null +++ b/jaxlib/gpu/py_client_gpu.cc @@ -0,0 +1,295 @@ +/* Copyright 2022 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include + +#include "nanobind/nanobind.h" +#include "absl/algorithm/container.h" +#include "absl/base/casts.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/ascii.h" +#include "absl/strings/numbers.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/include/llvm/Support/Casting.h" +#include "jaxlib/gpu/vendor.h" +#include "xla/ffi/ffi.h" +#include "xla/ffi/ffi_api.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/host_callback.h" +#include "xla/pjrt/transpose.h" +#include "xla/primitive_util.h" +#include "xla/python/callback.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/py_host_callback.h" +#include "xla/python/types.h" +#include "xla/service/custom_call_status.h" +#include "xla/service/custom_call_target_registry.h" +#include "xla/shape_util.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/statusor.h" + +namespace nb = nanobind; + +namespace jax { +namespace JAX_GPU_NAMESPACE { + +void XlaPythonGpuCallback(gpuStream_t stream, void** buffers, + const char* opaque, size_t opaque_len, + XlaCustomCallStatus* status) { + // Ignore `descriptor` arg to callback + buffers += 1; + uint64_t descriptor; + if (!absl::SimpleAtoi(opaque, &descriptor)) { + throw xla::XlaRuntimeError("Invalid callback descriptor"); + return; + } + xla::CpuCallback* callback = + absl::bit_cast(static_cast(descriptor)); + size_t arity = callback->num_args(); + std::vector host_input_buffers(arity); + // Copy input GPU buffers to host + for (size_t i = 0; i < arity; ++i) { + const xla::CpuCallback::Arg& arg = callback->args()[i]; + if (arg.type == xla::TOKEN) { + host_input_buffers[i] = nullptr; + continue; + } + void* buf = new char[arg.size_in_bytes]; + host_input_buffers[i] = buf; + // TODO(b/238441608): Use pinned memory here to speed up the transfer. + auto gpu_res = gpuMemcpyAsync(buf, buffers[i], arg.size_in_bytes, + gpuMemcpyDeviceToHost, stream); + CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; + } + CHECK_EQ(gpuStreamSynchronize(stream), gpuSuccess) + << "Failed to gpuStreamSynchronize"; + nb::gil_scoped_acquire gil; + nb::tuple host_input_arrays = nb::steal(PyTuple_New(arity)); + for (size_t i = 0; i < arity; ++i) { + xla::CpuCallback::Arg arg = callback->args()[i]; + if (arg.type == xla::TOKEN) { + PyTuple_SET_ITEM(host_input_arrays.ptr(), i, nb::none().inc_ref().ptr()); + continue; + } + nb::capsule base(host_input_buffers[i], [](void* ptr) noexcept { + delete[] static_cast(ptr); + }); + auto array = xla::nb_numpy_ndarray(arg.dtype, arg.dims, arg.strides, + const_cast(host_input_buffers[i]), + /*base=*/base); + array.attr("flags").attr("writeable") = nb::bool_(false); + PyTuple_SET_ITEM(host_input_arrays.ptr(), i, array.inc_ref().ptr()); + } + xla::EnterHostCallback(); + absl::StatusOr maybe_result_tuple = + callback->Call(host_input_arrays); + xla::LeaveHostCallback(); + if (!maybe_result_tuple.ok()) { + absl::string_view msg = maybe_result_tuple.status().message(); + XlaCustomCallStatusSetFailure(status, msg.data(), msg.length()); + return; + } + nb::tuple result_tuple = maybe_result_tuple.value(); + std::vector temp_buffers; + for (size_t i = 0; i < callback->results().size(); ++i) { + xla::CpuCallback::Result result = callback->results()[i]; + if (result.type == xla::TOKEN) { + continue; + } + nb::object output = + nb::borrow(PyTuple_GetItem(result_tuple.ptr(), i)); + xla::nb_numpy_ndarray array = + xla::nb_numpy_ndarray::ensure(std::move(output)); + absl::Span dims( + reinterpret_cast(array.shape()), array.ndim()); + absl::Span strides( + reinterpret_cast(array.strides()), array.ndim()); + if (strides == result.expected_strides) { + auto gpu_res = + gpuMemcpyAsync(buffers[arity + i], array.data(), result.size_in_bytes, + gpuMemcpyHostToDevice, stream); + CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; + } else { + void* temp = new char[result.size_in_bytes]; + temp_buffers.push_back(temp); + xla::TransposePlan::Options options; + options.elem_size_in_bytes = xla::primitive_util::ByteWidth(result.type); + options.dims = dims; + options.permutation = result.reversed_layout; + options.input_layout = xla::TransposePlan::Striding{strides}; + absl::StatusOr> plan = + callback->transpose_cache().GetOrCreate(options); + if (!plan.ok()) { + throw xla::XlaRuntimeError(plan.status().ToString()); + } + plan.value()->Execute(array.data(), temp); + auto gpu_res = + gpuMemcpyAsync(buffers[arity + i], temp, result.size_in_bytes, + gpuMemcpyHostToDevice, stream); + CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; + } + } + nb::gil_scoped_release release; + CHECK_EQ(gpuStreamSynchronize(stream), gpuSuccess) + << "Failed to gpuStreamSynchronize"; + for (int i = 0; i < temp_buffers.size(); ++i) { + delete[] static_cast(temp_buffers[i]); + } +} + +// TODO(danfm): When compiled as part of a jaxlib plugin, this will register +// the custom call target in the plugin's registry. This won't affect +// registration via the Python API, but we should remove this once we have +// fully migrated to the plugin interface. +XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( + "xla_python_gpu_callback", &XlaPythonGpuCallback, + absl::AsciiStrToUpper(JAX_GPU_PLUGIN_NAME)); + +absl::Status XlaFfiPythonGpuCallback( + gpuStream_t stream, + std::vector>* callbacks, + uint64_t index, xla::ffi::RemainingArgs args, + xla::ffi::RemainingRets rets) { + auto loaded_callback = llvm::dyn_cast_or_null( + callbacks->at(index).get()); + if (loaded_callback == nullptr) { + return absl::InternalError( + "Expected a PyCpuLoadedHostCallback, got something else."); + } + xla::CpuCallback* callback = loaded_callback->cpu_callback(); + size_t arity = args.size(); + std::vector host_input_buffers(arity); + // Copy input GPU buffers to host + for (size_t i = 0; i < arity; ++i) { + auto arg = args.get(i); + if (arg->element_type() == xla::TOKEN) { + host_input_buffers[i] = nullptr; + continue; + } + void* buf = new char[arg->size_bytes()]; + host_input_buffers[i] = buf; + // TODO(b/238441608): Use pinned memory here to speed up the transfer. + auto gpu_res = + gpuMemcpyAsync(buf, arg.value().untyped_data(), arg->size_bytes(), + gpuMemcpyDeviceToHost, stream); + CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; + } + CHECK_EQ(gpuStreamSynchronize(stream), gpuSuccess) + << "Failed to gpuStreamSynchronize"; + nb::gil_scoped_acquire gil; + nb::tuple host_input_arrays = nb::steal(PyTuple_New(arity)); + for (size_t i = 0; i < arity; ++i) { + auto arg = args.get(i); + xla::PrimitiveType ptype = arg->element_type(); + if (ptype == xla::TOKEN) { + PyTuple_SET_ITEM(host_input_arrays.ptr(), i, nb::none().inc_ref().ptr()); + } else { + nb::capsule base(host_input_buffers[i], [](void* ptr) noexcept { + delete[] static_cast(ptr); + }); + TF_ASSIGN_OR_RETURN(auto dtype, xla::PrimitiveTypeToNbDtype(ptype)); + auto array = xla::nb_numpy_ndarray(dtype, arg->dimensions(), std::nullopt, + host_input_buffers[i], base); + array.attr("flags").attr("writeable") = nb::bool_(false); + PyTuple_SET_ITEM(host_input_arrays.ptr(), i, array.inc_ref().ptr()); + } + } + + xla::EnterHostCallback(); + // TODO(dsuo): Change this to use the Python vectorcall protocol, which allows + // you to avoid constructing a tuple for the arguments. + absl::StatusOr maybe_result_tuple = + callback->FfiCall(host_input_arrays); + xla::LeaveHostCallback(); + TF_ASSIGN_OR_RETURN(auto result_tuple, maybe_result_tuple); + + std::vector temp_buffers; + for (size_t i = 0; i < rets.size(); ++i) { + auto ret = rets.get(i).value(); + auto ptype = ret->element_type(); + if (ptype == xla::TOKEN) continue; + nb::object output = + nb::borrow(PyTuple_GetItem(result_tuple.ptr(), i)); + xla::nb_numpy_ndarray array = + xla::nb_numpy_ndarray::ensure(std::move(output)); + absl::Span strides( + reinterpret_cast(array.strides()), array.ndim()); + // We expect the output to be in default numpy layout. + TF_ASSIGN_OR_RETURN(auto expected_shape, xla::ShapeUtil::MakeValidatedShape( + ptype, ret->dimensions())); + auto expected_strides = ByteStridesForShape(expected_shape); + if (strides == expected_strides) { + auto gpu_res = + gpuMemcpyAsync(ret->untyped_data(), array.data(), ret->size_bytes(), + gpuMemcpyHostToDevice, stream); + CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; + } else { + void* temp = new char[ret->size_bytes()]; + temp_buffers.push_back(temp); + xla::TransposePlan::Options options; + options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); + absl::Span dims( + reinterpret_cast(array.shape()), array.ndim()); + options.dims = dims; + absl::InlinedVector reversed_layout; + reversed_layout.resize(expected_shape.rank()); + absl::c_reverse_copy(expected_shape.layout().minor_to_major(), + reversed_layout.begin()); + options.permutation = reversed_layout; + options.input_layout = xla::TransposePlan::Striding{strides}; + TF_ASSIGN_OR_RETURN(auto plan, + callback->transpose_cache().GetOrCreate(options)); + plan->Execute(array.data(), temp); + auto gpu_res = + gpuMemcpyAsync(ret->untyped_data(), temp, ret->size_bytes(), + gpuMemcpyHostToDevice, stream); + CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; + } + } + nb::gil_scoped_release release; + CHECK_EQ(gpuStreamSynchronize(stream), gpuSuccess) + << "Failed to gpuStreamSynchronize"; + for (int i = 0; i < temp_buffers.size(); ++i) { + delete[] static_cast(temp_buffers[i]); + } + return absl::OkStatus(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + kXlaFfiPythonGpuCallback, XlaFfiPythonGpuCallback, + xla::ffi::Ffi::Bind() + .Ctx>() + .Ctx>>>() + .Attr("index") + .RemainingArgs() + .RemainingRets()); +XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), + "xla_ffi_python_gpu_callback", + absl::AsciiStrToUpper(JAX_GPU_PLUGIN_NAME), + kXlaFfiPythonGpuCallback); + +} // namespace JAX_GPU_NAMESPACE +} // namespace jax diff --git a/jaxlib/gpu/py_client_gpu.h b/jaxlib/gpu/py_client_gpu.h new file mode 100644 index 000000000000..6be2d40823dc --- /dev/null +++ b/jaxlib/gpu/py_client_gpu.h @@ -0,0 +1,37 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAX_JAXLIB_GPU_PY_CLIENT_GPU_H_ +#define JAX_JAXLIB_GPU_PY_CLIENT_GPU_H_ + +#include + +#include "jaxlib/gpu/vendor.h" +#include "xla/ffi/ffi.h" +#include "xla/service/custom_call_status.h" + +namespace jax { +namespace JAX_GPU_NAMESPACE { + +void XlaPythonGpuCallback(gpuStream_t stream, void** buffers, + const char* opaque, size_t opaque_len, + XlaCustomCallStatus* status); + +XLA_FFI_DECLARE_HANDLER_SYMBOL(kXlaFfiPythonGpuCallback); + +} // namespace JAX_GPU_NAMESPACE +} // namespace jax + +#endif // JAX_JAXLIB_GPU_PY_CLIENT_GPU_H_ diff --git a/jaxlib/gpu/vendor.h b/jaxlib/gpu/vendor.h index 7334d4690b59..cadd5453107a 100644 --- a/jaxlib/gpu/vendor.h +++ b/jaxlib/gpu/vendor.h @@ -48,6 +48,7 @@ limitations under the License. #define JAX_GPU_NAMESPACE cuda #define JAX_GPU_PREFIX "cu" +#define JAX_GPU_PLUGIN_NAME "cuda" typedef cuComplex gpuComplex; typedef cuDoubleComplex gpuDoubleComplex; @@ -413,6 +414,7 @@ constexpr uint32_t kNumThreadsPerWarp = 32; #define JAX_GPU_NAMESPACE hip #define JAX_GPU_PREFIX "hip" +#define JAX_GPU_PLUGIN_NAME "rocm" #define JAX_GPU_HAVE_SPARSE 1 #define JAX_GPU_HAVE_64_BIT 0 diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index 9a25a795fd14..867048509afa 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -555,11 +555,55 @@ py_library( ], ) +cc_library( + name = "py_client_gpu", + srcs = ["//jaxlib/gpu:py_client_gpu.cc"], + hdrs = ["//jaxlib/gpu:py_client_gpu.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":hip_vendor", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@nanobind", + "@xla//xla:comparison_util", + "@xla//xla:shape_util", + "@xla//xla/ffi", + "@xla//xla/ffi:ffi_api", + "@xla//xla/pjrt:exceptions", + "@xla//xla/pjrt:host_callback", + "@xla//xla/pjrt:transpose", + "@xla//xla/python:callback", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:py_host_callback", + "@xla//xla/python:types", + "@xla//xla/python/ifrt", + "@xla//xla/service:custom_call_status", + "@xla//xla/service:custom_call_target_registry", + "@xla//xla/service:platform_util", + "@xla//xla/tsl/concurrency:ref_count", + "@xla//xla/tsl/platform:errors", + "@xla//xla/tsl/platform:statusor", + ], +) + nanobind_extension( name = "rocm_plugin_extension", srcs = ["rocm_plugin_extension.cc"], module_name = "rocm_plugin_extension", deps = [ + ":py_client_gpu", + "//jaxlib:kernel_nanobind_helpers", "//jaxlib/gpu:gpu_plugin_extension", "@com_google_absl//absl/log", "@com_google_absl//absl/strings", diff --git a/jaxlib/rocm/rocm_plugin_extension.cc b/jaxlib/rocm/rocm_plugin_extension.cc index 1dd1f1943fc8..1e8013f2bc1b 100644 --- a/jaxlib/rocm/rocm_plugin_extension.cc +++ b/jaxlib/rocm/rocm_plugin_extension.cc @@ -21,11 +21,14 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "rocm/include/hip/hip_runtime.h" #include "jaxlib/gpu/gpu_plugin_extension.h" +#include "jaxlib/gpu/py_client_gpu.h" +#include "jaxlib/kernel_nanobind_helpers.h" namespace nb = nanobind; namespace xla { namespace { + std::string ToString(hipError_t result) { #define OSTREAM_ROCM_ERROR(__name) \ case hipError##__name: \ @@ -62,10 +65,19 @@ std::string ToString(hipError_t result) { return absl::StrCat("hipError_t(", static_cast(result), ")"); } } + +nb::dict Registrations() { + nb::dict dict; + dict["xla_python_gpu_callback"] = + jax::EncapsulateFunction(jax::hip::XlaPythonGpuCallback); + return dict; +} + } // namespace NB_MODULE(rocm_plugin_extension, m) { BuildGpuPluginExtension(m); + m.def("registrations", &Registrations); m.def( "get_device_ordinal", [](std::intptr_t data_value) { From ee74c289ac02b0f0f07ce8819bef4b2f97207d4b Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 19 Mar 2025 10:03:57 -0700 Subject: [PATCH 021/483] Move //jaxlib:handle_pool to //jaxlib/gpu:handle_pool. This is a GPU-specific target. PiperOrigin-RevId: 738441625 --- jaxlib/BUILD | 15 --------------- jaxlib/cuda/BUILD | 8 ++++---- jaxlib/gpu/BUILD | 15 +++++++++++++++ jaxlib/gpu/blas_handle_pool.cc | 2 +- jaxlib/gpu/blas_handle_pool.h | 2 +- jaxlib/{ => gpu}/handle_pool.h | 6 +++--- jaxlib/gpu/rnn_kernels.cc | 2 +- jaxlib/gpu/solver_handle_pool.cc | 2 +- jaxlib/gpu/solver_handle_pool.h | 2 +- jaxlib/gpu/sparse_kernels.cc | 2 +- jaxlib/gpu/sparse_kernels.h | 2 +- jaxlib/rocm/BUILD | 8 ++++---- 12 files changed, 33 insertions(+), 33 deletions(-) rename jaxlib/{ => gpu}/handle_pool.h (96%) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index a5e8cee08cdc..faf52a702386 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -171,21 +171,6 @@ cc_library( ], ) -cc_library( - name = "handle_pool", - hdrs = ["handle_pool.h"], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], - deps = [ - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/synchronization", - ], -) - # This isn't a CPU kernel. This exists to catch cases where jaxlib is built for the wrong # target architecture. nanobind_extension( diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index 23ab64aa2d01..4e74cc2dcf5b 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -89,7 +89,7 @@ cc_library( deps = [ ":cuda_gpu_kernel_helpers", ":cuda_vendor", - "//jaxlib:handle_pool", + "//jaxlib/gpu:handle_pool", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", @@ -155,8 +155,8 @@ cc_library( ":cuda_gpu_kernel_helpers", ":cuda_vendor", ":ffi_wrapper", - "//jaxlib:handle_pool", "//jaxlib:kernel_helpers", + "//jaxlib/gpu:handle_pool", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", @@ -195,7 +195,7 @@ cc_library( deps = [ ":cuda_gpu_kernel_helpers", ":cuda_vendor", - "//jaxlib:handle_pool", + "//jaxlib/gpu:handle_pool", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", @@ -308,8 +308,8 @@ cc_library( ":cuda_gpu_kernel_helpers", ":cuda_vendor", ":ffi_wrapper", - "//jaxlib:handle_pool", "//jaxlib:kernel_helpers", + "//jaxlib/gpu:handle_pool", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/jaxlib/gpu/BUILD b/jaxlib/gpu/BUILD index de55989bf73f..3613be567533 100644 --- a/jaxlib/gpu/BUILD +++ b/jaxlib/gpu/BUILD @@ -93,6 +93,21 @@ xla_py_proto_library( deps = [":triton_proto"], ) +cc_library( + name = "handle_pool", + hdrs = ["handle_pool.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/synchronization", + ], +) + cc_library( name = "gpu_plugin_extension", srcs = ["gpu_plugin_extension.cc"], diff --git a/jaxlib/gpu/blas_handle_pool.cc b/jaxlib/gpu/blas_handle_pool.cc index 2ce204453039..ff381b802ab2 100644 --- a/jaxlib/gpu/blas_handle_pool.cc +++ b/jaxlib/gpu/blas_handle_pool.cc @@ -19,7 +19,7 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/vendor.h" -#include "jaxlib/handle_pool.h" +#include "jaxlib/gpu/handle_pool.h" namespace jax { diff --git a/jaxlib/gpu/blas_handle_pool.h b/jaxlib/gpu/blas_handle_pool.h index b3cdbaa88867..43724baab45e 100644 --- a/jaxlib/gpu/blas_handle_pool.h +++ b/jaxlib/gpu/blas_handle_pool.h @@ -18,7 +18,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "jaxlib/gpu/vendor.h" -#include "jaxlib/handle_pool.h" +#include "jaxlib/gpu/handle_pool.h" namespace jax { diff --git a/jaxlib/handle_pool.h b/jaxlib/gpu/handle_pool.h similarity index 96% rename from jaxlib/handle_pool.h rename to jaxlib/gpu/handle_pool.h index 9201d8d579c5..9189bb174b06 100644 --- a/jaxlib/handle_pool.h +++ b/jaxlib/gpu/handle_pool.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef JAXLIB_HANDLE_POOL_H_ -#define JAXLIB_HANDLE_POOL_H_ +#ifndef JAXLIB_GPU_HANDLE_POOL_H_ +#define JAXLIB_GPU_HANDLE_POOL_H_ #include #include @@ -107,4 +107,4 @@ void HandlePool::Return(HandleType handle, } // namespace jax -#endif // JAXLIB_HANDLE_POOL_H_ +#endif // JAXLIB_GPU_HANDLE_POOL_H_ diff --git a/jaxlib/gpu/rnn_kernels.cc b/jaxlib/gpu/rnn_kernels.cc index e9820bc31f1e..45f8ba8187ba 100644 --- a/jaxlib/gpu/rnn_kernels.cc +++ b/jaxlib/gpu/rnn_kernels.cc @@ -23,7 +23,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "jaxlib/gpu/ffi_wrapper.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" -#include "jaxlib/handle_pool.h" +#include "jaxlib/gpu/handle_pool.h" #include "jaxlib/kernel_helpers.h" #include "xla/service/custom_call_status.h" diff --git a/jaxlib/gpu/solver_handle_pool.cc b/jaxlib/gpu/solver_handle_pool.cc index c55ea923b21b..416ccf9d1bbc 100644 --- a/jaxlib/gpu/solver_handle_pool.cc +++ b/jaxlib/gpu/solver_handle_pool.cc @@ -19,7 +19,7 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/vendor.h" -#include "jaxlib/handle_pool.h" +#include "jaxlib/gpu/handle_pool.h" #ifdef JAX_GPU_CUDA #include "third_party/gpus/cuda/include/cusolverSp.h" diff --git a/jaxlib/gpu/solver_handle_pool.h b/jaxlib/gpu/solver_handle_pool.h index c46c062b3054..4e369ea85520 100644 --- a/jaxlib/gpu/solver_handle_pool.h +++ b/jaxlib/gpu/solver_handle_pool.h @@ -18,7 +18,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "jaxlib/gpu/vendor.h" -#include "jaxlib/handle_pool.h" +#include "jaxlib/gpu/handle_pool.h" #ifdef JAX_GPU_CUDA #include "third_party/gpus/cuda/include/cusolverSp.h" diff --git a/jaxlib/gpu/sparse_kernels.cc b/jaxlib/gpu/sparse_kernels.cc index 5b620a05236d..c66e96b6b89e 100644 --- a/jaxlib/gpu/sparse_kernels.cc +++ b/jaxlib/gpu/sparse_kernels.cc @@ -28,7 +28,7 @@ limitations under the License. #include "jaxlib/gpu/ffi_wrapper.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/vendor.h" -#include "jaxlib/handle_pool.h" +#include "jaxlib/gpu/handle_pool.h" #include "jaxlib/kernel_helpers.h" #include "xla/service/custom_call_status.h" diff --git a/jaxlib/gpu/sparse_kernels.h b/jaxlib/gpu/sparse_kernels.h index 323431812758..0d74ebc7d8e4 100644 --- a/jaxlib/gpu/sparse_kernels.h +++ b/jaxlib/gpu/sparse_kernels.h @@ -24,7 +24,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "jaxlib/gpu/vendor.h" -#include "jaxlib/handle_pool.h" +#include "jaxlib/gpu/handle_pool.h" #include "xla/ffi/api/ffi.h" #include "xla/service/custom_call_status.h" diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index 867048509afa..1e54d82c4f71 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -79,7 +79,7 @@ cc_library( deps = [ ":hip_gpu_kernel_helpers", ":hip_vendor", - "//jaxlib:handle_pool", + "//jaxlib/gpu:handle_pool", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@local_config_rocm//rocm:hipblas", @@ -143,8 +143,8 @@ cc_library( ":ffi_wrapper", ":hip_gpu_kernel_helpers", ":hip_vendor", - "//jaxlib:handle_pool", "//jaxlib:kernel_helpers", + "//jaxlib/gpu:handle_pool", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", @@ -182,7 +182,7 @@ cc_library( deps = [ ":hip_gpu_kernel_helpers", ":hip_vendor", - "//jaxlib:handle_pool", + "//jaxlib/gpu:handle_pool", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@local_config_rocm//rocm:hipsolver", @@ -291,8 +291,8 @@ cc_library( ":ffi_wrapper", ":hip_gpu_kernel_helpers", ":hip_vendor", - "//jaxlib:handle_pool", "//jaxlib:kernel_helpers", + "//jaxlib/gpu:handle_pool", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", From 85e78840103a1ded7fa7feb1243ca3f0a9c7c63b Mon Sep 17 00:00:00 2001 From: Ayaka Date: Wed, 19 Mar 2025 10:07:08 -0700 Subject: [PATCH 022/483] Support error checking in auto mode PiperOrigin-RevId: 738443014 --- jax/_src/error_check.py | 108 +++++++++++++++++++++++++------------- tests/error_check_test.py | 47 ++++++++++++++++- 2 files changed, 117 insertions(+), 38 deletions(-) diff --git a/jax/_src/error_check.py b/jax/_src/error_check.py index 60dc2f76a5b2..88dcec7063d9 100644 --- a/jax/_src/error_check.py +++ b/jax/_src/error_check.py @@ -16,6 +16,7 @@ from functools import partial import threading +import warnings import jax from jax._src import core @@ -58,11 +59,11 @@ def __init__(self): def _initialize_error_code_ref() -> None: - """Initialize error_code_ref in the current thread. + """Initialize the error code ref in the current thread. - The size of the error code array is determined by the mesh in the context. In - single-device environment, the array is a scalar. In multi-device - environment, the array has the same shape as the mesh. + The shape and size of the error code array depend on the mesh in the context. + In single-device environments, the array is a scalar. In multi-device + environments, its shape and size match those of the mesh. """ with core.eval_context(): # Get mesh from the context. @@ -83,13 +84,18 @@ def _initialize_error_code_ref() -> None: class error_checking_context: - """Redefine the error checking state based on the mesh in the context. + """Redefine the internal error state based on the mesh in the context. - This context manager should be used when starting a multi-device - computation, and whenever the mesh is changed. + When using JAX in multi-device environments in explicit mode, error tracking + needs to be properly aligned with the device mesh. This context manager + ensures that the internal error state is correctly initialized based on the + current mesh configuration. - When exiting the context, the error checking state will be reset to the - original state. + This context manager should be used when starting a multi-device computation, + or when switching between different device meshes. + + On entering the context, it initializes a new error state based on the mesh in + the context. On exiting the context, it restores the previous error state. """ __slots__ = ("old_ref",) @@ -107,12 +113,28 @@ def __exit__(self, exc_type, exc_value, traceback): def set_error_if(pred: jax.Array, /, msg: str) -> None: - """Set error if any element of pred is true. - - If the error is already set, the new error will be ignored. It will not - override the existing error. - - In auto mode, this function does not work under jit. + """Set the internal error state if any element of `pred` is `True`. + + This function is used inside JAX computations to detect runtime errors without + immediately halting execution. When this function is traced (e.g., inside + :func:`jax.jit`), the corresponding error message and its traceback are + recorded. At execution time, if `pred` contains any `True` values, the error + state is set, but execution continues without interruption. The recorded error + can later be raised using :func:`raise_if_error`. + + If the error state has already been set, subsequent errors are ignored and + will not override the existing error. + + For multi-device environments, in explicit mode, users must call + :func:`error_checking_context()` to initialize a new error tracking state that + matches the device mesh. In auto mode, implicit cross-device communication may + occur inside this function, which could impact performance. A warning is + issued in such cases. + + Args: + pred: A JAX boolean array. If any element of `pred` is `True`, the internal + error state will be set. + msg: The corresponding error message to be raised later. """ if _error_storage.ref is None: _initialize_error_code_ref() @@ -127,28 +149,34 @@ def set_error_if(pred: jax.Array, /, msg: str) -> None: out_sharding = core.typeof(_error_storage.ref).sharding in_sharding: NamedSharding = core.typeof(pred).sharding - if out_sharding.mesh.shape_tuple == (): # single-device case. + # Reduce `pred`. + if all(dim is None for dim in out_sharding.spec): # single-device case. pred = pred.any() else: # multi-device case. has_auto_axes = mesh_lib.AxisType.Auto in in_sharding.mesh.axis_types - if has_auto_axes: - raise NotImplementedError( - "Error checking in auto mode is not supported yet. Please use" - " explicit mode." + if has_auto_axes: # auto mode. + warnings.warn( + "When at least one mesh axis of `pred` is in auto mode, calling" + " `set_error_if` will cause implicit communication between devices." + " To avoid this, consider converting the mesh axis in auto mode to" + " explicit mode.", + RuntimeWarning, ) - if out_sharding.mesh != in_sharding.mesh: - raise ValueError( - "The error code state and the predicate must be on the same mesh, " - f"but got {out_sharding.mesh} and {in_sharding.mesh} respectively. " - "Please use `with error_checking_context()` to redefine the error " - "code state based on the mesh." - ) - pred = shard_map( - partial(jnp.any, keepdims=True), - mesh=out_sharding.mesh, - in_specs=in_sharding.spec, - out_specs=out_sharding.spec, - )(pred) # perform per-device reduction + pred = pred.any() # reduce to a single scalar + else: # explicit mode. + if out_sharding.mesh != in_sharding.mesh: + raise ValueError( + "The error code state and the predicate must be on the same mesh, " + f"but got {out_sharding.mesh} and {in_sharding.mesh} respectively. " + "Please use `with error_checking_context()` to redefine the error " + "code state based on the mesh." + ) + pred = shard_map( + partial(jnp.any, keepdims=True), + mesh=out_sharding.mesh, + in_specs=in_sharding.spec, + out_specs=out_sharding.spec, + )(pred) # perform per-device reduction error_code = _error_storage.ref[...] should_update = jnp.logical_and(pred, error_code == jnp.uint32(_NO_ERROR)) @@ -158,10 +186,18 @@ def set_error_if(pred: jax.Array, /, msg: str) -> None: def raise_if_error() -> None: - """Raise error if an error is set. + """Raise an exception if the internal error state is set. + + This function should be called after a computation completes to check for any + errors that were marked during execution via `set_error_if()`. If an error + exists, it raises a `JaxValueError` with the corresponding error message. + + This function should not be called inside a traced function (e.g., inside + :func:`jax.jit`). Doing so will raise a `ValueError`. - This function should be called after the computation is finished. It should - not be called within a traced context, such as within a jitted function." + Raises: + JaxValueError: If the internal error state is set. + ValueError: If called within a traced JAX function. """ if _error_storage.ref is None: # if not initialized, do nothing return diff --git a/tests/error_check_test.py b/tests/error_check_test.py index b96c6281411f..ad67cadfb074 100644 --- a/tests/error_check_test.py +++ b/tests/error_check_test.py @@ -18,6 +18,7 @@ import jax from jax._src import config from jax._src import error_check +from jax._src import mesh as mesh_lib from jax._src import test_util as jtu import jax.numpy as jnp from jax.sharding import NamedSharding, PartitionSpec as P @@ -202,9 +203,51 @@ def f(x): if jit: f = jax.jit(f) - sharding = NamedSharding(mesh, P("x", "y")) - x = jnp.full((4, 4), -1, dtype=jnp.int32, device=sharding) with error_check.error_checking_context(): + x = jnp.full((4, 4), -1, dtype=jnp.int32) + f(x) + with self.assertRaisesRegex(JaxValueError, "x must be greater than 0"): + error_check.raise_if_error() + + sharding = NamedSharding(mesh, P("x", "y")) + with error_check.error_checking_context(): + y = jnp.full((4, 4), -1, dtype=jnp.int32, device=sharding) + f(y) + with self.assertRaisesRegex(JaxValueError, "x must be greater than 0"): + error_check.raise_if_error() + + # The unsharded version of `f` should still be able to check errors after + # exiting the error checking context. + f(x) + with self.assertRaisesRegex(JaxValueError, "x must be greater than 0"): + error_check.raise_if_error() + + @parameterized.product(jit=[True, False]) + @jtu.with_user_mesh( + (2, 2), + ("x", "y"), + axis_types=(mesh_lib.AxisType.Auto, mesh_lib.AxisType.Auto), + ) + @jtu.ignore_warning( + message=( + "When at least one mesh axis of `pred` is in auto mode, calling" + " `set_error_if` will cause implicit communication between devices." + " To avoid this, consider converting the mesh axis in auto mode to" + " explicit mode." + ), + category=RuntimeWarning, + ) + def test_error_check_auto_mode(self, jit, mesh): + def f(x): + error_check.set_error_if(x <= 0, "x must be greater than 0") + return x + 1 + + if jit: + f = jax.jit(f) + + with error_check.error_checking_context(): + sharding = NamedSharding(mesh, P("x", "y")) + x = jnp.full((4, 4), -1, dtype=jnp.int32, device=sharding) f(x) with self.assertRaisesRegex(JaxValueError, "x must be greater than 0"): error_check.raise_if_error() From b456855c40da6c7904638013d58488c3ff8304a8 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 19 Mar 2025 10:19:51 -0700 Subject: [PATCH 023/483] [pallas:mosaic_gpu] Added support for accessing cluster ID via `lax.axis_index` PiperOrigin-RevId: 738448436 --- jax/_src/pallas/mosaic_gpu/core.py | 18 ++-- jax/_src/pallas/mosaic_gpu/lowering.py | 138 +++++++++++++++++-------- tests/pallas/mosaic_gpu_test.py | 29 +++++- 3 files changed, 134 insertions(+), 51 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 630c1b8f4bed..5e4566ddfc9c 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -506,17 +506,21 @@ class GPUMesh: axis_names: tuple[str, ...] = () def __post_init__(self): - if len(self.axis_names) != len(self.grid) + (self.num_threads is not None): - raise ValueError("Need as many axis names as grid dimensions + warp groups") + if len(self.cluster) > 3: + raise ValueError(f"cluster= must be at most 3D, got {self}.") + num_axis_names = ( + len(self.grid) + len(self.cluster) + (self.num_threads is not None) + ) + if len(self.axis_names) != num_axis_names: + raise ValueError( + "Need an axis name for each grid and cluster dimension plus " + f" an additional axis name when num_threads= is given, got {self}." + ) if self.num_threads is not None and self.num_threads > 2048 // 128: raise ValueError( "Requested too many CUDA threads per block. Each Mosaic thread" " corresponds to 128 CUDA threads." ) - if self.cluster: - raise NotImplementedError( - "Pallas/MosaicGPU does not support clusters yet." - ) @property def backend(self) -> str: @@ -556,8 +560,6 @@ def _gpu_mesh_discharge_rule( ): if not isinstance(mesh, GPUMesh): raise TypeError(f"Mesh must be a GPUMesh, got {type(mesh)}") - if mesh.cluster: - raise NotImplementedError if compiler_params and not isinstance(compiler_params, GPUCompilerParams): raise TypeError( "Compiler params must be a GPUCompilerParams, got" diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 2ae51a8b22e8..1fae91773178 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -17,11 +17,13 @@ from __future__ import annotations import collections -from collections.abc import Callable, Hashable, MutableMapping, MutableSequence, Sequence +from collections.abc import Callable, Hashable, Iterable, MutableMapping, MutableSequence, Sequence import contextlib import dataclasses import functools +import itertools import math +import operator from typing import Any, Protocol, cast import jax @@ -233,10 +235,33 @@ def _reduce_sum_resource_estimator( return Resources(smem_scratch_bytes=4 * x_aval.dtype.itemsize) +@dataclasses.dataclass(frozen=True) +class _AxisNames: + grid: Sequence[Hashable] + cluster: Sequence[Hashable] = () + wg: Hashable | None = None + + def __iter__(self) -> Iterable[Hashable]: + return itertools.chain( + self.grid, self.cluster, [self.wg] if self.wg is not None else [] + ) + + @classmethod + def from_mesh( + cls, mesh: gpu_core.GPUMesh, axis_names: Sequence[str] + ) -> "_AxisNames": + wg_name = None + if mesh.num_threads is not None: + wg_name = axis_names[-1] + axis_names = axis_names[:-1] + grid_names, cluster_names = util.split_list(axis_names, [len(mesh.grid)]) + return cls(grid_names, cluster_names, wg_name) + + @dataclasses.dataclass class ModuleContext: name: str - grid_names: Sequence[Hashable] | None + axis_names: _AxisNames | None program_ids: Sequence[ir.Value] | None approx_math: bool single_wg_lane_predicate: ir.Value @@ -565,10 +590,15 @@ def body_fn(*refs): ) assert not new_consts + axis_names = ( + _AxisNames.from_mesh(mesh, grid_mapping.grid_names) + if mesh is not None + else _AxisNames(grid_mapping.grid_names) + ) with grid_mapping.trace_env(): return lower_jaxpr_to_module( parallel_grid, - grid_mapping.grid_names, + axis_names, block, mesh.cluster if mesh is not None else (), [bm.array_shape_dtype for bm in in_block_mappings], @@ -581,7 +611,7 @@ def body_fn(*refs): def lower_jaxpr_to_module( grid: Sequence[int], - grid_names: Sequence[str], + axis_names: _AxisNames, block: Sequence[int], cluster: Sequence[int], in_shapes: Sequence[jax.ShapeDtypeStruct], @@ -597,6 +627,11 @@ def lower_jaxpr_to_module( "thread_semantics", mgpu_core.ThreadSemantics.Lane ) + if len(cluster) < 3: + cluster = cluster + (1,) * (3 - len(cluster)) + else: + assert len(cluster) == 3 + if len(grid) <= 3: squashed_dims = () parallel_grid = grid + (1,) * (3 - len(grid)) @@ -614,7 +649,7 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): grouped_barriers[barrier].append(barrier_ref) module_ctx = ModuleContext( mlir.sanitize_name(debug_info.func_name), - grid_names, + axis_names, [_program_id(axis, squashed_dims) for axis in range(len(grid))], approx_math, mgpu.single_thread_predicate(per_block=False), @@ -645,7 +680,7 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): module, out_structs_gmem, _, launch_ctx, scratch_arr = ( mgpu_core._lower_as_gpu_kernel( body, - grid=parallel_grid, + grid=tuple(map(operator.mul, parallel_grid, cluster)), cluster=cluster, block=block, in_shapes=in_shapes, @@ -1605,49 +1640,68 @@ def _reduce_max_lowering_rule_wg(ctx: LoweringRuleContext, x, *, axes): return _reduce_lowering_rule_wg(kind, acc, ctx, x, axes=axes).result +def _block_id(ctx: LoweringRuleContext, dim: gpu_dialect.Dimension) -> ir.Value: + result = gpu_dialect.block_id(dim) + cluster_size = ctx.launch_ctx.cluster_size + if math.prod(cluster_size) == 1 or cluster_size[dim.value] == 1: + return result + # We scale the grid in the presence of clusters, so we need to scale the + # block ID back here. + return arith_dialect.divui(result, _as_index(cluster_size[dim.value])) + + @register_lowering_rule(lax.axis_index_p, mgpu.ThreadSemantics.Lane) def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): - i32 = ir.IntegerType.get_signless(32) - grid_names = ctx.module_ctx.grid_names + axis_names = ctx.module_ctx.axis_names + if not axis_names or axis_name not in axis_names: + raise ValueError( + "Named axes can only refer to GPUMesh axes in Mosaic GPU kernels" + ) + + if axis_names.wg is not None and axis_name == axis_names.wg: + return mgpu.warpgroup_idx(sync=True) + + if axis_name in axis_names.cluster: + idx = axis_names.cluster.index(axis_name) + return arith_dialect.index_cast( + ir.IntegerType.get_signless(32), + gpu_dialect.cluster_block_id(gpu_dialect.Dimension(idx)), + ) + squashed_dims = ctx.module_ctx.squashed_dims if squashed_dims: - unsquashed_names = grid_names[-3:] - squashed_names = grid_names[:-3] + unsquashed_names = axis_names.grid[-2:] + squashed_names = axis_names.grid[:-2] else: # These are unused but initialized for type checkers. - unsquashed_names = () - squashed_names = () - if grid_names and axis_name in grid_names: - if axis_name == grid_names[-1]: - return mgpu.warpgroup_idx(sync=True) + unsquashed_names = squashed_names = () + + if squashed_dims: + if axis_name in unsquashed_names: + # We add 1 to the index because the first dimension is the + # squashed dimension. + # e.g. for the grid (a, b, c, d, wg) + # squashed = (a, b) Mapped to Dimension.x (0) + # unsquashed = (c, d) Mapped to Dimension.y (1) and Dimension.z (2) + idx = unsquashed_names.index(axis_name) + 1 + return arith_dialect.index_cast( + ir.IntegerType.get_signless(32), + _block_id(ctx, gpu_dialect.Dimension(idx)), + ) else: - if squashed_dims: - if axis_name in unsquashed_names: - # We add 1 to the index because the first dimension is the - # squashed dimension. - # e.g. for the grid (a, b, c, d, wg) - # squashed = (a, b) Mapped to Dimension.x (0) - # unsquashed = (c, d) Mapped to Dimension.y (1) and Dimension.z (2) - idx = unsquashed_names.index(axis_name) + 1 - return arith_dialect.index_cast( - i32, - gpu_dialect.block_id(gpu_dialect.Dimension(idx)), - ) - elif axis_name in squashed_names: - # All squashed dimensions are mapped to Dimension.x. - block_id = gpu_dialect.block_id(gpu_dialect.Dimension.x) - axis = squashed_names.index(axis_name) - return _unravel_program_id(block_id, axis, squashed_dims) - else: - if axis_name in grid_names: - idx = grid_names.index(axis_name) - return arith_dialect.index_cast( - i32, - gpu_dialect.block_id(gpu_dialect.Dimension(idx)), - ) - raise ValueError( - "Named axes can only refer to GPUMesh axes in Mosaic GPU kernels" - ) + assert axis_name in squashed_names + # All squashed dimensions are mapped to Dimension.x. + axis = squashed_names.index(axis_name) + return _unravel_program_id( + _block_id(ctx, gpu_dialect.Dimension.x), axis, squashed_dims + ) + else: + assert axis_name in axis_names.grid + idx = axis_names.grid.index(axis_name) + return arith_dialect.index_cast( + ir.IntegerType.get_signless(32), + _block_id(ctx, gpu_dialect.Dimension(idx)), + ) @register_lowering_rule(primitives.debug_print_p, mgpu.ThreadSemantics.Lane) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 6792ddfaa9a8..38335925b44d 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -2079,7 +2079,6 @@ def _(): result.shape) np.testing.assert_array_equal(result, ref) - def test_cross_wg_barrier(self): mesh = plgpu.GPUMesh(num_threads=2, axis_names=("wg",)) @@ -2100,6 +2099,34 @@ def scoped(barrier): return inner(y_init) np.testing.assert_array_equal(f(), np.repeat([0, 1], 128).reshape(2, 128)) + def test_cluster(self): + mesh = plgpu.GPUMesh(grid=(2,), cluster=(2,), axis_names=("x", "cluster")) + + @jax.jit + def f(): + @pl.run_state + def inner(ref): + @pl.core_map(mesh) + def kernel(): + block_idx = jax.lax.axis_index("x") + cluster_idx = jax.lax.axis_index("cluster") + pl.debug_print("block: {} cluster: {}", block_idx, cluster_idx) + + ref[...] = ref[...] + return inner(jnp.zeros(128, np.int32)) + + with self.capture_stdout() as output: + jax.block_until_ready(f()) + self.assertEqual( + set(output().splitlines()), + { + "block: 0 cluster: 0", + "block: 1 cluster: 0", + "block: 0 cluster: 1", + "block: 1 cluster: 1", + }, + ) + class ExamplesTest(PallasTest): From 918192fd45e74ba1793a6507be6587cf01b814e8 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Wed, 19 Mar 2025 10:35:48 -0700 Subject: [PATCH 024/483] Move sparse op GPU lowerings from jaxlib into JAX. PiperOrigin-RevId: 738454875 --- jax/_src/lax/linalg.py | 32 ++- jax/experimental/sparse/_base.py | 10 - jax/experimental/sparse/_lowerings.py | 174 ++++++++++--- jax/experimental/sparse/bcsr.py | 30 +-- jax/experimental/sparse/coo.py | 62 ++--- jax/experimental/sparse/csr.py | 59 ++--- jaxlib/gpu_sparse.py | 357 +------------------------- 7 files changed, 233 insertions(+), 491 deletions(-) diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index c674401fb80d..3e9077d0a51c 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -2511,16 +2511,30 @@ def _tridiagonal_solve_shape_rule(dl_shape, d_shape, du_shape, b_shape, **_): "equal the dimensions of the diagonal arguments.") return b_shape -def _tridiagonal_solve_gpu_lowering(lowering, ctx, dl, d, du, b): +def _tridiagonal_solve_gpu_lowering(ctx, dl, d, du, b, *, target_name_prefix): _, _, _, b_aval = ctx.avals_in - if b_aval.dtype != np.float32 and b_aval.dtype != np.float64: + *batch_dims, m, n = b_aval.shape + batch_size = math.prod(batch_dims) + + mod = gpu_sparse._cusparse if target_name_prefix == "cu" else gpu_sparse._hipsparse + assert mod is not None + opaque = mod.build_gtsv2_descriptor(batch_size, m, n, m) + if b_aval.dtype == np.float32: + buffer_size = mod.gtsv2_f32_buffer_size(m, n, m) + target_name = "sparse_gtsv2_f32_ffi" + elif b_aval.dtype == np.float64: + buffer_size = mod.gtsv2_f64_buffer_size(m, n, m) + target_name = "sparse_gtsv2_f64_ffi" + else: raise NotImplementedError( "tridiagonal_solve is only implemented for float32 and float64 on GPU.") - m, n = b_aval.shape[-2:] - b_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, b_aval.shape) - return [lowering( - dl, d, du, b, m=m, n=n, ldb=m, t=b_aval.dtype, - b_shape_vals=b_shape_vals)] + + buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8) + sub_ctx = ctx.replace(avals_out=[*ctx.avals_out, buffer_aval]) + rule = _linalg_ffi_lowering( + f"{target_name_prefix}{target_name}", operand_output_aliases={3: 0}, + batch_partitionable=False) + return rule(sub_ctx, dl, d, du, b, opaque=opaque)[:1] def _tridiagonal_solve_cpu_lowering(ctx, dl, d, du, b, **kwargs): del kwargs # unused @@ -2628,11 +2642,11 @@ def _tridiagonal_solve_jax(dl, d, du, b, **_): platform='cpu') mlir.register_lowering( tridiagonal_solve_p, - partial(_tridiagonal_solve_gpu_lowering, gpu_sparse.cuda_gtsv2), + partial(_tridiagonal_solve_gpu_lowering, target_name_prefix='cu'), platform='cuda') mlir.register_lowering( tridiagonal_solve_p, - partial(_tridiagonal_solve_gpu_lowering, gpu_sparse.rocm_gtsv2), + partial(_tridiagonal_solve_gpu_lowering, target_name_prefix='hip'), platform='rocm') mlir.register_lowering(tridiagonal_solve_p, mlir.lower_fun( _tridiagonal_solve_jax, multiple_results=False)) diff --git a/jax/experimental/sparse/_base.py b/jax/experimental/sparse/_base.py index 7739af0291f1..36d84cb0db62 100644 --- a/jax/experimental/sparse/_base.py +++ b/jax/experimental/sparse/_base.py @@ -19,18 +19,8 @@ import jax from jax._src import core -from jax._src import ffi from jax._src import util from jax._src.typing import Array -from jax._src.lib import gpu_sparse - - -if hasattr(gpu_sparse, "registrations"): - for platform, targets in gpu_sparse.registrations().items(): - for name, value, api_version in targets: - ffi.register_ffi_target( - name, value, platform=platform, api_version=api_version - ) class JAXSparse(util.StrictABC): diff --git a/jax/experimental/sparse/_lowerings.py b/jax/experimental/sparse/_lowerings.py index 6962ef78bcff..76e74d13ed69 100644 --- a/jax/experimental/sparse/_lowerings.py +++ b/jax/experimental/sparse/_lowerings.py @@ -18,13 +18,29 @@ """ from functools import partial +from typing import Any from jax._src import core from jax._src import dispatch +from jax._src import ffi from jax._src.interpreters import mlir from jax._src.lib import gpu_sparse import numpy as np +if hasattr(gpu_sparse, "registrations"): + for platform, targets in gpu_sparse.registrations().items(): + for name, value, api_version in targets: + ffi.register_ffi_target( + name, value, platform=platform, api_version=api_version + ) + +def _get_module(target_name_prefix: str) -> Any: + if target_name_prefix == "cu": + return gpu_sparse._cusparse + elif target_name_prefix == "hip": + return gpu_sparse._hipsparse + else: + raise ValueError(f"Unsupported target_name_prefix: {target_name_prefix}") SUPPORTED_DATA_DTYPES = [np.float32, np.float64, np.complex64, np.complex128] SUPPORTED_INDEX_DTYPES = [np.int32] @@ -54,27 +70,30 @@ def _coo_spmv_abstract_eval(data, row, col, x, *, transpose, shape): shape=shape[1:] if transpose else shape[:1], dtype=x.dtype) -def _coo_spmv_gpu_lowering(coo_spmv_hlo, ctx, data, row, col, x, *, transpose, shape): +def _coo_spmv_gpu_lowering(ctx, data, row, col, x, *, transpose, shape, + target_name_prefix): + rows, cols = shape data_aval, row_aval, _, x_aval = ctx.avals_in - return [coo_spmv_hlo( - data, row, col, x, - shape=shape, - transpose=transpose, - data_dtype=data_aval.dtype, - index_dtype=row_aval.dtype, - x_dtype=x_aval.dtype)] + nnz, = data_aval.shape + buffer_size, opaque = _get_module(target_name_prefix).build_coo_matvec_descriptor( + data_aval.dtype, x_aval.dtype, data_aval.dtype, row_aval.dtype, + rows, cols, nnz, transpose) + buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8) + sub_ctx = ctx.replace(avals_out=[ctx.avals_out[0], buffer_aval]) + rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_coo_matvec_ffi") + return rule(sub_ctx, data, row, col, x, opaque=opaque)[:1] coo_spmv_p.def_abstract_eval(_coo_spmv_abstract_eval) dispatch.simple_impl(coo_spmv_p) if gpu_sparse.cuda_is_supported: mlir.register_lowering( coo_spmv_p, - partial(_coo_spmv_gpu_lowering, gpu_sparse.cuda_coo_matvec), + partial(_coo_spmv_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( coo_spmv_p, - partial(_coo_spmv_gpu_lowering, gpu_sparse.rocm_coo_matvec), + partial(_coo_spmv_gpu_lowering, target_name_prefix='hip'), platform='rocm') @@ -103,27 +122,51 @@ def _coo_spmm_abstract_eval(data, row, col, x, *, transpose, shape): shape=(shape[1] if transpose else shape[0], x.shape[1]), dtype=x.dtype) -def _coo_spmm_gpu_lowering(coo_spmm_hlo, ctx, data, row, col, x, *, transpose, shape): +def _coo_spmm_gpu_lowering(ctx, data, row, col, x, *, transpose, shape, + target_name_prefix): data_aval, row_aval, _, x_aval = ctx.avals_in - return [coo_spmm_hlo( - data, row, col, x, - shape=shape, - transpose=transpose, - data_dtype=data_aval.dtype, - index_dtype=row_aval.dtype, - x_dtype=x_aval.dtype)] + nnz, = data_aval.shape + _, Ccols = x_aval.shape + + batch_count = 1 + if len(shape) == 2: + rows, cols = shape + elif len(shape) == 3: + batch_count, rows, cols = shape + nnz = nnz // batch_count + else: + raise NotImplementedError(f"Unsupported shape: {shape}") + + # TODO(tianjianlu): use batch stride to trigger different mode of batch + # computation. Currently batch_stride = 0 is not allowed because of the issue + # in cusparse https://github.com/NVIDIA/CUDALibrarySamples/issues/81#issuecomment-1205562643 + # Set batch stride to be the matrix size for now. + lhs_batch_stride = nnz + B_rows = rows if transpose else cols + rhs_batch_stride = B_rows * Ccols + + buffer_size, opaque = _get_module(target_name_prefix).build_coo_matmat_descriptor( + data_aval.dtype, x_aval.dtype, data_aval.dtype, row_aval.dtype, + rows, cols, Ccols, nnz, transpose, batch_count, lhs_batch_stride, + rhs_batch_stride) + + buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8) + sub_ctx = ctx.replace(avals_out=[ctx.avals_out[0], buffer_aval]) + rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_coo_matmat_ffi") + return rule(sub_ctx, data, row, col, x, opaque=opaque)[:1] + coo_spmm_p.def_abstract_eval(_coo_spmm_abstract_eval) dispatch.simple_impl(coo_spmm_p) if gpu_sparse.cuda_is_supported: mlir.register_lowering( coo_spmm_p, - partial(_coo_spmm_gpu_lowering, gpu_sparse.cuda_coo_matmat), + partial(_coo_spmm_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( coo_spmm_p, - partial(_coo_spmm_gpu_lowering, gpu_sparse.rocm_coo_matmat), + partial(_coo_spmm_gpu_lowering, target_name_prefix='hip'), platform='rocm') # csr_spmv_p @@ -151,30 +194,33 @@ def _csr_spmv_abstract_eval(data, indices, indptr, x, *, transpose, shape): shape=shape[1:] if transpose else shape[:1], dtype=x.dtype) -def _csr_spmv_gpu_lowering(csr_spmv_hlo, ctx, data, indices, indptr, x, *, transpose, shape): +def _csr_spmv_gpu_lowering(ctx, data, indices, indptr, x, *, transpose, shape, + target_name_prefix): + rows, cols = shape data_aval, indices_aval, _, x_aval = ctx.avals_in - return [csr_spmv_hlo( - data, indices, indptr, x, - shape=shape, - transpose=transpose, - data_dtype=data_aval.dtype, - index_dtype=indices_aval.dtype, - x_dtype=x_aval.dtype)] + nnz, = data_aval.shape + buffer_size, opaque = _get_module(target_name_prefix).build_csr_matvec_descriptor( + data_aval.dtype, x_aval.dtype, data_aval.dtype, indices_aval.dtype, + rows, cols, nnz, transpose) + buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8) + sub_ctx = ctx.replace(avals_out=[ctx.avals_out[0], buffer_aval]) + rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_csr_matvec_ffi") + return rule(sub_ctx, data, indices, indptr, x, opaque=opaque)[:1] csr_spmv_p.def_abstract_eval(_csr_spmv_abstract_eval) dispatch.simple_impl(csr_spmv_p) if gpu_sparse.cuda_is_supported: mlir.register_lowering( csr_spmv_p, - partial(_csr_spmv_gpu_lowering, gpu_sparse.cuda_csr_matvec), + partial(_csr_spmv_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( csr_spmv_p, - partial(_csr_spmv_gpu_lowering, gpu_sparse.rocm_csr_matvec), + partial(_csr_spmv_gpu_lowering, target_name_prefix='hip'), platform='rocm') - # csr_spmm_p +# csr_spmm_p # This is an internal-only primitive that calls into cusparse CSR SpMM. # This is a raw lowering that does no validation of inputs; the indices are # assumed to be lexicographically sorted, deduplicated, and in-bounds. @@ -199,25 +245,71 @@ def _csr_spmm_abstract_eval(data, indices, indptr, x, *, transpose, shape): shape=(shape[1] if transpose else shape[0], x.shape[1]), dtype=x.dtype) -def _csr_spmm_gpu_lowering(csr_spmm_hlo, ctx, data, indices, indptr, x, *, transpose, shape): +def _csr_spmm_gpu_lowering(ctx, data, indices, indptr, x, *, transpose, shape, + target_name_prefix): + rows, cols = shape data_aval, indices_aval, _, x_aval = ctx.avals_in - return [csr_spmm_hlo( - data, indices, indptr, x, - shape=shape, - transpose=transpose, - data_dtype=data_aval.dtype, - index_dtype=indices_aval.dtype, - B_dtype=x_aval.dtype)] + nnz, = data_aval.shape + _, Ccols = x_aval.shape + buffer_size, opaque = _get_module(target_name_prefix).build_csr_matmat_descriptor( + data_aval.dtype, x_aval.dtype, data_aval.dtype, indices_aval.dtype, + rows, cols, Ccols, nnz, transpose) + buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8) + sub_ctx = ctx.replace(avals_out=[ctx.avals_out[0], buffer_aval]) + rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_csr_matmat_ffi") + return rule(sub_ctx, data, indices, indptr, x, opaque=opaque)[:1] csr_spmm_p.def_abstract_eval(_csr_spmm_abstract_eval) dispatch.simple_impl(csr_spmm_p) if gpu_sparse.cuda_is_supported: mlir.register_lowering( csr_spmm_p, - partial(_csr_spmm_gpu_lowering, gpu_sparse.cuda_csr_matmat), + partial(_csr_spmm_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( csr_spmm_p, - partial(_csr_spmm_gpu_lowering, gpu_sparse.rocm_csr_matmat), + partial(_csr_spmm_gpu_lowering, target_name_prefix='hip'), platform='rocm') + +def coo_todense_gpu_lowering(ctx, data, row, col, *, shape, target_name_prefix): + data_aval, row_aval, _ = ctx.avals_in + nnz, = data_aval.shape + rows, cols = shape + buffer_size, opaque = _get_module(target_name_prefix).build_coo_todense_descriptor( + data_aval.dtype, row_aval.dtype, rows, cols, nnz) + buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8) + sub_ctx = ctx.replace(avals_out=[ctx.avals_out[0], buffer_aval]) + rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_coo_todense_ffi") + return rule(sub_ctx, data, row, col, opaque=opaque)[0] + +def coo_fromdense_gpu_lowering(ctx, mat, *, nnz, index_dtype, target_name_prefix): + mat_aval, = ctx.avals_in + rows, cols = mat_aval.shape + buffer_size, opaque = _get_module(target_name_prefix).build_coo_fromdense_descriptor( + mat_aval.dtype, np.dtype(index_dtype), rows, cols, nnz) + buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8) + sub_ctx = ctx.replace(avals_out=[*ctx.avals_out, buffer_aval]) + rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_coo_fromdense_ffi") + return rule(sub_ctx, mat, opaque=opaque)[:3] + +def csr_todense_gpu_lowering(ctx, data, indices, indptr, *, shape, target_name_prefix): + data_aval, indices_aval, _, = ctx.avals_in + nnz, = data_aval.shape + rows, cols = shape + buffer_size, opaque = _get_module(target_name_prefix).build_csr_todense_descriptor( + data_aval.dtype, indices_aval.dtype, rows, cols, nnz) + buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8) + sub_ctx = ctx.replace(avals_out=[ctx.avals_out[0], buffer_aval]) + rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_csr_todense_ffi") + return rule(sub_ctx, data, indices, indptr, opaque=opaque)[0] + +def csr_fromdense_gpu_lowering(ctx, mat, *, nnz, index_dtype, target_name_prefix): + mat_aval, = ctx.avals_in + rows, cols = mat_aval.shape + buffer_size, opaque = _get_module(target_name_prefix).build_csr_fromdense_descriptor( + mat_aval.dtype, np.dtype(index_dtype), rows, cols, nnz) + buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8) + sub_ctx = ctx.replace(avals_out=[*ctx.avals_out, buffer_aval]) + rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_csr_fromdense_ffi") + return rule(sub_ctx, mat, opaque=opaque)[:3] diff --git a/jax/experimental/sparse/bcsr.py b/jax/experimental/sparse/bcsr.py index 7fefd1572f45..dc8be2237544 100644 --- a/jax/experimental/sparse/bcsr.py +++ b/jax/experimental/sparse/bcsr.py @@ -27,6 +27,7 @@ import jax.numpy as jnp from jax import lax from jax import tree_util +from jax.experimental.sparse import _lowerings from jax.experimental.sparse._base import JAXSparse from jax.experimental.sparse import bcoo from jax.experimental.sparse.util import ( @@ -620,9 +621,9 @@ def _bcsr_correct_out_of_bound_indices(data, indices, indptr, rhs, *, shape): _bcsr_correct_out_of_bound_indices, multiple_results=True) def _bcsr_dot_general_gpu_lowering( - csr_matvec_lowering, csr_matmat_lowering, + # csr_matvec_lowering, csr_matmat_lowering, ctx, lhs_data, lhs_indices, lhs_indptr, rhs, *, dimension_numbers, - preferred_element_type, lhs_spinfo: SparseInfo): + preferred_element_type, lhs_spinfo: SparseInfo, target_name_prefix): if not config.bcoo_cusparse_lowering.value: return _bcsr_dot_general_default_lowering( @@ -674,22 +675,23 @@ def _bcsr_dot_general_gpu_lowering( lhs_data, lhs_indices = _bcsr_correct_out_of_bound_indices_lowered( ctx, lhs_data, lhs_indices, lhs_indptr, rhs, shape=lhs_spinfo.shape) + sub_ctx = ctx if rhs_aval.ndim == 1: - dot_general_fn = csr_matvec_lowering - x_dtype = 'x_dtype' + dot_general_fn = _lowerings._csr_spmv_gpu_lowering elif rhs_aval.ndim == 2: - dot_general_fn = csr_matmat_lowering - x_dtype = 'B_dtype' + dot_general_fn = _lowerings._csr_spmm_gpu_lowering if rhs_contract[0] == 1: rhs = hlo.transpose(rhs, permutation=mlir.dense_int_array([1, 0])) + *avals_in, rhs_aval = sub_ctx.avals_in + rhs_aval = core.ShapedArray( + shape=(rhs_aval.shape[1], rhs_aval.shape[0]), dtype=rhs_aval.dtype) + sub_ctx = sub_ctx.replace(avals_in=[*avals_in, rhs_aval]) else: raise ValueError(f"rhs has to be 1d or 2d; get {rhs_aval.ndim}d.") - return [dot_general_fn(lhs_data, lhs_indices, lhs_indptr, rhs, - shape=lhs_spinfo.shape, transpose=False, - data_dtype=lhs_data_aval.dtype, - index_dtype=lhs_indices_aval.dtype, - **{x_dtype: rhs_aval.dtype})] + return dot_general_fn(sub_ctx, lhs_data, lhs_indices, lhs_indptr, rhs, + shape=lhs_spinfo.shape, transpose=False, + target_name_prefix=target_name_prefix) _bcsr_dot_general_default_lowering = mlir.lower_fun( _bcsr_dot_general_impl, multiple_results=False) @@ -700,14 +702,12 @@ def _bcsr_dot_general_gpu_lowering( if gpu_sparse.cuda_is_supported: mlir.register_lowering(bcsr_dot_general_p, partial(_bcsr_dot_general_gpu_lowering, - gpu_sparse.cuda_csr_matvec, - gpu_sparse.cuda_csr_matmat), + target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering(bcsr_dot_general_p, partial(_bcsr_dot_general_gpu_lowering, - gpu_sparse.rocm_csr_matvec, - gpu_sparse.rocm_csr_matmat), + target_name_prefix='hip'), platform='rocm') diff --git a/jax/experimental/sparse/coo.py b/jax/experimental/sparse/coo.py index c65bc87235d6..014fe9128c1b 100644 --- a/jax/experimental/sparse/coo.py +++ b/jax/experimental/sparse/coo.py @@ -26,6 +26,7 @@ import jax from jax import lax from jax.interpreters import mlir +from jax.experimental.sparse import _lowerings from jax.experimental.sparse._base import JAXSparse from jax.experimental.sparse.util import _coo_extract, CuSparseEfficiencyWarning from jax import tree_util @@ -205,7 +206,7 @@ def _coo_todense_abstract_eval(data, row, col, *, spinfo): _coo_todense_lowering = mlir.lower_fun( _coo_todense_impl, multiple_results=False) -def _coo_todense_gpu_lowering(coo_todense_hlo, ctx, data, row, col, *, spinfo): +def _coo_todense_gpu_lowering(ctx, data, row, col, *, spinfo, target_name_prefix): data_aval, row_aval, _ = ctx.avals_in dtype = data_aval.dtype if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)): @@ -226,8 +227,13 @@ def _coo_todense_gpu_lowering(coo_todense_hlo, ctx, data, row, col, *, spinfo): "back to the default implementation.", CuSparseEfficiencyWarning) return _coo_todense_lowering(ctx, data, row, col, spinfo=spinfo) - result = coo_todense_hlo( - data, row, col, shape=shape, data_dtype=dtype, index_dtype=row_aval.dtype) + sub_ctx = ctx + if transpose: + out_aval, = ctx.avals_out + out_aval = core.ShapedArray(shape=out_aval.shape[::-1], dtype=out_aval.dtype) + sub_ctx = sub_ctx.replace(avals_out=[out_aval]) + result = _lowerings.coo_todense_gpu_lowering( + sub_ctx, data, row, col, shape=shape, target_name_prefix=target_name_prefix) return ( [hlo.transpose(result, mlir.dense_int_array([1, 0]))] if transpose else [result]) @@ -255,12 +261,12 @@ def _coo_todense_transpose(ct, data, row, col, *, spinfo): if gpu_sparse.cuda_is_supported: mlir.register_lowering( coo_todense_p, - partial(_coo_todense_gpu_lowering, gpu_sparse.cuda_coo_todense), + partial(_coo_todense_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( coo_todense_p, - partial(_coo_todense_gpu_lowering, gpu_sparse.rocm_coo_todense), + partial(_coo_todense_gpu_lowering, target_name_prefix='hip'), platform='rocm') #-------------------------------------------------------------------- @@ -325,20 +331,15 @@ def _coo_fromdense_abstract_eval(mat, *, nse, index_dtype): _coo_fromdense_lowering = mlir.lower_fun( _coo_fromdense_impl, multiple_results=True) -def _coo_fromdense_gpu_lowering(coo_fromdense_hlo, ctx, mat, *, nse, - index_dtype): +def _coo_fromdense_gpu_lowering(ctx, mat, *, nse, index_dtype, target_name_prefix): dtype = ctx.avals_in[0].dtype if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)): warnings.warn(f"coo_fromdense cusparse/hipsparse lowering not available for {dtype=}. " "Falling back to default implementation.", CuSparseEfficiencyWarning) return _coo_fromdense_lowering(ctx, mat, nse=nse, index_dtype=index_dtype) - data, row, col = coo_fromdense_hlo( - mat, nnz=nse, - data_dtype=dtype, - index_dtype=np.dtype(index_dtype), - index_type=mlir.dtype_to_ir_type(np.dtype(index_dtype))) - return [data, row, col] - + return _lowerings.coo_fromdense_gpu_lowering( + ctx, mat, nnz=nse, index_dtype=index_dtype, + target_name_prefix=target_name_prefix) def _coo_fromdense_jvp(primals, tangents, *, nse, index_dtype): M, = primals @@ -373,12 +374,12 @@ def _coo_fromdense_transpose(ct, M, *, nse, index_dtype): if gpu_sparse.cuda_is_supported: mlir.register_lowering( coo_fromdense_p, - partial(_coo_fromdense_gpu_lowering, gpu_sparse.cuda_coo_fromdense), + partial(_coo_fromdense_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( coo_fromdense_p, - partial(_coo_fromdense_gpu_lowering, gpu_sparse.rocm_coo_fromdense), + partial(_coo_fromdense_gpu_lowering, target_name_prefix='hip'), platform='rocm') #-------------------------------------------------------------------- @@ -444,8 +445,8 @@ def _coo_matvec_abstract_eval(data, row, col, v, *, spinfo, transpose): _coo_matvec_lowering = mlir.lower_fun( _coo_matvec_impl, multiple_results=False) -def _coo_matvec_gpu_lowering(coo_matvec_hlo, ctx, data, row, col, v, *, spinfo, - transpose): +def _coo_matvec_gpu_lowering(ctx, data, row, col, v, *, spinfo, transpose, + target_name_prefix): data_aval, row_aval, _, x_aval = ctx.avals_in dtype = data_aval.dtype if dtype not in [np.float32, np.float64, np.complex64, np.complex128]: @@ -466,9 +467,9 @@ def _coo_matvec_gpu_lowering(coo_matvec_hlo, ctx, data, row, col, v, *, spinfo, return _coo_matvec_lowering(ctx, data, row, col, v, spinfo=spinfo, transpose=transpose) - return [coo_matvec_hlo( - data, row, col, v, shape=shape, transpose=transpose, - index_dtype=row_aval.dtype, data_dtype=dtype, x_dtype=x_aval.dtype)] + return _lowerings._coo_spmv_gpu_lowering( + ctx, data, row, col, v, transpose=transpose, shape=shape, + target_name_prefix=target_name_prefix) def _coo_matvec_jvp_mat(data_dot, data, row, col, v, *, spinfo, transpose): @@ -497,12 +498,12 @@ def _coo_matvec_transpose(ct, data, row, col, v, *, spinfo, transpose): if gpu_sparse.cuda_is_supported: mlir.register_lowering( coo_matvec_p, - partial(_coo_matvec_gpu_lowering, gpu_sparse.cuda_coo_matvec), + partial(_coo_matvec_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( coo_matvec_p, - partial(_coo_matvec_gpu_lowering, gpu_sparse.rocm_coo_matvec), + partial(_coo_matvec_gpu_lowering, target_name_prefix='hip'), platform='rocm') @@ -567,8 +568,8 @@ def _coo_matmat_abstract_eval(data, row, col, B, *, spinfo, transpose): _coo_matmat_lowering = mlir.lower_fun(_coo_matmat_impl, multiple_results=False) -def _coo_matmat_gpu_lowering(coo_matmat_hlo, ctx, data, row, col, B, *, spinfo, - transpose): +def _coo_matmat_gpu_lowering(ctx, data, row, col, B, *, spinfo, transpose, + target_name_prefix): data_aval, row_aval, _, B_aval = ctx.avals_in dtype = data_aval.dtype if dtype not in [np.float32, np.float64, np.complex64, np.complex128]: @@ -589,10 +590,9 @@ def _coo_matmat_gpu_lowering(coo_matmat_hlo, ctx, data, row, col, B, *, spinfo, return _coo_matmat_lowering(ctx, data, row, col, B, spinfo=spinfo, transpose=transpose) - return [coo_matmat_hlo(data, row, col, B, shape=shape, - transpose=transpose, x_dtype=B_aval.dtype, - data_dtype=data_aval.dtype, - index_dtype=row_aval.dtype)] + return _lowerings._coo_spmm_gpu_lowering( + ctx, data, row, col, B, transpose=transpose, shape=shape, + target_name_prefix=target_name_prefix) def _coo_matmat_jvp_left(data_dot, data, row, col, B, *, spinfo, transpose): @@ -618,10 +618,10 @@ def _coo_matmat_transpose(ct, data, row, col, B, *, spinfo, transpose): if gpu_sparse.cuda_is_supported: mlir.register_lowering( coo_matmat_p, - partial(_coo_matmat_gpu_lowering, gpu_sparse.cuda_coo_matmat), + partial(_coo_matmat_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( coo_matmat_p, - partial(_coo_matmat_gpu_lowering, gpu_sparse.rocm_coo_matmat), + partial(_coo_matmat_gpu_lowering, target_name_prefix='hip'), platform='rocm') diff --git a/jax/experimental/sparse/csr.py b/jax/experimental/sparse/csr.py index 84171855b85e..cbc5bad1100b 100644 --- a/jax/experimental/sparse/csr.py +++ b/jax/experimental/sparse/csr.py @@ -23,6 +23,7 @@ import jax from jax.interpreters import mlir +from jax.experimental.sparse import _lowerings from jax.experimental.sparse._base import JAXSparse from jax.experimental.sparse.coo import _coo_matmat, _coo_matvec, _coo_todense, COOInfo from jax.experimental.sparse.util import _csr_to_coo, _csr_extract, CuSparseEfficiencyWarning @@ -249,17 +250,16 @@ def _csr_todense_abstract_eval(data, indices, indptr, *, shape): _csr_todense_lowering = mlir.lower_fun( _csr_todense_impl, multiple_results=False) -def _csr_todense_gpu_lowering(csr_todense_hlo, ctx, data, indices, indptr, *, - shape): +def _csr_todense_gpu_lowering(ctx, data, indices, indptr, *, shape, target_name_prefix): data_aval, indices_aval, _ = ctx.avals_in dtype = data_aval.dtype if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)): warnings.warn(f"csr_todense cusparse/hipsparse lowering not available for {dtype=}. " "Falling back to default implementation.", CuSparseEfficiencyWarning) return _csr_todense_lowering(ctx, data, indices, indptr, shape=shape) - return [csr_todense_hlo( - data, indices, indptr, shape=shape, data_dtype=dtype, - index_dtype=indices_aval.dtype)] + return [_lowerings.csr_todense_gpu_lowering( + ctx, data, indices, indptr, shape=shape, + target_name_prefix=target_name_prefix)] def _csr_todense_jvp(data_dot, data, indices, indptr, *, shape): @@ -284,12 +284,12 @@ def _csr_todense_transpose(ct, data, indices, indptr, *, shape): if gpu_sparse.cuda_is_supported: mlir.register_lowering( csr_todense_p, - partial(_csr_todense_gpu_lowering, gpu_sparse.cuda_csr_todense), + partial(_csr_todense_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( csr_todense_p, - partial(_csr_todense_gpu_lowering, gpu_sparse.rocm_csr_todense), + partial(_csr_todense_gpu_lowering, target_name_prefix='hip'), platform='rocm') @@ -359,16 +359,16 @@ def _csr_fromdense_abstract_eval(mat, *, nse, index_dtype): _csr_fromdense_lowering = mlir.lower_fun(_csr_fromdense_impl, multiple_results=True) -def _csr_fromdense_gpu_lowering(csr_fromdense_hlo, ctx, mat, *, nse, index_dtype): +def _csr_fromdense_gpu_lowering(ctx, mat, *, nse, index_dtype, + target_name_prefix): dtype = ctx.avals_in[0].dtype if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)): warnings.warn(f"csr_fromdense cusparse/hipsparse lowering not available for {dtype=}. " "Falling back to default implementation.", CuSparseEfficiencyWarning) return _csr_fromdense_lowering(ctx, mat, nse=nse, index_dtype=index_dtype) - data, indices, indptr = csr_fromdense_hlo( - mat, nnz=nse, index_dtype=np.dtype(index_dtype), - data_dtype=dtype, index_type=mlir.dtype_to_ir_type(np.dtype(index_dtype))) - return [data, indices, indptr] + return _lowerings.csr_fromdense_gpu_lowering( + ctx, mat, nnz=nse, index_dtype=index_dtype, + target_name_prefix=target_name_prefix) def _csr_fromdense_jvp(primals, tangents, *, nse, index_dtype): @@ -404,12 +404,12 @@ def _csr_fromdense_transpose(ct, M, *, nse, index_dtype): if gpu_sparse.cuda_is_supported: mlir.register_lowering( csr_fromdense_p, - partial(_csr_fromdense_gpu_lowering, gpu_sparse.cuda_csr_fromdense), + partial(_csr_fromdense_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( csr_fromdense_p, - partial(_csr_fromdense_gpu_lowering, gpu_sparse.rocm_csr_fromdense), + partial(_csr_fromdense_gpu_lowering, target_name_prefix='hip'), platform='rocm') #-------------------------------------------------------------------- @@ -470,8 +470,8 @@ def _csr_matvec_abstract_eval(data, indices, indptr, v, *, shape, transpose): _csr_matvec_lowering = mlir.lower_fun(_csr_matvec_impl, multiple_results=False) -def _csr_matvec_gpu_lowering(csr_matvec_hlo, ctx, data, indices, indptr, v, *, - shape, transpose): +def _csr_matvec_gpu_lowering(ctx, data, indices, indptr, v, *, shape, transpose, + target_name_prefix): data_aval, indices_aval, _, v_aval = ctx.avals_in dtype = data_aval.dtype if dtype not in [np.float32, np.float64, np.complex64, np.complex128]: @@ -479,10 +479,9 @@ def _csr_matvec_gpu_lowering(csr_matvec_hlo, ctx, data, indices, indptr, v, *, "Falling back to default implementation.", CuSparseEfficiencyWarning) return _csr_matvec_lowering(ctx, data, indices, indptr, v, shape=shape, transpose=transpose) - return [csr_matvec_hlo( - data, indices, indptr, v, shape=shape, transpose=transpose, - data_dtype=dtype, index_dtype=indices_aval.dtype, x_dtype=v_aval.dtype)] - + return _lowerings._csr_spmv_gpu_lowering( + ctx, data, indices, indptr, v, shape=shape, transpose=transpose, + target_name_prefix=target_name_prefix) def _csr_matvec_jvp_mat(data_dot, data, indices, indptr, v, *, shape, transpose): return _csr_matvec(data_dot, indices, indptr, v, shape=shape, transpose=transpose) @@ -511,12 +510,12 @@ def _csr_matvec_transpose(ct, data, indices, indptr, v, *, shape, transpose): if gpu_sparse.cuda_is_supported: mlir.register_lowering( csr_matvec_p, - partial(_csr_matvec_gpu_lowering, gpu_sparse.cuda_csr_matvec), + partial(_csr_matvec_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( csr_matvec_p, - partial(_csr_matvec_gpu_lowering, gpu_sparse.rocm_csr_matvec), + partial(_csr_matvec_gpu_lowering, target_name_prefix='hip'), platform='rocm') @@ -580,8 +579,8 @@ def _csr_matmat_abstract_eval(data, indices, indptr, B, *, shape, transpose): _csr_matmat_lowering = mlir.lower_fun(_csr_matmat_impl, multiple_results=False) -def _csr_matmat_gpu_lowering(csr_matmat_hlo, ctx, data, indices, indptr, B, *, - shape, transpose): +def _csr_matmat_gpu_lowering(ctx, data, indices, indptr, B, *, shape, transpose, + target_name_prefix): data_aval, indices_aval, _, B_aval = ctx.avals_in dtype = data_aval.dtype if dtype not in [np.float32, np.float64, np.complex64, np.complex128]: @@ -589,11 +588,9 @@ def _csr_matmat_gpu_lowering(csr_matmat_hlo, ctx, data, indices, indptr, B, *, "Falling back to default implementation.", CuSparseEfficiencyWarning) return _csr_matmat_lowering(ctx, data, indices, indptr, B, shape=shape, transpose=transpose) - return [csr_matmat_hlo( - data, indices, indptr, B, shape=shape, transpose=transpose, - index_dtype=indices_aval.dtype, data_dtype=data_aval.dtype, - B_dtype=B_aval.dtype)] - + return _lowerings._csr_spmm_gpu_lowering( + ctx, data, indices, indptr, B, shape=shape, transpose=transpose, + target_name_prefix=target_name_prefix) def _csr_matmat_jvp_left(data_dot, data, indices, indptr, B, *, shape, transpose): return _csr_matmat(data_dot, indices, indptr, B, shape=shape, transpose=transpose) @@ -621,10 +618,10 @@ def _csr_matmat_transpose(ct, data, indices, indptr, B, *, shape, transpose): if gpu_sparse.cuda_is_supported: mlir.register_lowering( csr_matmat_p, - partial(_csr_matmat_gpu_lowering, gpu_sparse.cuda_csr_matmat), + partial(_csr_matmat_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( csr_matmat_p, - partial(_csr_matmat_gpu_lowering, gpu_sparse.rocm_csr_matmat), + partial(_csr_matmat_gpu_lowering, target_name_prefix='hip'), platform='rocm') diff --git a/jaxlib/gpu_sparse.py b/jaxlib/gpu_sparse.py index d8645041c946..cc2b2ad08e55 100644 --- a/jaxlib/gpu_sparse.py +++ b/jaxlib/gpu_sparse.py @@ -11,25 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -cusparse wrappers for performing sparse matrix computations in JAX -""" -import math -from functools import partial from typing import Any -import jaxlib.mlir.ir as ir - -import numpy as np - -from .hlo_helpers import custom_call, mk_result_types_and_shapes - from .plugin_support import import_from_plugin _cusparse = import_from_plugin("cuda", "_sparse") _hipsparse = import_from_plugin("rocm", "_sparse") +cuda_is_supported = bool(_cusparse and _cusparse.sparse_supported) +rocm_is_supported = bool(_hipsparse and _hipsparse.sparse_supported) + def registrations() -> dict[str, list[tuple[str, Any, int]]]: registrations = {"CUDA": [], "ROCM": []} for platform, module in [("CUDA", _cusparse), ("ROCM", _hipsparse)]: @@ -38,346 +30,3 @@ def registrations() -> dict[str, list[tuple[str, Any, int]]]: (name, value, int(name.endswith("_ffi"))) for name, value in module.registrations().items()) return registrations # pytype: disable=bad-return-type - - -cuda_is_supported = bool(_cusparse and _cusparse.sparse_supported) -rocm_is_supported = bool(_hipsparse and _hipsparse.sparse_supported) - - -def _validate_csr_hlo(data, indices, indptr, shape): - data_type = ir.RankedTensorType(data.type) - indices_type = ir.RankedTensorType(indices.type) - indptr_type = ir.RankedTensorType(indptr.type) - - nnz, = data_type.shape - assert indices_type.shape == [nnz] - assert indptr_type.element_type == indices_type.element_type - assert indptr_type.shape == [shape[0] + 1] - return data_type.element_type, indices_type.element_type, nnz - -def _validate_coo_hlo(data, row, col): - data_type = ir.RankedTensorType(data.type) - row_type = ir.RankedTensorType(row.type) - col_type = ir.RankedTensorType(col.type) - - nnz, = data_type.shape - assert row_type.shape == [nnz] - assert col_type.element_type == row_type.element_type - assert col_type.shape == [nnz] - return data_type.element_type, row_type.element_type, nnz - - -def _csr_todense_hlo(platform, gpu_sparse, data, indices, indptr, *, shape, - data_dtype, index_dtype): - """CSR to dense matrix.""" - data_type, index_type, nnz = _validate_csr_hlo(data, indices, indptr, shape) - rows, cols = shape - - buffer_size, opaque = gpu_sparse.build_csr_todense_descriptor( - data_dtype, index_dtype, rows, cols, nnz) - - out = custom_call( - f"{platform}sparse_csr_todense_ffi", - result_types=[ - ir.RankedTensorType.get(shape, data_type), - ir.RankedTensorType.get([buffer_size], - ir.IntegerType.get_signless(8)), - ], - operands=[data, indices, indptr], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[[0]] * 3, - result_layouts=[[1, 0], [0]]).results - return out[0] - -cuda_csr_todense = partial(_csr_todense_hlo, "cu", _cusparse) -rocm_csr_todense = partial(_csr_todense_hlo, "hip", _hipsparse) - - -def _csr_fromdense_hlo(platform, gpu_sparse, mat, *, nnz, index_dtype, - data_dtype, index_type): - """CSR from dense matrix.""" - mat_type = ir.RankedTensorType(mat.type) - rows, cols = mat_type.shape - - buffer_size, opaque = gpu_sparse.build_csr_fromdense_descriptor( - data_dtype, index_dtype, rows, cols, nnz) - - out = custom_call( - f"{platform}sparse_csr_fromdense_ffi", - result_types=[ - ir.RankedTensorType.get([nnz], mat_type.element_type), - ir.RankedTensorType.get([nnz], index_type), - ir.RankedTensorType.get([rows + 1], index_type), - ir.RankedTensorType.get([buffer_size], - ir.IntegerType.get_signless(8)), - ], - operands=[mat], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[[1, 0]], - result_layouts=[[0]] * 4).results - return out[:3] - -cuda_csr_fromdense = partial(_csr_fromdense_hlo, "cu", _cusparse) -rocm_csr_fromdense = partial(_csr_fromdense_hlo, "hip", _hipsparse) - - -def _csr_matvec_hlo(platform, gpu_sparse, data, indices, indptr, x, *, shape, - transpose=False, compute_dtype=None, compute_type=None, - data_dtype, index_dtype, x_dtype): - """CSR matrix/vector multiply.""" - data_type, index_type, nnz = _validate_csr_hlo(data, indices, indptr, shape) - rows, cols = shape - - if compute_dtype is None: - compute_dtype = data_dtype - compute_type = data_type - - buffer_size, opaque = gpu_sparse.build_csr_matvec_descriptor( - data_dtype, x_dtype, compute_dtype, index_dtype, - rows, cols, nnz, transpose) - out_size = cols if transpose else rows - - out = custom_call( - f"{platform}sparse_csr_matvec_ffi", - result_types=[ - ir.RankedTensorType.get([out_size], compute_type), - ir.RankedTensorType.get([buffer_size], - ir.IntegerType.get_signless(8)), - ], - operands=[data, indices, indptr, x], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[[0]] * 4, - result_layouts=[[0]] * 2).results - return out[0] - -cuda_csr_matvec = partial(_csr_matvec_hlo, "cu", _cusparse) -rocm_csr_matvec = partial(_csr_matvec_hlo, "hip", _hipsparse) - - -def _csr_matmat_hlo(platform, gpu_sparse, data, indices, indptr, B, *, shape, - transpose=False, compute_dtype=None, compute_type=None, - index_dtype, data_dtype, B_dtype): - """CSR from dense matrix.""" - data_type, index_type, nnz = _validate_csr_hlo(data, indices, indptr, shape) - rows, cols = shape - B_shape = ir.RankedTensorType(B.type).shape - _, Ccols = B_shape - - if compute_dtype is None: - compute_dtype = data_dtype - compute_type = data_type - - buffer_size, opaque = gpu_sparse.build_csr_matmat_descriptor( - data_dtype, B_dtype, compute_dtype, index_dtype, - rows, cols, Ccols, nnz, transpose) - out_size = cols if transpose else rows - - out = custom_call( - f"{platform}sparse_csr_matmat_ffi", - result_types=[ - ir.RankedTensorType.get([out_size, Ccols], compute_type), - ir.RankedTensorType.get([buffer_size], - ir.IntegerType.get_signless(8)), - ], - operands=[data, indices, indptr, B], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[[0], [0], [0], [1, 0]], - result_layouts=[[1, 0], [0]]).results - return out[0] - -cuda_csr_matmat = partial(_csr_matmat_hlo, "cu", _cusparse) -rocm_csr_matmat = partial(_csr_matmat_hlo, "hip", _hipsparse) - - -def _coo_todense_hlo(platform, gpu_sparse, data, row, col, *, shape, - data_dtype, index_dtype): - """COO to dense matrix.""" - data_type, _, nnz = _validate_coo_hlo(data, row, col) - rows, cols = shape - - buffer_size, opaque = gpu_sparse.build_coo_todense_descriptor( - data_dtype, index_dtype, rows, cols, nnz) - - out = custom_call( - f"{platform}sparse_coo_todense_ffi", - result_types=[ - ir.RankedTensorType.get(shape, data_type), - ir.RankedTensorType.get([buffer_size], - ir.IntegerType.get_signless(8)), - ], - operands=[data, row, col], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[[0]] * 3, - result_layouts=[[1, 0], [0]]).results - return out[0] - -cuda_coo_todense = partial(_coo_todense_hlo, "cu", _cusparse) -rocm_coo_todense = partial(_coo_todense_hlo, "hip", _hipsparse) - - -def _coo_fromdense_hlo(platform, gpu_sparse, mat, *, nnz, data_dtype, - index_dtype, index_type): - """COO from dense matrix.""" - mat_type = ir.RankedTensorType(mat.type) - rows, cols = mat_type.shape - - buffer_size, opaque = gpu_sparse.build_coo_fromdense_descriptor( - data_dtype, index_dtype, rows, cols, nnz) - - out = custom_call( - f"{platform}sparse_coo_fromdense_ffi", - result_types=[ - ir.RankedTensorType.get([nnz], mat_type.element_type), - ir.RankedTensorType.get([nnz], index_type), - ir.RankedTensorType.get([nnz], index_type), - ir.RankedTensorType.get([buffer_size], - ir.IntegerType.get_signless(8)), - ], - operands=[mat], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[[1, 0]], - result_layouts=[[0]] * 4).results - return out[:3] - -cuda_coo_fromdense = partial(_coo_fromdense_hlo, "cu", _cusparse) -rocm_coo_fromdense = partial(_coo_fromdense_hlo, "hip", _hipsparse) - - -def _coo_matvec_hlo(platform, gpu_sparse, data, row, col, x, *, shape, - transpose=False, compute_dtype=None, compute_type=None, - index_dtype, data_dtype, x_dtype): - """COO matrix/vector multiply.""" - data_type, _, nnz = _validate_coo_hlo(data, row, col) - rows, cols = shape - - if compute_dtype is None: - compute_dtype = data_dtype - compute_type = data_type - - buffer_size, opaque = gpu_sparse.build_coo_matvec_descriptor( - data_dtype, x_dtype, compute_dtype, index_dtype, - rows, cols, nnz, transpose) - out_size = cols if transpose else rows - - out = custom_call( - f"{platform}sparse_coo_matvec_ffi", - result_types=[ - ir.RankedTensorType.get([out_size], compute_type), - ir.RankedTensorType.get([buffer_size], - ir.IntegerType.get_signless(8)), - ], - operands=[data, row, col, x], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[[0]] * 4, - result_layouts=[[0]] * 2).results - return out[0] - -cuda_coo_matvec = partial(_coo_matvec_hlo, "cu", _cusparse) -rocm_coo_matvec = partial(_coo_matvec_hlo, "hip", _hipsparse) - - -def _coo_matmat_hlo(platform, gpu_sparse, data, row, col, B, *, shape, - transpose=False, compute_dtype=None, compute_type=None, - x_dtype, data_dtype, index_dtype): - """COO from dense matrix.""" - data_type, _, nnz = _validate_coo_hlo(data, row, col) - is_batched_matmat = False - batch_count = 1 - if len(shape) == 2: - rows, cols = shape - elif len(shape) == 3: - is_batched_matmat = True - batch_count, rows, cols = shape - # Redefine nnz as nnz per batch. - nnz = nnz // batch_count - - B_shape = ir.RankedTensorType(B.type).shape - _, Ccols = B_shape - - if compute_dtype is None: - compute_dtype = data_dtype - compute_type = data_type - - # TODO(tianjianlu): use batch stride to trigger different mode of batch - # computation. Currently batch_stride = 0 is not allowed because of the issue - # in cusparse https://github.com/NVIDIA/CUDALibrarySamples/issues/81#issuecomment-1205562643 - # Set batch stride to be the matrix size for now. - lhs_batch_stride = nnz - B_rows = rows if transpose else cols - rhs_batch_stride = B_rows * Ccols - - buffer_size, opaque = gpu_sparse.build_coo_matmat_descriptor( - data_dtype, x_dtype, compute_dtype, index_dtype, - rows, cols, Ccols, nnz, transpose, batch_count, lhs_batch_stride, - rhs_batch_stride) - out_size = cols if transpose else rows - - if is_batched_matmat: - out_shape = [batch_count, out_size, Ccols] - out_layout = [2, 1, 0] - else: - out_shape = [out_size, Ccols] - out_layout = [1, 0] - - out = custom_call( - f"{platform}sparse_coo_matmat_ffi", - result_types=[ - ir.RankedTensorType.get(out_shape, compute_type), - ir.RankedTensorType.get([buffer_size], - ir.IntegerType.get_signless(8)), - ], - operands=[data, row, col, B], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[[0], [0], [0], [1, 0]], - result_layouts=[out_layout, [0]]).results - return out[0] - -cuda_coo_matmat = partial(_coo_matmat_hlo, "cu", _cusparse) -rocm_coo_matmat = partial(_coo_matmat_hlo, "hip", _hipsparse) - - -def _gtsv2_hlo( - platform, gpu_sparse, dl, d, du, B, *, m, n, ldb, t, b_shape_vals=None): - """Calls `cusparsegtsv2(dl, d, du, B, m, n, ldb)`.""" - assert len(b_shape_vals) >= 2 - batch_dim_vals = b_shape_vals[:-2] - batch_size = math.prod(batch_dim_vals) - num_bd = len(b_shape_vals) - 2 - f32 = (t == np.float32) - if f32: - buffer_size = gpu_sparse.gtsv2_f32_buffer_size(m, n, ldb) - else: - buffer_size = gpu_sparse.gtsv2_f64_buffer_size(m, n, ldb) - - b_layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) - d_layout = (num_bd,) + tuple(range(num_bd - 1, -1, -1)) - b_type = ir.RankedTensorType(B.type) - - shape_type_pairs = [ - (batch_dim_vals + (ldb, n), b_type.element_type), - ((buffer_size,), ir.IntegerType.get_signless(8)) - ] - result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs) - opaque = gpu_sparse.build_gtsv2_descriptor(batch_size, m, n, ldb) - out = custom_call( - f"{platform}sparse_gtsv2_" + ("f32" if f32 else "f64") + "_ffi", - result_types=result_types, - operands=[dl, d, du, B], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[d_layout] * 3 + [b_layout], - result_layouts=[b_layout, [0]], - operand_output_aliases={3: 0}, - result_shapes=result_shapes).results - return out[0] - -cuda_gtsv2 = partial(_gtsv2_hlo, "cu", _cusparse) -rocm_gtsv2 = partial(_gtsv2_hlo, "hip", _hipsparse) From 84ec21e03e88014a2cdaebee23075924f2054181 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 19 Mar 2025 10:42:17 -0700 Subject: [PATCH 025/483] Add sliding window support to the ragged paged attention. PiperOrigin-RevId: 738457532 --- .../pallas/ops/tpu/ragged_paged_attention.py | 23 +++++++++-- .../pallas/tpu_ragged_paged_attention_test.py | 40 ++++++++++++++++++- 2 files changed, 58 insertions(+), 5 deletions(-) diff --git a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py b/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py index 6600d765024c..a9b61da290f7 100644 --- a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py +++ b/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py @@ -19,7 +19,6 @@ specifications. It supports mixed prefill and decoding, enhancing throughput during inference. """ - import functools import jax from jax import lax @@ -81,6 +80,7 @@ def ref_ragged_paged_attention( num_seqs: jax.Array, # i32[1], *, sm_scale: float = 1.0, + sliding_window: int | None = None, mask_value: float = DEFAULT_MASK_VALUE, ): _, _, num_kv_heads, head_dim = k_pages.shape @@ -105,7 +105,10 @@ def ref_ragged_paged_attention( jnp.int32, attn.shape, 1 ) kv_span = jax.lax.broadcasted_iota(jnp.int32, attn.shape, 2) - attn += jnp.where(q_span < kv_span, mask_value, 0.0) + mask = q_span < kv_span + if sliding_window is not None: + mask = jnp.logical_or(mask, q_span - sliding_window >= kv_span) + attn += jnp.where(mask, mask_value, 0.0) attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype) out = jnp.einsum("hqk,khd->qhd", attn, v).astype(queries.dtype) outputs.append(out) @@ -122,6 +125,7 @@ def validate_inputs_on_runtime( page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] cu_q_lens: jax.Array, # i32[max_num_seqs + 1] num_seqs, # i32[1] + sliding_window: int | None = None, ): check_inputs_shapes( q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, num_seqs @@ -150,6 +154,8 @@ def validate_inputs_on_runtime( raise ValueError( f"{q_len=} must be less or equal to {kv_len=} at sequence {i}." ) + if sliding_window is not None and sliding_window <= 0: + raise ValueError(f"{sliding_window=} must be positive.") # Expect to run these checks during compile time. @@ -221,7 +227,8 @@ def ragged_paged_attention_kernel( m_ref, # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128] *, sm_scale: float, - mask_value: float, + sliding_window: int | None = None, + mask_value: float = DEFAULT_MASK_VALUE, ): num_q_per_blk, num_q_heads_per_blk, head_dim = q_ref.shape num_seqs = num_seqs_ref[0] @@ -373,7 +380,7 @@ def flash_attention( def masked_store(ref, val, start, end, group=1): iota = lax.broadcasted_iota(jnp.int32, ref.shape, 0) // group mask = jnp.logical_and(iota >= start, iota < end) - pl.store(ref, tuple(slice(None) for _ in ref.shape), val, mask=mask) + pl.store(ref, idx=tuple(slice(None) for _ in ref.shape), val=val, mask=mask) qk = ( jnp.einsum("nd,md->nm", q, k, preferred_element_type=jnp.float32) @@ -422,6 +429,9 @@ def init_scratch_ref(): 1, ) causal_mask = row_ids < col_ids + if sliding_window is not None: + causal_mask = jnp.logical_or(causal_mask, + row_ids - sliding_window>=col_ids) qk += jnp.where(causal_mask, mask_value, 0.0) m_curr = jnp.max(qk, axis=1, keepdims=True) s_curr = jnp.exp(qk - m_curr) @@ -601,6 +611,7 @@ def can_be_xla_fully_tiled(x, packing): "num_kv_pages_per_block", "num_queries_per_block", "vmem_limit_bytes", + "sliding_window", ], ) def ragged_paged_attention( @@ -614,6 +625,7 @@ def ragged_paged_attention( num_seqs: jax.Array, # i32[1] *, sm_scale: float = 1.0, + sliding_window: int | None = None, mask_value: float = DEFAULT_MASK_VALUE, num_kv_pages_per_block: int = 16, num_queries_per_block: int = 128, @@ -632,6 +644,7 @@ def ragged_paged_attention( kv_lens, only the first num_seqs+1 values are valid. num_seqs: the dynamic number of sequences. sm_scale: the softmax scale which will be applied to the Q@K^T. + sliding_window: the sliding window size for the attention. mask_value: mask value for causal mask. num_kv_pages_per_block: number of kv pages to be processed in one flash attention block in the pallas kernel. @@ -705,6 +718,7 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_): functools.partial( ragged_paged_attention_kernel, sm_scale=sm_scale, + sliding_window=sliding_window, mask_value=mask_value, ), grid_spec=pltpu.PrefetchScalarGridSpec( @@ -724,6 +738,7 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_): out_shape=jax.ShapeDtypeStruct(shape=q.shape, dtype=jnp.float32), name="ragged_paged_attention_kernel", ) + # TODO(jevinjiang): Use f32 acc scratch for output! So we only need # to transfer output with desired dtype back to HBM. return kernel(*scalar_prefetches, q, k_pages, v_pages).astype(q.dtype) diff --git a/tests/pallas/tpu_ragged_paged_attention_test.py b/tests/pallas/tpu_ragged_paged_attention_test.py index bffcebc5254b..80d78ec32d07 100644 --- a/tests/pallas/tpu_ragged_paged_attention_test.py +++ b/tests/pallas/tpu_ragged_paged_attention_test.py @@ -13,6 +13,7 @@ # limitations under the License. import random + from absl.testing import absltest from absl.testing import parameterized import jax @@ -50,6 +51,7 @@ def _test_ragged_paged_attention( vmem_limit_bytes=32 * 1024 * 1024, max_num_batched_tokens=512, max_num_seq=8, + sliding_window: int | None = None, ): if not jtu.is_device_tpu_at_least(version=4): self.skipTest("Expect TPUv4+") @@ -101,8 +103,10 @@ def _test_ragged_paged_attention( page_indices, cu_q_lens, num_seqs, + sliding_window=sliding_window, ) + actual_num_q_tokens = cu_q_lens[num_seqs[0]] output = ragged_paged_attention( q, k_pages, @@ -114,7 +118,8 @@ def _test_ragged_paged_attention( num_kv_pages_per_block=num_kv_pages_per_block, num_queries_per_block=num_queries_per_block, vmem_limit_bytes=vmem_limit_bytes, - )[: cu_q_lens[num_seqs[0]]] + sliding_window=sliding_window, + )[: actual_num_q_tokens] expected = ref_ragged_paged_attention( q, @@ -124,6 +129,7 @@ def _test_ragged_paged_attention( page_indices, cu_q_lens, num_seqs=num_seqs, + sliding_window=sliding_window, ) tols = { "float32": 0.15, @@ -266,6 +272,7 @@ def test_ragged_paged_attention_mixed(self, dtype): dtype=[jnp.float32, jnp.bfloat16], num_kv_pages_per_block=[4, 8], num_queries_per_block=[32, 64], + sliding_window=[None, 5, 128], ) def test_ragged_paged_attention_complex( self, @@ -274,6 +281,7 @@ def test_ragged_paged_attention_complex( dtype, num_kv_pages_per_block, num_queries_per_block, + sliding_window: int | None, ): seq_lens = [] for _ in range(num_seqs): @@ -294,8 +302,38 @@ def test_ragged_paged_attention_complex( num_pages, num_kv_pages_per_block=num_kv_pages_per_block, num_queries_per_block=num_queries_per_block, + sliding_window=sliding_window, ) + def test_ragged_paged_attention_sliding_window_should_be_positive(self): + dtype=jnp.float32 + seq_lens = [(192, 328), (128, 180), (64, 255)] + num_heads = (32, 8) + head_dim = 128 + page_size = 16 + num_pages = 1000 + + with self.assertRaisesRegex(ValueError, "must be positive"): + self._test_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + sliding_window=0, + ) + + with self.assertRaisesRegex(ValueError, "must be positive"): + self._test_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + sliding_window=-1, + ) if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 5a5415bcda4944edb0009a94f793031a8708f0d5 Mon Sep 17 00:00:00 2001 From: Pearu Peterson Date: Wed, 19 Mar 2025 19:42:44 +0200 Subject: [PATCH 026/483] Rename arguments x, y of assertAllClose and friends to actual, expected. --- jax/_src/test_util.py | 81 +++++++++++++++++++++---------------------- 1 file changed, 40 insertions(+), 41 deletions(-) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 18f7efa16223..0dace13821fc 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -1343,15 +1343,15 @@ def assertDeprecationWarnsOrRaises(self, deprecation_id: str, message: str): else: return self.assertWarnsRegex(DeprecationWarning, message) - def assertArraysEqual(self, x, y, *, check_dtypes=True, err_msg='', + def assertArraysEqual(self, actual, desired, *, check_dtypes=True, err_msg='', allow_object_dtype=False, verbose=True): """Assert that x and y arrays are exactly equal.""" if check_dtypes: - self.assertDtypesMatch(x, y) - x = np.asarray(x) - y = np.asarray(y) + self.assertDtypesMatch(actual, desired) + actual = np.asarray(actual) + desired = np.asarray(desired) - if (not allow_object_dtype) and (x.dtype == object or y.dtype == object): + if (not allow_object_dtype) and (actual.dtype == object or desired.dtype == object): # See https://github.com/jax-ml/jax/issues/17867 raise TypeError( "assertArraysEqual may be poorly behaved when np.asarray casts to dtype=object. " @@ -1361,57 +1361,57 @@ def assertArraysEqual(self, x, y, *, check_dtypes=True, err_msg='', # Work around https://github.com/numpy/numpy/issues/18992 with np.errstate(over='ignore'): - np.testing.assert_array_equal(x, y, err_msg=err_msg, + np.testing.assert_array_equal(actual, desired, err_msg=err_msg, verbose=verbose) - def assertArraysAllClose(self, x, y, *, check_dtypes=True, atol=None, + def assertArraysAllClose(self, actual, desired, *, check_dtypes=True, atol=None, rtol=None, err_msg=''): - """Assert that x and y are close (up to numerical tolerances).""" - self.assertEqual(x.shape, y.shape) - atol = max(tolerance(_dtype(x), atol), tolerance(_dtype(y), atol)) - rtol = max(tolerance(_dtype(x), rtol), tolerance(_dtype(y), rtol)) + """Assert that actual and desired are close (up to numerical tolerances).""" + self.assertEqual(actual.shape, desired.shape) + atol = max(tolerance(_dtype(actual), atol), tolerance(_dtype(desired), atol)) + rtol = max(tolerance(_dtype(actual), rtol), tolerance(_dtype(desired), rtol)) - _assert_numpy_allclose(x, y, atol=atol, rtol=rtol, err_msg=err_msg) + _assert_numpy_allclose(actual, desired, atol=atol, rtol=rtol, err_msg=err_msg) if check_dtypes: - self.assertDtypesMatch(x, y) + self.assertDtypesMatch(actual, desired) - def assertDtypesMatch(self, x, y, *, canonicalize_dtypes=True): + def assertDtypesMatch(self, actual, desired, *, canonicalize_dtypes=True): if not config.enable_x64.value and canonicalize_dtypes: - self.assertEqual(_dtypes.canonicalize_dtype(_dtype(x), allow_extended_dtype=True), - _dtypes.canonicalize_dtype(_dtype(y), allow_extended_dtype=True)) + self.assertEqual(_dtypes.canonicalize_dtype(_dtype(actual), allow_extended_dtype=True), + _dtypes.canonicalize_dtype(_dtype(desired), allow_extended_dtype=True)) else: - self.assertEqual(_dtype(x), _dtype(y)) + self.assertEqual(_dtype(actual), _dtype(desired)) - def assertAllClose(self, x, y, *, check_dtypes=True, atol=None, rtol=None, + def assertAllClose(self, actual, desired, *, check_dtypes=True, atol=None, rtol=None, canonicalize_dtypes=True, err_msg=''): - """Assert that x and y, either arrays or nested tuples/lists, are close.""" - if isinstance(x, dict): - self.assertIsInstance(y, dict) - self.assertEqual(set(x.keys()), set(y.keys())) - for k in x.keys(): - self.assertAllClose(x[k], y[k], check_dtypes=check_dtypes, atol=atol, + """Assert that actual and desired, either arrays or nested tuples/lists, are close.""" + if isinstance(actual, dict): + self.assertIsInstance(desired, dict) + self.assertEqual(set(actual.keys()), set(desired.keys())) + for k in actual.keys(): + self.assertAllClose(actual[k], desired[k], check_dtypes=check_dtypes, atol=atol, rtol=rtol, canonicalize_dtypes=canonicalize_dtypes, err_msg=err_msg) - elif is_sequence(x) and not hasattr(x, '__array__'): - self.assertTrue(is_sequence(y) and not hasattr(y, '__array__')) - self.assertEqual(len(x), len(y)) - for x_elt, y_elt in zip(x, y): - self.assertAllClose(x_elt, y_elt, check_dtypes=check_dtypes, atol=atol, + elif is_sequence(actual) and not hasattr(actual, '__array__'): + self.assertTrue(is_sequence(desired) and not hasattr(desired, '__array__')) + self.assertEqual(len(actual), len(desired)) + for actual_elt, desired_elt in zip(actual, desired): + self.assertAllClose(actual_elt, desired_elt, check_dtypes=check_dtypes, atol=atol, rtol=rtol, canonicalize_dtypes=canonicalize_dtypes, err_msg=err_msg) - elif hasattr(x, '__array__') or np.isscalar(x): - self.assertTrue(hasattr(y, '__array__') or np.isscalar(y)) + elif hasattr(actual, '__array__') or np.isscalar(actual): + self.assertTrue(hasattr(desired, '__array__') or np.isscalar(desired)) if check_dtypes: - self.assertDtypesMatch(x, y, canonicalize_dtypes=canonicalize_dtypes) - x = np.asarray(x) - y = np.asarray(y) - self.assertArraysAllClose(x, y, check_dtypes=False, atol=atol, rtol=rtol, + self.assertDtypesMatch(actual, desired, canonicalize_dtypes=canonicalize_dtypes) + actual = np.asarray(actual) + desired = np.asarray(desired) + self.assertArraysAllClose(actual, desired, check_dtypes=False, atol=atol, rtol=rtol, err_msg=err_msg) - elif x == y: + elif actual == desired: return else: - raise TypeError((type(x), type(y))) + raise TypeError((type(actual), type(desired))) def assertMultiLineStrippedEqual(self, expected, what): """Asserts two strings are equal, after dedenting and stripping each line.""" @@ -1426,7 +1426,6 @@ def assertMultiLineStrippedEqual(self, expected, what): self.assertMultiLineEqual(expected_clean, what_clean, msg=f"Found\n{what}\nExpecting\n{expected}") - @contextmanager def assertNoWarnings(self): with test_warning_util.raise_on_warnings(): @@ -1496,9 +1495,9 @@ def wrapped_fun(*args): python_should_be_executing = False compiled_ans = cfun(*args) - self.assertAllClose(python_ans, monitored_ans, check_dtypes=check_dtypes, + self.assertAllClose(monitored_ans, python_ans, check_dtypes=check_dtypes, atol=atol or tol, rtol=rtol or tol) - self.assertAllClose(python_ans, compiled_ans, check_dtypes=check_dtypes, + self.assertAllClose(compiled_ans, python_ans, check_dtypes=check_dtypes, atol=atol or tol, rtol=rtol or tol) args = args_maker() @@ -1509,7 +1508,7 @@ def wrapped_fun(*args): python_should_be_executing = False compiled_ans = cfun(*args) - self.assertAllClose(python_ans, compiled_ans, check_dtypes=check_dtypes, + self.assertAllClose(compiled_ans, python_ans, check_dtypes=check_dtypes, atol=atol or tol, rtol=rtol or tol) def _CheckAgainstNumpy(self, numpy_reference_op, lax_op, args_maker, From 7a67c9bd63bcf160c1edb884930df1bfe1108496 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 19 Mar 2025 11:30:41 -0700 Subject: [PATCH 027/483] Fix lint error on main --- jax/experimental/pallas/ops/tpu/ragged_paged_attention.py | 2 +- tests/pallas/tpu_ragged_paged_attention_test.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py b/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py index a9b61da290f7..90b808282c22 100644 --- a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py +++ b/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py @@ -431,7 +431,7 @@ def init_scratch_ref(): causal_mask = row_ids < col_ids if sliding_window is not None: causal_mask = jnp.logical_or(causal_mask, - row_ids - sliding_window>=col_ids) + row_ids - sliding_window >= col_ids) qk += jnp.where(causal_mask, mask_value, 0.0) m_curr = jnp.max(qk, axis=1, keepdims=True) s_curr = jnp.exp(qk - m_curr) diff --git a/tests/pallas/tpu_ragged_paged_attention_test.py b/tests/pallas/tpu_ragged_paged_attention_test.py index 80d78ec32d07..ba574a4ce98c 100644 --- a/tests/pallas/tpu_ragged_paged_attention_test.py +++ b/tests/pallas/tpu_ragged_paged_attention_test.py @@ -306,7 +306,7 @@ def test_ragged_paged_attention_complex( ) def test_ragged_paged_attention_sliding_window_should_be_positive(self): - dtype=jnp.float32 + dtype = jnp.float32 seq_lens = [(192, 328), (128, 180), (64, 255)] num_heads = (32, 8) head_dim = 128 From 9d534ad2cd40e11ce9d1f19c80300a35b9332c8d Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 19 Mar 2025 14:41:25 -0400 Subject: [PATCH 028/483] Update version numbers after JAX 0.5.3 release. --- CHANGELOG.md | 4 +++- jax/version.py | 2 +- setup.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9faff67cf305..9a817ce80937 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,7 +14,9 @@ Remember to align the itemized text with the first line of an item within a list When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md. --> -## jax 0.5.3 +## Unreleased + +## jax 0.5.3 (Mar 19, 2025) * New Features diff --git a/jax/version.py b/jax/version.py index 13df5f00a11b..6ed6a5fda600 100644 --- a/jax/version.py +++ b/jax/version.py @@ -21,7 +21,7 @@ import pathlib import subprocess -_version = "0.5.3" +_version = "0.5.4" # The following line is overwritten by build scripts in distributions & # releases. Do not modify this manually, or jax/jaxlib build will fail. _release_version: str | None = None diff --git a/setup.py b/setup.py index a5c8500dc1cf..dbb7040d2d2b 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ _current_jaxlib_version = '0.5.3' # The following should be updated after each new jaxlib release. -_latest_jaxlib_version_on_pypi = '0.5.1' +_latest_jaxlib_version_on_pypi = '0.5.3' _libtpu_version = '0.0.11.*' From 4489303dfc6548878c81c4e5b3209e0aba002332 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 19 Mar 2025 12:43:36 -0700 Subject: [PATCH 029/483] Delete `ParsedPartitionSpec` and `preprocess` function and do a couple more cleanups PiperOrigin-RevId: 738503430 --- jax/_src/interpreters/pxla.py | 41 +++++++++++--- jax/_src/named_sharding.py | 103 +++------------------------------- jax/_src/pjit.py | 38 ------------- jax/_src/sharding_impls.py | 5 +- 4 files changed, 41 insertions(+), 146 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index c06eda5214ed..387f0661ae9d 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2463,14 +2463,41 @@ def cost_analysis(self) -> dict[str, float]: return xe.hlo_module_cost_analysis(backend, self.hlo().as_hlo_module()) +def get_op_sharding_from_executable( + executable) -> tuple[Sequence[xc.OpSharding], Sequence[xc.OpSharding]]: + in_op_shardings: list[xc.OpSharding] = [] + parameter_shardings_from_xla = executable.get_parameter_shardings() + if parameter_shardings_from_xla is not None: + in_op_shardings = parameter_shardings_from_xla + + out_op_shardings: list[xc.OpSharding] = [] + output_shardings_from_xla = executable.get_output_shardings() + if output_shardings_from_xla is not None: + out_op_shardings = output_shardings_from_xla + + return in_op_shardings, out_op_shardings + + +def get_pspec_from_executable( + executable, mesh: Mesh +) -> tuple[tuple[PartitionSpec, ...], tuple[PartitionSpec, ...]]: + input_op_s, output_op_s = get_op_sharding_from_executable(executable) + in_pspec: list[PartitionSpec] = [] + for s in input_op_s: + in_pspec.extend(sharding_impls.parse_flatten_op_sharding(s, mesh)) + + out_pspec: list[PartitionSpec] = [] + for s in output_op_s: + out_pspec.extend(sharding_impls.parse_flatten_op_sharding(s, mesh)) + return tuple(in_pspec), tuple(out_pspec) + + def get_out_shardings_from_executable( xla_executable, device_assignment: Sequence[xc.Device], num_out_avals: int, num_ordered_effects: int, ) -> Sequence[sharding_impls.GSPMDSharding] | None: - from jax._src import pjit - try: omk = xla_executable.get_output_memory_kinds()[0] if num_ordered_effects > 0: @@ -2486,7 +2513,7 @@ def get_out_shardings_from_executable( return [sharding_impls.GSPMDSharding.get_replicated(device_assignment, memory_kind=mk) for mk in omk] - _, out_op_shardings = pjit.get_op_sharding_from_executable(xla_executable) + _, out_op_shardings = get_op_sharding_from_executable(xla_executable) if not out_op_shardings: return None @@ -2517,14 +2544,12 @@ def _get_in_shardings_from_xla( num_ordered_effects: int ) -> Sequence[GSPMDSharding] | None: """Returns input shardings from XLA.""" - from jax._src import pjit - # When the device assignment only has 1 device, SPMD partitioner will not run. # Hence the op shardings will not be set on the `hlo_module`. if len(device_assignment) == 1: return [GSPMDSharding.get_replicated(device_assignment)] * num_in_avals - in_op_shardings, _ = pjit.get_op_sharding_from_executable(xla_executable) + in_op_shardings, _ = get_op_sharding_from_executable(xla_executable) if not in_op_shardings: return None @@ -2543,9 +2568,7 @@ def _get_in_shardings_from_xla( def _get_mesh_pspec_shardings_from_executable( xla_executable, mesh: Mesh ) -> tuple[Sequence[NamedSharding], Sequence[NamedSharding]]: - from jax._src import pjit - - in_pspec, out_pspec = pjit.get_pspec_from_executable(xla_executable, mesh) + in_pspec, out_pspec = get_pspec_from_executable(xla_executable, mesh) return ([NamedSharding(mesh, i) for i in in_pspec], [NamedSharding(mesh, o) for o in out_pspec]) diff --git a/jax/_src/named_sharding.py b/jax/_src/named_sharding.py index 5accdd880a79..3d5b2e67f169 100644 --- a/jax/_src/named_sharding.py +++ b/jax/_src/named_sharding.py @@ -21,11 +21,11 @@ from typing import Any, Union from jax._src import config -from jax._src.util import use_cpp_class, cache, use_cpp_method, tuple_insert +from jax._src.util import use_cpp_class, cache, use_cpp_method from jax._src.lib import xla_client as xc from jax._src.lib.mlir.dialects import sdy from jax._src import mesh as mesh_lib -from jax._src.partition_spec import PartitionSpec, UnconstrainedSingleton +from jax._src.partition_spec import PartitionSpec from jax._src import sharding as JSharding from jax._src import xla_bridge as xb import numpy as np @@ -198,7 +198,7 @@ def is_fully_addressable(self) -> bool: # Speed up `is_fully_addressable` since there is a high chance that the # mesh across multiple NamedSharding objects will be the same. if config.enable_empty_arrays.value: - client = self._internal_device_list[0].client + client = self._internal_device_list[0].client # type: ignore return (len(self.mesh._process_indices) == 1 and next(iter(self.mesh._process_indices)) == xb.process_index(client)) @@ -325,80 +325,6 @@ def __repr__(self): if self.replicated_axes else '') return f"SdyArraySharding([{dim_sharding_repr}]{device_id_repr}{rar})" -# TODO(yashkatariya): Remove this after jax 0.5.2 release -class ParsedPartitionSpec: - __slots__ = ('_user_spec', 'partitions') - - _user_spec: PartitionSpec | None - partitions: tuple[tuple[MeshAxisName, ...] | UnconstrainedSingleton, ...] - - def __init__(self, user_spec, partitions): - self._user_spec = user_spec - assert None not in partitions, partitions - self.partitions = tuple(partitions) - - def get_partition_spec(self) -> PartitionSpec: - if isinstance(self._user_spec, PartitionSpec): - return self._user_spec - else: - return get_single_pspec(self) - - def insert_axis_partitions(self, dim, val): - parts = self.partitions - too_short = dim - len(parts) - if too_short > 0: - parts += ((),) * too_short - new_partitions = tuple_insert(parts, dim, val) - return ParsedPartitionSpec(None, new_partitions) - - @classmethod - def from_user_input( - cls, - entry: PartitionSpec | None, - arg_name: str, - allow_unconstrained_dims: bool = False, - ) -> ParsedPartitionSpec: - if entry is None: - return cls(entry, ()) - if not isinstance(entry, PartitionSpec): - raise TypeError(f"{arg_name} are expected to be " - f"PartitionSpec instances or None, but got {entry}") - axis_specs = [] - for axis_spec in entry: - if axis_spec is None: - axis_spec = () - elif isinstance(axis_spec, (list, tuple)): - axis_spec = tuple(axis_spec) - elif axis_spec is PartitionSpec.UNCONSTRAINED: - if not allow_unconstrained_dims: - raise ValueError(f"Unconstrained dims are not allowed: {entry}") - axis_spec = PartitionSpec.UNCONSTRAINED - else: - axis_spec = (axis_spec,) - axis_specs.append(axis_spec) - new_entry = PartitionSpec( - *[tuple(e) if isinstance(e, (list, tuple)) else e for e in entry]) - return cls(new_entry, axis_specs) - - def __hash__(self): - return hash(self.partitions) - - def __eq__(self, other): - if not isinstance(other, ParsedPartitionSpec): - return False - return self.partitions == other.partitions - - def __len__(self): - return len(self.partitions) - - def __getitem__(self, i): - return self.partitions[i] - - def __iter__(self): - return iter(self.partitions) - - def __repr__(self): - return f"ParsedPartitionSpec(partitions={self.partitions})" @cache(max_size=4096, trace_context_in_key=False) def named_sharding_to_xla_hlo_sharding( @@ -491,18 +417,8 @@ def array_mapping_to_axis_resources(array_mapping: ArrayMapping): partitions.append(None) return PartitionSpec(*partitions) -get_single_pspec = lambda p: array_mapping_to_axis_resources(get_array_mapping(p)) # type: ignore - -# TODO(yashkatariya): Remove this after jax 0.5.2 release -def preprocess(mesh, spec, parsed_pspec, _manual_axes=frozenset()): - if parsed_pspec is None: - spec = PartitionSpec() if spec is None else spec - parsed_pspec = ParsedPartitionSpec.from_user_input( - spec, "NamedSharding spec", allow_unconstrained_dims=True) - _check_unique_resources(parsed_pspec, "NamedSharding spec", mesh) - _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes) - return parsed_pspec +@cache(max_size=128, trace_context_in_key=False) def check_pspec(mesh, spec, _manual_axes=frozenset()): _check_unique_resources(spec, "NamedSharding spec", mesh) _check_mesh_resource_axis(mesh, spec, _manual_axes) @@ -517,13 +433,10 @@ def __init__(self, message, mesh, pspec): def __str__(self): return f"{self.message}" -def _check_unique_resources( - pspec: ParsedPartitionSpec | PartitionSpec, arg_name: str, mesh=None, -) -> None: +def _check_unique_resources(pspec: PartitionSpec, arg_name: str, mesh=None + ) -> None: resource_counts: dict[MeshAxisName, int] = {} duplicate = False - pspec = (pspec.get_partition_spec() if isinstance(pspec, ParsedPartitionSpec) - else pspec) for d in pspec: if d is PartitionSpec.UNCONSTRAINED or d is None: continue @@ -542,10 +455,8 @@ def _check_unique_resources( f' for {mesh_lib.show_axes(multiple_uses)}'), mesh=mesh, pspec=pspec) -@cache(max_size=128, trace_context_in_key=False) + def _check_mesh_resource_axis(mesh, pspec, _manual_axes): - pspec = (pspec.get_partition_spec() if isinstance(pspec, ParsedPartitionSpec) - else pspec) for p in pspec: if p is PartitionSpec.UNCONSTRAINED or p is None: continue diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index b6024dcdfedd..d690cd6e9c67 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -2905,41 +2905,3 @@ def get_unconstrained_dims(sharding: NamedSharding): assert sharding.spec is not None return {i for i, axes in enumerate(sharding.spec) if axes is PartitionSpec.UNCONSTRAINED} - - -def get_op_sharding_from_executable( - executable) -> tuple[Sequence[xc.OpSharding], Sequence[xc.OpSharding]]: - in_op_shardings: list[xc.OpSharding] = [] - parameter_shardings_from_xla = executable.get_parameter_shardings() - if parameter_shardings_from_xla is not None: - in_op_shardings = parameter_shardings_from_xla - - out_op_shardings: list[xc.OpSharding] = [] - output_shardings_from_xla = executable.get_output_shardings() - if output_shardings_from_xla is not None: - out_op_shardings = output_shardings_from_xla - - return in_op_shardings, out_op_shardings - - -def _get_ppspec_from_executable( - executable, mesh - ) -> tuple[Sequence[PartitionSpec], Sequence[PartitionSpec]]: - input_op_shardings, output_op_sharding = get_op_sharding_from_executable( - executable - ) - in_pspec: list[PartitionSpec] = [] - for s in input_op_shardings: - in_pspec.extend(parse_flatten_op_sharding(s, mesh)) - - out_pspec: list[PartitionSpec] = [] - for s in output_op_sharding: - out_pspec.extend(parse_flatten_op_sharding(s, mesh)) - return in_pspec, out_pspec - - -def get_pspec_from_executable( - executable, mesh: pxla.Mesh -) -> tuple[tuple[PartitionSpec, ...], tuple[PartitionSpec, ...]]: - in_pspec, out_pspec = _get_ppspec_from_executable(executable, mesh) - return tuple(in_pspec), tuple(out_pspec) diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index f3295a75cf7a..0ed8568e4bcc 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -37,10 +37,9 @@ from jax._src.lib.mlir.dialects import sdy from jax._src.named_sharding import ( # noqa: F401 SdyArraySharding, SdyDimSharding, UnspecifiedValue, AUTO, - ParsedPartitionSpec, _check_unique_resources, NamedSharding, UNSPECIFIED, + _check_unique_resources, NamedSharding, UNSPECIFIED, ArrayMapping, ArrayMappingOrAutoOrUnspecified, get_array_mapping, - array_mapping_to_axis_resources, get_single_pspec, preprocess, - named_sharding_to_xla_hlo_sharding) + array_mapping_to_axis_resources, named_sharding_to_xla_hlo_sharding) from jax._src.op_shardings import ( are_op_shardings_equal, get_num_ways_dim_sharded, is_op_sharding_replicated) from jax._src.partition_spec import PartitionSpec From 362fb7ae9d1413e0eadd4ff7227b318c99700a8b Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 19 Mar 2025 13:39:18 -0700 Subject: [PATCH 030/483] Remove code to support jaxlib < 0.5.3. The new xla_extension_version is 320. PiperOrigin-RevId: 738522486 --- jax/_src/export/_export.py | 16 ++--- jax/_src/interpreters/mlir.py | 9 +-- jax/_src/lax/lax.py | 6 -- jax/_src/sharding_impls.py | 4 +- jax/_src/util.py | 62 +------------------ .../mosaic/gpu/dialect_lowering.py | 8 +-- .../mosaic/gpu/transform_inference.py | 6 +- jax/experimental/sparse/linalg.py | 6 -- tests/lax_test.py | 6 -- tests/linalg_test.py | 3 - 10 files changed, 16 insertions(+), 110 deletions(-) diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index afae3d9bcdc2..9b6a0f80930f 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -43,7 +43,7 @@ from jax._src.interpreters import mlir from jax._src.interpreters import pxla from jax._src.lib import xla_client -from jax._src.lib import xla_extension, xla_extension_version +from jax._src.lib import xla_extension from jax._src.lib.mlir import ir, passmanager from jax._src.lib.mlir.dialects import hlo from jax._src.lib.mlir.dialects import func as func_dialect @@ -674,10 +674,8 @@ def _export_lowered( # Shardy was used during lowering if we can find the Shardy mesh in the # module. Note that the mesh should have been lifted by the # `sdy-lift-inlined-meshes` pass in mlir.py. - shardy_enabled = False - if xla_extension_version >= 319: - shardy_enabled = xla_extension.sdy.lowered_with_shardy( - mlir.module_to_bytecode(mlir_module)) + shardy_enabled = xla_extension.sdy.lowered_with_shardy( + mlir.module_to_bytecode(mlir_module)) mlir_module_serialized = _module_to_bytecode(mlir_module, shardy_enabled) @@ -784,7 +782,7 @@ def _get_exported_vjp(exp_primal: Exported) -> Exported: _get_vjp=_get_exported_vjp) def _module_to_bytecode(module: ir.Module, shardy_enabled: bool) -> bytes: - if xla_extension_version >= 319 and shardy_enabled: + if shardy_enabled: mlir_str = xla_extension.sdy.sdy_round_trip_export_pipeline( mlir.module_to_bytecode(module)) else: @@ -1423,10 +1421,8 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args, ctx.module_context.shape_poly_state.uses_dim_vars = True submodule = ir.Module.parse(exported.mlir_module()) - shardy_enabled = False - if xla_extension_version >= 319: - shardy_enabled = xla_extension.sdy.lowered_with_shardy( - mlir.module_to_bytecode(submodule)) + shardy_enabled = xla_extension.sdy.lowered_with_shardy( + mlir.module_to_bytecode(submodule)) if shardy_enabled: submodule = ir.Module.parse(xla_extension.sdy.sdy_round_trip_import_shardings( mlir.module_to_bytecode(submodule))) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 1369f72ac74c..8f257b976dff 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -55,7 +55,7 @@ SdyArraySharding, SdyArrayShardingList) from jax._src.util import foreach from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension, xla_extension_version +from jax._src.lib import xla_extension from jax._src.lib.mlir import dialects, ir, passmanager from jax._src.lib.mlir.dialects import func as func_dialect, hlo from jax._src.lib.mlir import register_jax_dialects @@ -3031,11 +3031,8 @@ def refine_polymorphic_shapes(module: ir.Module) -> ir.Module: mlir_module=module_to_bytecode(module), enable_shape_assertions=True, validate_static_shapes=True) - if xla_extension_version >= 319: - refined_module_str = refine_polymorphic_shapes( - enable_shardy=config.use_shardy_partitioner.value) - else: - refined_module_str = refine_polymorphic_shapes() + refined_module_str = refine_polymorphic_shapes( + enable_shardy=config.use_shardy_partitioner.value) except Exception as e: raise ValueError( "Error refining shapes. " + diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 86a75ada63ad..388ad49ec83d 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -66,7 +66,6 @@ from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import chlo from jax._src.lib.mlir.dialects import hlo -from jax._src.lib import xla_extension_version from jax._src.sharding_impls import (PmapSharding, NamedSharding, PartitionSpec as P, canonicalize_sharding) from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DTypeLike, Shape @@ -2267,11 +2266,6 @@ def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike, case DotAlgorithmPreset.BF16_BF16_F32_X6: return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 6, False) case DotAlgorithmPreset.BF16_BF16_F32_X9: - if xla_extension_version < 320: - raise ValueError( - "The dot algorithm BF16_BF16_F32_X9 requires XLA extension " - "version >= 320." - ) return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 9, False) case DotAlgorithmPreset.TF32_TF32_F32: return hlo.DotAlgorithm.get(tf32, tf32, f32, 1, 1, 1, False) diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 0ed8568e4bcc..efa1b4cfd5b6 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -33,7 +33,6 @@ from jax._src import xla_bridge as xb from jax._src import mesh_utils from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension_version from jax._src.lib.mlir.dialects import sdy from jax._src.named_sharding import ( # noqa: F401 SdyArraySharding, SdyDimSharding, UnspecifiedValue, AUTO, @@ -881,8 +880,7 @@ def parse_flatten_op_sharding( return out elif hlo_sharding.is_replicated(): return [PartitionSpec()] - elif (xla_extension_version >= 319 and hlo_sharding.is_maximal() - and mesh.size == 1): + elif hlo_sharding.is_maximal() and mesh.size == 1: return [PartitionSpec()] elif hlo_sharding.is_tiled(): mesh_shape = mesh.shape diff --git a/jax/_src/util.py b/jax/_src/util.py index 0e28aea04b5a..d558954e881c 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -108,11 +108,7 @@ def foreach(f, *args): return None else: - # TODO(phawkins): remove after jaxlib 0.5.2 is the minimum. - if hasattr(jaxlib_utils, 'foreach'): - foreach = jaxlib_utils.foreach - else: - foreach = safe_map + foreach = jaxlib_utils.foreach def unzip2(xys: Iterable[tuple[T1, T2]] @@ -244,61 +240,8 @@ def curry(f): """ return wraps(f)(partial(partial, f)) -# TODO(phawkins): make this unconditional after jaxlib 0.5.3 is the minimum. toposort: Callable[[Iterable[Any]], list[Any]] -if hasattr(jaxlib_utils, "topological_sort"): - toposort = partial(jaxlib_utils.topological_sort, "parents") -else: - - def toposort(end_nodes): - if not end_nodes: - return [] - end_nodes = _remove_duplicates(end_nodes) - - child_counts = {} - stack = list(end_nodes) - while stack: - node = stack.pop() - if id(node) in child_counts: - child_counts[id(node)] += 1 - else: - child_counts[id(node)] = 1 - stack.extend(node.parents) - for node in end_nodes: - child_counts[id(node)] -= 1 - - sorted_nodes = [] - childless_nodes = [ - node for node in end_nodes if child_counts[id(node)] == 0 - ] - assert childless_nodes - while childless_nodes: - node = childless_nodes.pop() - sorted_nodes.append(node) - for parent in node.parents: - if child_counts[id(parent)] == 1: - childless_nodes.append(parent) - else: - child_counts[id(parent)] -= 1 - sorted_nodes = sorted_nodes[::-1] - - check_toposort(sorted_nodes) - return sorted_nodes - - def check_toposort(nodes): - visited = set() - for node in nodes: - assert all(id(parent) in visited for parent in node.parents) - visited.add(id(node)) - - def _remove_duplicates(node_list): - seen = set() - out = [] - for n in node_list: - if id(n) not in seen: - seen.add(id(n)) - out.append(n) - return out +toposort = partial(jaxlib_utils.topological_sort, "parents") def split_merge(predicate, xs): @@ -320,7 +263,6 @@ def merge(new_lhs, new_rhs): return lhs, rhs, merge - def _ignore(): return None diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index d605a2dea8f9..ae702d50ebb7 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -849,13 +849,9 @@ def _mgpu_wait_op_lowering_rule( return [] -# TODO(bchetioui): remove this once jaxlib minimum version >= 0.5.2. -SliceSMEMOp = getattr(mgpu, "SliceSMEMOp", None) - - -@_register_lowering(SliceSMEMOp) +@_register_lowering(mgpu.SliceSMEMOp) def _mgpu_slice_smem_op_lowering_rule( - ctx: LoweringContext, op: SliceSMEMOp + ctx: LoweringContext, op: mgpu.SliceSMEMOp ) -> Sequence[ir.Value]: del ctx return [_slice_smem(op.result.type, op.offset)] diff --git a/jax/experimental/mosaic/gpu/transform_inference.py b/jax/experimental/mosaic/gpu/transform_inference.py index ef2d3661674c..d285e5df188f 100644 --- a/jax/experimental/mosaic/gpu/transform_inference.py +++ b/jax/experimental/mosaic/gpu/transform_inference.py @@ -172,11 +172,9 @@ def _infer_vector_load_store_transforms( return None -# TODO(bchetioui): remove this once jaxlib minimum version >= 0.5.2. -SliceSMEMOp = getattr(mgpu, "SliceSMEMOp", None) -@partial(_add_transform_inference_rule, SliceSMEMOp) -def _infer_slice_smem_transforms(op: SliceSMEMOp) -> OptionalTransforms: +@partial(_add_transform_inference_rule, mgpu.SliceSMEMOp) +def _infer_slice_smem_transforms(op: mgpu.SliceSMEMOp) -> OptionalTransforms: transforms = None uses = cast(ir.OpResult, op.result).uses diff --git a/jax/experimental/sparse/linalg.py b/jax/experimental/sparse/linalg.py index a931b0a30dcf..b2e57caba9a6 100644 --- a/jax/experimental/sparse/linalg.py +++ b/jax/experimental/sparse/linalg.py @@ -29,7 +29,6 @@ from jax._src import core from jax._src import ffi from jax._src.interpreters import ad -from jax._src.lib import gpu_solver import numpy as np from scipy.sparse import csr_matrix, linalg @@ -534,11 +533,6 @@ def _spsolve_abstract_eval(data, indices, indptr, b, *, tol, reorder): def _spsolve_gpu_lowering(ctx, data, indices, indptr, b, *, tol, reorder): - # TODO(danfm): remove after JAX 0.5.1 release. - if hasattr(gpu_solver, "cuda_csrlsvqr"): - data_aval, _, _, _, = ctx.avals_in - return gpu_solver.cuda_csrlsvqr(data_aval.dtype, data, indices, - indptr, b, tol, reorder) return ffi.ffi_lowering("cusolver_csrlsvqr_ffi")( ctx, data, indices, indptr, b, tol=np.float64(tol), reorder=np.int32(reorder)) diff --git a/tests/lax_test.py b/tests/lax_test.py index 8764caeb2e49..f7cca2c9b48f 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -49,7 +49,6 @@ from jax._src.lax import lax as lax_internal from jax._src.util import NumpyComplexWarning, safe_zip from jax._src.tree_util import tree_map -from jax._src.lib import xla_extension_version config.parse_flags_with_absl() @@ -1128,11 +1127,6 @@ def testDotAlgorithm(self, algorithm, dtype): raise SkipTest( f"The dot algorithm '{algorithm}' is not supported on CPU.") if jtu.test_device_matches(["gpu"]): - if (algorithm == lax.DotAlgorithmPreset.BF16_BF16_F32_X9 and - xla_extension_version < 320): - raise SkipTest( - f"The dot algorithm ${algorithm} requires XLA extension version " - ">= 320.") # GPU algorithm support is a little spotty. It is checked in # xla/service/algorithm_util.cc and the logic is copied here. if algorithm in { diff --git a/tests/linalg_test.py b/tests/linalg_test.py index feab105ccbe2..60e507d84782 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -867,9 +867,6 @@ def testSVD(self, b, m, n, dtype, full_matrices, compute_uv, hermitian, algorith self.skipTest("Hermitian SVD doesn't support the algorithm parameter.") if not jtu.test_device_matches(["cpu", "gpu"]): self.skipTest("SVD algorithm selection only supported on CPU and GPU.") - # TODO(danfm): Remove this check after 0.5.2 is released. - if jtu.test_device_matches(["cpu"]) and jtu.jaxlib_version() <= (0, 5, 1): - self.skipTest("SVD algorithm selection on CPU requires a newer jaxlib version.") if jtu.test_device_matches(["cpu"]) and algorithm == lax.linalg.SvdAlgorithm.JACOBI: self.skipTest("Jacobi SVD not supported on GPU.") From 29e90a30cd7b4373ba3755fd1ddc9b2abc4b85d4 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 19 Mar 2025 13:48:01 -0700 Subject: [PATCH 031/483] Add a presubmit check to test against oldest supported numpy PiperOrigin-RevId: 738525650 --- .github/workflows/oldest_supported_numpy.yml | 60 ++++++++++++++++++++ ci/run_pytest_cpu.sh | 6 +- ci/utilities/install_wheels_locally.sh | 34 +++++------ 3 files changed, 81 insertions(+), 19 deletions(-) create mode 100644 .github/workflows/oldest_supported_numpy.yml diff --git a/.github/workflows/oldest_supported_numpy.yml b/.github/workflows/oldest_supported_numpy.yml new file mode 100644 index 000000000000..80e0cb154ecd --- /dev/null +++ b/.github/workflows/oldest_supported_numpy.yml @@ -0,0 +1,60 @@ +# CI - Oldest Supported NumPy (presubmit) +# This workflow tests the oldest supported NumPy and jaxlib versions. + +name: CI - Oldest Supported NumPy (presubmit) + +on: + pull_request: + branches: + - main + push: + branches: + - main + - 'release/**' + +# This should also be set to read-only in the project settings, but it's nice to +# document and enforce the permissions here. +permissions: + contents: read + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + # Don't cancel in-progress jobs for main/release branches. + cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }} + +jobs: + test-oldest-supported-numpy: + if: github.event.repository.fork == false + defaults: + run: + shell: bash + runs-on: "linux-x86-n2-64" + container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" +# Begin Presubmit Naming Check - name modification requires internal check to be updated + name: "CI - Oldest Supported NumPy (Python 3.10, x64=0)" +# End Presubmit Naming Check github-oldest-supported-numpy-presubmit + + env: + JAXCI_PYTHON: "python3.10" + JAXCI_ENABLE_X64: 0 + JAX_NUM_GENERATED_CASES: 5 + + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Install Python dependencies + run: | + $JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt + + # Install NumPy and SciPy with the oldest supported versions + $JAXCI_PYTHON -m uv pip install numpy==1.25.2 scipy==1.11.1 + + # Install JAX using the changes in the PR + $JAXCI_PYTHON -m uv pip install -e .[minimum-jaxlib] + # Halt for testing + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@main + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Run Pytest CPU tests + timeout-minutes: 30 + run: ./ci/run_pytest_cpu.sh \ No newline at end of file diff --git a/ci/run_pytest_cpu.sh b/ci/run_pytest_cpu.sh index 43581ef2c96c..9de29691f753 100755 --- a/ci/run_pytest_cpu.sh +++ b/ci/run_pytest_cpu.sh @@ -26,13 +26,13 @@ set -exu -o history -o allexport # Source default JAXCI environment variables. source ci/envs/default.env +# Set up the build environment. +source "ci/utilities/setup_build_environment.sh" + # Install jaxlib wheel inside the $JAXCI_OUTPUT_DIR directory on the system. echo "Installing wheels locally..." source ./ci/utilities/install_wheels_locally.sh -# Set up the build environment. -source "ci/utilities/setup_build_environment.sh" - # Print all the installed packages echo "Installed packages:" "$JAXCI_PYTHON" -m uv pip list diff --git a/ci/utilities/install_wheels_locally.sh b/ci/utilities/install_wheels_locally.sh index f98f7658ad18..64f88765bb75 100644 --- a/ci/utilities/install_wheels_locally.sh +++ b/ci/utilities/install_wheels_locally.sh @@ -30,23 +30,25 @@ for i in "${!WHEELS[@]}"; do fi done -if [[ -z "${WHEELS[@]}" ]]; then - echo "ERROR: No wheels found under $JAXCI_OUTPUT_DIR" - exit 1 -fi +if [[ -n "${WHEELS[@]}" ]]; then + echo "Installing the following wheels:" + echo "${WHEELS[@]}" -echo "Installing the following wheels:" -echo "${WHEELS[@]}" - -# Install `uv` if it's not already installed. `uv` is much faster than pip for -# installing Python packages. -if ! command -v uv >/dev/null 2>&1; then - pip install uv~=0.5.30 -fi + # Install `uv` if it's not already installed. `uv` is much faster than pip for + # installing Python packages. + if ! command -v uv >/dev/null 2>&1; then + pip install uv~=0.5.30 + fi -# On Windows, convert MSYS Linux-like paths to Windows paths. -if [[ $(uname -s) =~ "MSYS_NT" ]]; then - "$JAXCI_PYTHON" -m uv pip install $(cygpath -w "${WHEELS[@]}") + # On Windows, convert MSYS Linux-like paths to Windows paths. + if [[ $(uname -s) =~ "MSYS_NT" ]]; then + "$JAXCI_PYTHON" -m uv pip install $(cygpath -w "${WHEELS[@]}") + else + "$JAXCI_PYTHON" -m uv pip install "${WHEELS[@]}" + fi else - "$JAXCI_PYTHON" -m uv pip install "${WHEELS[@]}" + # Note that we don't exit here because the wheels may have been installed + # earlier in a different step in the CI job. + echo "INFO: No wheels found under $JAXCI_OUTPUT_DIR" + echo "INFO: Skipping local wheel installation." fi \ No newline at end of file From 945582add83b33d4796b44a57935133232a33d3e Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 19 Mar 2025 14:31:13 -0700 Subject: [PATCH 032/483] jax.numpy: add tests for __jax_array__ handling --- tests/BUILD | 9 + tests/array_extensibility_test.py | 516 ++++++++++++++++++++++++++++++ 2 files changed, 525 insertions(+) create mode 100644 tests/array_extensibility_test.py diff --git a/tests/BUILD b/tests/BUILD index 0ffa68ed8eb3..6706971dc7d1 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -80,6 +80,15 @@ jax_py_test( ] + py_deps("absl/testing"), ) +jax_py_test( + name = "array_extensibility_test", + srcs = ["array_extensibility_test.py"], + deps = [ + "//jax", + "//jax:test_util", + ] + py_deps("absl/testing"), +) + jax_multiplatform_test( name = "array_interoperability_test", srcs = ["array_interoperability_test.py"], diff --git a/tests/array_extensibility_test.py b/tests/array_extensibility_test.py new file mode 100644 index 000000000000..45c83f7967ce --- /dev/null +++ b/tests/array_extensibility_test.py @@ -0,0 +1,516 @@ +# Copyright 2018 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from absl.testing import absltest +from absl.testing import parameterized +from typing import Any, Callable, NamedTuple + +import jax +import jax.numpy as jnp +from jax.typing import ArrayLike +from jax._src import config +from jax._src import test_util as jtu + + +config.parse_flags_with_absl() + + +class JaxArrayWrapper: + """Class that provides a __jax_array__ method.""" + x: ArrayLike + + def __init__(self, x: ArrayLike): + self.x = x + + def __jax_array__(self) -> jax.Array: + return jnp.asarray(self.x) + + +class NumPyAPI(NamedTuple): + fun: Callable[..., Any] + args: list[jax.ShapeDtypeStruct] + kwargs: dict[str, Any] + + def name(self): + return self.fun.__name__ + + def make_args(self, rng): + rng = jtu.rand_default(rng) + return jax.tree.map(lambda arg: rng(arg.shape, arg.dtype), self.args) + + @classmethod + def sig(cls, fun: Callable[..., Any], *args: Any, **kwargs: Any) -> 'NumPyAPI': + return cls(fun, args, kwargs) + + +class ShapeDtype: + """Shortcut for specifying ShapeDtypeStruct.""" + def __init__(self, dtype): + self.dtype = jax.dtypes.canonicalize_dtype(dtype) + def __getitem__(self, shape) -> jax.ShapeDtypeStruct: + if isinstance(shape, int): + shape = (shape,) + return jax.ShapeDtypeStruct(shape, self.dtype) + +Bool = ShapeDtype(bool) +Int = ShapeDtype(int) +Uint8 = ShapeDtype('uint8') +Float = ShapeDtype(float) +Complex = ShapeDtype(complex) + + +# NumPy namespace objects skipped in the enumeration below, mainly because +# they are not functions or do not take arrays as positional arguments. +SKIPPED_APIS = [ + 'apply_along_axis', + 'apply_over_axes', + 'arange', + 'astype', + 'bartlett', + 'bfloat16', + 'blackman', + 'block', + 'bool', + 'bool_', + 'broadcast_shapes', + 'c_', + 'cdouble', + 'character', + 'complex128', + 'complex64', + 'complex_', + 'complexfloating', + 'csingle', + 'diag_indices', + 'double', + 'dtype', + 'e', + 'einsum', + 'einsum_path', + 'euler_gamma', + 'empty', + 'eye', + 'finfo', + 'flexible', + 'float_', + 'float16', + 'float32', + 'float4_e2m1fn', + 'float64', + 'float8_e3m4', + 'float8_e4m3', + 'float8_e4m3b11fnuz', + 'float8_e4m3fn', + 'float8_e4m3fnuz', + 'float8_e5m2', + 'float8_e5m2fnuz', + 'float8_e8m0fnu', + 'floating', + 'from_dlpack', + 'frombuffer', + 'fromfile', + 'fromfunction', + 'fromiter', + 'frompyfunc', + 'fromstring', + 'full', + 'generic', + 'geomspace', + 'get_printoptions', + 'gradient', + 'hamming', + 'hanning', + 'identity', + 'iinfo', + 'index_exp', + 'indices', + 'inexact', + 'inf', + 'int16', + 'int2', + 'int32', + 'int4', + 'int64', + 'int8', + 'int_', + 'integer', + 'isdtype', + 'issubdtype' + 'iterable' + 'kaiser' + 'kron' + 'ix_', + 'linalg', + 'linspace', + 'load', + 'logspace', + 'mask_indices', + 'mgrid', + 'nan', + 'ndarray', + 'newaxis', + 'number', + 'object_', + 'ogrid', + 'ones', + 'pi', + 'printoptions', + 'promote_types' + 'r_', + 'result_type', + 's_', + 'save', + 'savez', + 'set_printoptions', + 'signedinteger', + 'single', + 'tri', + 'tril_indices', + 'triu_indices', + 'ufunc', + 'uint', + 'uint16', + 'uint2', + 'uint32', + 'uint4', + 'uint64', + 'uint8', + 'unsignedinteger', + 'vectorize', + 'zeros', +] + +# TODO(jakevdp): commented APIs are ones which do not yet support +# __jax_array__ on inputs. We should fix these! +NUMPY_APIS = [ + NumPyAPI.sig(jnp.abs, Float[5]), + NumPyAPI.sig(jnp.absolute, Float[5]), + NumPyAPI.sig(jnp.acos, Float[5]), + NumPyAPI.sig(jnp.acosh, Float[5]), + NumPyAPI.sig(jnp.add, Float[5], Float[5]), + NumPyAPI.sig(jnp.all, Bool[5]), + NumPyAPI.sig(jnp.allclose, Float[5], Float[5]), + NumPyAPI.sig(jnp.amax, Float[5]), + NumPyAPI.sig(jnp.amin, Float[5]), + NumPyAPI.sig(jnp.angle, Float[5]), + NumPyAPI.sig(jnp.any, Float[5]), + NumPyAPI.sig(jnp.append, Float[10], Float[()]), + NumPyAPI.sig(jnp.arccos, Float[5]), + NumPyAPI.sig(jnp.arccosh, Float[5]), + NumPyAPI.sig(jnp.arcsin, Float[5]), + NumPyAPI.sig(jnp.arcsinh, Float[5]), + NumPyAPI.sig(jnp.arctan, Float[5]), + NumPyAPI.sig(jnp.arctan2, Float[5], Float[5]), + NumPyAPI.sig(jnp.arctanh, Float[5]), + NumPyAPI.sig(jnp.argmax, Float[10]), + NumPyAPI.sig(jnp.argmin, Float[10]), + NumPyAPI.sig(jnp.argpartition, Float[10], kth=5), + NumPyAPI.sig(jnp.argsort, Float[10]), + # NumPyAPI.sig(jnp.argwhere, [float], [(10,)]), + NumPyAPI.sig(jnp.around, Float[5]), + NumPyAPI.sig(jnp.array, Float[5]), + NumPyAPI.sig(jnp.array_equal, Float[5], Float[5]), + NumPyAPI.sig(jnp.array_equiv, Float[5], Float[5]), + # NumPyAPI.sig(jnp.array_repr, Float[5]), + NumPyAPI.sig(jnp.array_split, Float[9], indices_or_sections=3), + # NumPyAPI.sig(jnp.array_str, Float[5]), + NumPyAPI.sig(jnp.asarray, Float[5]), + NumPyAPI.sig(jnp.asin, Float[5]), + NumPyAPI.sig(jnp.asinh, Float[5]), + NumPyAPI.sig(jnp.atan, Float[5]), + NumPyAPI.sig(jnp.atan2, Float[5], Float[5]), + NumPyAPI.sig(jnp.atanh, Float[5]), + NumPyAPI.sig(jnp.atleast_1d, Float[5]), + NumPyAPI.sig(jnp.atleast_2d, Float[5]), + NumPyAPI.sig(jnp.atleast_3d, Float[5]), + NumPyAPI.sig(jnp.average, Float[10]), + # NumPyAPI.sig(jnp.bincount, int[10]), + NumPyAPI.sig(jnp.bitwise_and, Int[5], Int[5]), + NumPyAPI.sig(jnp.bitwise_count, Int[5]), + NumPyAPI.sig(jnp.bitwise_invert, Int[5]), + NumPyAPI.sig(jnp.bitwise_left_shift, Int[5], Int[5]), + NumPyAPI.sig(jnp.bitwise_not, Int[5]), + NumPyAPI.sig(jnp.bitwise_or, Int[5], Int[5]), + NumPyAPI.sig(jnp.bitwise_right_shift, Int[5], Int[5]), + NumPyAPI.sig(jnp.bitwise_xor, Int[5], Int[5]), + NumPyAPI.sig(jnp.broadcast_arrays, Float[5]), + # NumPyAPI.sig(jnp.broadcast_to, Float[()], shape=(10,)), + # NumPyAPI.sig(jnp.can_cast, Float[()], to='int32'), + NumPyAPI.sig(jnp.cbrt, Float[5]), + NumPyAPI.sig(jnp.ceil, Float[5]), + # NumPyAPI.sig(jnp.choose, [int, float], [(3,), (10,)]), + NumPyAPI.sig(jnp.clip, Float[5]), + # NumPyAPI.sig(jnp.column_stack, [float], [(3, 10)]), + NumPyAPI.sig(jnp.compress, Float[10], Bool[10]), + # NumPyAPI.sig(jnp.concat, [Float[5], Float[5]]), + # NumPyAPI.sig(jnp.concatenate, [Float[5], Float[5]]), + NumPyAPI.sig(jnp.conj, Float[5]), + NumPyAPI.sig(jnp.conjugate, Float[5]), + NumPyAPI.sig(jnp.convolve, Float[7], Float[3]), + NumPyAPI.sig(jnp.copy, Float[5]), + NumPyAPI.sig(jnp.copysign, Float[5], Float[5]), + NumPyAPI.sig(jnp.corrcoef, Float[7], Float[7]), + NumPyAPI.sig(jnp.correlate, Float[7], Float[3]), + NumPyAPI.sig(jnp.cos, Float[5]), + NumPyAPI.sig(jnp.cosh, Float[5]), + # NumPyAPI.sig(np.count_nonzero, [float], [(10,)]), + # NumPyAPI.sig(np.cov, [float], [(10,)]), + # NumPyAPI.sig(np.cross, [float, float], [(3,), (3,)]), + # NumPyAPI.sig(np.cumprod, [float], [(10,)]), + # NumPyAPI.sig(np.cumsum, [float], [(10,)]), + # NumPyAPI.sig(np.cumulative_prod, [float], [(10,)]), + # NumPyAPI.sig(np.cumulative_sum, [float], [(10,)]), + NumPyAPI.sig(jnp.deg2rad, Float[5]), + NumPyAPI.sig(jnp.degrees, Float[5]), + # NumPyAPI.sig(jnp.delete, Float[5], Int[()]), + NumPyAPI.sig(jnp.diag, Float[5]), + # NumPyAPI.sig(jnp.diag_indices_from, Float[5, 5]), + NumPyAPI.sig(jnp.diagflat, Float[5]), + NumPyAPI.sig(jnp.diagonal, Float[5, 5]), + NumPyAPI.sig(jnp.diff, Float[5]), + NumPyAPI.sig(jnp.digitize, Float[5], Float[5]), + NumPyAPI.sig(jnp.divide, Float[5], Float[5]), + NumPyAPI.sig(jnp.divmod, Float[5], Float[5]), + NumPyAPI.sig(jnp.dot, Float[5], Float[5]), + NumPyAPI.sig(jnp.dsplit, Float[3, 5, 6], indices_or_sections=2), + # NumPyAPI.sig(jnp.dstack, Float[3, 5]), + NumPyAPI.sig(jnp.ediff1d, Float[5]), + # NumPyAPI.sig(jnp.empty_like, Float[5]), + NumPyAPI.sig(jnp.equal, Float[5], Float[5]), + NumPyAPI.sig(jnp.exp, Float[5]), + NumPyAPI.sig(jnp.exp2, Float[5]), + # NumPyAPI.sig(jnp.expand_dims, Float[5], axis=0), + NumPyAPI.sig(jnp.expm1, Float[5]), + NumPyAPI.sig(jnp.extract, Bool[5], Float[5]), + NumPyAPI.sig(jnp.fabs, Float[5]), + NumPyAPI.sig(jnp.fft.fft, Float[5]), + NumPyAPI.sig(jnp.fft.fft2, Float[5, 5]), + NumPyAPI.sig(jnp.fft.ifft, Float[5]), + NumPyAPI.sig(jnp.fft.ifft2, Float[5, 5]), + NumPyAPI.sig(jnp.fill_diagonal, Float[5, 5], Float[()], inplace=False), + NumPyAPI.sig(jnp.fix, Float[5]), + NumPyAPI.sig(jnp.flatnonzero, Float[5]), + NumPyAPI.sig(jnp.flip, Float[5]), + NumPyAPI.sig(jnp.fliplr, Float[5, 5]), + NumPyAPI.sig(jnp.flipud, Float[5, 5]), + NumPyAPI.sig(jnp.float_power, Float[5], Float[5]), + NumPyAPI.sig(jnp.floor, Float[5]), + NumPyAPI.sig(jnp.floor_divide, Float[5], Float[5]), + NumPyAPI.sig(jnp.fmax, Float[5], Float[5]), + NumPyAPI.sig(jnp.fmin, Float[5], Float[5]), + NumPyAPI.sig(jnp.fmod, Float[5], Float[5]), + NumPyAPI.sig(jnp.frexp, Float[5]), + # NumPyAPI.sig(jnp.full_like, Float[5], Float[()]), + NumPyAPI.sig(jnp.gcd, Int[5], Int[5]), + NumPyAPI.sig(jnp.greater, Float[5], Float[5]), + NumPyAPI.sig(jnp.greater_equal, Float[5], Float[5]), + NumPyAPI.sig(jnp.heaviside, Float[5], Float[5]), + # NumPyAPI.sig(jnp.histogram, Float[5]), + NumPyAPI.sig(jnp.histogram2d, Float[5], Float[5]), + NumPyAPI.sig(jnp.histogram_bin_edges, Float[5]), + # NumPyAPI.sig(jnp.histogramdd, Float[5, 3]), + # NumPyAPI.sig(jnp.hsplit, Float[3, 5], Int[1]), + NumPyAPI.sig(jnp.hstack, (Float[5], Float[5])), + NumPyAPI.sig(jnp.hypot, Float[5], Float[5]), + NumPyAPI.sig(jnp.i0, Float[5]), + NumPyAPI.sig(jnp.imag, Complex[5]), + NumPyAPI.sig(jnp.inner, Float[5], Float[5]), + NumPyAPI.sig(jnp.insert, Float[5], Int[()], Float[2]), + NumPyAPI.sig(jnp.interp, Float[10], Float[5], Float[5]), + NumPyAPI.sig(jnp.intersect1d, Int[5], Int[5]), + NumPyAPI.sig(jnp.invert, Int[5]), + NumPyAPI.sig(jnp.isclose, Float[5], Float[5]), + NumPyAPI.sig(jnp.iscomplex, Float[5]), + NumPyAPI.sig(jnp.iscomplexobj, Complex[5]), + NumPyAPI.sig(jnp.isfinite, Float[5]), + NumPyAPI.sig(jnp.isin, Int[5], Int[10]), + NumPyAPI.sig(jnp.isinf, Float[5]), + NumPyAPI.sig(jnp.isnan, Float[5]), + # NumPyAPI.sig(jnp.isneginf, Float[5]), + # NumPyAPI.sig(jnp.isposinf, Float[5]), + NumPyAPI.sig(jnp.isreal, Float[5]), + NumPyAPI.sig(jnp.isrealobj, Float[5]), + NumPyAPI.sig(jnp.isscalar, Float[()]), + NumPyAPI.sig(jnp.lcm, Int[5], Int[5]), + NumPyAPI.sig(jnp.ldexp, Float[5], Int[5]), + NumPyAPI.sig(jnp.left_shift, Int[5], Int[5]), + NumPyAPI.sig(jnp.less, Float[5], Float[5]), + NumPyAPI.sig(jnp.less_equal, Float[5], Float[5]), + NumPyAPI.sig(jnp.lexsort, [Float[5], Float[5]]), + NumPyAPI.sig(jnp.log, Float[5]), + NumPyAPI.sig(jnp.log10, Float[5]), + NumPyAPI.sig(jnp.log1p, Float[5]), + NumPyAPI.sig(jnp.log2, Float[5]), + NumPyAPI.sig(jnp.logaddexp, Float[5], Float[5]), + NumPyAPI.sig(jnp.logaddexp2, Float[5], Float[5]), + NumPyAPI.sig(jnp.logical_and, Int[5], Int[5]), + NumPyAPI.sig(jnp.logical_not, Int[5]), + NumPyAPI.sig(jnp.logical_or, Int[5], Int[5]), + NumPyAPI.sig(jnp.logical_xor, Int[5], Int[5]), + NumPyAPI.sig(jnp.matmul, Float[5, 5], Float[5]), + # NumPyAPI.sig(jnp.matrix_transpose, Float[5, 6]), + NumPyAPI.sig(jnp.matvec, Float[5, 5], Float[5]), + NumPyAPI.sig(jnp.max, Float[5]), + NumPyAPI.sig(jnp.maximum, Float[5], Float[5]), + NumPyAPI.sig(jnp.mean, Float[5]), + NumPyAPI.sig(jnp.median, Float[5]), + NumPyAPI.sig(jnp.meshgrid, Float[5], Float[5]), + NumPyAPI.sig(jnp.min, Float[5]), + NumPyAPI.sig(jnp.minimum, Float[5], Float[5]), + NumPyAPI.sig(jnp.mod, Float[5], Float[5]), + NumPyAPI.sig(jnp.modf, Float[5]), + NumPyAPI.sig(jnp.moveaxis, Float[5, 3], source=0, destination=1), + NumPyAPI.sig(jnp.multiply, Float[5], Float[5]), + NumPyAPI.sig(jnp.nan_to_num, Float[5]), + NumPyAPI.sig(jnp.nanargmax, Float[5]), + NumPyAPI.sig(jnp.nanargmin, Float[5]), + NumPyAPI.sig(jnp.nancumprod, Float[5]), + NumPyAPI.sig(jnp.nancumsum, Float[5]), + NumPyAPI.sig(jnp.nanmax, Float[5]), + NumPyAPI.sig(jnp.nanmean, Float[5]), + NumPyAPI.sig(jnp.nanmedian, Float[5]), + NumPyAPI.sig(jnp.nanmin, Float[5]), + NumPyAPI.sig(jnp.nanpercentile, Float[5], q=75), + NumPyAPI.sig(jnp.nanprod, Float[5]), + NumPyAPI.sig(jnp.nanquantile, Float[5], q=0.75), + NumPyAPI.sig(jnp.nanstd, Float[5]), + NumPyAPI.sig(jnp.nansum, Float[5]), + NumPyAPI.sig(jnp.nanvar, Float[5]), + # NumPyAPI.sig(jnp.ndim, Float[5]), + NumPyAPI.sig(jnp.negative, Float[5]), + NumPyAPI.sig(jnp.nextafter, Float[5], Float[5]), + NumPyAPI.sig(jnp.nonzero, Float[5]), + NumPyAPI.sig(jnp.not_equal, Float[5], Float[5]), + # NumPyAPI.sig(jnp.ones_like, Float[5]), + NumPyAPI.sig(jnp.outer, Float[5], Float[5]), + NumPyAPI.sig(jnp.packbits, Int[5]), + # NumPyAPI.sig(jnp.pad, Float[5], pad_width=2), + NumPyAPI.sig(jnp.partition, Float[5], kth=3), + NumPyAPI.sig(jnp.percentile, Float[5], q=75), + NumPyAPI.sig(jnp.permute_dims, Float[3, 5], axes=(1, 0)), + NumPyAPI.sig(jnp.piecewise, Float[5], [Bool[5], Bool[5]], funclist=[jnp.sin, jnp.cos]), + NumPyAPI.sig(jnp.place, Float[5], Bool[5], Float[3], inplace=False), + NumPyAPI.sig(jnp.poly, Float[5]), + NumPyAPI.sig(jnp.polyadd, Float[5], Float[5]), + NumPyAPI.sig(jnp.polyder, Float[5]), + NumPyAPI.sig(jnp.polydiv, Float[5], Float[5]), + NumPyAPI.sig(jnp.polyfit, Float[5], Float[5], deg=2), + NumPyAPI.sig(jnp.polyint, Float[5]), + NumPyAPI.sig(jnp.polymul, Float[5], Float[5]), + NumPyAPI.sig(jnp.polysub, Float[5], Float[5]), + NumPyAPI.sig(jnp.polyval, Float[5], Float[10]), + NumPyAPI.sig(jnp.positive, Float[5]), + # NumPyAPI.sig(jnp.pow, Float[5], Float[5]), + # NumPyAPI.sig(jnp.power, Float[5], Float[5]), + NumPyAPI.sig(jnp.prod, Float[5]), + NumPyAPI.sig(jnp.ptp, Float[5]), + NumPyAPI.sig(jnp.put, Float[5], Int[()], Float[()], inplace=False), + NumPyAPI.sig(jnp.put_along_axis, Float[5], Int[1], Float[1], axis=0, inplace=False), + NumPyAPI.sig(jnp.quantile, Float[5], q=0.75), + NumPyAPI.sig(jnp.rad2deg, Float[5]), + NumPyAPI.sig(jnp.radians, Float[5]), + NumPyAPI.sig(jnp.ravel, Float[5]), + # NumPyAPI.sig(jnp.ravel_multi_index, Int[2, 5], dims=(2, 3)), + NumPyAPI.sig(jnp.real, Complex[5]), + NumPyAPI.sig(jnp.reciprocal, Float[5]), + NumPyAPI.sig(jnp.remainder, Float[5], Float[5]), + # NumPyAPI.sig(jnp.repeat, Float[5], Int[5]), + # NumPyAPI.sig(jnp.reshape, Float[6], (2, 3)), + NumPyAPI.sig(jnp.resize, Float[6], new_shape=(2, 3)), + NumPyAPI.sig(jnp.right_shift, Int[5], Int[5]), + NumPyAPI.sig(jnp.rint, Float[5]), + NumPyAPI.sig(jnp.roll, Float[5], Int[1]), + NumPyAPI.sig(jnp.rollaxis, Float[5, 4], axis=1), + NumPyAPI.sig(jnp.roots, Float[5]), + NumPyAPI.sig(jnp.rot90, Float[5, 3]), + NumPyAPI.sig(jnp.round, Float[5]), + NumPyAPI.sig(jnp.searchsorted, Float[5], Float[5]), + # NumPyAPI.sig(jnp.select, [Bool[5], Bool[5]], [Float[5], Float[5]], Float[5]), + NumPyAPI.sig(jnp.setdiff1d, Int[5], Int[5]), + NumPyAPI.sig(jnp.setxor1d, Int[5], Int[5]), + # NumPyAPI.sig(jnp.shape, Float[5, 3]), + NumPyAPI.sig(jnp.sign, Float[5]), + NumPyAPI.sig(jnp.signbit, Float[5]), + NumPyAPI.sig(jnp.sin, Float[5]), + NumPyAPI.sig(jnp.sinc, Float[5]), + NumPyAPI.sig(jnp.sinh, Float[5]), + # NumPyAPI.sig(jnp.size, Float[5]), + NumPyAPI.sig(jnp.sort, Float[5]), + NumPyAPI.sig(jnp.sort_complex, Complex[5]), + NumPyAPI.sig(jnp.spacing, Float[5]), + NumPyAPI.sig(jnp.split, Float[6], indices_or_sections=2), + NumPyAPI.sig(jnp.sqrt, Float[5]), + NumPyAPI.sig(jnp.square, Float[5]), + NumPyAPI.sig(jnp.squeeze, Float[5]), + # NumPyAPI.sig(jnp.stack, [Float[2, 3], Float[2, 3]], axis=1), + NumPyAPI.sig(jnp.std, Float[5]), + NumPyAPI.sig(jnp.subtract, Float[5], Float[5]), + NumPyAPI.sig(jnp.sum, Float[5]), + NumPyAPI.sig(jnp.swapaxes, Float[3, 5], axis1=1, axis2=0), + NumPyAPI.sig(jnp.take, Float[5], Int[2]), + NumPyAPI.sig(jnp.take_along_axis, Float[5], Int[2], axis=0), + NumPyAPI.sig(jnp.tan, Float[5]), + NumPyAPI.sig(jnp.tanh, Float[5]), + NumPyAPI.sig(jnp.tensordot, Float[2, 3, 4], Float[3, 4, 5]), + # NumPyAPI.sig(jnp.tile, Float[5], reps=(2,)), + NumPyAPI.sig(jnp.trace, Float[5, 5]), + # NumPyAPI.sig(jnp.transpose, Float[5, 6]), + NumPyAPI.sig(jnp.trapezoid, Float[5]), + NumPyAPI.sig(jnp.tril, Float[5, 6]), + # NumPyAPI.sig(jnp.tril_indices_from, Float[5, 6]), + NumPyAPI.sig(jnp.trim_zeros, Float[5]), + NumPyAPI.sig(jnp.triu, Float[5, 6]), + # NumPyAPI.sig(jnp.triu_indices_from, Float[5, 6]), + NumPyAPI.sig(jnp.true_divide, Float[5], Float[5]), + NumPyAPI.sig(jnp.trunc, Float[5]), + NumPyAPI.sig(jnp.union1d, Int[5], Int[5]), + NumPyAPI.sig(jnp.unique, Int[10]), + NumPyAPI.sig(jnp.unique_all, Int[10]), + NumPyAPI.sig(jnp.unique_counts, Int[10]), + NumPyAPI.sig(jnp.unique_inverse, Int[10]), + NumPyAPI.sig(jnp.unique_values, Int[10]), + NumPyAPI.sig(jnp.unpackbits, Uint8[8]), + NumPyAPI.sig(jnp.unravel_index, Int[5], shape=(2, 3)), + NumPyAPI.sig(jnp.unstack, Float[5]), + NumPyAPI.sig(jnp.unwrap, Float[5]), + NumPyAPI.sig(jnp.vander, Float[5]), + NumPyAPI.sig(jnp.var, Float[5]), + NumPyAPI.sig(jnp.vdot, Float[5], Float[5]), + NumPyAPI.sig(jnp.vecdot, Float[5], Float[5]), + NumPyAPI.sig(jnp.vecmat, Float[5], Float[5, 3]), + NumPyAPI.sig(jnp.vsplit, Float[6], indices_or_sections=2), + NumPyAPI.sig(jnp.vstack, [Float[5], Float[2, 5]]), + NumPyAPI.sig(jnp.where, Bool[5], Float[5], Float[5]), + # NumPyAPI.sig(jnp.zeros_like, Float[5]), +] + + +class JaxArrayTests(jtu.JaxTestCase): + @parameterized.named_parameters( + {'testcase_name': api.name(), 'api': api} for api in NUMPY_APIS) + def test_numpy_api_supports_jax_array(self, api): + fun = api.fun + args = api.make_args(self.rng()) + wrapped_args = jax.tree.map(JaxArrayWrapper, args) + kwargs = api.kwargs + + expected = fun(*args, **kwargs) + wrapped = fun(*wrapped_args, **kwargs) + + self.assertAllClose(wrapped, expected, atol=0, rtol=0) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) From 16dc0ad1dd475a5ea994f03d95127eb2a003d43b Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 19 Mar 2025 14:47:50 -0700 Subject: [PATCH 033/483] Add `jax_source_package` macros and target to generate a source package `.tar.gz`. Refactor `jax_wheel` macros, so it outputs a `.whl` file only. When the macros returns one output object only, it allows all downstream dependencies consume it easily without the need to filter the macros outputs. The previous implementation design (when `jax_wheel` returned `.tar.gz` and `.whl` files) required one of two options: either create a new target that produces `.whl` only, or to implement filename filtering in the downstream rules. With the new implementation we can just depend on `//:jax_wheel` target that produces the `.whl`. PiperOrigin-RevId: 738547491 --- BUILD.bazel | 18 ++++++- build/build.py | 3 ++ build_wheel.py | 17 +++++- jaxlib/jax.bzl | 100 +++++++++++++++++++++++++----------- jaxlib/tools/build_utils.py | 8 +-- 5 files changed, 110 insertions(+), 36 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index 33cbefd29f0b..eb43d7ec0fd8 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -15,6 +15,7 @@ load("@xla//third_party/py:python_wheel.bzl", "collect_data_files", "transitive_py_deps") load( "//jaxlib:jax.bzl", + "jax_source_package", "jax_wheel", ) @@ -67,7 +68,6 @@ py_binary( jax_wheel( name = "jax_wheel", - build_wheel_only = False, platform_independent = True, source_files = [ ":transitive_py_data", @@ -82,3 +82,19 @@ jax_wheel( wheel_binary = ":build_wheel", wheel_name = "jax", ) + +jax_source_package( + name = "jax_source_package", + source_files = [ + ":transitive_py_data", + ":transitive_py_deps", + "//jax:py.typed", + "AUTHORS", + "LICENSE", + "README.md", + "pyproject.toml", + "setup.py", + ], + source_package_binary = ":build_wheel", + source_package_name = "jax", +) diff --git a/build/build.py b/build/build.py index d38b911bb904..cdb568171b66 100755 --- a/build/build.py +++ b/build/build.py @@ -68,6 +68,7 @@ # rule as the default. WHEEL_BUILD_TARGET_DICT_NEW = { "jax": "//:jax_wheel", + "jax_source_package": "//:jax_source_package", "jaxlib": "//jaxlib/tools:jaxlib_wheel", "jax-cuda-plugin": "//jaxlib/tools:jax_cuda_plugin_wheel", "jax-cuda-pjrt": "//jaxlib/tools:jax_cuda_pjrt_wheel", @@ -661,6 +662,8 @@ async def main(): # Append the build target to the Bazel command. build_target = wheel_build_targets[wheel] wheel_build_command.append(build_target) + if args.use_new_wheel_build_rule and wheel == "jax": + wheel_build_command.append(wheel_build_targets["jax_source_package"]) if not args.use_new_wheel_build_rule: wheel_build_command.append("--") diff --git a/build_wheel.py b/build_wheel.py index f8e1595d3c3a..b4db96773527 100644 --- a/build_wheel.py +++ b/build_wheel.py @@ -47,6 +47,20 @@ parser.add_argument( "--srcs", help="source files for the wheel", action="append" ) +parser.add_argument( + "--build-wheel-only", + default=False, + help=( + "Whether to build the wheel only. Optional." + ), +) +parser.add_argument( + "--build-source-package-only", + default=False, + help=( + "Whether to build the source package only. Optional." + ), +) args = parser.parse_args() @@ -94,7 +108,8 @@ def prepare_srcs(deps: list[str], srcs_dir: str) -> None: args.output_path, package_name="jax", git_hash=args.jaxlib_git_hash, - build_wheel_only=False, + build_wheel_only=args.build_wheel_only, + build_source_package_only=args.build_source_package_only, ) finally: if tmpdir: diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 89f1545995d5..02e6b10b1de1 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -362,7 +362,7 @@ def _get_full_wheel_name( free_threaded_suffix = "t" if py_freethreaded.lower() == "yes" else "", ) -def _get_source_distribution_name(package_name, wheel_version): +def _get_source_package_name(package_name, wheel_version): return "{package_name}-{wheel_version}.tar.gz".format( package_name = package_name, wheel_version = wheel_version, @@ -394,37 +394,47 @@ def _jax_wheel_impl(ctx): no_abi = ctx.attr.no_abi platform_independent = ctx.attr.platform_independent build_wheel_only = ctx.attr.build_wheel_only + build_source_package_only = ctx.attr.build_source_package_only editable = ctx.attr.editable platform_name = ctx.attr.platform_name + + output_dir_path = "" + outputs = [] if editable: output_dir = ctx.actions.declare_directory(output_path + "/" + ctx.attr.wheel_name) - wheel_dir = output_dir.path + output_dir_path = output_dir.path outputs = [output_dir] args.add("--editable") else: - wheel_name = _get_full_wheel_name( - package_name = ctx.attr.wheel_name, - no_abi = no_abi, - platform_independent = platform_independent, - platform_name = platform_name, - cpu_name = cpu, - wheel_version = full_wheel_version, - py_freethreaded = py_freethreaded, - ) - wheel_file = ctx.actions.declare_file(output_path + - "/" + wheel_name) - wheel_dir = wheel_file.path[:wheel_file.path.rfind("/")] - outputs = [wheel_file] - if not build_wheel_only: - source_distribution_name = _get_source_distribution_name( + if build_wheel_only: + wheel_name = _get_full_wheel_name( package_name = ctx.attr.wheel_name, + no_abi = no_abi, + platform_independent = platform_independent, + platform_name = platform_name, + cpu_name = cpu, wheel_version = full_wheel_version, + py_freethreaded = py_freethreaded, ) - source_distribution_file = ctx.actions.declare_file(output_path + - "/" + source_distribution_name) - outputs.append(source_distribution_file) - - args.add("--output_path", wheel_dir) # required argument + wheel_file = ctx.actions.declare_file(output_path + + "/" + wheel_name) + output_dir_path = wheel_file.path[:wheel_file.path.rfind("/")] + outputs = [wheel_file] + if ctx.attr.wheel_name == "jax": + args.add("--build-wheel-only", "True") + if build_source_package_only: + source_package_name = _get_source_package_name( + package_name = ctx.attr.wheel_name, + wheel_version = full_wheel_version, + ) + source_package_file = ctx.actions.declare_file(output_path + + "/" + source_package_name) + output_dir_path = source_package_file.path[:source_package_file.path.rfind("/")] + outputs = [source_package_file] + if ctx.attr.wheel_name == "jax": + args.add("--build-source-package-only", "True") + + args.add("--output_path", output_dir_path) # required argument if not platform_independent: args.add("--cpu", cpu) args.add("--jaxlib_git_hash", git_hash) # required argument @@ -472,16 +482,17 @@ _jax_wheel = rule( "wheel_name": attr.string(mandatory = True), "no_abi": attr.bool(default = False), "platform_independent": attr.bool(default = False), - "build_wheel_only": attr.bool(default = True), + "build_wheel_only": attr.bool(mandatory = True, default = True), + "build_source_package_only": attr.bool(mandatory = True, default = False), "editable": attr.bool(default = False), - "cpu": attr.string(mandatory = True), - "platform_name": attr.string(mandatory = True), + "cpu": attr.string(), + "platform_name": attr.string(), "git_hash": attr.label(default = Label("//jaxlib/tools:jaxlib_git_hash")), "source_files": attr.label_list(allow_files = True), "output_path": attr.label(default = Label("//jaxlib/tools:output_path")), "enable_cuda": attr.bool(default = False), # A cuda/rocm version is required for gpu wheels; for cpu wheels, it can be an empty string. - "platform_version": attr.string(mandatory = True, default = ""), + "platform_version": attr.string(), "skip_gpu_kernels": attr.bool(default = False), "enable_rocm": attr.bool(default = False), "include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:include_cuda_libs")), @@ -498,7 +509,6 @@ def jax_wheel( wheel_name, no_abi = False, platform_independent = False, - build_wheel_only = True, editable = False, enable_cuda = False, enable_rocm = False, @@ -509,11 +519,10 @@ def jax_wheel( Common artifact attributes are grouped within a single macro. Args: - name: the name of the wheel + name: the target name wheel_binary: the binary to use to build the wheel wheel_name: the name of the wheel no_abi: whether to build a wheel without ABI - build_wheel_only: whether to build a wheel without source distribution editable: whether to build an editable wheel platform_independent: whether to build a wheel without platform tag enable_cuda: whether to build a cuda wheel @@ -522,7 +531,7 @@ def jax_wheel( source_files: the source files to include in the wheel Returns: - A directory containing the wheel + A wheel file or a wheel directory. """ _jax_wheel( name = name, @@ -530,7 +539,8 @@ def jax_wheel( wheel_name = wheel_name, no_abi = no_abi, platform_independent = platform_independent, - build_wheel_only = build_wheel_only, + build_wheel_only = True, + build_source_package_only = False, editable = editable, enable_cuda = enable_cuda, enable_rocm = enable_rocm, @@ -554,6 +564,34 @@ def jax_wheel( source_files = source_files, ) +def jax_source_package( + name, + source_package_binary, + source_package_name, + source_files = []): + """Create jax source package. + + Common artifact attributes are grouped within a single macro. + + Args: + name: the target name + source_package_binary: the binary to use to build the package + source_package_name: the name of the source package + source_files: the source files to include in the package + + Returns: + A jax source package file. + """ + _jax_wheel( + name = name, + wheel_binary = source_package_binary, + wheel_name = source_package_name, + build_source_package_only = True, + build_wheel_only = False, + platform_independent = True, + source_files = source_files, + ) + jax_test_file_visibility = [] jax_export_file_visibility = [] diff --git a/jaxlib/tools/build_utils.py b/jaxlib/tools/build_utils.py index 4c50cff16743..582a0c9f1d6f 100644 --- a/jaxlib/tools/build_utils.py +++ b/jaxlib/tools/build_utils.py @@ -65,6 +65,7 @@ def build_wheel( package_name: str, git_hash: str = "", build_wheel_only: bool = True, + build_source_package_only: bool = False, ) -> None: """Builds a wheel in `output_path` using the source tree in `sources_path`.""" env = dict(os.environ) @@ -78,7 +79,8 @@ def build_wheel( env["USERPROFILE"] = env.get("SYSTEMDRIVE", "C:") subprocess.run( [sys.executable, "-m", "build", "-n"] - + (["-w"] if build_wheel_only else []), + + (["-w"] if build_wheel_only else []) + + (["-s"] if build_source_package_only else []), check=True, cwd=sources_path, env=env, @@ -97,10 +99,10 @@ def build_wheel( sys.stderr.write(" bazel run //build:requirements.update" + f" --repo_env=HERMETIC_PYTHON_VERSION={py_version}\n\n") shutil.copy(wheel, output_path) - if not build_wheel_only: + if build_source_package_only: for dist in glob.glob(os.path.join(sources_path, "dist", "*.tar.gz")): output_file = os.path.join(output_path, os.path.basename(dist)) - sys.stderr.write(f"Output source distribution: {output_file}\n\n") + sys.stderr.write(f"Output source package: {output_file}\n\n") shutil.copy(dist, output_path) From f74711254feae1e1d0ba532ac4b3b56e388d036d Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 19 Mar 2025 15:40:40 -0700 Subject: [PATCH 034/483] Fix `lax_autodiff_test` on v5p PiperOrigin-RevId: 738565192 --- tests/lax_autodiff_test.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/lax_autodiff_test.py b/tests/lax_autodiff_test.py index a69f44f37754..aea9d2ad3dff 100644 --- a/tests/lax_autodiff_test.py +++ b/tests/lax_autodiff_test.py @@ -205,14 +205,16 @@ class LaxAutodiffTest(jtu.JaxTestCase): )) def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol): rng = rng_factory(self.rng()) - if jtu.test_device_matches(["cpu"]): + if jtu.test_device_matches(["cpu", "tpu"]): if op is lax.cosh and dtype == np.complex64: - tol = 3e-1 # 2nd-order gradients are noisy on CPU + tol = 3e-1 # 2nd-order gradients are noisy on CPU and TPU if jtu.test_device_matches(["tpu"]): if op is lax.pow: raise SkipTest("pow grad imprecise on tpu") if op is lax.cos: order = 1 # 2nd-order gradient is imprecise on TPU. + if op is lax.sin: + order = 1 # 2nd-order gradient is imprecise on TPUv5p. if op is lax.log: order = 1 # 2nd-order gradient is imprecise on TPU. From 47dde87b9d734baf4b9f58f896305cdba0b9f484 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 19 Mar 2025 15:53:08 -0700 Subject: [PATCH 035/483] Use np.ones to avoid signed integer overflow at run time PiperOrigin-RevId: 738569856 --- tests/pjit_test.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 293b37a9fbc7..6cf11494988a 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -6360,8 +6360,8 @@ def f(x): def test_intermediate_einsum(self, mesh): shape1 = (8, 32, 1, 16) shape2 = (8, 32, 1, 8) - np_inp1 = np.arange(math.prod(shape1)).reshape(shape1) - np_inp2 = np.arange(math.prod(shape2)).reshape(shape2) + np_inp1 = np.ones(math.prod(shape1)).reshape(shape1) + np_inp2 = np.ones(math.prod(shape2)).reshape(shape2) s = NamedSharding(mesh, P('data')) arr1 = jax.device_put(np_inp1, s) @@ -6387,9 +6387,9 @@ def test_intermediate_einsum_auto_complete_spec(self, mesh): shape1 = (8, 32, 2*16) shape2 = (8, 32, 2, 8) shape3 = (8, 32, 2, 8) - np_inp1 = np.arange(math.prod(shape1)).reshape(shape1) - np_inp2 = np.arange(math.prod(shape2)).reshape(shape2) - np_inp3 = np.arange(math.prod(shape3)).reshape(shape3) + np_inp1 = np.ones(math.prod(shape1)).reshape(shape1) + np_inp2 = np.ones(math.prod(shape2)).reshape(shape2) + np_inp3 = np.ones(math.prod(shape3)).reshape(shape3) arr1 = jax.device_put(np_inp1, s) arr2 = jax.device_put(np_inp2, s) @@ -6436,8 +6436,8 @@ def f(condition, x, y): def test_intermediate_einsum_conflict_error(self, mesh): shape1 = (8, 32, 1, 16) shape2 = (8, 32, 1, 8) - np_inp1 = np.arange(math.prod(shape1)).reshape(shape1) - np_inp2 = np.arange(math.prod(shape2)).reshape(shape2) + np_inp1 = np.ones(math.prod(shape1)).reshape(shape1) + np_inp2 = np.ones(math.prod(shape2)).reshape(shape2) arr1 = jax.device_put( np_inp1, NamedSharding(mesh, P(None, None, None, 'data'))) From ab42a3e6382a0e2eedcc63176f36ec0d48f617bf Mon Sep 17 00:00:00 2001 From: Pearu Peterson Date: Wed, 12 Mar 2025 15:09:57 +0200 Subject: [PATCH 036/483] Fix betainc edge cases and inaccuracies when a is close to zero. --- jax/_src/lax/special.py | 45 ++++++++++++++---- jax/_src/test_util.py | 2 +- tests/lax_scipy_special_functions_test.py | 56 +++++++++++++++++++---- 3 files changed, 84 insertions(+), 19 deletions(-) diff --git a/jax/_src/lax/special.py b/jax/_src/lax/special.py index b70513bc2d20..ba2687d4acd7 100644 --- a/jax/_src/lax/special.py +++ b/jax/_src/lax/special.py @@ -194,12 +194,18 @@ def nth_partial_betainc_numerator(iteration, a, b, x): iteration_is_one = eq(iteration_bcast, full_like(iteration_bcast, 1)) iteration_minus_one = iteration_bcast - full_like(iteration_bcast, 1) m = iteration_minus_one // full_like(iteration_minus_one, 2) + m_is_zero = eq(m, full_like(m, 0)) m = convert_element_type(m, dtype) one = full_like(a, 1) two = full_like(a, 2.0) # Partial numerator terms - even_numerator = -(a + m) * (a + b + m) * x / ( - (a + two * m) * (a + two * m + one)) + + # When a is close to zero and m == 0, using zero_numerator avoids + # inaccuracies when FTZ or DAZ is enabled: + zero_numerator = -(a + b) * x / (a + one) + even_numerator = select(m_is_zero, zero_numerator, + -(a + m) * (a + b + m) * x / ( + (a + two * m) * (a + two * m + one))) odd_numerator = m * (b - m) * x / ((a + two * m - one) * (a + two * m)) one_numerator = full_like(x, 1.0) numerator = select(iteration_is_even, even_numerator, odd_numerator) @@ -210,12 +216,24 @@ def nth_partial_betainc_denominator(iteration, a, b, x): return select(eq(iteration_bcast, full_like(iteration_bcast, 0)), full_like(x, 0), full_like(x, 1)) + a_is_zero = bitwise_or(eq(a, full_like(a, 0)), eq(b, full_like(b, float('inf')))) + b_is_zero = bitwise_or(eq(b, full_like(b, 0)), eq(a, full_like(a, float('inf')))) + x_is_zero = eq(x, full_like(x, 0)) + x_is_one = eq(x, full_like(x, 1)) + x_is_not_zero = bitwise_not(x_is_zero) + x_is_not_one = bitwise_not(x_is_one) + is_nan = bitwise_or(bitwise_or(_isnan(a), _isnan(b)), _isnan(x)) + + result_is_zero = bitwise_or(bitwise_and(b_is_zero, x_is_not_one), bitwise_and(a_is_zero, x_is_zero)) + result_is_one = bitwise_or(bitwise_and(a_is_zero, x_is_not_zero), bitwise_and(b_is_zero, x_is_one)) + result_is_nan = bitwise_or(bitwise_or(bitwise_or( - le(a, full_like(a, 0)), le(b, full_like(b, 0))), + lt(a, full_like(a, 0)), lt(b, full_like(b, 0))), lt(x, full_like(x, 0))), gt(x, full_like(x, 1))) + result_is_nan = bitwise_or(result_is_nan, bitwise_or(bitwise_and(a_is_zero, b_is_zero), is_nan)) - # The continued fraction will converge rapidly when x < (a+1)/(a+b+2) - # as per: http://dlmf.nist.gov/8.17.E23 + # The continued fraction will converge rapidly when x < + # (a+1)/(a+b+2) as per: http://dlmf.nist.gov/8.17.E23. # # Otherwise, we can rewrite using the symmetry relation as per: # http://dlmf.nist.gov/8.17.E4 @@ -234,10 +252,21 @@ def nth_partial_betainc_denominator(iteration, a, b, x): inputs=[a, b, x] ) - lbeta_ab = lgamma(a) + lgamma(b) - lgamma(a + b) - result = continued_fraction * exp(log(x) * a + log1p(-x) * b - lbeta_ab) / a + # For very small a and to avoid division by zero, we'll use + # a * gamma(a) = gamma(a + 1) -> 1 as a -> 0+. + very_small = (dtypes.finfo(dtype).tiny * 2).astype(dtype) + lbeta_ab_small_a = lgamma(b) - lgamma(a + b) + lbeta_ab = lgamma(a) + lbeta_ab_small_a + factor = select(lt(a, full_like(a, very_small)), + exp(log1p(-x) * b - lbeta_ab_small_a), + exp(log(x) * a + log1p(-x) * b - lbeta_ab) / a) + result = continued_fraction * factor + result = select(converges_rapidly, result, sub(full_like(result, 1), result)) + + result = select(result_is_zero, full_like(a, 0), result) + result = select(result_is_one, full_like(a, 1), result) result = select(result_is_nan, full_like(a, float('nan')), result) - return select(converges_rapidly, result, sub(full_like(result, 1), result)) + return result class IgammaMode(Enum): VALUE = 1 diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 0dc4fe641029..3a18d12e9d4b 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -1522,7 +1522,7 @@ def _CheckAgainstNumpy(self, numpy_reference_op, lax_op, args_maker, args = args_maker() lax_ans = lax_op(*args) numpy_ans = numpy_reference_op(*args) - self.assertAllClose(numpy_ans, lax_ans, check_dtypes=check_dtypes, + self.assertAllClose(lax_ans, numpy_ans, check_dtypes=check_dtypes, atol=atol or tol, rtol=rtol or tol, canonicalize_dtypes=canonicalize_dtypes) diff --git a/tests/lax_scipy_special_functions_test.py b/tests/lax_scipy_special_functions_test.py index f4e4e4f48213..4b3945a84453 100644 --- a/tests/lax_scipy_special_functions_test.py +++ b/tests/lax_scipy_special_functions_test.py @@ -288,35 +288,71 @@ def testExpiDisableJit(self): self.assertAllClose(result_jit, result_nojit) def testGammaIncBoundaryValues(self): - dtype = jax.numpy.zeros(0).dtype # default float dtype. + dtype = jax.dtypes.canonicalize_dtype(float) nan = float('nan') inf = float('inf') if jtu.parse_version(scipy.__version__) >= (1, 16): - samples_slice = slice(None) + a_samples = [0, 0, 0, 1, nan, 1, nan, 0, 1, 1, nan] + x_samples = [0, 1, 2, 0, 1, nan, nan, inf, inf, -1, inf] else: # disable samples that contradict with scipy/scipy#22441 - samples_slice = slice(None, -1) - args_maker = lambda: [np.array([0, 0, 0, 1, nan, 1, nan, 0, 1, nan][samples_slice]).astype(dtype), - np.array([0, 1, 2, 0, 1, nan, nan, inf, inf, inf][samples_slice]).astype(dtype)] + a_samples = [0, 0, 0, 1, nan, 1, nan, 0, 1, 1] + x_samples = [0, 1, 2, 0, 1, nan, nan, inf, inf, -1] + + args_maker = lambda: (np.array(a_samples, dtype=dtype), np.array(x_samples, dtype=dtype)) + rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5 self._CheckAgainstNumpy(lsp_special.gammainc, osp_special.gammainc, args_maker, rtol=rtol) self._CompileAndCheck(lsp_special.gammainc, args_maker, rtol=rtol) def testGammaIncCBoundaryValues(self): - dtype = jax.numpy.zeros(0).dtype # default float dtype. + dtype = jax.dtypes.canonicalize_dtype(float) nan = float('nan') inf = float('inf') if jtu.parse_version(scipy.__version__) >= (1, 16): - samples_slice = slice(None) + a_samples = [0, 0, 0, 1, nan, 1, nan, 0, 1, 1, nan] + x_samples = [0, 1, 2, 0, 1, nan, nan, inf, inf, -1, inf] else: # disable samples that contradict with scipy/scipy#22441 - samples_slice = slice(None, -1) - args_maker = lambda: [np.array([0, 0, 0, 1, nan, 1, nan, 0, 1, 1, nan][samples_slice]).astype(dtype), - np.array([0, 1, 2, 0, 1, nan, nan, inf, inf, -1, inf][samples_slice]).astype(dtype)] + a_samples = [0, 0, 0, 1, nan, 1, nan, 0, 1, 1] + x_samples = [0, 1, 2, 0, 1, nan, nan, inf, inf, -1] + + args_maker = lambda: (np.array(a_samples, dtype=dtype), np.array(x_samples, dtype=dtype)) + rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5 self._CheckAgainstNumpy(lsp_special.gammaincc, osp_special.gammaincc, args_maker, rtol=rtol) self._CompileAndCheck(lsp_special.gammaincc, args_maker, rtol=rtol) + def testBetaIncBoundaryValues(self): + dtype = jax.dtypes.canonicalize_dtype(float) + fi = jax.numpy.finfo(dtype) + nan = float('nan') + inf = float('inf') + tiny = fi.tiny + eps = fi.eps + if jtu.parse_version(scipy.__version__) >= (1, 16): + # TODO(pearu): enable tiny samples when a fix to scipy/scipy#22682 + # will be available + a_samples = [nan, -0.5, inf, 0, eps, 1, tiny][:-1] + b_samples = [nan, -0.5, inf, 0, eps, 1, tiny][:-1] + elif jtu.parse_version(scipy.__version__) >= (1, 12): + # disabled samples that contradict with scipy/scipy#22425 + a_samples = [nan, -0.5, 0.5] + b_samples = [nan, -0.5, 0.5] + else: + a_samples = [-0.5, 0.5] + b_samples = [-0.5, 0.5] + x_samples = [nan, -0.5, 0, 0.5, 1, 1.5] + + a_samples = np.array(a_samples, dtype=dtype) + b_samples = np.array(b_samples, dtype=dtype) + x_samples = np.array(x_samples, dtype=dtype) + + args_maker = lambda: np.meshgrid(a_samples, b_samples, x_samples) + + rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 5e-5 + self._CheckAgainstNumpy(osp_special.betainc, lsp_special.betainc, args_maker, rtol=rtol) + self._CompileAndCheck(lsp_special.betainc, args_maker, rtol=rtol) if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 2562da7026ccd930e5f0972598c7d5479175b787 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 19 Mar 2025 18:26:36 -0700 Subject: [PATCH 037/483] Expose profiler_data submodule from XLA to Jaxlib. PiperOrigin-RevId: 738613439 --- jaxlib/setup.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/jaxlib/setup.py b/jaxlib/setup.py index b3a37a25f1b2..60f17a987307 100644 --- a/jaxlib/setup.py +++ b/jaxlib/setup.py @@ -68,10 +68,10 @@ def has_ext_modules(self): url='https://github.com/jax-ml/jax', license='Apache-2.0', classifiers=[ - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - "Programming Language :: Python :: 3.13", + 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', + 'Programming Language :: Python :: 3.13', ], package_data={ 'jaxlib': [ @@ -105,7 +105,7 @@ def has_ext_modules(self): 'triton/*.so', 'include/xla/ffi/api/*.h', ], - 'jaxlib.xla_extension': ['*.pyi'], + 'jaxlib.xla_extension': ['*.pyi', 'profiler/*.pyi'], }, zip_safe=False, distclass=BinaryDistribution, From b5c467e6cf702160be29ee93084f3f9a0da2b888 Mon Sep 17 00:00:00 2001 From: carlosgmartin Date: Wed, 19 Mar 2025 23:56:24 -0400 Subject: [PATCH 038/483] Fix doc for random.categorical replace argument. --- jax/_src/random.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax/_src/random.py b/jax/_src/random.py index 094268c65825..c0663dc67f80 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -1568,8 +1568,8 @@ def categorical( shape: Optional, a tuple of nonnegative integers representing the result shape. Must be broadcast-compatible with ``np.delete(logits.shape, axis)``. The default (None) produces a result shape equal to ``np.delete(logits.shape, axis)``. - replace: If True, perform sampling without replacement. Default (False) is to - perform sampling with replacement. + replace: If True (default), perform sampling with replacement. If False, perform + sampling without replacement. Returns: A random array with int dtype and shape given by ``shape`` if ``shape`` From 258ed1b0a5bd56b797d2ca47627db539c1be81f8 Mon Sep 17 00:00:00 2001 From: Yunlong Liu Date: Thu, 20 Mar 2025 04:03:11 +0000 Subject: [PATCH 039/483] Fixes the stream annotation compute on box. --- jax/_src/interpreters/mlir.py | 5 ++++- tests/memories_test.py | 22 +++++++++++++++++----- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 1369f72ac74c..4a723ffe5227 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -2340,7 +2340,10 @@ def wrap_compute_type_in_place(ctx, op): if ctx.jaxpr_eqn_ctx is not None and ctx.jaxpr_eqn_ctx.compute_type is not None: if ctx.jaxpr_eqn_ctx.compute_type.startswith("gpu_stream:"): stream = ctx.jaxpr_eqn_ctx.compute_type.split(":")[1] - dict_attr = {"_xla_stream_annotation": ir.StringAttr.get(stream)} + dict_attr = { + "_xla_stream_annotation": ir.StringAttr.get(stream), + "inlineable": ir.StringAttr.get("false"), + } op.operation.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(dict_attr) else: dict_attr = {"_xla_compute_type": ir.StringAttr.get( diff --git a/tests/memories_test.py b/tests/memories_test.py index 0ca973c4d221..bdb88b418697 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -1670,24 +1670,36 @@ def test_stream_annotation_inside_shmap(self): arr1 = jax.device_put(np_inp, s) arr2 = jax.device_put(np_inp, s) + # Makes sure the compute wrapped here is fusible. + # This is a workaround for limitations in XLA. + # 1) Compute-on boxes contain a single instruction cannot work. + # 2) Compute-on boxes contain tiny matmul cannot work. @compute_on('gpu_stream:1') @jax.jit def g(x, y): - return x * y + return x * y + x @compute_on('gpu_stream:2') @jax.jit def h(x, y): - return x * y + return x * y + x def f(x, y): z = g(x, y) w = h(3 * x, 2 * y) return z + w - out = jax.jit(shard_map(f, mesh=mesh, in_specs=(P('x'), P('x')), - out_specs=P('x')))(arr1, arr2) - self.assertArraysEqual(out, arr1 * 7) + compiled_f = jax.jit( + shard_map(f, mesh=mesh, in_specs=(P('x'), P('x')), + out_specs=P('x'))).lower(arr1, arr2).compile( + {"xla_gpu_experimental_stream_annotation": True} + ) + compiled_text = compiled_f.as_text() + self.assertIn('call-start', compiled_text) + self.assertIn('_xla_stream_annotation="1"', compiled_text) + self.assertIn('call-start.1', compiled_f.as_text()) + self.assertIn('_xla_stream_annotation="2"', compiled_text) + self.assertArraysEqual(compiled_f(arr1, arr2), arr1 * 11) class ActivationOffloadingTest(jtu.JaxTestCase): From e0c093314d8d9a6f68953f0c340c1b01d50ce386 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 19 Mar 2025 21:24:30 -0700 Subject: [PATCH 040/483] Remove ; in code blocks of `thinking_in_jax.md` PiperOrigin-RevId: 738656531 --- docs/notebooks/thinking_in_jax.ipynb | 4 ++-- docs/notebooks/thinking_in_jax.md | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/notebooks/thinking_in_jax.ipynb b/docs/notebooks/thinking_in_jax.ipynb index 5ddcdd32e2b4..560e0500ad13 100644 --- a/docs/notebooks/thinking_in_jax.ipynb +++ b/docs/notebooks/thinking_in_jax.ipynb @@ -57,7 +57,7 @@ "\n", "x_np = np.linspace(0, 10, 1000)\n", "y_np = 2 * np.sin(x_np) * np.cos(x_np)\n", - "plt.plot(x_np, y_np);" + "plt.plot(x_np, y_np)" ] }, { @@ -91,7 +91,7 @@ "\n", "x_jnp = jnp.linspace(0, 10, 1000)\n", "y_jnp = 2 * jnp.sin(x_jnp) * jnp.cos(x_jnp)\n", - "plt.plot(x_jnp, y_jnp);" + "plt.plot(x_jnp, y_jnp)" ] }, { diff --git a/docs/notebooks/thinking_in_jax.md b/docs/notebooks/thinking_in_jax.md index 0693f6ba8579..b107f78635f6 100644 --- a/docs/notebooks/thinking_in_jax.md +++ b/docs/notebooks/thinking_in_jax.md @@ -42,7 +42,7 @@ import numpy as np x_np = np.linspace(0, 10, 1000) y_np = 2 * np.sin(x_np) * np.cos(x_np) -plt.plot(x_np, y_np); +plt.plot(x_np, y_np) ``` ```{code-cell} ipython3 @@ -53,7 +53,7 @@ import jax.numpy as jnp x_jnp = jnp.linspace(0, 10, 1000) y_jnp = 2 * jnp.sin(x_jnp) * jnp.cos(x_jnp) -plt.plot(x_jnp, y_jnp); +plt.plot(x_jnp, y_jnp) ``` +++ {"id": "kTZcsCJiuPG8"} From 4da751a97a2a7837e977ecde77cd4ba0a05cfda5 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 19 Mar 2025 21:50:36 -0700 Subject: [PATCH 041/483] Reverts e0c093314d8d9a6f68953f0c340c1b01d50ce386 PiperOrigin-RevId: 738662342 --- docs/notebooks/thinking_in_jax.ipynb | 4 ++-- docs/notebooks/thinking_in_jax.md | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/notebooks/thinking_in_jax.ipynb b/docs/notebooks/thinking_in_jax.ipynb index 560e0500ad13..5ddcdd32e2b4 100644 --- a/docs/notebooks/thinking_in_jax.ipynb +++ b/docs/notebooks/thinking_in_jax.ipynb @@ -57,7 +57,7 @@ "\n", "x_np = np.linspace(0, 10, 1000)\n", "y_np = 2 * np.sin(x_np) * np.cos(x_np)\n", - "plt.plot(x_np, y_np)" + "plt.plot(x_np, y_np);" ] }, { @@ -91,7 +91,7 @@ "\n", "x_jnp = jnp.linspace(0, 10, 1000)\n", "y_jnp = 2 * jnp.sin(x_jnp) * jnp.cos(x_jnp)\n", - "plt.plot(x_jnp, y_jnp)" + "plt.plot(x_jnp, y_jnp);" ] }, { diff --git a/docs/notebooks/thinking_in_jax.md b/docs/notebooks/thinking_in_jax.md index b107f78635f6..0693f6ba8579 100644 --- a/docs/notebooks/thinking_in_jax.md +++ b/docs/notebooks/thinking_in_jax.md @@ -42,7 +42,7 @@ import numpy as np x_np = np.linspace(0, 10, 1000) y_np = 2 * np.sin(x_np) * np.cos(x_np) -plt.plot(x_np, y_np) +plt.plot(x_np, y_np); ``` ```{code-cell} ipython3 @@ -53,7 +53,7 @@ import jax.numpy as jnp x_jnp = jnp.linspace(0, 10, 1000) y_jnp = 2 * jnp.sin(x_jnp) * jnp.cos(x_jnp) -plt.plot(x_jnp, y_jnp) +plt.plot(x_jnp, y_jnp); ``` +++ {"id": "kTZcsCJiuPG8"} From 58ba4106c33752856738bbf5f22cd16854eb2b22 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Wed, 19 Mar 2025 21:51:15 -0700 Subject: [PATCH 042/483] [mosaic_gpu] Check for dropped activity records in cupti profiler. PiperOrigin-RevId: 738662559 --- jaxlib/mosaic/gpu/mosaic_gpu_ext.cc | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc index a726acd4d662..f91018cf7287 100644 --- a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc +++ b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc @@ -195,6 +195,12 @@ void callback_complete(CUcontext context, uint32_t streamId, THROW_IF_CUPTI_ERROR(status); } } + + size_t num_dropped; + THROW_IF_CUPTI_ERROR( + cuptiActivityGetNumDroppedRecords(context, streamId, &num_dropped), + "failed to get number of dropped activity records"); + THROW_IF(num_dropped > 0, "activity records were dropped"); } NB_MODULE(_mosaic_gpu_ext, m) { From 509c65895dd3e6011e08269f2b1c61ba620ce7ed Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Wed, 19 Mar 2025 23:41:08 -0700 Subject: [PATCH 043/483] [mosaic_gpu] Make cupti finalization optional. cupti initialization / finalization is somewhat expensive. This gives us the option of avoiding repeated initialization when performing multiple cupti timings. Disable kernel activity to ensure we've restored cupti to its original state. PiperOrigin-RevId: 738685851 --- jaxlib/mosaic/gpu/mosaic_gpu_ext.cc | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc index f91018cf7287..2c7242b6e6c0 100644 --- a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc +++ b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc @@ -245,15 +245,23 @@ NB_MODULE(_mosaic_gpu_ext, m) { cuptiActivityEnable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL), "failed to enable tracking of kernel activity by CUPTI"); }); - m.def("_cupti_get_timings", []() { - THROW_IF_CUPTI_ERROR( - cuptiActivityFlushAll(CUPTI_ACTIVITY_FLAG_FLUSH_FORCED), - "failed to flush CUPTI activity buffers"); - THROW_IF_CUPTI_ERROR(cuptiFinalize(), "failed to detach CUPTI"); - THROW_IF_CUPTI_ERROR(cuptiUnsubscribe(profiler_state.subscriber), - "failed to unsubscribe from CUPTI"); - return profiler_state.timings; - }); + m.def( + "_cupti_get_timings", + [](bool finalize) { + THROW_IF_CUPTI_ERROR( + cuptiActivityDisable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL), + "failed to disable tracking of kernel activity by CUPTI"); + THROW_IF_CUPTI_ERROR( + cuptiActivityFlushAll(CUPTI_ACTIVITY_FLAG_FLUSH_FORCED), + "failed to flush CUPTI activity buffers"); + if (finalize) { + THROW_IF_CUPTI_ERROR(cuptiFinalize(), "failed to detach CUPTI"); + } + THROW_IF_CUPTI_ERROR(cuptiUnsubscribe(profiler_state.subscriber), + "failed to unsubscribe from CUPTI"); + return profiler_state.timings; + }, + nb::arg("finalize") = true); } } // namespace From 6e204171f53d08ead0238f25e647ad4a41367c5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Longeri?= Date: Wed, 19 Mar 2025 23:47:35 -0700 Subject: [PATCH 044/483] [Mosaic:TPU] Add overload to ComputeTileStrides that just takes a shape. PiperOrigin-RevId: 738687016 --- jaxlib/mosaic/dialect/tpu/util.cc | 12 ++++++------ jaxlib/mosaic/dialect/tpu/util.h | 9 ++++++++- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/util.cc b/jaxlib/mosaic/dialect/tpu/util.cc index 651cef85f740..44cc301d6f0d 100644 --- a/jaxlib/mosaic/dialect/tpu/util.cc +++ b/jaxlib/mosaic/dialect/tpu/util.cc @@ -45,18 +45,18 @@ std::ostream &operator<<(std::ostream &os, Print p) { return os; } -SmallVector ComputeTileStrides(MemRefType memref_ty, +SmallVector ComputeTileStrides(absl::Span shape, absl::Span tiling) { - SmallVector tile_strides(memref_ty.getRank()); + SmallVector tile_strides(shape.size()); int64_t stride = 1; - for (int64_t i = 0; i < memref_ty.getRank(); ++i) { - int64_t idx = memref_ty.getRank() - 1 - i; + for (int64_t i = 0; i < shape.size(); ++i) { + int64_t idx = shape.size() - 1 - i; int64_t tiling_idx = tiling.size() - 1 - i; tile_strides[idx] = stride; if (tiling_idx >= 0) { - stride *= llvm::divideCeil(memref_ty.getShape()[idx], tiling[tiling_idx]); + stride *= llvm::divideCeil(shape[idx], tiling[tiling_idx]); } else { - stride *= memref_ty.getShape()[idx]; + stride *= shape[idx]; } } return tile_strides; diff --git a/jaxlib/mosaic/dialect/tpu/util.h b/jaxlib/mosaic/dialect/tpu/util.h index 2e19cb820b5b..f9ab1b7e349d 100644 --- a/jaxlib/mosaic/dialect/tpu/util.h +++ b/jaxlib/mosaic/dialect/tpu/util.h @@ -192,8 +192,15 @@ std::string shapeToString(const T &shape) { return os.str(); } -SmallVector ComputeTileStrides(MemRefType memref_ty, +SmallVector ComputeTileStrides(absl::Span shape, absl::Span tiling); + +inline SmallVector ComputeTileStrides( + MemRefType memref_ty, absl::Span tiling) { + absl::Span shape(memref_ty.getShape().data(), + memref_ty.getShape().size()); + return ComputeTileStrides(shape, tiling); +} // Assuming MKN matmul - This function must only be called after // canonicalization passes. // From 2d43fb473001af5f3a779b7d9ce7f54bf0ae1fe6 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Thu, 20 Mar 2025 02:07:33 -0700 Subject: [PATCH 045/483] =?UTF-8?q?[Mosaic=20GPU]=C2=A0Introduce=20an=20op?= =?UTF-8?q?timization=20barrier=20op.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Also add layout inference and lowering rules for it. Its initial use case will be to fence WGMMA accumulator registers. As a result, transform inference is not immediately useful for this op, and we omit it here. PiperOrigin-RevId: 738718000 --- .../mosaic/gpu/dialect_lowering.py | 33 ++++++++++++ .../mosaic/gpu/layout_inference.py | 51 ++++++++++++++++++- jaxlib/mosaic/dialect/gpu/mosaic_gpu.td | 29 +++++++++++ tests/mosaic/gpu_layout_inference_test.py | 51 +++++++++++++++++++ 4 files changed, 162 insertions(+), 2 deletions(-) diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index ae702d50ebb7..936bba73915b 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -35,6 +35,7 @@ from jax._src.lib.mlir.dialects import nvvm from jax._src.lib.mlir.dialects import scf from jax._src.lib.mlir.dialects import vector +from jax._src.util import safe_zip from jax.experimental.mosaic.gpu import layouts as layouts_lib import numpy as np @@ -203,6 +204,38 @@ def _initialize_barrier_op_lowering_rule( barrier_base_ptr, initialize_barrier_op.barriers_ref.type), +# TODO(bchetioui): remove once minimum jaxlib >= 0.5.3. +OptimizationBarrierOp = getattr(mgpu, "OptimizationBarrierOp", None) + + +@_register_lowering(OptimizationBarrierOp) +def _optimization_barrier_op_lowering_rule( + _: LoweringContext, + op: OptimizationBarrierOp, +) -> Sequence[ir.Value]: + if not all(ir.VectorType.isinstance(operand.type) for operand in op.operands): + raise NotImplementedError( + f"Optimization barrier op {op} has non-vector operands." + ) + + fragmented_arrays = [] + for operand, layout in safe_zip(op.operands, inference_utils.in_layouts(op)): + ty = ir.VectorType(operand.type) + is_signed = False if ir.IntegerType.isinstance(ty.element_type) else None + fragmented_arrays.append( + _fragmented_array_from_ir(operand, layout, is_signed=is_signed) + ) + + lowered_fragmented_arrays = fa.optimization_barrier(*fragmented_arrays) + if isinstance(lowered_fragmented_arrays, fa.FragmentedArray): + lowered_fragmented_arrays = [lowered_fragmented_arrays] + + return [ + _fragmented_array_to_ir(arr, result.type) + for arr, result in safe_zip(lowered_fragmented_arrays, op.results) + ] + + @_register_lowering(arith.ConstantOp) def _arith_constant_op_lowering_rule( _: LoweringContext, op: arith.ConstantOp diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py index 470b0d328d8e..dec75e4db1a0 100644 --- a/jax/experimental/mosaic/gpu/layout_inference.py +++ b/jax/experimental/mosaic/gpu/layout_inference.py @@ -44,7 +44,9 @@ def _add_layout_inference_rule(op: type[ir.OpView], rule: LayoutInferenceRule): - _layout_inference_rules[op.OPERATION_NAME] = rule # pytype: disable=attribute-error + if op is not None: + _layout_inference_rules[op.OPERATION_NAME] = rule # pytype: disable=attribute-error + return rule def _set_layout_attributes( @@ -192,7 +194,7 @@ def is_array(v: ir.Value) -> bool: # This is left for a future change, and currently we only do "down # propagation". layout = _choose_representative_layout(layouts) - # It is unsafe to t conclude that this op produces a splat if not all inputs + # It is unsafe to conclude that this op produces a splat if not all inputs # have been inferred: some of them might turn out not to be splats! if layouts_lib.is_splat_fragmented_layout(layout) and not all_inputs_have_layout: return None @@ -247,6 +249,51 @@ def is_array(v: ir.Value) -> bool: _add_layout_inference_rule(op, _infer_pointwise_op_layouts) +# TODO(bchetioui): remove once minimum jaxlib >= 0.5.3. +OptimizationBarrierOp = getattr(mgpu, "OptimizationBarrierOp", None) + + +@partial(_add_layout_inference_rule, OptimizationBarrierOp) +def _infer_optimization_barrier_op_layout( + op: OptimizationBarrierOp, +) -> OptionalLayouts: + def is_array(v: ir.Value) -> bool: + return ir.VectorType.isinstance(v.type) + + if inference_utils.has_in_layouts_set(op): + op_in_layouts = list(inference_utils.in_layouts(op)) + return op_in_layouts, op_in_layouts + + if inference_utils.has_out_layouts_set(op): + op_out_layouts = list(inference_utils.out_layouts(op)) + return op_out_layouts, op_out_layouts + + layouts = [None] * len(op.operands) + for i, operand in enumerate(filter(is_array, op.operands)): + layouts[i] = inference_utils.value_layout(operand) + + for i, result in enumerate(filter(is_array, op.results)): + possible_layouts = set() + for op_operand_use in cast(ir.OpResult, result).uses: + consumer = op_operand_use.owner + op_user = consumer.operands[op_operand_use.operand_number] + layout = inference_utils.in_layout_for_operand(consumer, op_user) + if layout is not None: + possible_layouts.add(layout) + if possible_layouts and layouts[i] is None: + # TODO(bchetioui): we could actually just pick any user layout here, + # and optimize later. This is fine for now. + layouts[i] = _choose_representative_layout(possible_layouts) + + # TODO(bchetioui): handle annotating layout for only certain operands. + # Otherwise, layouts may not get propagated through optimization barriers, if + # a single branch does not carry any forcing layout, which is pretty bad. + if any(layout is None for layout in layouts): + return None + + return layouts, layouts + + @partial(_add_layout_inference_rule, arith.ConstantOp) def _infer_constant_op_layout(constant_op: arith.ConstantOp) -> OptionalLayouts: if not ir.VectorType.isinstance(constant_op.result.type): diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td index 108ff952b571..f0a37084b759 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td @@ -457,4 +457,33 @@ def MosaicGPU_WGMMAOp : Op { let hasVerifier = 1; } +def MosaicGPU_OptimizationBarrierOp : Op { + let summary = "Prevents MLIR from moving operations across the barrier."; + + let arguments = (ins + Variadic:$operands + ); + let results = (outs Variadic); + + let extraClassDeclaration = [{ + static llvm::LogicalResult inferReturnTypes( + mlir::MLIRContext *, + std::optional location, + mlir::ValueRange operands, + mlir::DictionaryAttr attributes, + mlir::OpaqueProperties properties, + mlir::RegionRange regions, + llvm::SmallVectorImpl &inferredReturnTypes) { + if (operands.empty()) { + return ::mlir::emitOptionalError( + location, "expected non-empty operands"); + } + ::mlir::TypeRange operand_types = operands.getTypes(); + inferredReturnTypes.assign(operand_types.begin(), operand_types.end()); + return ::mlir::success(); + } + }]; +} + #endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_GPU_MOSAIC_GPU_TD_ diff --git a/tests/mosaic/gpu_layout_inference_test.py b/tests/mosaic/gpu_layout_inference_test.py index 36c8ff9cf47e..893e21efc6d0 100644 --- a/tests/mosaic/gpu_layout_inference_test.py +++ b/tests/mosaic/gpu_layout_inference_test.py @@ -442,6 +442,57 @@ def body(lhs, rhs): self.assertNotIn("in_layouts", f.attributes) self.assertNotIn("out_layouts", f.attributes) + def test_optimization_barrier_op_propagates_user_layouts(self): + add = optimization_barrier = None + + def body(lhs, rhs): + nonlocal add, optimization_barrier + optimization_barrier = mgpu.dialect.OptimizationBarrierOp([lhs, rhs]) + lhs, rhs = optimization_barrier.results + add = arith.AddFOp(lhs, rhs) + + with ir.InsertionPoint(self.module.body): + shape = (32, 4) + ty = ir.VectorType.get(shape, ir.BF16Type.get()) + func.FuncOp.from_py_func(ty, ty)(body) + + splat_layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape)) + add.attributes["out_layouts"] = ir.ArrayAttr.get([splat_layout]) + mgpu.infer_layout(self.module) + + self.assertSequenceEqual( + optimization_barrier.attributes["in_layouts"], + [splat_layout, splat_layout], + ) + self.assertSequenceEqual( + optimization_barrier.attributes["out_layouts"], + [splat_layout, splat_layout], + ) + + def test_optimization_barrier_op_propagates_producer_layouts(self): + add = optimization_barrier = None + + def body(lhs, rhs): + nonlocal add, optimization_barrier + add = arith.AddFOp(lhs, rhs) + optimization_barrier = mgpu.dialect.OptimizationBarrierOp([add]) + + with ir.InsertionPoint(self.module.body): + shape = (32, 4) + ty = ir.VectorType.get(shape, ir.BF16Type.get()) + func.FuncOp.from_py_func(ty, ty)(body) + + splat_layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape)) + add.attributes["out_layouts"] = ir.ArrayAttr.get([splat_layout]) + mgpu.infer_layout(self.module) + + self.assertSequenceEqual( + optimization_barrier.attributes["in_layouts"], [splat_layout] + ) + self.assertSequenceEqual( + optimization_barrier.attributes["out_layouts"], [splat_layout] + ) + if __name__ == "__main__": parameterized.absltest.main(testLoader=jtu.JaxTestLoader()) From 18326abea65df2efee72cf909d9f4ae910df4f76 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Thu, 20 Mar 2025 02:38:54 -0700 Subject: [PATCH 046/483] [mosaic_gpu] Don't time the warmup step in cupti profiler. Initializing and finalizing cupti has an overhead. PiperOrigin-RevId: 738725435 --- jax/experimental/mosaic/gpu/profiler.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/jax/experimental/mosaic/gpu/profiler.py b/jax/experimental/mosaic/gpu/profiler.py index 99fefc1adc9c..32b3edf7caf9 100644 --- a/jax/experimental/mosaic/gpu/profiler.py +++ b/jax/experimental/mosaic/gpu/profiler.py @@ -102,23 +102,21 @@ def _measure_cupti(f, aggregate): if not isinstance(f, (stages.Wrapped, stages.Compiled)): f = jax.jit(f) - def run(*args, **kwargs): + def wrapper(*args, **kwargs): + jax.block_until_ready(f(*args, **kwargs)) # Warmup. mosaic_gpu_lib._mosaic_gpu_ext._cupti_init() try: results = jax.block_until_ready(f(*args, **kwargs)) finally: timings = mosaic_gpu_lib._mosaic_gpu_ext._cupti_get_timings() - return results, timings - def wrapper(*args, **kwargs): - run(*args, **kwargs) # Warmup. - results, timings = run(*args, **kwargs) if not timings: return results, None elif aggregate: return results, sum(item[1] for item in timings) else: return results, timings + return wrapper From e2b6859e7d3e5c0c01be9013d6cb680ab647d9a4 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Thu, 20 Mar 2025 02:51:56 -0700 Subject: [PATCH 047/483] Deprecate the jaxlib.hlo_helpers submodule. jaxlib no longer includes any lowering logic, so we don't need this module anymore. Users would be better served by the APIs in JAX core like `jax.ffi` or `jax.interpreters.mlir`. This module isn't covered by JAX's compatibility policy, so no formal deprecation period is required, but there are enough users that we should keep this warning for at least one full release cycle. PiperOrigin-RevId: 738728721 --- jax/_src/lib/__init__.py | 1 - jaxlib/hlo_helpers.py | 11 +++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py index 7933bb769733..70dc914668cf 100644 --- a/jax/_src/lib/__init__.py +++ b/jax/_src/lib/__init__.py @@ -109,7 +109,6 @@ def _xla_gc_callback(*args): import jaxlib.gpu_sparse as gpu_sparse # pytype: disable=import-error # noqa: F401 import jaxlib.gpu_prng as gpu_prng # pytype: disable=import-error # noqa: F401 import jaxlib.gpu_linalg as gpu_linalg # pytype: disable=import-error # noqa: F401 -import jaxlib.hlo_helpers as hlo_helpers # pytype: disable=import-error # noqa: F401 # Jaxlib code is split between the Jax and the Tensorflow repositories. # Only for the internal usage of the JAX developers, we expose a version diff --git a/jaxlib/hlo_helpers.py b/jaxlib/hlo_helpers.py index 0d57a04f1aa7..11ff844ae53f 100644 --- a/jaxlib/hlo_helpers.py +++ b/jaxlib/hlo_helpers.py @@ -19,11 +19,22 @@ from collections.abc import Callable, Sequence from functools import partial from typing import Union +import warnings import jaxlib.mlir.ir as ir import jaxlib.mlir.dialects.stablehlo as hlo import numpy as np +# TODO(danfm): This module isn't covered by JAX's compatibility policy, so no +# formal deprecation period is required, but there are enough users that we +# should keep this warning for at least one full release cycle. +# Deprecation added 2025-03-19 after the release of v0.5.3. Remove this whole +# module after the release of v0.5.4 or later. +warnings.warn( + "The jaxlib.hlo_helpers submodule is deprecated. Instead, use jax.ffi if " + "possible or, for lower-level operations, jax.interpreters.mlir.", + DeprecationWarning, +) _dtype_to_ir_type_factory : dict[np.dtype, Callable[[], ir.Type]] = { np.dtype(np.bool_): partial(ir.IntegerType.get_signless, 1), From f1298ae7f11464e697ebec66e75046e90ae739e1 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Thu, 20 Mar 2025 03:01:50 -0700 Subject: [PATCH 048/483] Remove XLA FFI GPU callback handler. - In order to migrate the GPU FFI handler from the internal API intended for static linking to the external API intended for dynamic linking, we need to migrate both CPU and GPU FFI handlers at the same time. - Builds break if we include both versions of the FFI APIs. - Now that py_client_gpu sits in jaxlib, tests that run new FFI API in jaxlib against old FFI API in xla (and vice versa) for GPU targets will fail. - This change lets us update the CPU handler first in XLA and then update the GPU handler second in jaxlib. - Because the GPU handler depends on new symbols in xla, we need to land XLA changes first anyway (i.e., no point to deleting both CPU and GPU to try to land jaxlib and xla in one go). PiperOrigin-RevId: 738730955 --- jaxlib/cuda/BUILD | 12 ---- jaxlib/gpu/py_client_gpu.cc | 136 ------------------------------------ jaxlib/gpu/py_client_gpu.h | 3 - jaxlib/rocm/BUILD | 12 ---- 4 files changed, 163 deletions(-) diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index 4e74cc2dcf5b..ee32888864dd 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -668,34 +668,22 @@ cc_library( features = ["-use_header_modules"], deps = [ ":cuda_vendor", - "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", - "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@llvm-project//llvm:Support", "@nanobind", "@xla//xla:comparison_util", - "@xla//xla:shape_util", - "@xla//xla/ffi", - "@xla//xla/ffi:ffi_api", "@xla//xla/pjrt:exceptions", "@xla//xla/pjrt:host_callback", "@xla//xla/pjrt:transpose", "@xla//xla/python:callback", "@xla//xla/python:nb_numpy", - "@xla//xla/python:py_host_callback", - "@xla//xla/python:types", - "@xla//xla/python/ifrt", "@xla//xla/service:custom_call_status", "@xla//xla/service:custom_call_target_registry", "@xla//xla/service:platform_util", - "@xla//xla/tsl/concurrency:ref_count", - "@xla//xla/tsl/platform:errors", - "@xla//xla/tsl/platform:statusor", ], ) diff --git a/jaxlib/gpu/py_client_gpu.cc b/jaxlib/gpu/py_client_gpu.cc index d6faa1859eb8..cf701574959b 100644 --- a/jaxlib/gpu/py_client_gpu.cc +++ b/jaxlib/gpu/py_client_gpu.cc @@ -21,9 +21,7 @@ limitations under the License. #include #include "nanobind/nanobind.h" -#include "absl/algorithm/container.h" #include "absl/base/casts.h" -#include "absl/container/inlined_vector.h" #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -31,24 +29,15 @@ limitations under the License. #include "absl/strings/numbers.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "llvm/include/llvm/Support/Casting.h" #include "jaxlib/gpu/vendor.h" -#include "xla/ffi/ffi.h" -#include "xla/ffi/ffi_api.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/host_callback.h" #include "xla/pjrt/transpose.h" #include "xla/primitive_util.h" #include "xla/python/callback.h" -#include "xla/python/ifrt/host_callback.h" #include "xla/python/nb_numpy.h" -#include "xla/python/py_host_callback.h" -#include "xla/python/types.h" #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_target_registry.h" -#include "xla/shape_util.h" -#include "xla/tsl/concurrency/ref_count.h" -#include "xla/tsl/platform/statusor.h" namespace nb = nanobind; @@ -166,130 +155,5 @@ void XlaPythonGpuCallback(gpuStream_t stream, void** buffers, XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( "xla_python_gpu_callback", &XlaPythonGpuCallback, absl::AsciiStrToUpper(JAX_GPU_PLUGIN_NAME)); - -absl::Status XlaFfiPythonGpuCallback( - gpuStream_t stream, - std::vector>* callbacks, - uint64_t index, xla::ffi::RemainingArgs args, - xla::ffi::RemainingRets rets) { - auto loaded_callback = llvm::dyn_cast_or_null( - callbacks->at(index).get()); - if (loaded_callback == nullptr) { - return absl::InternalError( - "Expected a PyCpuLoadedHostCallback, got something else."); - } - xla::CpuCallback* callback = loaded_callback->cpu_callback(); - size_t arity = args.size(); - std::vector host_input_buffers(arity); - // Copy input GPU buffers to host - for (size_t i = 0; i < arity; ++i) { - auto arg = args.get(i); - if (arg->element_type() == xla::TOKEN) { - host_input_buffers[i] = nullptr; - continue; - } - void* buf = new char[arg->size_bytes()]; - host_input_buffers[i] = buf; - // TODO(b/238441608): Use pinned memory here to speed up the transfer. - auto gpu_res = - gpuMemcpyAsync(buf, arg.value().untyped_data(), arg->size_bytes(), - gpuMemcpyDeviceToHost, stream); - CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; - } - CHECK_EQ(gpuStreamSynchronize(stream), gpuSuccess) - << "Failed to gpuStreamSynchronize"; - nb::gil_scoped_acquire gil; - nb::tuple host_input_arrays = nb::steal(PyTuple_New(arity)); - for (size_t i = 0; i < arity; ++i) { - auto arg = args.get(i); - xla::PrimitiveType ptype = arg->element_type(); - if (ptype == xla::TOKEN) { - PyTuple_SET_ITEM(host_input_arrays.ptr(), i, nb::none().inc_ref().ptr()); - } else { - nb::capsule base(host_input_buffers[i], [](void* ptr) noexcept { - delete[] static_cast(ptr); - }); - TF_ASSIGN_OR_RETURN(auto dtype, xla::PrimitiveTypeToNbDtype(ptype)); - auto array = xla::nb_numpy_ndarray(dtype, arg->dimensions(), std::nullopt, - host_input_buffers[i], base); - array.attr("flags").attr("writeable") = nb::bool_(false); - PyTuple_SET_ITEM(host_input_arrays.ptr(), i, array.inc_ref().ptr()); - } - } - - xla::EnterHostCallback(); - // TODO(dsuo): Change this to use the Python vectorcall protocol, which allows - // you to avoid constructing a tuple for the arguments. - absl::StatusOr maybe_result_tuple = - callback->FfiCall(host_input_arrays); - xla::LeaveHostCallback(); - TF_ASSIGN_OR_RETURN(auto result_tuple, maybe_result_tuple); - - std::vector temp_buffers; - for (size_t i = 0; i < rets.size(); ++i) { - auto ret = rets.get(i).value(); - auto ptype = ret->element_type(); - if (ptype == xla::TOKEN) continue; - nb::object output = - nb::borrow(PyTuple_GetItem(result_tuple.ptr(), i)); - xla::nb_numpy_ndarray array = - xla::nb_numpy_ndarray::ensure(std::move(output)); - absl::Span strides( - reinterpret_cast(array.strides()), array.ndim()); - // We expect the output to be in default numpy layout. - TF_ASSIGN_OR_RETURN(auto expected_shape, xla::ShapeUtil::MakeValidatedShape( - ptype, ret->dimensions())); - auto expected_strides = ByteStridesForShape(expected_shape); - if (strides == expected_strides) { - auto gpu_res = - gpuMemcpyAsync(ret->untyped_data(), array.data(), ret->size_bytes(), - gpuMemcpyHostToDevice, stream); - CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; - } else { - void* temp = new char[ret->size_bytes()]; - temp_buffers.push_back(temp); - xla::TransposePlan::Options options; - options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); - absl::Span dims( - reinterpret_cast(array.shape()), array.ndim()); - options.dims = dims; - absl::InlinedVector reversed_layout; - reversed_layout.resize(expected_shape.rank()); - absl::c_reverse_copy(expected_shape.layout().minor_to_major(), - reversed_layout.begin()); - options.permutation = reversed_layout; - options.input_layout = xla::TransposePlan::Striding{strides}; - TF_ASSIGN_OR_RETURN(auto plan, - callback->transpose_cache().GetOrCreate(options)); - plan->Execute(array.data(), temp); - auto gpu_res = - gpuMemcpyAsync(ret->untyped_data(), temp, ret->size_bytes(), - gpuMemcpyHostToDevice, stream); - CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; - } - } - nb::gil_scoped_release release; - CHECK_EQ(gpuStreamSynchronize(stream), gpuSuccess) - << "Failed to gpuStreamSynchronize"; - for (int i = 0; i < temp_buffers.size(); ++i) { - delete[] static_cast(temp_buffers[i]); - } - return absl::OkStatus(); -} - -XLA_FFI_DEFINE_HANDLER_SYMBOL( - kXlaFfiPythonGpuCallback, XlaFfiPythonGpuCallback, - xla::ffi::Ffi::Bind() - .Ctx>() - .Ctx>>>() - .Attr("index") - .RemainingArgs() - .RemainingRets()); -XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), - "xla_ffi_python_gpu_callback", - absl::AsciiStrToUpper(JAX_GPU_PLUGIN_NAME), - kXlaFfiPythonGpuCallback); - } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/gpu/py_client_gpu.h b/jaxlib/gpu/py_client_gpu.h index 6be2d40823dc..e9454504f5d9 100644 --- a/jaxlib/gpu/py_client_gpu.h +++ b/jaxlib/gpu/py_client_gpu.h @@ -19,7 +19,6 @@ limitations under the License. #include #include "jaxlib/gpu/vendor.h" -#include "xla/ffi/ffi.h" #include "xla/service/custom_call_status.h" namespace jax { @@ -29,8 +28,6 @@ void XlaPythonGpuCallback(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len, XlaCustomCallStatus* status); -XLA_FFI_DECLARE_HANDLER_SYMBOL(kXlaFfiPythonGpuCallback); - } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index 1e54d82c4f71..99df757018f3 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -566,34 +566,22 @@ cc_library( features = ["-use_header_modules"], deps = [ ":hip_vendor", - "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", - "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@llvm-project//llvm:Support", "@nanobind", "@xla//xla:comparison_util", - "@xla//xla:shape_util", - "@xla//xla/ffi", - "@xla//xla/ffi:ffi_api", "@xla//xla/pjrt:exceptions", "@xla//xla/pjrt:host_callback", "@xla//xla/pjrt:transpose", "@xla//xla/python:callback", "@xla//xla/python:nb_numpy", - "@xla//xla/python:py_host_callback", - "@xla//xla/python:types", - "@xla//xla/python/ifrt", "@xla//xla/service:custom_call_status", "@xla//xla/service:custom_call_target_registry", "@xla//xla/service:platform_util", - "@xla//xla/tsl/concurrency:ref_count", - "@xla//xla/tsl/platform:errors", - "@xla//xla/tsl/platform:statusor", ], ) From 84cc397b4eab2825b9b3479995fd700a3e17f17f Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Thu, 20 Mar 2025 04:00:30 -0700 Subject: [PATCH 049/483] [XLA:GPU][Triton] Remove sparsity code. It's unused but causes significant burden during Triton integrates. PiperOrigin-RevId: 738744625 --- tests/BUILD | 14 --- tests/sparse_nm_test.py | 209 ---------------------------------------- 2 files changed, 223 deletions(-) delete mode 100644 tests/sparse_nm_test.py diff --git a/tests/BUILD b/tests/BUILD index 0ffa68ed8eb3..b126655b0a06 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1101,20 +1101,6 @@ jax_multiplatform_test( ] + py_deps("scipy"), ) -jax_multiplatform_test( - name = "sparse_nm_test", - srcs = ["sparse_nm_test.py"], - enable_backends = [], - enable_configs = [ - "gpu_a100", - "gpu_h100", - ], - deps = [ - "//jax:experimental_sparse", - "//jax:pallas_gpu", - ], -) - jax_multiplatform_test( name = "sparsify_test", srcs = ["sparsify_test.py"], diff --git a/tests/sparse_nm_test.py b/tests/sparse_nm_test.py deleted file mode 100644 index 9ecf30eb6229..000000000000 --- a/tests/sparse_nm_test.py +++ /dev/null @@ -1,209 +0,0 @@ -# Copyright 2024 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math - -import numpy as np -from absl.testing import absltest -from absl.testing import parameterized - -import jax -import jax.numpy as jnp -from jax import dtypes -from jax._src import config -from jax._src import test_util as jtu -from jax.experimental.sparse import nm - -jax.config.parse_flags_with_absl() - - -class SpmmTest(jtu.JaxTestCase): - def setUp(self): - if not jtu.test_device_matches(["gpu"]): - self.skipTest("Only works on GPU") - if (jtu.test_device_matches(["cuda"]) and - not jtu.is_cuda_compute_capability_at_least("8.0")): - self.skipTest("Only works on GPUs with capability >= sm80") - super().setUp() - - # ----- Test different input shapes - @parameterized.product( - tile_m=(32, 128), - tile_n=(32, 128), - tile_k=(32, 128), - batch=(None, 5), - sparse_idx=(0, 1), - ) - @jtu.run_on_devices("gpu") - def test_shapes(self, tile_m, tile_n, tile_k, batch, sparse_idx): - # Build keyword arguments - kwargs = { - "dimension_numbers": (((1,), (1,)), (tuple(), tuple())), - "sparse_operand_idx": sparse_idx, - } - if batch: - kwargs["dimension_numbers"] = (((2,), (2,)), ((0,), (0,))) - - # Build input data - batch_dims = (batch,) if batch else tuple() - lhs = ( - (np.arange((batch or 1) * tile_m * tile_k) % 11) - .astype(dtypes.bfloat16) - .reshape(batch_dims + (tile_m, tile_k)) - ) - rhs = ( - (np.arange((batch or 1) * tile_n * tile_k) % 13) - .astype(dtypes.bfloat16) - .reshape(batch_dims + (tile_n, tile_k)) - ) - - # Build sparsity mask and metadata - sp = [lhs, rhs][sparse_idx] - mask = np.tile([True, False], math.prod(sp.shape) // 2).reshape(sp.shape) - sparse = sp[mask].reshape(sp.shape[:-1] + (sp.shape[-1] // 2,)) - meta = nm.nm_pack(mask) - - # Calculate sparse and dense dots - if sparse_idx == 0: - dot_sparse = nm.nm_spmm(sparse, rhs, meta, **kwargs) - dot_dense = jnp.einsum("...mk,...nk->...mn", (lhs * mask), rhs) - else: - dot_sparse = nm.nm_spmm(lhs, sparse, meta, **kwargs) - dot_dense = jnp.einsum("...mk,...nk->...mn", lhs, (rhs * mask)) - - # Verify the result - jtu.check_eq(dot_sparse, dot_dense.astype(dtypes.bfloat16)) - - # ----- Test different input types - @parameterized.product( - lhs_type=[jnp.int8, jnp.int16, jnp.float16, jnp.bfloat16], - rhs_type=[jnp.bfloat16], - output_type=[jnp.bfloat16, jnp.float32], - ) - @jtu.run_on_devices("gpu") - def test_types(self, lhs_type, rhs_type, output_type): - tile_m, tile_n, tile_k = 64, 32, 128 - - # Build input data - lhs = ( - (np.arange(tile_m * tile_k) % 17) - .astype(lhs_type) - .reshape((tile_m, tile_k)) - ) - rhs = ( - (np.arange(tile_k * tile_n) % 19) - .astype(rhs_type) - .reshape((tile_k, tile_n)) - ) - - # Build sparsity mask and metadata - mask = np.tile([True, False], tile_m * tile_k // 2).reshape(lhs.shape) - sparse = lhs[mask].reshape(tile_m, tile_k // 2) - meta = nm.nm_pack(mask) - - # Calculate sparse and dense dots - dot_sparse = nm.nm_spmm(sparse, rhs, meta, output_dtype=output_type) - dot_dense = (lhs * mask) @ rhs - - # Verify the result - jtu.check_close(dot_sparse, dot_dense.astype(output_type), rtol=0.01) - - # ----- Test validation - @jtu.run_on_devices("gpu") - def test_validate_nm_pack(self): - with self.assertRaisesRegex(TypeError, "Mask should be bool"): - nm.nm_pack(jnp.zeros(16, jnp.int8)) - with self.assertRaisesRegex( - TypeError, "Inner dimension size should be divisible by 16" - ): - nm.nm_pack(jnp.array([False] * 8)) - - @jtu.run_on_devices("gpu") - def test_validate_nm_spmm(self): - batch, tile_m, tile_n, tile_k = 2, 64, 32, 128 - lhs = jnp.zeros((batch, tile_m, tile_k // 2), dtype=jnp.bfloat16) - rhs = jnp.zeros((batch, tile_k, tile_n), dtype=jnp.bfloat16) - meta = jnp.zeros((batch, tile_m, tile_k // 16), dtype=jnp.uint16) - - if config.enable_x64.value: - with self.assertRaisesRegex(TypeError, "Unsupported lhs input type"): - nm.nm_spmm(jnp.zeros(lhs.shape, dtype=jnp.int64), rhs, meta) - with self.assertRaisesRegex(TypeError, "Unsupported rhs input type"): - nm.nm_spmm(lhs, jnp.zeros(rhs.shape, dtype=jnp.int64), meta) - with self.assertRaisesRegex(TypeError, "Unsupported output type"): - nm.nm_spmm(lhs, rhs, meta, output_dtype=jnp.int64) - - # Check dimension numbers - nm_spmm_with_dnums = lambda c, b: nm.nm_spmm( - lhs, rhs, meta, dimension_numbers=(c, b) - ) - with self.assertRaisesRegex( - TypeError, "Only single contracting dimension is supported" - ): - nm_spmm_with_dnums(((0, 2), (0, 1)), (tuple(), tuple())) - with self.assertRaisesRegex( - TypeError, "Incorrect dimension numbers for lhs" - ): - nm_spmm_with_dnums(((2,), (1,)), ((2,), (0,))) - with self.assertRaisesRegex( - TypeError, "Incorrect dimension numbers for rhs" - ): - nm_spmm_with_dnums(((2,), (1,)), ((0,), (1,))) - with self.assertRaisesRegex( - TypeError, "Only single non-contracting dimension is supported" - ): - nm_spmm_with_dnums(((2,), (1,)), (tuple(), tuple())) - with self.assertRaisesRegex( - TypeError, "Batch dimension sizes do not match" - ): - nm.nm_spmm( - lhs, - rhs.reshape(1, tile_k, tile_n * batch), - meta, - dimension_numbers=(((2,), (1,)), ((0,), (0,))), - ) - - # Check metadata - nm_spmm_with_meta = lambda m: nm.nm_spmm( - lhs, rhs, m, dimension_numbers=(((2,), (1,)), ((0,), (0,))) - ) - with self.assertRaisesRegex(TypeError, "Metadata must be uint16"): - nm_spmm_with_meta(jnp.zeros(meta.shape, dtype=jnp.uint8)) - with self.assertRaisesRegex( - TypeError, "Metadata shape must match the operand shape" - ): - nm_spmm_with_meta(meta.reshape(1, batch * tile_m, tile_k // 16)) - with self.assertRaisesRegex( - TypeError, - "Metadata must be exactly 8 times less than the contracting dimension" - " for 2:4 structured sparsity", - ): - nm_spmm_with_meta(jnp.repeat(meta, 2, axis=-1)) - with self.assertRaisesRegex( - TypeError, "Contracting dimension must be the minor one" - ): - nm.nm_spmm(lhs, rhs, meta, dimension_numbers=(((1,), (1,)), ((0,), (0,)))) - with self.assertRaisesRegex( - TypeError, "Contracting dimension sizes should have 2:4 ratio" - ): - nm.nm_spmm( - lhs, - jnp.repeat(rhs, 2, axis=1), - meta, - dimension_numbers=(((2,), (1,)), ((0,), (0,))), - ) - - -if __name__ == "__main__": - absltest.main(testLoader=jtu.JaxTestLoader()) From 1c8e60e6c299bb6cc39f5d9a0d68df327c79da10 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 20 Mar 2025 04:30:49 -0700 Subject: [PATCH 050/483] Update XLA dependency to use revision http://github.com/openxla/xla/commit/d3145a119840723c16fd27ee342729d68fddb7ef. PiperOrigin-RevId: 738751933 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index f81e3931b1dc..08c5af0c32b8 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "0d20d73f2c8f21c21b9f343c4363a76e980f032e" -XLA_SHA256 = "9df61c200b0a54b7a5c55155fa7a454e33d660e6a49239b6980f5a10305fecc5" +XLA_COMMIT = "d3145a119840723c16fd27ee342729d68fddb7ef" +XLA_SHA256 = "daf2a72e36a9358803a8156c48b32117c9699fd327fcbc37b465f1a0045bccae" def repo(): tf_http_archive( From 4d6f15f20c588fffd87ad1d610d92b636b194c5d Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 20 Mar 2025 05:17:57 -0700 Subject: [PATCH 051/483] [Mosaic GPU] Add support for slicing tiled refs with (tile aligned) dynamic base offsets PiperOrigin-RevId: 738762062 --- jax/_src/pallas/mosaic_gpu/BUILD | 1 + jax/_src/pallas/mosaic_gpu/core.py | 84 +++++++++++++++++++++++------- tests/pallas/mosaic_gpu_test.py | 28 ++++++++++ 3 files changed, 93 insertions(+), 20 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/BUILD b/jax/_src/pallas/mosaic_gpu/BUILD index e5b491aef330..ab35eebafc04 100644 --- a/jax/_src/pallas/mosaic_gpu/BUILD +++ b/jax/_src/pallas/mosaic_gpu/BUILD @@ -80,6 +80,7 @@ pytype_strict_library( "//jax:mosaic_gpu", "//jax:state_types", "//jax:tree_util", + "//jax/_src/lib", "//jax/_src/pallas", "//jaxlib/mlir:ir", ] + py_deps("numpy"), diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 5e4566ddfc9c..b1e0a683f64d 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -29,10 +29,11 @@ from jax._src import dtypes from jax._src import effects from jax._src import tree_util +from jax._src.lib.mlir.dialects import arith as arith_dialect from jax._src.pallas import core as pallas_core +from jax._src.state import discharge as state_discharge from jax._src.state import indexing from jax._src.state import types as state_types -from jax._src.state import discharge as state_discharge import jax.experimental.mosaic.gpu as mgpu import jax.numpy as jnp from jaxlib.mlir import ir @@ -135,6 +136,24 @@ def cmap_body(): return wrapper +def _is_known_divisible(value, divisor, fuel=10) -> bool: + """Returns True if the value is statically known to be divisible by the divisor.""" + if fuel < 0: + return False + if not isinstance(value.owner, ir.Operation): + return False + def_op = value.owner.opview + match def_op: + case arith_dialect.IndexCastOp(): + return _is_known_divisible(value.owner.operands[0], divisor, fuel - 1) + case arith_dialect.ConstantOp(): + return ir.IntegerAttr(def_op.value).value % divisor == 0 + case arith_dialect.MulIOp(): + return (_is_known_divisible(value.owner.operands[0], divisor, fuel // 2) or + _is_known_divisible(value.owner.operands[1], divisor, (fuel + 1)// 2)) + return False + + @dataclasses.dataclass(frozen=True) class GPUMemoryRef(pallas_core.MemoryRef): transforms: Sequence[MemoryRefTransform] = () @@ -171,7 +190,7 @@ def __call__(self, aval: jax_core.ShapedArray) -> jax_core.ShapedArray: shape=self.to_gpu_transform().transform_shape(aval.shape) ) -Index = slice | int | ir.Value +Index = mgpu.DynamicSlice | slice | int | ir.Value @dataclasses.dataclass(frozen=True) class TilingTransform(MemoryRefTransform): @@ -218,16 +237,37 @@ def untransform_index( ) -> tuple[tuple[Index, ...], state_types.Transform]: untiled_idxs = idxs[: -len(self.tiling)] tiled_idxs = idxs[-len(self.tiling) :] - idxs_after_tiling = [] + idxs_after_tiling: list[Index] = [] for idx, tile in zip(tiled_idxs, self.tiling): - if not isinstance(idx, slice): - raise NotImplementedError("Non-slice indices are not supported") - assert isinstance(idx, slice) - if idx.step is not None and idx.step != 1: - raise NotImplementedError("Strided slices unsupported") - if (idx.start is not None and idx.start % tile) or (idx.stop is not None and idx.stop % tile): - raise ValueError("Non-empty slices must be tile aligned") - idxs_after_tiling.append(slice(idx.start // tile, idx.stop // tile)) + if isinstance(idx, slice): + if idx.step is not None and idx.step != 1: + raise NotImplementedError("Strided slices unsupported") + if (idx.start is not None and idx.start % tile) or (idx.stop is not None and idx.stop % tile): + raise ValueError("Non-empty slices must be tile aligned") + idxs_after_tiling.append(slice(idx.start // tile, idx.stop // tile)) + elif isinstance(idx, mgpu.DynamicSlice): + if idx.length % tile: + raise ValueError( + f"Dynamic slice length ({idx.length}) is not divisible by the" + f" tiling ({tile})" + ) + if isinstance(idx.base, ir.Value): + if not _is_known_divisible(idx.base, tile): + raise ValueError( + "Dynamic slice base index (which is a dynamic value) cannot be" + f" statically proven to be divisible by the tiling ({tile})" + ) + new_base = arith_dialect.divui(idx.base, mgpu.c(tile, idx.base.type)) + else: + if idx.base % tile: + raise ValueError( + f"Dynamic slice base ({idx.base}) is not divisible by the" + f" tiling ({tile})" + ) + new_base = idx.base // tile + idxs_after_tiling.append(mgpu.DynamicSlice(new_base, idx.length // tile)) + else: + raise TypeError(f"Unsupported index type: {type(idx)}") return (*untiled_idxs, *idxs_after_tiling, *(slice(None) for _ in self.tiling)), self def undo_to_gpu_transform(self) -> mgpu.MemRefTransform: @@ -285,7 +325,7 @@ def untransform_index( self, idxs: tuple[Index, ...] ) -> tuple[tuple[Index, ...], state_types.Transform]: removed_dims = [ - i for i, idx in enumerate(idxs) if not isinstance(idx, slice) + i for i, idx in enumerate(idxs) if not isinstance(idx, (slice, mgpu.ds)) ] new_perm = tuple( p - sum(d < p for d in removed_dims) @@ -358,18 +398,22 @@ def untransform_index( ) -> tuple[tuple[Index, ...], state_types.Transform]: if not idxs: return idxs, self - if not all(isinstance(idx, slice) for idx in idxs[-2:]): + if not all(isinstance(idx, (slice, mgpu.ds)) for idx in idxs[-2:]): raise NotImplementedError( "Non-slice indices are not supported in 2 minormost dims" ) last_idx = idxs[-1] - assert isinstance(last_idx, slice) - if last_idx.step is not None and last_idx.step != 1: - raise NotImplementedError("Swizzled dims cannot be sliced") - if (last_idx.start is not None and last_idx.start != 0) or ( - last_idx.stop is not None and last_idx.stop != self.swizzle - ): - raise ValueError("Swizzled dims cannot be sliced") + if isinstance(last_idx, mgpu.DynamicSlice): + if last_idx.base != 0 or last_idx.length != self.swizzle: + raise ValueError("Swizzled dims cannot be sliced") + else: + assert isinstance(last_idx, slice) + if ( + (last_idx.step is not None and last_idx.step != 1) + or (last_idx.start is not None and last_idx.start != 0) + or (last_idx.stop is not None and last_idx.stop != self.swizzle) + ): + raise ValueError("Swizzled dims cannot be sliced") return idxs, self diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 38335925b44d..40e98bf05ba9 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1132,6 +1132,34 @@ def kernel(x_ref, o_ref): x = jnp.arange(256, dtype=jnp.int32) np.testing.assert_array_equal(kernel(x), jnp.broadcast_to(jnp.sum(x) * 3, [256])) + # Not testing with warpgroup semantics, because we want to enforce a layout. + def test_tile_slicing(self): + shape = (256, 128) + block_spec = plgpu.GPUBlockSpec( + transforms=( + plgpu.TilingTransform((64, 64)), + plgpu.SwizzleTransform(128), + ) + ) + @functools.partial( + pl.pallas_call, + in_specs=[block_spec], + out_specs=block_spec, + out_shape=jax.ShapeDtypeStruct((64, 64), jnp.uint16), + ) + def kernel(x_ref, o_ref): + def sum_tiles(row, acc): + row_slice = pl.ds(row * 64, 64) + for col in range(128 // 64): + acc += x_ref[row_slice, pl.ds(col * 64, 64)] + return acc + acc = plgpu.layout_cast(jnp.zeros((64, 64), jnp.uint16), plgpu.Layout.WGMMA) + o_ref[...] = _fori_loop(False, 0, 256 // 64, sum_tiles, acc) + + x = jnp.arange(math.prod(shape), dtype=jnp.uint16).reshape(shape) + y = x.reshape(256 // 64, 64, 128 // 64, 64).sum(axis=(0, 2), dtype=jnp.uint16) + np.testing.assert_array_equal(kernel(x), y) + def test_input_output_aliases(self): # Note that we're writing to the input pointer, which should alias b_ptr. def kernel(a_ref, b_ref): From 2c90fe2dea5a0ec5941d21973e33f4334d43ed0b Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 20 Mar 2025 06:12:22 -0700 Subject: [PATCH 052/483] Reorder C++ imports. PiperOrigin-RevId: 738774175 --- examples/jax_cpp/main.cc | 2 +- jaxlib/gpu/vendor.h | 2 +- jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc | 2 +- .../mlir/_mlir_libs/register_jax_dialects.cc | 3 +- jaxlib/mlir/_mlir_libs/triton_ext.cc | 2 +- jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc | 15 ++-- jaxlib/mosaic/dialect/gpu/mosaic_gpu.h | 4 +- jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc | 38 ++++---- jaxlib/mosaic/dialect/tpu/array_util.cc | 4 +- jaxlib/mosaic/dialect/tpu/array_util.h | 2 +- jaxlib/mosaic/dialect/tpu/array_util_test.cc | 2 +- .../dialect/tpu/integrations/c/tpu_dialect.cc | 6 +- jaxlib/mosaic/dialect/tpu/layout.cc | 2 +- jaxlib/mosaic/dialect/tpu/layout.h | 2 +- jaxlib/mosaic/dialect/tpu/tpu_dialect.cc | 8 +- jaxlib/mosaic/dialect/tpu/tpu_dialect.h | 7 +- jaxlib/mosaic/dialect/tpu/tpu_ops.cc | 18 ++-- .../tpu/transforms/apply_vector_layout.cc | 25 +++--- .../apply_vector_layout_extensions.h | 6 +- .../tpu/transforms/canonicalize_mosaic.cc | 33 ++++--- .../apply_vector_layout_extensions.cc | 4 +- .../infer_vector_layout_extensions.cc | 6 +- .../tpu/transforms/infer_memref_layout.cc | 2 +- .../tpu/transforms/infer_vector_layout.cc | 19 ++-- .../infer_vector_layout_extensions.h | 4 +- .../tpu/transforms/linalg_vectorization.cc | 52 +++++------ .../transforms/memory_space_specialization.cc | 10 +-- .../tpu/transforms/relayout_insertion.cc | 13 ++- jaxlib/mosaic/dialect/tpu/transforms/serde.cc | 10 +-- jaxlib/mosaic/dialect/tpu/transforms/serde.h | 10 +-- jaxlib/mosaic/dialect/tpu/util.cc | 16 ++-- jaxlib/mosaic/dialect/tpu/util.h | 10 +-- jaxlib/mosaic/dialect/tpu/vreg_util.cc | 20 ++--- jaxlib/mosaic/dialect/tpu/vreg_util.h | 12 +-- jaxlib/mosaic/dialect/tpu/vreg_util_test.cc | 28 +++--- jaxlib/mosaic/gpu/custom_call.cc | 86 +++++++++---------- jaxlib/mosaic/gpu/launch_lowering.cc | 46 +++++----- jaxlib/mosaic/gpu/passes.cc | 24 +++--- jaxlib/mosaic/gpu/serde.cc | 8 +- jaxlib/mosaic/gpu/serde.h | 12 +-- jaxlib/mosaic/gpu/target.cc | 4 +- jaxlib/mosaic/pass_boilerplate.h | 8 +- jaxlib/mosaic/serde.cc | 18 ++-- jaxlib/mosaic/serde.h | 10 +-- jaxlib/triton/triton_dialect_capi.cc | 12 +-- jaxlib/triton/triton_dialect_capi.h | 4 +- 46 files changed, 307 insertions(+), 324 deletions(-) diff --git a/examples/jax_cpp/main.cc b/examples/jax_cpp/main.cc index 0a1d3a63acfd..5d1190ff1f2c 100644 --- a/examples/jax_cpp/main.cc +++ b/examples/jax_cpp/main.cc @@ -41,7 +41,7 @@ limitations under the License. #include #include -#include "third_party/absl/status/statusor.h" +#include "absl/status/statusor.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/literal.h" diff --git a/jaxlib/gpu/vendor.h b/jaxlib/gpu/vendor.h index cadd5453107a..58a02e7c568c 100644 --- a/jaxlib/gpu/vendor.h +++ b/jaxlib/gpu/vendor.h @@ -29,7 +29,7 @@ limitations under the License. #include "third_party/gpus/cuda/include/cublas_v2.h" #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/cuda_fp8.h" -#include "third_party/gpus/cuda/include/cuda_runtime_api.h" +#include "cuda_runtime_api.h" #include "third_party/gpus/cuda/include/cufft.h" #include "third_party/gpus/cuda/include/cusolverDn.h" #include "third_party/gpus/cuda/include/cusolver_common.h" diff --git a/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc b/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc index c73084abc99d..7483d7ed1eea 100644 --- a/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc +++ b/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc @@ -16,9 +16,9 @@ limitations under the License. #include #include +#include "nanobind/nanobind.h" #include "mlir-c/IR.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" // IWYU pragma: keep -#include "nanobind/nanobind.h" #include "jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h" #include "jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h" diff --git a/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc b/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc index 9da841acc7de..64f84965b8e2 100644 --- a/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc +++ b/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc @@ -2,6 +2,7 @@ // This module is called by mlir/__init__.py during initialization. #include +#include "shardy/integrations/c/passes.h" #include "mlir-c/Dialect/Arith.h" #include "mlir-c/Dialect/Func.h" #include "mlir-c/Dialect/GPU.h" @@ -14,10 +15,8 @@ #include "mlir-c/Dialect/Vector.h" #include "mlir-c/Transforms.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" -#include "shardy/integrations/c/passes.h" #include "jaxlib/mosaic/gpu/integrations/c/passes.h" - namespace nb = nanobind; #define REGISTER_DIALECT(name) \ diff --git a/jaxlib/mlir/_mlir_libs/triton_ext.cc b/jaxlib/mlir/_mlir_libs/triton_ext.cc index 2a13c40d963f..e824d4058d7e 100644 --- a/jaxlib/mlir/_mlir_libs/triton_ext.cc +++ b/jaxlib/mlir/_mlir_libs/triton_ext.cc @@ -15,9 +15,9 @@ limitations under the License. #include +#include "nanobind/nanobind.h" #include "mlir-c/IR.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" -#include "nanobind/nanobind.h" #include "jaxlib/triton/triton_dialect_capi.h" namespace nb = nanobind; diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc index a1e7b571d20e..2358a97ba20d 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc @@ -18,6 +18,11 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep #include "llvm/Support/Casting.h" @@ -32,7 +37,9 @@ limitations under the License. #include "mlir/Dialect/SCF/Utils/Utils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep #include "mlir/IR/ImplicitLocOpBuilder.h" @@ -43,14 +50,6 @@ limitations under the License. #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" #include "mlir/Support/LLVM.h" -#include "absl/algorithm/container.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/Diagnostics.h" #include "tsl/platform/statusor.h" // Generated definitions. diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h index b4f13c50bd8c..47b286aec302 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h @@ -19,6 +19,8 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "absl/strings/string_view.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" @@ -28,8 +30,6 @@ limitations under the License. #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Support/LLVM.h" -#include "absl/status/status.h" -#include "absl/strings/string_view.h" // Generated definitions. #include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_dialect.h.inc" // IWYU pragma: keep diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc b/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc index 527aa7c7ce25..c259da3e737c 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc @@ -25,25 +25,25 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" -#include "llvm/include/llvm/ADT/ArrayRef.h" -#include "llvm/include/llvm/ADT/SmallVector.h" -#include "mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h" -#include "mlir/include/mlir/Conversion/LLVMCommon/StructBuilder.h" -#include "mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" -#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/include/mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/include/mlir/Dialect/SCF/Utils/Utils.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/IR/Diagnostics.h" -#include "mlir/include/mlir/IR/MLIRContext.h" -#include "mlir/include/mlir/IR/OwningOpRef.h" -#include "mlir/include/mlir/IR/Types.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/IR/Verifier.h" -#include "mlir/include/mlir/Interfaces/DataLayoutInterfaces.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" +#include "mlir/Conversion/LLVMCommon/StructBuilder.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" +#include "mlir/Support/LLVM.h" #include "tsl/platform/errors.h" namespace mosaic_gpu { diff --git a/jaxlib/mosaic/dialect/tpu/array_util.cc b/jaxlib/mosaic/dialect/tpu/array_util.cc index 4c1e79667c0f..f7d559fb08bc 100644 --- a/jaxlib/mosaic/dialect/tpu/array_util.cc +++ b/jaxlib/mosaic/dialect/tpu/array_util.cc @@ -19,8 +19,8 @@ limitations under the License. #include "absl/log/check.h" #include "absl/types/span.h" -#include "llvm/include/llvm/ADT/STLExtras.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "llvm/ADT/STLExtras.h" +#include "mlir/Support/LLVM.h" namespace mlir::tpu::internal { diff --git a/jaxlib/mosaic/dialect/tpu/array_util.h b/jaxlib/mosaic/dialect/tpu/array_util.h index 1b755dbf8495..ab8e98d17836 100644 --- a/jaxlib/mosaic/dialect/tpu/array_util.h +++ b/jaxlib/mosaic/dialect/tpu/array_util.h @@ -20,7 +20,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/types/span.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "mlir/Support/LLVM.h" #include "jaxlib/mosaic/dialect/tpu/util.h" #include "xla/array.h" diff --git a/jaxlib/mosaic/dialect/tpu/array_util_test.cc b/jaxlib/mosaic/dialect/tpu/array_util_test.cc index 18c2f94fa8b6..bcbf417a967b 100644 --- a/jaxlib/mosaic/dialect/tpu/array_util_test.cc +++ b/jaxlib/mosaic/dialect/tpu/array_util_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include -#include "mlir/include/mlir/Support/LLVM.h" +#include "mlir/Support/LLVM.h" #include "xla/array.h" namespace mlir::tpu { diff --git a/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc b/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc index 772e87beff71..ce7e90d45fb9 100644 --- a/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc +++ b/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc @@ -22,6 +22,8 @@ limitations under the License. #include #include +#include "absl/log/check.h" +#include "absl/log/log.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MemAlloc.h" #include "llvm/Support/raw_ostream.h" @@ -39,8 +41,6 @@ limitations under the License. #include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/log/check.h" -#include "absl/log/log.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h" @@ -410,7 +410,7 @@ MLIR_CAPI_EXPORTED void mlirTpuRegisterMosaicSerdePass() { mlir::tpu::registerMosaicSerdePass(); } -#include "mlir/CAPI/Pass.h" // IWYU pragma: keep +#include "mlir/CAPI/Pass.h" // IWYU pragma: keep #include "mlir/CAPI/Support.h" // IWYU pragma: keep extern "C" { diff --git a/jaxlib/mosaic/dialect/tpu/layout.cc b/jaxlib/mosaic/dialect/tpu/layout.cc index 172f2e91b41f..c54c99fc9825 100644 --- a/jaxlib/mosaic/dialect/tpu/layout.cc +++ b/jaxlib/mosaic/dialect/tpu/layout.cc @@ -27,6 +27,7 @@ limitations under the License. #include #include +#include "absl/log/check.h" #include "llvm/ADT/Hashing.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/MathExtras.h" @@ -41,7 +42,6 @@ limitations under the License. #include "mlir/IR/ValueRange.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/log/check.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/util.h" diff --git a/jaxlib/mosaic/dialect/tpu/layout.h b/jaxlib/mosaic/dialect/tpu/layout.h index 2c45be62fa7d..bcfe205d58a9 100644 --- a/jaxlib/mosaic/dialect/tpu/layout.h +++ b/jaxlib/mosaic/dialect/tpu/layout.h @@ -24,6 +24,7 @@ limitations under the License. #include #include +#include "absl/log/check.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/bit.h" #include "llvm/Support/ErrorHandling.h" @@ -38,7 +39,6 @@ limitations under the License. #include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/log/check.h" namespace mlir::tpu { diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc index 59ca5d7a3437..73c119b70e1a 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc @@ -23,7 +23,11 @@ limitations under the License. #include #include +#include "absl/hash/hash.h" +#include "absl/log/log.h" #include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep. +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" @@ -32,10 +36,6 @@ limitations under the License. #include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep. #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/hash/hash.h" -#include "absl/log/log.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.cc.inc" #include "jaxlib/mosaic/dialect/tpu/tpu_enums.cc.inc" #include "xla/layout.h" diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h index 0800a9e75087..cf74689dd3e6 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h @@ -24,11 +24,10 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/Support/LogicalResult.h" +#include "mlir/Support/LogicalResult.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_enums.h.inc" #include "jaxlib/mosaic/dialect/tpu/transforms/serde.h" diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index c73accb09b26..b69a6ae06a7f 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -19,24 +19,24 @@ limitations under the License. #include #include +#include "absl/log/check.h" +#include "absl/strings/str_format.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/FormatVariadic.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/OperationSupport.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/log/check.h" -#include "absl/strings/str_format.h" -#include "mlir/include/mlir/Dialect/Math/IR/Math.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/Diagnostics.h" -#include "mlir/include/mlir/IR/IRMapping.h" -#include "mlir/include/mlir/IR/OperationSupport.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/util.h" diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 1997ffe34535..7755738a4fc7 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -13,15 +13,23 @@ #include #include +#include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/types/span.h" +#include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVectorExtras.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/Compiler.h" +#include "llvm/Support/LogicalResult.h" #include "llvm/Support/MathExtras.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Traits.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" @@ -33,9 +41,11 @@ #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Diagnostics.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" #include "mlir/IR/Region.h" #include "mlir/IR/TypeRange.h" #include "mlir/IR/Types.h" @@ -45,21 +55,6 @@ #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/algorithm/container.h" -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "absl/types/span.h" -#include "llvm/include/llvm/ADT/APInt.h" -#include "llvm/include/llvm/Support/LogicalResult.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/include/mlir/Dialect/Math/IR/Math.h" -#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/include/mlir/IR/OperationSupport.h" #include "jaxlib/mosaic/dialect/tpu/array_util.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h index 33c9e7421004..fded0d1dbfd7 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h @@ -3,9 +3,9 @@ #include -#include "llvm/include/llvm/ADT/StringMap.h" -#include "mlir/include/mlir/IR/Operation.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "llvm/ADT/StringMap.h" +#include "mlir/IR/Operation.h" +#include "mlir/Support/LLVM.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" diff --git a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc index 5efbdb9cb437..6f56489ab4b1 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc @@ -13,31 +13,30 @@ // NOLINTNEXTLINE(misc-include-cleaner) #include "mlir/Dialect/MemRef/IR/MemRef.h" // NOLINTNEXTLINE(misc-include-cleaner) +#include "absl/log/check.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringSet.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Region.h" +#include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/log/check.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/Dialect/Math/IR/Math.h" -#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h" -#include "mlir/include/mlir/IR/AffineExpr.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/Block.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/include/mlir/IR/OpDefinition.h" -#include "mlir/include/mlir/IR/Operation.h" -#include "mlir/include/mlir/IR/PatternMatch.h" -#include "mlir/include/mlir/IR/Region.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/Support/LLVM.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/vreg_util.h" diff --git a/jaxlib/mosaic/dialect/tpu/transforms/extensions/apply_vector_layout_extensions.cc b/jaxlib/mosaic/dialect/tpu/transforms/extensions/apply_vector_layout_extensions.cc index e7528533938f..067f8e592e30 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/extensions/apply_vector_layout_extensions.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/extensions/apply_vector_layout_extensions.cc @@ -1,7 +1,7 @@ #include "jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h" -#include "llvm/include/llvm/ADT/StringMap.h" -#include "mlir/include/mlir/IR/Operation.h" +#include "llvm/ADT/StringMap.h" +#include "mlir/IR/Operation.h" namespace mlir::tpu::extensions { diff --git a/jaxlib/mosaic/dialect/tpu/transforms/extensions/infer_vector_layout_extensions.cc b/jaxlib/mosaic/dialect/tpu/transforms/extensions/infer_vector_layout_extensions.cc index c9c4a97e6222..9dbf89724fef 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/extensions/infer_vector_layout_extensions.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/extensions/infer_vector_layout_extensions.cc @@ -3,9 +3,9 @@ #include #include -#include "mlir/include/mlir/IR/Operation.h" -#include "mlir/include/mlir/Support/LLVM.h" -#include "mlir/include/mlir/Support/LogicalResult.h" +#include "mlir/IR/Operation.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" namespace mlir::tpu::extensions { diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc index 0926f8a3c7b5..cdf48632784b 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc @@ -6,6 +6,7 @@ #include #include +#include "absl/log/check.h" #include "llvm/ADT/bit.h" #include "llvm/Support/MathExtras.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -23,7 +24,6 @@ #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/log/check.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/util.h" #include "xla/layout.h" diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index 0081feba985b..00e53314e588 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -23,6 +23,9 @@ limitations under the License. #include #include +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/types/span.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVectorExtras.h" #include "llvm/Support/raw_ostream.h" @@ -32,22 +35,16 @@ limitations under the License. #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/OpDefinition.h" #include "mlir/IR/Value.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/types/span.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/include/mlir/IR/OpDefinition.h" -#include "mlir/include/mlir/IR/Visitors.h" -#include "mlir/include/mlir/Pass/Pass.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h" diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h index d240f27fd42d..36fa2ce8113f 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h @@ -4,8 +4,8 @@ #include #include -#include "mlir/include/mlir/IR/Operation.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "mlir/IR/Operation.h" +#include "mlir/Support/LLVM.h" namespace mlir::tpu::extensions { diff --git a/jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.cc b/jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.cc index 949a26a4f593..0d310ff45b30 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.cc @@ -19,32 +19,32 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h" -#include "mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/include/mlir/Dialect/Math/IR/Math.h" -#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/include/mlir/Dialect/Utils/StaticValueUtils.h" -#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h" -#include "mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" -#include "mlir/include/mlir/IR/AffineMap.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/DialectRegistry.h" -#include "mlir/include/mlir/IR/Matchers.h" -#include "mlir/include/mlir/IR/Operation.h" -#include "mlir/include/mlir/IR/OperationSupport.h" -#include "mlir/include/mlir/IR/PatternMatch.h" -#include "mlir/include/mlir/IR/Types.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/Pass/Pass.h" -#include "mlir/include/mlir/Support/LLVM.h" -#include "mlir/include/mlir/Support/LogicalResult.h" -#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/Transforms/Hoisting.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" +#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" namespace mlir::tpu { diff --git a/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc b/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc index b73ea0f1250f..f78df135a45a 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc @@ -14,11 +14,11 @@ limitations under the License. ==============================================================================*/ #include "absl/log/check.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/Support/LLVM.h" -#include "mlir/include/mlir/Support/LogicalResult.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" namespace mlir { diff --git a/jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc b/jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc index b88504e35068..8aae7a10279a 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc @@ -3,20 +3,19 @@ #include #include +#include "absl/log/check.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/MathExtras.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" #include "mlir/IR/Visitors.h" #include "mlir/Pass/Pass.h" -#include "absl/log/check.h" -#include "llvm/include/llvm/ADT/STLExtras.h" -#include "llvm/include/llvm/Support/MathExtras.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/Diagnostics.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "mlir/Support/LLVM.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/util.h" diff --git a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc index 0981c263d252..5f6c9bd712ff 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc @@ -18,19 +18,17 @@ limitations under the License. #include #include +#include "llvm/ADT/StringMap.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/Value.h" #include "mlir/IR/Visitors.h" #include "mlir/Support/LLVM.h" -#include "llvm/include/llvm/ADT/StringMap.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/OpDefinition.h" -#include "mlir/include/mlir/IR/OperationSupport.h" -#include "mlir/include/mlir/Support/LogicalResult.h" +#include "mlir/Support/LogicalResult.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/serde.h" diff --git a/jaxlib/mosaic/dialect/tpu/transforms/serde.h b/jaxlib/mosaic/dialect/tpu/transforms/serde.h index 64753a22e7be..ccb32131e519 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/serde.h +++ b/jaxlib/mosaic/dialect/tpu/transforms/serde.h @@ -4,11 +4,11 @@ #include #include -#include "llvm/include/llvm/ADT/StringRef.h" -#include "llvm/include/llvm/Support/CommandLine.h" -#include "mlir/include/mlir/Interfaces/DataLayoutInterfaces.h" -#include "mlir/include/mlir/Pass/Pass.h" -#include "mlir/include/mlir/Pass/PassRegistry.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/CommandLine.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" #include "jaxlib/mosaic/pass_boilerplate.h" namespace mlir::tpu { diff --git a/jaxlib/mosaic/dialect/tpu/util.cc b/jaxlib/mosaic/dialect/tpu/util.cc index 44cc301d6f0d..e61d9fa8d417 100644 --- a/jaxlib/mosaic/dialect/tpu/util.cc +++ b/jaxlib/mosaic/dialect/tpu/util.cc @@ -22,16 +22,16 @@ limitations under the License. #include #include -#include "llvm/Support/MathExtras.h" #include "absl/log/check.h" #include "absl/types/span.h" -#include "llvm/include/llvm/Support/raw_ostream.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/IR/ValueRange.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "llvm/Support/MathExtras.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Support/LLVM.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" diff --git a/jaxlib/mosaic/dialect/tpu/util.h b/jaxlib/mosaic/dialect/tpu/util.h index f9ab1b7e349d..dadd71800f3e 100644 --- a/jaxlib/mosaic/dialect/tpu/util.h +++ b/jaxlib/mosaic/dialect/tpu/util.h @@ -10,19 +10,17 @@ #include #include +#include "absl/status/status.h" +#include "absl/types/span.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Location.h" +#include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/status/status.h" -#include "absl/types/span.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/Diagnostics.h" -#include "mlir/include/mlir/IR/Value.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "tsl/platform/statusor.h" diff --git a/jaxlib/mosaic/dialect/tpu/vreg_util.cc b/jaxlib/mosaic/dialect/tpu/vreg_util.cc index 1f59ee13a311..72e0bf7f0caf 100644 --- a/jaxlib/mosaic/dialect/tpu/vreg_util.cc +++ b/jaxlib/mosaic/dialect/tpu/vreg_util.cc @@ -19,16 +19,16 @@ limitations under the License. #include #include "absl/log/check.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/Diagnostics.h" -#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/include/mlir/IR/Types.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/IR/ValueRange.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Support/LLVM.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/util.h" #include "xla/array.h" diff --git a/jaxlib/mosaic/dialect/tpu/vreg_util.h b/jaxlib/mosaic/dialect/tpu/vreg_util.h index 86955e128f59..8c2967e776c7 100644 --- a/jaxlib/mosaic/dialect/tpu/vreg_util.h +++ b/jaxlib/mosaic/dialect/tpu/vreg_util.h @@ -19,12 +19,12 @@ limitations under the License. #include #include -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/include/mlir/IR/Types.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" #include "xla/array.h" namespace mlir::tpu { diff --git a/jaxlib/mosaic/dialect/tpu/vreg_util_test.cc b/jaxlib/mosaic/dialect/tpu/vreg_util_test.cc index ea3063361e1a..8a6d437ab73c 100644 --- a/jaxlib/mosaic/dialect/tpu/vreg_util_test.cc +++ b/jaxlib/mosaic/dialect/tpu/vreg_util_test.cc @@ -21,20 +21,20 @@ limitations under the License. #include #include -#include "llvm/include/llvm/ADT/TypeSwitch.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/include/mlir/IR/MLIRContext.h" -#include "mlir/include/mlir/IR/OwningOpRef.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/Support/DebugStringHelper.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "llvm/ADT/TypeSwitch.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/DebugStringHelper.h" +#include "mlir/Support/LLVM.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" namespace mlir::tpu { diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index 402e099c8d6b..d4f4d1732b2e 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -40,49 +40,49 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" -#include "llvm/include/llvm/ADT/SmallVector.h" -#include "llvm/include/llvm/Support/CodeGen.h" -#include "llvm/include/llvm/Support/TargetSelect.h" -#include "mlir/include/mlir/Conversion/ArithToLLVM/ArithToLLVM.h" -#include "mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" -#include "mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" -#include "mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" -#include "mlir/include/mlir/Conversion/IndexToLLVM/IndexToLLVM.h" -#include "mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h" -#include "mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" -#include "mlir/include/mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" -#include "mlir/include/mlir/Conversion/Passes.h" -#include "mlir/include/mlir/Conversion/UBToLLVM/UBToLLVM.h" -#include "mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/Dialect/Arith/Transforms/Passes.h" -#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/include/mlir/Dialect/GPU/Transforms/Passes.h" -#include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h" -#include "mlir/include/mlir/Dialect/Math/IR/Math.h" -#include "mlir/include/mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h" -#include "mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h" -#include "mlir/include/mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/include/mlir/ExecutionEngine/ExecutionEngine.h" -#include "mlir/include/mlir/ExecutionEngine/OptUtils.h" -#include "mlir/include/mlir/IR/AsmState.h" -#include "mlir/include/mlir/IR/DialectRegistry.h" -#include "mlir/include/mlir/IR/MLIRContext.h" -#include "mlir/include/mlir/Parser/Parser.h" -#include "mlir/include/mlir/Pass/PassManager.h" -#include "mlir/include/mlir/Pass/PassRegistry.h" -#include "mlir/include/mlir/Support/LLVM.h" -#include "mlir/include/mlir/Support/LogicalResult.h" -#include "mlir/include/mlir/Target/LLVM/NVVM/Target.h" -#include "mlir/include/mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" -#include "mlir/include/mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" -#include "mlir/include/mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" -#include "mlir/include/mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" -#include "mlir/include/mlir/Transforms/Passes.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/CodeGen.h" +#include "llvm/Support/TargetSelect.h" +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" +#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" +#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" +#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" +#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" +#include "mlir/Conversion/Passes.h" +#include "mlir/Conversion/UBToLLVM/UBToLLVM.h" +#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Transforms/Passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/GPU/Transforms/Passes.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/ExecutionEngine/ExecutionEngine.h" +#include "mlir/ExecutionEngine/OptUtils.h" +#include "mlir/IR/AsmState.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Target/LLVM/NVVM/Target.h" +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" +#include "mlir/Transforms/Passes.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/mosaic/dialect/gpu/mosaic_gpu.h" #include "jaxlib/mosaic/gpu/launch_lowering.h" diff --git a/jaxlib/mosaic/gpu/launch_lowering.cc b/jaxlib/mosaic/gpu/launch_lowering.cc index 0331d800ec50..f3f982f07481 100644 --- a/jaxlib/mosaic/gpu/launch_lowering.cc +++ b/jaxlib/mosaic/gpu/launch_lowering.cc @@ -31,29 +31,29 @@ limitations under the License. #include #include -#include "llvm/include/llvm/ADT/STLExtras.h" -#include "llvm/include/llvm/ADT/StringRef.h" -#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h" -#include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/IR/DialectRegistry.h" -#include "mlir/include/mlir/IR/Location.h" -#include "mlir/include/mlir/IR/SymbolTable.h" -#include "mlir/include/mlir/IR/TypeRange.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/IR/ValueRange.h" -#include "mlir/include/mlir/IR/Visitors.h" -#include "mlir/include/mlir/Interfaces/DataLayoutInterfaces.h" -#include "mlir/include/mlir/Pass/Pass.h" -#include "mlir/include/mlir/Pass/PassRegistry.h" -#include "mlir/include/mlir/Support/LLVM.h" -#include "mlir/include/mlir/Support/LogicalResult.h" -#include "mlir/include/mlir/Support/TypeID.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/TypeRange.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Support/TypeID.h" namespace mosaic { namespace gpu { diff --git a/jaxlib/mosaic/gpu/passes.cc b/jaxlib/mosaic/gpu/passes.cc index 1815e18ca927..1705405d2f32 100644 --- a/jaxlib/mosaic/gpu/passes.cc +++ b/jaxlib/mosaic/gpu/passes.cc @@ -19,18 +19,18 @@ limitations under the License. #include #include -#include "llvm/include/llvm/ADT/StringRef.h" -#include "mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h" -#include "mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h" -#include "mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/IR/SymbolTable.h" -#include "mlir/include/mlir/Pass/PassRegistry.h" -#include "mlir/include/mlir/Support/LLVM.h" -#include "mlir/include/mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/DialectConversion.h" #include "jaxlib/mosaic/pass_boilerplate.h" namespace mosaic { diff --git a/jaxlib/mosaic/gpu/serde.cc b/jaxlib/mosaic/gpu/serde.cc index f4cf846acc11..5fca1d445774 100644 --- a/jaxlib/mosaic/gpu/serde.cc +++ b/jaxlib/mosaic/gpu/serde.cc @@ -15,10 +15,10 @@ limitations under the License. #include "jaxlib/mosaic/gpu/serde.h" -#include "llvm/include/llvm/ADT/StringMap.h" -#include "llvm/include/llvm/ADT/StringRef.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Support/LLVM.h" #include "jaxlib/mosaic/serde.h" namespace mosaic::gpu { diff --git a/jaxlib/mosaic/gpu/serde.h b/jaxlib/mosaic/gpu/serde.h index d1e25e3f0912..29dda33d0c5a 100644 --- a/jaxlib/mosaic/gpu/serde.h +++ b/jaxlib/mosaic/gpu/serde.h @@ -19,12 +19,12 @@ limitations under the License. #include #include -#include "llvm/include/llvm/ADT/StringRef.h" -#include "llvm/include/llvm/Support/CommandLine.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/Interfaces/DataLayoutInterfaces.h" -#include "mlir/include/mlir/Pass/Pass.h" -#include "mlir/include/mlir/Pass/PassRegistry.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/CommandLine.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" #include "jaxlib/mosaic/pass_boilerplate.h" namespace mosaic::gpu { diff --git a/jaxlib/mosaic/gpu/target.cc b/jaxlib/mosaic/gpu/target.cc index a1a66a709cbe..a259b3dead7b 100644 --- a/jaxlib/mosaic/gpu/target.cc +++ b/jaxlib/mosaic/gpu/target.cc @@ -23,8 +23,8 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" -#include "llvm/include/llvm/MC/MCSubtargetInfo.h" -#include "llvm/include/llvm/MC/TargetRegistry.h" +#include "llvm/MC/MCSubtargetInfo.h" +#include "llvm/MC/TargetRegistry.h" namespace mosaic::gpu { diff --git a/jaxlib/mosaic/pass_boilerplate.h b/jaxlib/mosaic/pass_boilerplate.h index 546981feeef7..96d9e85a1d2d 100644 --- a/jaxlib/mosaic/pass_boilerplate.h +++ b/jaxlib/mosaic/pass_boilerplate.h @@ -18,10 +18,10 @@ limitations under the License. #include -#include "mlir/include/mlir/IR/DialectRegistry.h" -#include "mlir/include/mlir/Pass/Pass.h" -#include "mlir/include/mlir/Support/LLVM.h" -#include "mlir/include/mlir/Support/TypeID.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/TypeID.h" namespace jaxlib { namespace mlir { diff --git a/jaxlib/mosaic/serde.cc b/jaxlib/mosaic/serde.cc index 88bca44bf181..307164d91dd9 100644 --- a/jaxlib/mosaic/serde.cc +++ b/jaxlib/mosaic/serde.cc @@ -18,15 +18,15 @@ limitations under the License. #include #include -#include "llvm/include/llvm/ADT/StringMap.h" -#include "llvm/include/llvm/ADT/StringRef.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/IR/Operation.h" -#include "mlir/include/mlir/IR/OperationSupport.h" -#include "mlir/include/mlir/IR/Visitors.h" -#include "mlir/include/mlir/Interfaces/DataLayoutInterfaces.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" +#include "mlir/Support/LLVM.h" namespace jaxlib::mosaic { diff --git a/jaxlib/mosaic/serde.h b/jaxlib/mosaic/serde.h index 762d9e5dad73..fdcaf58d4a8e 100644 --- a/jaxlib/mosaic/serde.h +++ b/jaxlib/mosaic/serde.h @@ -18,11 +18,11 @@ limitations under the License. #include -#include "llvm/include/llvm/ADT/StringMap.h" -#include "llvm/include/llvm/ADT/StringRef.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Support/LLVM.h" namespace jaxlib::mosaic { diff --git a/jaxlib/triton/triton_dialect_capi.cc b/jaxlib/triton/triton_dialect_capi.cc index 6a46d2914f57..8781fd16d76a 100644 --- a/jaxlib/triton/triton_dialect_capi.cc +++ b/jaxlib/triton/triton_dialect_capi.cc @@ -15,12 +15,12 @@ limitations under the License. #include "jaxlib/triton/triton_dialect_capi.h" -#include "llvm/include/llvm/Support/Casting.h" -#include "mlir/include/mlir-c/IR.h" -#include "mlir/include/mlir/CAPI/IR.h" -#include "mlir/include/mlir/CAPI/Registration.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/Dialect.h" +#include "llvm/Support/Casting.h" +#include "mlir-c/IR.h" +#include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" diff --git a/jaxlib/triton/triton_dialect_capi.h b/jaxlib/triton/triton_dialect_capi.h index 8c27b5b82500..7d2a2f10404a 100644 --- a/jaxlib/triton/triton_dialect_capi.h +++ b/jaxlib/triton/triton_dialect_capi.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef JAXLIB_TRITON_TRITON_DIALECT_CAPI_H_ #define JAXLIB_TRITON_TRITON_DIALECT_CAPI_H_ -#include "mlir/include/mlir-c/IR.h" -#include "mlir/include/mlir-c/Support.h" +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" #ifdef __cplusplus extern "C" { From 7fa7db7a9fd874fd4561f66fdf1fdecb0611432e Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 20 Mar 2025 07:39:10 -0700 Subject: [PATCH 053/483] Update XLA dependency to use revision http://github.com/openxla/xla/commit/6e2c9b024cec7dca4b2e1b07cc89373574c9c5af. PiperOrigin-RevId: 738795997 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 08c5af0c32b8..b20048193bd1 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "d3145a119840723c16fd27ee342729d68fddb7ef" -XLA_SHA256 = "daf2a72e36a9358803a8156c48b32117c9699fd327fcbc37b465f1a0045bccae" +XLA_COMMIT = "6e2c9b024cec7dca4b2e1b07cc89373574c9c5af" +XLA_SHA256 = "387917467d6f6e8358d54ba2b89f3fef14a00e62d8b0a096bf07acc8186444d4" def repo(): tf_http_archive( From c098b363fb032bbf812eceef679141e5261380bd Mon Sep 17 00:00:00 2001 From: Tom Natan Date: Thu, 20 Mar 2025 08:01:08 -0700 Subject: [PATCH 054/483] [JAX Shardy] Unskip stream annotation test when shardy is enabled, since the underlying issue is now resolved. PiperOrigin-RevId: 738802372 --- tests/memories_test.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/memories_test.py b/tests/memories_test.py index fdd654e2186d..adc45dbdb0c1 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -1664,8 +1664,6 @@ class StreamAnnotationTest(jtu.JaxTestCase): def test_stream_annotation_inside_shmap(self): if not jtu.test_device_matches(["gpu"]): self.skipTest("Stream annotation is only supported on GPU.") - if config.use_shardy_partitioner.value: - self.skipTest("Doesn't work with shardy") mesh = jtu.create_mesh((2,), ('x',)) s = NamedSharding(mesh, P('x')) From 8bbd738df1d77b998241b36a110eb5545cf4d2f3 Mon Sep 17 00:00:00 2001 From: Tom Natan Date: Thu, 20 Mar 2025 08:36:38 -0700 Subject: [PATCH 055/483] [JAX Shardy] #sdy Unskip another test that is now passing PiperOrigin-RevId: 738814411 --- tests/memories_test.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/memories_test.py b/tests/memories_test.py index adc45dbdb0c1..570b0c375834 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -756,9 +756,6 @@ def init(): def test_compute_no_inputs_host_replicated(self): if xb.backend_xla_version() is not None and xb.backend_xla_version() < 3: self.skipTest("This test requires an xla_version >= 3.") - if config.use_shardy_partitioner.value: - self.skipTest("XLA failure due to b/370786664 and b/366411266. " - "Enable when fixed.") mesh = jtu.create_mesh((4,), ('data')) tpu_sharding = NamedSharding(mesh, P('data')) From dad1b41f7bbe9f4d4bc39c261358df3f0823c84d Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 20 Mar 2025 08:57:12 -0700 Subject: [PATCH 056/483] Reverts 2562da7026ccd930e5f0972598c7d5479175b787 PiperOrigin-RevId: 738820673 --- jaxlib/setup.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/jaxlib/setup.py b/jaxlib/setup.py index 60f17a987307..b3a37a25f1b2 100644 --- a/jaxlib/setup.py +++ b/jaxlib/setup.py @@ -68,10 +68,10 @@ def has_ext_modules(self): url='https://github.com/jax-ml/jax', license='Apache-2.0', classifiers=[ - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', - 'Programming Language :: Python :: 3.12', - 'Programming Language :: Python :: 3.13', + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", ], package_data={ 'jaxlib': [ @@ -105,7 +105,7 @@ def has_ext_modules(self): 'triton/*.so', 'include/xla/ffi/api/*.h', ], - 'jaxlib.xla_extension': ['*.pyi', 'profiler/*.pyi'], + 'jaxlib.xla_extension': ['*.pyi'], }, zip_safe=False, distclass=BinaryDistribution, From 1ec0585361b9306d02cd01c07053f196750afe59 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 20 Mar 2025 09:06:27 -0700 Subject: [PATCH 057/483] Fix process_allgather of global jax.Arrays with shardy PiperOrigin-RevId: 738823617 --- jax/experimental/multihost_utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/jax/experimental/multihost_utils.py b/jax/experimental/multihost_utils.py index 2bde1fbeadc4..7be349f0fc8f 100644 --- a/jax/experimental/multihost_utils.py +++ b/jax/experimental/multihost_utils.py @@ -99,8 +99,11 @@ def _identity_fn(x): def _handle_array_process_allgather(inp, tiled): if isinstance(inp, array.ArrayImpl) and not inp.is_fully_addressable: - reps = sharding_impls.GSPMDSharding.get_replicated( - inp.sharding._device_assignment) + if isinstance(inp.sharding, sharding_impls.NamedSharding): + reps = inp.sharding.with_spec(P()) + else: + reps = sharding_impls.GSPMDSharding.get_replicated( + inp.sharding._device_assignment, memory_kind=inp.sharding.memory_kind) out = jax.jit(_identity_fn, out_shardings=reps)(inp) else: # All inputs here will be fully addressable. From 1b37613e4b9385ef1da473cbe3b12d1ac82dd833 Mon Sep 17 00:00:00 2001 From: Michael Hudgins Date: Thu, 20 Mar 2025 19:01:20 +0000 Subject: [PATCH 058/483] Introduce optional CUDA presubmit for additional hardware config --- .github/workflows/bazel_optional_cuda.yml | 65 +++++++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 .github/workflows/bazel_optional_cuda.yml diff --git a/.github/workflows/bazel_optional_cuda.yml b/.github/workflows/bazel_optional_cuda.yml new file mode 100644 index 000000000000..71936aeb9ae8 --- /dev/null +++ b/.github/workflows/bazel_optional_cuda.yml @@ -0,0 +1,65 @@ +name: CI - Bazel Optional CUDA tests +on: + workflow_dispatch: + inputs: + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: choice + required: true + default: 'no' + options: + - 'yes' + - 'no' + pull_request: + branches: + - main + types: [ labeled, synchronize ] + schedule: + - cron: "0 */2 * * *" # Run once every 2 hours +permissions: + contents: read +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + # Don't cancel in-progress jobs for main/release branches. + cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }} +jobs: + run_tests: + if: ${{ github.event.repository.fork == false && (github.event_name == 'schedule' || contains(github.event.pull_request.labels.*.name, 'CI Optional GPU Presubmit')) }} + runs-on: ${{ matrix.runner }} + container: 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.8-cudnn9.8:latest' + strategy: + matrix: + # Optional gpus to run against + runner: ["linux-x86-a4-224-b200-1gpu"] + name: "Bazel single accelerator CUDA tests (${{ matrix.runner }})" +# End Presubmit Naming Check github-cuda-presubmits + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@main + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Run Bazel CUDA Tests + run: | + nvidia-smi + bazel test --config=ci_linux_x86_64_cuda \ + --config=resultstore \ + --config=rbe_cache \ + --repo_env=HERMETIC_CUDA_VERSION="12.8.0" \ + --repo_env=HERMETIC_CUDNN_VERSION="9.8.0" \ + --repo_env=HERMETIC_PYTHON_VERSION="3.13" \ + --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \ + --run_under "$(pwd)/build/parallel_accelerator_execute.sh" \ + --test_output=errors \ + --test_env=JAX_ACCELERATOR_COUNT=1 \ + --test_env=JAX_TESTS_PER_ACCELERATOR=32 \ + --local_test_jobs=32 \ + --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow \ + --test_tag_filters=-multiaccelerator \ + --test_env=TF_CPP_MIN_LOG_LEVEL=0 \ + --test_env=JAX_SKIP_SLOW_TESTS=true \ + --action_env=JAX_ENABLE_X64="1" \ + --action_env=NCCL_DEBUG=WARN \ + --color=yes \ + //tests:gpu_tests //tests:backend_independent_tests \ + //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests \ No newline at end of file From 549c6694513d778f93ae63f282d8990c520d0a27 Mon Sep 17 00:00:00 2001 From: shuw Date: Thu, 20 Mar 2025 19:28:32 +0000 Subject: [PATCH 059/483] Straight-through estimator for nvfp4 --- jax/_src/cudnn/scaled_matmul_stablehlo.py | 27 +++++++++++++++++------ tests/scaled_matmul_stablehlo_test.py | 24 ++++++++++++++++++++ 2 files changed, 44 insertions(+), 7 deletions(-) diff --git a/jax/_src/cudnn/scaled_matmul_stablehlo.py b/jax/_src/cudnn/scaled_matmul_stablehlo.py index 1a8dee293082..ddcde6a95b26 100644 --- a/jax/_src/cudnn/scaled_matmul_stablehlo.py +++ b/jax/_src/cudnn/scaled_matmul_stablehlo.py @@ -489,18 +489,24 @@ def quantize(x, config): assert config.scale_type == jnp.float8_e8m0fnu scales_q = cast_to_e8m0_with_rounding_up(scales) scaled_x = x / e8m0_to_dtype(scales_q, scales.dtype) + clipped_x = jnp.clip(scaled_x, -MAX, MAX) + x_q = clipped_x.astype(config.data_type) elif config.mode == "nvfp4": assert config.scale_type == jnp.float8_e4m3fn assert config.global_scale.dtype == jnp.float32 - - scales = scales / config.global_scale - scales_q = jax.lax.optimization_barrier(scales.astype(jnp.float8_e4m3fn)) - scaled_x = x / (scales_q.astype(jnp.float32) * - config.global_scale).astype(x.dtype) + SCALE_MAX = jnp.finfo(config.scale_type).max.astype(x.dtype) + + prev_amax = config.global_scale * (MAX * SCALE_MAX) + scales_q = jnp.clip( + (amax / prev_amax) * SCALE_MAX, 0, SCALE_MAX + ).astype(config.scale_type) + x_q = jnp.where( + amax <= prev_amax, + (x * MAX) / amax, + jnp.clip((x * MAX) / prev_amax, -MAX, MAX), + ).astype(config.data_type) else: raise ValueError(f"Unrecognized mode: {config.mode}.") - clipped_x = jnp.clip(scaled_x, -MAX, MAX) - x_q = clipped_x.astype(config.data_type) x_q = x_q.reshape(x_shape) # shape = (B, M, K) scales_q = jnp.reshape(scales_q, scales_q.shape[:-1]).view( @@ -639,6 +645,13 @@ def scaled_dot_bwd(dimension_numbers, preferred_element_type, configs, res, g): } grad_lhs = scaled_dot_general_transpose_lhs(*args, **lhs_kw_args) grad_rhs = scaled_dot_general_transpose_rhs(*args, **rhs_kw_args) + + if configs[2].mode == "nvfp4": + assert rhs.dtype == lhs.dtype + MAX = jnp.finfo(configs[0].data_type).max.astype(lhs.dtype) + SCALE_MAX = jnp.finfo(configs[0].scale_type).max.astype(lhs.dtype) + grad_lhs = jnp.where(jnp.abs(lhs) <= configs[0].global_scale * MAX * SCALE_MAX, grad_lhs, 0) + grad_rhs = jnp.where(jnp.abs(rhs) <= configs[1].global_scale * MAX * SCALE_MAX, grad_rhs, 0) return (grad_lhs, grad_rhs) diff --git a/tests/scaled_matmul_stablehlo_test.py b/tests/scaled_matmul_stablehlo_test.py index 141839a19a08..224f6b6204e5 100644 --- a/tests/scaled_matmul_stablehlo_test.py +++ b/tests/scaled_matmul_stablehlo_test.py @@ -194,6 +194,11 @@ def generate_nvfp4_quantized_tensors(dot_config, output_type): amax_a = jnp.max(jnp.abs(a)).astype(jnp.float32) amax_b = jnp.max(jnp.abs(b)).astype(jnp.float32) + # To emulate calibrated amax + amax_sf = 0.9 + amax_a *= amax_sf + amax_b *= amax_sf + # Update global scales data_max = jnp.finfo(block_scale_configs_nvfp4[0].data_type).max.astype( jnp.float32 @@ -567,8 +572,27 @@ def fwd(a, b, is_ref=False, use_normalized=False): out_ref, _ = j_train_fwd_ref(a_dq, b_dq) self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2) + def _grad_clip(amax, x, grad): + return jnp.where(jnp.abs(x) <= amax, grad, 0) + + data_max = jnp.finfo(jnp.float4_e2m1fn).max.astype(output_type) + scale_max = jnp.finfo(jnp.float8_e4m3fn).max.astype(output_type) + prev_amax_a = a_gs * data_max * scale_max + prev_amax_b = b_gs * data_max * scale_max + + x_grad_ref = _grad_clip(prev_amax_a, a_raw, x_grad_ref) + w_grad_ref = _grad_clip(prev_amax_b, b_raw, w_grad_ref) self.assertArraysAllClose(x_grad, x_grad_ref, rtol=1e-2, atol=1e1) self.assertArraysAllClose(w_grad, w_grad_ref, rtol=1e-2, atol=1e1) + # Verify straight_through_estimator + self.assertArraysEqual( + jnp.where(jnp.abs(a_raw) > prev_amax_a, x_grad, 0), + jnp.zeros_like(x_grad) + ) + self.assertArraysEqual( + jnp.where(jnp.abs(b_raw) > prev_amax_b, w_grad, 0), + jnp.zeros_like(w_grad) + ) else: j_inference = jax.jit(fwd) j_inference_ref = jax.jit(partial(fwd, is_ref=True, use_normalized=True)) From 59e480db99ea221c21efc566d4fe7f51ffebadf8 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Thu, 20 Mar 2025 12:51:04 -0700 Subject: [PATCH 060/483] [Mosaic GPU] Skip Mosaic GPU tests if jax_pallas_use_mosaic_gpu flag is not set. PiperOrigin-RevId: 738906466 --- jax/_src/pallas/pallas_call.py | 2 +- tests/pallas/BUILD | 1 - tests/pallas/mosaic_gpu_test.py | 4 ++++ tests/pallas/ops_test.py | 3 ++- 4 files changed, 7 insertions(+), 3 deletions(-) diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index d0b74b2e5148..fbe3d23c6c27 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -1206,7 +1206,7 @@ def _trace_kernel_to_jaxpr( return jaxpr, tuple(consts) -_PALLAS_USE_MOSAIC_GPU = config.bool_flag( +_PALLAS_USE_MOSAIC_GPU = config.bool_state( "jax_pallas_use_mosaic_gpu", default=config.bool_env("JAX_PALLAS_USE_MOSAIC_GPU", False), help=( diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 987a3aa9d50a..8ec4eea7aa1f 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -215,7 +215,6 @@ jax_multiplatform_test( "gpu_h100", ], env = { - "JAX_PALLAS_USE_MOSAIC_GPU": "1", "JAX_PALLAS_VERBOSE_ERRORS": "0", }, deps = [ diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 40e98bf05ba9..27017e4eb740 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -26,6 +26,7 @@ from jax import lax from jax._src import test_util as jtu from jax._src.pallas.mosaic_gpu import pipeline as mgpu_pipeline +from jax._src.pallas import pallas_call from jax.experimental import pallas as pl from jax.experimental.pallas import mosaic_gpu as plgpu import jax.numpy as jnp @@ -59,6 +60,9 @@ class PallasTest(jtu.JaxTestCase): def setUp(self): if not jtu.is_cuda_compute_capability_at_least("9.0"): self.skipTest("Only works on a GPU with capability >= sm90") + context_stack = contextlib.ExitStack() + context_stack.enter_context(pallas_call._PALLAS_USE_MOSAIC_GPU(True)) + self.addCleanup(context_stack.close) super().setUp() diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 0fc375bf64a1..38426747d85d 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -30,6 +30,7 @@ from jax._src import linear_util as lu from jax._src import state from jax._src import test_util as jtu +from jax._src.pallas import pallas_call from jax.experimental import pallas as pl from jax.interpreters import partial_eval as pe import jax.numpy as jnp @@ -61,7 +62,7 @@ jax.config.parse_flags_with_absl() jtu.setup_hypothesis(max_examples=50) -use_mosaic_gpu = jax.config.read("jax_pallas_use_mosaic_gpu") +use_mosaic_gpu = pallas_call._PALLAS_USE_MOSAIC_GPU.value intx = dtypes.canonicalize_dtype(jnp.int64) floatx = dtypes.canonicalize_dtype(jnp.float64) From 412f1d35dca4efa655a291f3d4b3f8f3d6b4546d Mon Sep 17 00:00:00 2001 From: Gleb Pobudzey Date: Fri, 7 Feb 2025 07:12:19 +0000 Subject: [PATCH 061/483] Adding sharding support to dynamic masks --- .../splash_attention_kernel.py | 25 +- .../splash_attention_mask_info.py | 110 ++++++--- tests/pallas/BUILD | 16 ++ ...pu_splash_attention_kernel_sharded_test.py | 223 ++++++++++++++++++ .../tpu_splash_attention_kernel_test.py | 14 +- .../pallas/tpu_splash_attention_mask_test.py | 4 +- 6 files changed, 352 insertions(+), 40 deletions(-) create mode 100644 tests/pallas/tpu_splash_attention_kernel_sharded_test.py diff --git a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py index 4b6e4a41c43b..d0fb6f2f9670 100644 --- a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py +++ b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py @@ -2293,6 +2293,26 @@ def _splash_attention( mask_function: MaskFunctionType | None, interpret: bool, ) -> SplashCustomReturnType: + """ + For dynamic masks, `partial_mask_blocks` has shape (head_count, q_blocks, kv_blocks, block_q, block_kv). + This shape allows sharding across both head count and query sequence dimensions. + + Note: The leading dimensions (head_count, q_blocks, kv_blocks) must be + collapsed into a single dimension before being passed to the kernel. + """ + def _collapse_partial_mask_blocks(mask_info: mask_info_lib.MaskInfo | None): + if mask_info is None or mask_info.partial_mask_blocks is None: + return mask_info + + return mask_info._replace( + partial_mask_blocks=mask_info.partial_mask_blocks.reshape( + -1, *mask_info.partial_mask_blocks.shape[-2:] + ) + ) + + fwd_mask_info = _collapse_partial_mask_blocks(fwd_mask_info) + dq_mask_info = _collapse_partial_mask_blocks(dq_mask_info) + dkv_mask_info = _collapse_partial_mask_blocks(dkv_mask_info) return _splash_attention_custom( fwd_mask_info, dq_mask_info, @@ -2352,13 +2372,16 @@ def manual_sharding_spec(self, sharding: jax.sharding.NamedSharding): spec = sharding.spec assert len(spec) == 2 replicated = jax.sharding.PartitionSpec() + partial_mask_blocks_spec = ( + spec if self.fwd_mask_info.is_dynamic_mask else replicated + ) # Shard q_sequence over the sequence dimension only. q_sequence_spec = jax.sharding.PartitionSpec(spec[1]) mask_info_specs = mask_info_lib.MaskInfo( # pytype: disable=wrong-arg-types data_next=spec if self.fwd_mask_info.data_next is not None else None, mask_next=spec if self.fwd_mask_info.mask_next is not None else None, block_mask=spec if self.fwd_mask_info.block_mask is not None else None, - partial_mask_blocks=replicated + partial_mask_blocks=partial_mask_blocks_spec if self.fwd_mask_info.partial_mask_blocks is not None else None, q_sequence=q_sequence_spec diff --git a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py index 65081e79c0cf..9c79fbbf7e09 100644 --- a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py +++ b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py @@ -67,6 +67,10 @@ class MaskInfo(NamedTuple): q_sequence: A i32[q_sequence_length] NumPy array. When using causal masking, this contains the list of indices that correspond to q tokens. For plain causal this is just np.arange(q_sequence_length). + is_dynamic_mask: A bool indicating whether the mask is dynamic or static. + When True, the leading dimensions of `partial_mask_blocks` (num_heads, + q_blocks, kv_blocks) are not collapsed, allowing us to shard it along + those dimensions. """ data_next: np.ndarray | jax.Array | None @@ -74,6 +78,7 @@ class MaskInfo(NamedTuple): block_mask: np.ndarray | jax.Array | None partial_mask_blocks: np.ndarray | jax.Array | None q_sequence: np.ndarray | None + is_dynamic_mask: bool = None def _downcast_to_small_type(array: np.ndarray) -> np.ndarray: @@ -168,7 +173,7 @@ def __eq__(self, other: object) -> bool: def _get_mask_info_for_shard( output_shape: tuple[int, int, int], has_mask_next: bool, - mask: mask_lib.MultiHeadMask, + mask: mask_lib.MultiHeadMask | jax.Array, block_shape: tuple[int, int], coords_to_partial_mask_block_index: dict[tuple[int, int, int], int], masks_per_head_shard: int, @@ -338,7 +343,8 @@ def _process_dynamic_mask( launched. q_seq_shards: Number of Q sequence shards of the mesh in which the kernel is launched. - shrink_grid: Whether or not we should apply the grid shrinking optimization. This is currently ignored. + shrink_grid: Whether or not we should apply the grid shrinking optimization. + This is currently ignored. Returns: `MaskInfo`, a sparse representation of the dense mask. @@ -349,11 +355,6 @@ def _process_dynamic_mask( """ del shrink_grid - - # TODO(pobudzey): Properly support sharding. - if head_shards != 1 or q_seq_shards != 1: - raise ValueError('Dynamic mask processing does not support sharding.') - if len(mask.shape) != 3: raise ValueError(f'Expected a 3-dim mask, instead got: {mask.shape}.') @@ -370,6 +371,18 @@ def _process_dynamic_mask( if kv_mod != 0: raise ValueError(f'{kv_block_size=} should divide {kv_seq_len=}.') + q_seq_len_per_shard, mod = divmod(q_seq_len, q_seq_shards) + if mod != 0: + raise ValueError(f'{q_seq_shards=} should divide {q_seq_len=}.') + + q_blocks_per_shard, mod = divmod(q_seq_len_per_shard, q_block_size) + if mod != 0: + raise ValueError(f'{q_block_size=} should divide {q_seq_len_per_shard=}.') + + heads_per_shard, mod = divmod(head_count, head_shards) + if mod != 0: + raise ValueError(f'{head_shards=} should divide {head_count=}.') + block_mask_shape = ( head_count, q_blocks_count, @@ -398,26 +411,66 @@ def _process_dynamic_mask( block_mask = jnp.where(is_full_mask, 2, block_mask) block_mask = jnp.where(is_empty_mask, 0, block_mask) - # TODO(pobudzey): Return the next valid mask index instead of 0 for a more efficient pipeline. - mask_next = jnp.where( - jnp.logical_or(is_empty_mask, is_full_mask), - 0, - jnp.arange(math.prod(block_mask_shape), dtype=np.int32).reshape( - block_mask_shape - ), - ) + q_sequence_axis = 1 + head_axis = 0 - # data_next stores the index of the next non-empty data block in the sequence. - # The indices of empty blocks are set to 0 to avoid copying extra data when - # pipeling. - if is_dkv: - data_next = jnp.arange(q_blocks_count, dtype=np.int32)[None, :, None] - else: - data_next = jnp.arange(kv_blocks_count, dtype=np.int32)[None, None, :] - data_next = jnp.broadcast_to(data_next, block_mask_shape) - data_next = jnp.where(is_empty_mask, 0, data_next) + # Each iteration of the loop processes a slice of the mask info + # tensors of this shape: + mask_info_slice_shape = (heads_per_shard, q_blocks_per_shard, kv_blocks_count) + + # Collect mask_info shards along the head dimension, concatentate (or + # broadcast) them after the loop. + data_next_per_head_list, mask_next_per_head_list = [], [] + for head_shard in range(head_shards): + head_start = head_shard * heads_per_shard + mask_head_slice = slice(head_start, head_start + heads_per_shard) + + # Collect mask_info shards along the q_sequence dimension, concatenate them + # after the loop. + data_next_sequence_slices, mask_next_sequence_slices = [], [] + for q_seq_len_shard in range(q_seq_shards): + q_seq_len_start = q_seq_len_shard * q_blocks_per_shard + blocked_q_seq_len_slice = slice( + q_seq_len_start, q_seq_len_start + q_blocks_per_shard + ) + local_block_mask = block_mask[mask_head_slice, blocked_q_seq_len_slice] + + mask_next_slice = jnp.arange( + math.prod(mask_info_slice_shape), dtype=np.int32 + ).reshape(mask_info_slice_shape) + mask_next_slice = jnp.where(local_block_mask == 1, mask_next_slice, 0) + + # data_next stores the index of the next non-empty data block in the sequence. + # The indices of empty blocks are set to 0 to avoid copying extra data when + # pipeling. + if is_dkv: + data_next_slice = jnp.arange(q_blocks_per_shard, dtype=np.int32)[ + None, :, None + ] + else: + data_next_slice = jnp.arange(kv_blocks_count, dtype=np.int32)[ + None, None, : + ] + data_next_slice = jnp.broadcast_to(data_next_slice, mask_info_slice_shape) + data_next_slice = jnp.where(local_block_mask == 0, 0, data_next_slice) + + data_next_sequence_slices.append(data_next_slice) + mask_next_sequence_slices.append(mask_next_slice) + + # Concatenate the sequence shards. + data_next_per_head = jnp.concatenate( + data_next_sequence_slices, axis=q_sequence_axis + ) + data_next_per_head_list.append(data_next_per_head) + mask_next_per_head = jnp.concatenate( + mask_next_sequence_slices, axis=q_sequence_axis + ) + mask_next_per_head_list.append(mask_next_per_head) + + # Concatenate (or broadcast) the head shards. + data_next = jnp.concatenate(data_next_per_head_list, axis=head_axis) + mask_next = jnp.concatenate(mask_next_per_head_list, axis=head_axis) - partial_mask_blocks = partial_mask_blocks.reshape(-1, *block_shape) if is_dkv: partial_mask_blocks = partial_mask_blocks.swapaxes(-1, -2) @@ -438,9 +491,11 @@ def _downcast(array: jax.Array, max_value: int) -> jax.Array: if downcast_smem_data: block_mask = block_mask.astype(np.int8) # values are in the range [0, 1, 2] data_next = _downcast( - data_next, q_blocks_count if is_dkv else kv_blocks_count + data_next, q_blocks_per_shard if is_dkv else kv_blocks_count + ) + mask_next = _downcast( + mask_next, heads_per_shard * q_blocks_per_shard * kv_blocks_count ) - mask_next = _downcast(mask_next, math.prod(block_mask_shape)) return ( MaskInfo( @@ -449,6 +504,7 @@ def _downcast(array: jax.Array, max_value: int) -> jax.Array: block_mask=block_mask, partial_mask_blocks=partial_mask_blocks, q_sequence=None, + is_dynamic_mask=True, ), None, ) diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 8ec4eea7aa1f..34af5e16a9b6 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -540,6 +540,22 @@ jax_multiplatform_test( ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"), ) +jax_multiplatform_test( + name = "tpu_splash_attention_kernel_sharded_test", + srcs = ["tpu_splash_attention_kernel_sharded_test.py"], + enable_configs = [ + "tpu_v5e_4x2", + "tpu_v5p_2x2", + ], + shard_count = 5, + deps = [ + "//jax:extend", + "//jax:pallas_tpu", + "//jax:pallas_tpu_ops", + ], +) + + # This test doesn't need a TPU; it only tests numpy-using helpers. jax_py_test( name = "tpu_splash_attention_mask_test", diff --git a/tests/pallas/tpu_splash_attention_kernel_sharded_test.py b/tests/pallas/tpu_splash_attention_kernel_sharded_test.py new file mode 100644 index 000000000000..db14b44938e9 --- /dev/null +++ b/tests/pallas/tpu_splash_attention_kernel_sharded_test.py @@ -0,0 +1,223 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for partitioning splash_attention.""" + +import functools +import math +from absl.testing import absltest, parameterized +import jax +from jax import random +from jax._src import test_util as jtu +from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel as splash +from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib +from jax.experimental.shard_map import shard_map +import jax.numpy as jnp +from jax.sharding import PartitionSpec +import numpy as np + +partial = functools.partial + +jax.config.parse_flags_with_absl() + + +@jtu.with_config(jax_traceback_filtering="off") +class PallasBaseTest(jtu.JaxTestCase): + INTERPRET = False + + def setUp(self): + super().setUp() + if not jtu.is_device_tpu(): + self.skipTest("Test requires TPU.") + + if len(jax.devices()) < 4: + self.skipTest("This test requires at least 4 devices.") + + def _assert_allclose(self, x, y, **kwargs): + if x.dtype == np.dtype(jnp.bfloat16): + x = x.astype(np.float32) + if y.dtype == np.dtype(jnp.bfloat16): + y = y.astype(np.float32) + self.assertEqual(x.dtype, y.dtype) + self.assertTupleEqual(x.shape, y.shape) + np.testing.assert_allclose(x, y, **kwargs) + + +def generate_mask(shape, num_heads, seed) -> np.ndarray: + assert num_heads >= 2 + assert shape > (64, 64) + + masks = [ + mask_lib.make_causal_mask(shape), + mask_lib.make_local_attention_mask(shape, window_size=(64, 64)), + ] + masks += [mask_lib.make_random_mask(shape, 0.8, seed)] * (num_heads - 2) + return np.stack(masks, axis=0) + + +class SplashAttentionShardingTest(PallasBaseTest): + + @parameterized.product( + topology=[(1, 1), (2, 1), (2, 2), (1, 2), (1, 4), (4, 1)], + num_heads=[2, 4, 16], + dtype=[jnp.bfloat16], + is_dynamic_mask=[False, True], + ) + def test_dynamic_mask_manual_partitioning_mha( + self, topology, num_heads, dtype, is_dynamic_mask + ): + k1, k2, k3 = random.split(random.key(0), 3) + seq_len = 1024 + head_dim = 128 + + head_shards, q_seq_shards = topology + num_devices = math.prod(topology) + + if head_shards > num_heads: + self.skipTest( + f"This test requires {num_heads} heads, but has only" + f" {head_shards} head shards available." + ) + + if len(jax.devices()) < num_devices: + self.skipTest( + f"This test requires {num_devices} devices, but has only" + f" {len(jax.devices())} devices available." + ) + + q = random.uniform(k1, (num_heads, seq_len, head_dim), dtype=dtype) + k = random.uniform(k2, (num_heads, seq_len, head_dim), dtype=dtype) + v = random.uniform(k3, (num_heads, seq_len, head_dim), dtype=dtype) + + mask = generate_mask((seq_len, seq_len), num_heads, seed=0) + if is_dynamic_mask: + mask = jnp.array(mask) + + devices = np.asarray(jax.devices()[:num_devices]).reshape( + head_shards, q_seq_shards + ) + + mesh = jax.sharding.Mesh(devices, ("heads", "q_seq")) + q_spec = PartitionSpec( + "heads" if head_shards > 1 else None, + "q_seq" if q_seq_shards > 1 else None, + ) + kv_spec = PartitionSpec("heads" if head_shards > 1 else None, None) + kernel = splash.make_splash_mha( + mask, head_shards=head_shards, q_seq_shards=q_seq_shards + ) + kernel_spec = kernel.manual_sharding_spec( + jax.sharding.NamedSharding(mesh, q_spec) + ) + + @partial( + shard_map, + mesh=mesh, + in_specs=( + kernel_spec, + q_spec, + kv_spec, + kv_spec, + ), + out_specs=q_spec, + check_rep=False, + ) + def f(kernel, q, k, v): + return kernel(q, k, v) + + out = f(kernel, q, k, v) + out_ref = jax.vmap(splash.attention_reference)(mask, q, k, v, None) + self._assert_allclose(out, out_ref, rtol=3e-3, atol=3e-3) + + @parameterized.product( + topology=[(1, 1), (2, 1), (2, 2), (1, 2), (1, 4), (4, 1)], + num_heads=[2, 4], + dtype=[jnp.bfloat16], + is_dynamic_mask=[False, True], + ) + def test_dynamic_mask_manual_partitioning_mha_bwd( + self, topology, num_heads, dtype, is_dynamic_mask + ): + assert num_heads % 2 == 0 + k1, k2, k3, k4 = random.split(random.key(0), 4) + seq_len = 1024 + head_dim = 128 + + head_shards, q_seq_shards = topology + num_devices = math.prod(topology) + + if head_shards > num_heads: + self.skipTest( + f"This test requires {num_heads} heads, but has only" + f" {head_shards} head shards available." + ) + + q = random.uniform(k1, (num_heads, seq_len, head_dim), dtype=dtype) + k = random.uniform(k2, (num_heads, seq_len, head_dim), dtype=dtype) + v = random.uniform(k3, (num_heads, seq_len, head_dim), dtype=dtype) + + mask = generate_mask((seq_len, seq_len), num_heads, seed=0) + if is_dynamic_mask: + mask = jnp.array(mask) + + devices = np.asarray(jax.devices()[:num_devices]).reshape( + head_shards, q_seq_shards + ) + + mesh = jax.sharding.Mesh(devices, ("heads", "q_seq")) + q_spec = PartitionSpec( + "heads" if head_shards > 1 else None, + "q_seq" if q_seq_shards > 1 else None, + ) + kv_spec = PartitionSpec("heads" if head_shards > 1 else None, None) + + kernel = splash.make_splash_mha( + mask, head_shards=head_shards, q_seq_shards=q_seq_shards + ) + kernel_spec = kernel.manual_sharding_spec( + jax.sharding.NamedSharding(mesh, q_spec) + ) + + @partial( + shard_map, + mesh=mesh, + in_specs=( + kernel_spec, + q_spec, + kv_spec, + kv_spec, + ), + out_specs=q_spec, + check_rep=False, + ) + def f(kernel, q, k, v): + return kernel(q, k, v) + + f_ref = jax.vmap(splash.attention_reference) + + out, out_vjp = jax.vjp(f, kernel, q, k, v) + out_ref, out_vjp_ref = jax.vjp(f_ref, mask, q, k, v, None) + self._assert_allclose(out, out_ref, rtol=3e-3, atol=3e-3) + + do = random.uniform(k4, out.shape, dtype=out.dtype) + _, dq, dk, dv = out_vjp(do) + _, dq_ref, dk_ref, dv_ref, _ = out_vjp_ref(do.astype(jnp.float32)) + + self.assertAllClose(dq, dq_ref, atol=5e-2) + self.assertAllClose(dk, dk_ref, atol=5e-2) + self.assertAllClose(dv, dv_ref, atol=5e-2) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/tpu_splash_attention_kernel_test.py b/tests/pallas/tpu_splash_attention_kernel_test.py index dfe0bcc0da3b..240a9c91c02d 100644 --- a/tests/pallas/tpu_splash_attention_kernel_test.py +++ b/tests/pallas/tpu_splash_attention_kernel_test.py @@ -303,14 +303,6 @@ def attn_logits_soft_cap_strategy() -> hps.SearchStrategy[float | None]: return hps.one_of(hps.just(None), hps.floats(min_value=1.0, max_value=50.0)) -def to_dynamic_mask(mask: mask_lib.MultiHeadMask) -> jax.Array: - q_seq_len, kv_seq_len = mask.masks[0].shape - full_mask_slice = (slice(0, q_seq_len), slice(0, kv_seq_len)) - dynamic_mask = jnp.stack([m[full_mask_slice] for m in mask.masks], axis=0) - - return dynamic_mask - - @jtu.with_config(jax_traceback_filtering="off") class PallasBaseTest(jtu.JaxTestCase): INTERPRET = False @@ -384,7 +376,7 @@ def test_splash_attention(self, is_mqa, is_segmented, is_dynamic_mask, data): masks = data.draw(mha_mask_strategy(q_seq_len, kv_seq_len, num_q_heads)) mask = mask_lib.MultiHeadMask(tuple(m.get_mask() for m in masks)) if is_dynamic_mask: - mask = to_dynamic_mask(mask) + mask = jnp.array(mask[:, :, :]) block_sizes = data.draw(block_sizes_strategy(q_seq_len, kv_seq_len)) if is_mqa: @@ -460,7 +452,7 @@ def test_splash_attention_fwd( masks = data.draw(mha_mask_strategy(q_seq_len, kv_seq_len, num_q_heads)) mask = mask_lib.MultiHeadMask(tuple(m.get_mask() for m in masks)) if is_dynamic_mask: - mask = to_dynamic_mask(mask) + mask = jnp.array(mask[:, :, :]) block_sizes = data.draw(block_sizes_strategy(q_seq_len, kv_seq_len)) if is_mqa: attn_ref = splash.make_masked_mqa_reference(mask) @@ -628,7 +620,7 @@ def test_splash_attention_bwd( masks = data.draw(mha_mask_strategy(q_seq_len, kv_seq_len, num_q_heads)) mask = mask_lib.MultiHeadMask(tuple(m.get_mask() for m in masks)) if use_dynamic_mask: - mask = to_dynamic_mask(mask) + mask = jnp.array(mask[:, :, :]) block_sizes = data.draw( block_sizes_strategy(q_seq_len, kv_seq_len, include_bwd_blocks=True, use_fused_bwd_kernel=use_fused_bwd_kernel) diff --git a/tests/pallas/tpu_splash_attention_mask_test.py b/tests/pallas/tpu_splash_attention_mask_test.py index f39b4d839340..5379eb10990f 100644 --- a/tests/pallas/tpu_splash_attention_mask_test.py +++ b/tests/pallas/tpu_splash_attention_mask_test.py @@ -2166,7 +2166,9 @@ def test_dynamic_mask(self, is_dkv: bool): self.assertArraysEqual(mask_info.block_mask, _expected_block_mask) self.assertArraysEqual( - mask_info.partial_mask_blocks, + mask_info.partial_mask_blocks.reshape( + -1, *mask_info.partial_mask_blocks.shape[-2:] + ), _expected_partial_mask_blocks, ) self.assertArraysEqual(mask_info.mask_next, _expected_mask_next) From ea7fa29be73f322eed727a59f1dcbf8cb7ac7170 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 20 Mar 2025 13:23:09 -0700 Subject: [PATCH 062/483] Allow `tuple(arrays)` as an input to `make_array_from_single_device_arrays`. Fixes https://github.com/jax-ml/jax/issues/27303 PiperOrigin-RevId: 738917340 --- jax/_src/array.py | 7 ++++--- tests/array_test.py | 12 ++++++++++++ 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/jax/_src/array.py b/jax/_src/array.py index e49963ccda9c..ee196026887d 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -1024,7 +1024,7 @@ def make_array_from_single_device_arrays( shape : Shape of the output ``jax.Array``. This conveys information already included with ``sharding`` and ``arrays`` and serves as a double check. sharding: Sharding: A global Sharding instance which describes how the output jax.Array is laid out across devices. - arrays: Sequence of ``jax.Array``\s that are each single device addressable. ``len(arrays)`` + arrays: `list` or `tuple` of ``jax.Array``\s that are each single device addressable. ``len(arrays)`` must equal ``len(sharding.addressable_devices)`` and the shape of each array must be the same. For multiprocess code, each process will call with a different ``arrays`` argument that corresponds to that processes' data. These arrays are commonly created via ``jax.device_put``. @@ -1071,14 +1071,15 @@ def make_array_from_single_device_arrays( if dtypes.issubdtype(aval.dtype, dtypes.extended): return aval.dtype._rules.make_sharded_array(aval, sharding, arrays, committed=True) + arrays = list(arrays) if isinstance(arrays, tuple) else arrays # TODO(phawkins): ideally the cast() could be checked. try: return ArrayImpl(aval, sharding, cast(Sequence[ArrayImpl], arrays), committed=True) except TypeError: - if not isinstance(arrays, Sequence): + if not isinstance(arrays, list): raise TypeError("jax.make_array_from_single_device_arrays `arrays` " - "argument must be a Sequence (list or tuple), but got " + "argument must be a list or tuple, but got " f"{type(arrays)}.") if any(isinstance(arr, core.Tracer) for arr in arrays): raise ValueError( diff --git a/tests/array_test.py b/tests/array_test.py index cc8990828ded..6100283cc032 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -1301,6 +1301,18 @@ def f(x): with self.assertRaisesRegex(TypeError, msg): jax.jit(f)(x) + def test_make_array_from_single_device_arrays_tuple(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + shape = (8, 8) + s = jax.sharding.NamedSharding(mesh, P('x', 'y')) + inp_data = np.arange(math.prod(shape)).reshape(shape) + + arrays = tuple( + jax.device_put(inp_data[index], d) + for d, index in s.addressable_devices_indices_map(shape).items()) + + jax.make_array_from_single_device_arrays(shape, s, arrays) # doesn't crash + def test_make_array_from_single_device_arrays_bad_inputs(self): x = jnp.arange(10) mesh = jtu.create_mesh((2,), ('x',)) From d0b71fa1ceb11e9fbf89a8d0e4f6be47b80ab382 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Thu, 20 Mar 2025 14:04:48 -0700 Subject: [PATCH 063/483] [Mosaic GPU] Add preliminary TMEM allocation support for Pallas/Mosaic GPU. PiperOrigin-RevId: 738932990 --- jax/_src/pallas/mosaic_gpu/core.py | 3 + jax/_src/pallas/mosaic_gpu/lowering.py | 108 ++++++++++++++++++++++--- jax/experimental/mosaic/gpu/core.py | 2 + jax/experimental/pallas/mosaic_gpu.py | 2 + tests/pallas/mosaic_gpu_test.py | 29 +++++++ 5 files changed, 134 insertions(+), 10 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index b1e0a683f64d..1e4a9de1830c 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -101,6 +101,8 @@ class GPUMemorySpace(enum.Enum): GMEM = "gmem" #: Shared memory. SMEM = "smem" + #: Tensor memory. + TMEM = "tmem" #: Registers. REGS = "regs" @@ -452,6 +454,7 @@ def to_block_mapping( GMEM = GPUMemorySpace.GMEM SMEM = GPUMemorySpace.SMEM +TMEM = GPUMemorySpace.TMEM REGS = GPUMemorySpace.REGS diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 1fae91773178..ef4c80cb4649 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -59,6 +59,7 @@ from jax.experimental.mosaic.gpu import core as mgpu_core from jax.experimental.mosaic.gpu import profiler as mgpu_profiler from jax.experimental.mosaic.gpu import utils as mgpu_utils +from jax.experimental.mosaic.gpu import tcgen05 import jax.numpy as jnp import numpy as np @@ -100,6 +101,7 @@ def arrival_multiplier(self) -> int: @dataclasses.dataclass(kw_only=True, frozen=True) class Resources: smem_scratch_bytes: int = 0 + tmem_scratch_cols: int = 0 barrier_counts: collections.Counter[mgpu.Barrier] = dataclasses.field( default_factory=collections.Counter ) @@ -110,6 +112,12 @@ def __post_init__(self): "smem_scratch_bytes", _align_to(self.smem_scratch_bytes, _SMEM_ALIGNMENT), ) + object.__setattr__( + self, + "tmem_scratch_cols", + # TMEM must be allocated in 128x8 chunks. + _align_to(self.tmem_scratch_cols, 8), + ) @property def barriers(self) -> Sequence[mgpu.Barrier]: @@ -122,6 +130,7 @@ def __add__(self, other: Resources) -> Resources: # we will allocate two barriers, even though one would be enough. return Resources( smem_scratch_bytes=self.smem_scratch_bytes + other.smem_scratch_bytes, + tmem_scratch_cols=self.tmem_scratch_cols + other.tmem_scratch_cols, barrier_counts=self.barrier_counts + other.barrier_counts, ) @@ -130,6 +139,9 @@ def __or__(self, other: Resources) -> Resources: smem_scratch_bytes=max( self.smem_scratch_bytes, other.smem_scratch_bytes ), + tmem_scratch_cols=max( + self.tmem_scratch_cols, other.tmem_scratch_cols + ), barrier_counts=self.barrier_counts | other.barrier_counts, ) @@ -218,10 +230,26 @@ def _run_scoped_resource_estimator( ) ]) ) - else: + elif aval.memory_space == gpu_core.TMEM: + if aval.dtype.itemsize != 4: + raise ValueError("TMEM only supports 32-bit types.") + if len(aval.shape) != 2: + raise ValueError("TMEM allocations must be 2D.") + if aval.shape[0] % tcgen05.TMEM_ROWS != 0: + raise ValueError("TMEM shape[0] must be a multiple of 128.") + if aval.shape[1] % 8 != 0: + raise ValueError("TMEM shape[1] must be a multiple of 8.") + rs += Resources(tmem_scratch_cols=aval.shape[1]) + elif aval.memory_space == gpu_core.SMEM: rs += Resources( smem_scratch_bytes=math.prod(aval.shape) * aval.dtype.itemsize ) + elif aval.memory_space == gpu_core.REGS: + # Don't need to allocate anything. + pass + else: + raise NotImplementedError( + f"Unsupported memory space: {aval.memory_space}") return rs + _estimate_resources(ctx, jaxpr) @@ -267,6 +295,9 @@ class ModuleContext: single_wg_lane_predicate: ir.Value smem_requested_bytes: int smem_used_bytes: int + tmem_requested_cols: int + tmem_used_cols: int + tmem_base_ptr: ir.Value runtime_barriers: MutableMapping[ mgpu.Barrier, MutableSequence[mgpu.BarrierRef] ] @@ -286,6 +317,27 @@ def reserve_barrier(self, barrier: mgpu.Barrier) -> mgpu.BarrierRef: raise RuntimeError(f"Barrier {barrier} is already reserved") return available.pop() + @contextlib.contextmanager + def alloc_tmem( + self, + struct: jax.ShapeDtypeStruct, + layout: tcgen05.TMEMLayout | None = None + ) -> ir.Value: + if self.tmem_used_cols > 0: + raise NotImplementedError( + "Multiple TMEM allocations are not implemented.") + if layout is None: + layout = tcgen05._infer_tmem_layout(struct.shape, collective=False) + cols_used = np.prod(struct.shape) // tcgen05.TMEM_ROWS + self.tmem_used_cols += cols_used + off = self.tmem_base_ptr + tmem_ref = tcgen05.TMEMRef(address=off, + shape=struct.shape, + dtype=mgpu_utils.dtype_to_ir_type(struct.dtype), + layout=layout) + yield tmem_ref + self.tmem_used_cols -= cols_used + # TODO(cperivol): Only return the shapes and figure out the sizes when freeing. @contextlib.contextmanager def scratch_view( @@ -642,11 +694,15 @@ def lower_jaxpr_to_module( parallel_grid = (math.prod(grid[:-2]), *grid[-2:]) def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): - *buffers_gmem, (runtime_smem, runtime_barriers) = buffers + *buffers_gmem, (runtime_smem, runtime_barriers, runtime_tmem) = buffers grouped_barriers = collections.defaultdict(list) for barrier, barrier_ref in zip(rs.barriers, runtime_barriers): grouped_barriers[barrier].append(barrier_ref) + if runtime_tmem is not None: + tmem_cols = math.prod(runtime_tmem.shape) // tcgen05.TMEM_ROWS + else: + tmem_cols = 0 module_ctx = ModuleContext( mlir.sanitize_name(debug_info.func_name), axis_names, @@ -655,6 +711,9 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): mgpu.single_thread_predicate(per_block=False), smem_requested_bytes=math.prod(ir.MemRefType(runtime_smem.type).shape), smem_used_bytes=0, + tmem_requested_cols=tmem_cols, + tmem_used_cols=0, + tmem_base_ptr=runtime_tmem.address if runtime_tmem else None, runtime_barriers=grouped_barriers, name_stack=source_info_util.NameStack(), traceback_caches=mlir.TracebackCaches(), @@ -671,6 +730,18 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): smem_scratch_bytes = params.get("smem_scratch_bytes") if smem_scratch_bytes is None: smem_scratch_bytes = rs.smem_scratch_bytes + tmem_scratch_cols = rs.tmem_scratch_cols + + scratch_buffers = [ + jax.ShapeDtypeStruct(shape=[smem_scratch_bytes], dtype=np.int8), + rs.barriers, + ] + if tmem_scratch_cols > 0: + scratch_buffers.append( + mgpu.TMEM(shape=[tcgen05.TMEM_ROWS, tmem_scratch_cols], dtype=np.int32), + ) + else: + scratch_buffers.append(None) prof_ctx = prof_spec = None if prof_space := params.get("profile_space", 0): @@ -685,10 +756,7 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): block=block, in_shapes=in_shapes, out_shape=out_shapes, - smem_scratch_shape=( - jax.ShapeDtypeStruct(shape=[smem_scratch_bytes], dtype=np.int8), - rs.barriers, - ), + smem_scratch_shape=scratch_buffers, module_name=mlir.sanitize_name(debug_info.func_name), prof_spec=prof_spec, ) @@ -990,14 +1058,26 @@ def _ndindexer_indices(indexer: indexing.NDIndexer) -> tuple[gpu_core.Index, ... @register_lowering_rule(sp.get_p, mgpu.ThreadSemantics.Lane) -def _get_lowering_rule(ctx: LoweringRuleContext, x_smem, *leaves, tree): - if not isinstance(x_smem, ir.Value) and ir.MemRefType.isinstance(x_smem): - raise TypeError(f"Can only load from references (got {x_smem}).") +def _get_lowering_rule(ctx: LoweringRuleContext, x_ref, *leaves, tree): + if isinstance(x_ref, tcgen05.TMEMRef): + transforms = jax.tree.unflatten(tree, leaves) + if len(transforms) != 1 or not isinstance( + transforms[0], indexing.NDIndexer): + raise NotImplementedError( + "Only a single indexing transform is supported for TMEM refs.") + indexer = cast(indexing.NDIndexer, transforms[0]) + if not gpu_core.is_trivial_index(indexer.indices, x_ref.shape): + raise NotImplementedError( + "Only trivial indexing is supported for TMEM refs.") + return x_ref[:] + + if not isinstance(x_ref, ir.Value) and ir.MemRefType.isinstance(x_ref): + raise TypeError(f"Can only load from references (got {x_ref}).") x_aval = ctx.avals_in[0] transforms = jax.tree.unflatten(tree, leaves) - x_smem, transforms = _handle_reshaping(x_smem, transforms) + x_smem, transforms = _handle_reshaping(x_ref, transforms) x_smem, transforms = _handle_indexing(x_smem, transforms) match transforms: @@ -1784,6 +1864,14 @@ def _run_scoped_lowering_rule( ) input_refs.append(input_ref) should_discharge.append(False) + elif aval.memory_space == gpu_core.TMEM: + input_ref = alloc_stack.enter_context( + ctx.module_ctx.alloc_tmem( + jax.ShapeDtypeStruct(shape=aval.shape, dtype=aval.dtype), + ) + ) + input_refs.append(input_ref) + should_discharge.append(False) else: raise ValueError(f"Can't convert to ref: {aval}") diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index b255893e2e2e..f5331eb1b56a 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -307,6 +307,8 @@ def _smem_tree_size(smem_buffers: ShapeTree) -> int: raise NotImplementedError("Misaligned barrier allocation") size += num_barriers * utils.MBARRIER_BYTES case TMEM(_): + # TODO(justinfu): This can trigger misaligned barrier allocations + # if TMEM is requested before barriers b/c it's not divisible by 8. size += 4 # i32 takes up 4 bytes case _: size += _count_buffer_bytes(l) diff --git a/jax/experimental/pallas/mosaic_gpu.py b/jax/experimental/pallas/mosaic_gpu.py index 631b4f720984..aab58d092190 100644 --- a/jax/experimental/pallas/mosaic_gpu.py +++ b/jax/experimental/pallas/mosaic_gpu.py @@ -51,3 +51,5 @@ GMEM = GPUMemorySpace.GMEM #: Alias of :data:`jax.experimental.pallas.mosaic_gpu.GPUMemorySpace.SMEM`. SMEM = GPUMemorySpace.SMEM +#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.GPUMemorySpace.TMEM`. +TMEM = GPUMemorySpace.TMEM diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 27017e4eb740..94c2620f7ae6 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -83,6 +83,13 @@ def setUp(self): super().setUp() +class PallasSm100ATest(PallasTest, jtu.CudaArchSpecificTest): + + def setUp(self): + self.skip_unless_sm100a() + super().setUp() + + class PallasCallTest(PallasTest): @parameterized.product( @@ -1531,6 +1538,28 @@ def scope(acc_ref): np.testing.assert_allclose(res, a @ b, rtol=1e-3) +class PallasCallSm100ATest(PallasSm100ATest): + + def test_tmem_alloc(self): + mesh = plgpu.GPUMesh(num_threads=1, axis_names=("x")) + @pl.run_state + def inner(y_ref): + @pl.core_map(mesh) + def _(): + def scope(tmem_ref, smem_ref): + # Issue a write so the TMEM load is not DCE'd. + smem_ref[...] = tmem_ref[...] + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(smem_ref, y_ref) + plgpu.wait_smem_to_gmem(0) + pl.run_scoped(scope, + plgpu.TMEM((128, 128), jnp.float32), + plgpu.SMEM((128, 128), jnp.float32)) + y_init = jnp.zeros((128, 128), np.float32) + # Test that this runs without errors. + jax.block_until_ready(inner(y_init)) + + class PipelineTest(PallasTest): def test_pipeline_mode(self): From 55b55e6b1b64c24c3dd87274427594fc56e8f6e0 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Mon, 10 Mar 2025 17:00:44 +0100 Subject: [PATCH 064/483] Enable multi-threading in Jax Context with shared thread pool --- jax/_src/interpreters/mlir.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 47d3fc52de26..f96f07be4149 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -593,6 +593,15 @@ def module_to_bytecode(module: ir.Module) -> bytes: # Translation rules +# Create one global thread pool that can be shared between multiple ir.Contexts +# and enabling multi-threading +# TODO: remove this check after jaxlib 0.5.4 +if hasattr(ir, "ThreadPool"): + global_thread_pool = ir.ThreadPool() +else: + global_thread_pool = None + + class JaxIrContext(ir.Context): def __init__(self, *args, **kwargs): # Note: we're very intentionally *not* calling the __init__() of our @@ -607,12 +616,16 @@ def make_ir_context() -> ir.Context: context.append_dialect_registry(upstream_dialects) context.load_all_available_dialects() - # If threading is enabled, each MLIR context will keep alive a thread pool. - # Since we cache MLIR modules (and hence contexts), this means we might keep - # several threads alive for each cache entry. This is a terrible idea. However - # we don't do any heavy computation on MLIR modules from Python anyway, so we - # just disable threading. - context.enable_multithreading(False) + # TODO: remove this check after v0.5.4 jaxlib + if global_thread_pool is not None: + context.set_thread_pool(global_thread_pool) + else: + # If threading is enabled, each MLIR context will keep alive a thread pool. + # Since we cache MLIR modules (and hence contexts), this means we might keep + # several threads alive for each cache entry. This is a terrible idea. However + # we don't do any heavy computation on MLIR modules from Python anyway, so we + # just disable threading. + context.enable_multithreading(False) # TODO(bartchr): Once JAX is released with SDY, remove the if. if dialects.sdy: dialects.sdy.register_dialect(context) From a8fb0e01f8d083fff337d3c26375bb1b77344a99 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 20 Mar 2025 15:22:31 -0700 Subject: [PATCH 065/483] [sharding_in_types] Fix a dynamic_slice bug where in the transpose, `DUS`'s operand was not sharded properly PiperOrigin-RevId: 738959282 --- jax/_src/lax/slicing.py | 3 ++- tests/pjit_test.py | 24 ++++++++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index c26de99c7374..b4a8817fbb8d 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -1476,7 +1476,8 @@ def _dynamic_slice_transpose_rule(t, operand, *start_indices, slice_sizes): if type(t) is ad_util.Zero: return [ad_util.Zero(operand.aval)] + [None] * len(start_indices) else: - zeros = lax.full(operand_shape, 0, operand_dtype) + zeros = lax.full(operand_shape, 0, operand_dtype, + sharding=operand.aval.sharding) return ([dynamic_update_slice_p.bind(zeros, t, *start_indices)] + [None] * len(start_indices)) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 6cf11494988a..6a1a73fe4301 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7090,6 +7090,30 @@ def f(x): self.assertEqual(out.shape, expected_shape) self.assertEqual(out.sharding, NamedSharding(mesh, expected_spec)) + @jtu.with_user_mesh((2,), ('x',)) + def test_dynamic_slice(self, mesh): + np_inp = np.arange(16., dtype=np.float32) + s = NamedSharding(mesh, P('x')) + arr = jax.device_put(np_inp, s) + + @jax.jit + def f(x): + y = lax.dynamic_slice_in_dim(x, jnp.array(1, dtype=np.int32), 2) + self.assertEqual(y.aval.sharding.spec, P('x')) + return y + + out = f(arr) + self.assertEqual(out.sharding, s) + + def g(x): + return jnp.sum(f(x)) + + out = jax.jit(jax.grad(g))(arr) + self.assertEqual(out.sharding, arr.sharding) + + out = jax.grad(g)(arr) + self.assertEqual(out.sharding, arr.sharding) + def test_auto_axes_computation_follows_data_error(self): mesh = jtu.create_mesh((2,), ('x',), axis_types=(AxisType.Explicit,)) s = NamedSharding(mesh, P('x')) From 1fe24ca7552576d59649f335cf5878c069d180f8 Mon Sep 17 00:00:00 2001 From: shuw Date: Thu, 20 Mar 2025 23:26:21 +0000 Subject: [PATCH 066/483] Improve based on review 1 --- jax/_src/cudnn/scaled_matmul_stablehlo.py | 15 +++-- tests/scaled_matmul_stablehlo_test.py | 75 +++++++++++++++++++---- 2 files changed, 71 insertions(+), 19 deletions(-) diff --git a/jax/_src/cudnn/scaled_matmul_stablehlo.py b/jax/_src/cudnn/scaled_matmul_stablehlo.py index ddcde6a95b26..b1d353e7bcd1 100644 --- a/jax/_src/cudnn/scaled_matmul_stablehlo.py +++ b/jax/_src/cudnn/scaled_matmul_stablehlo.py @@ -489,8 +489,6 @@ def quantize(x, config): assert config.scale_type == jnp.float8_e8m0fnu scales_q = cast_to_e8m0_with_rounding_up(scales) scaled_x = x / e8m0_to_dtype(scales_q, scales.dtype) - clipped_x = jnp.clip(scaled_x, -MAX, MAX) - x_q = clipped_x.astype(config.data_type) elif config.mode == "nvfp4": assert config.scale_type == jnp.float8_e4m3fn assert config.global_scale.dtype == jnp.float32 @@ -499,15 +497,15 @@ def quantize(x, config): prev_amax = config.global_scale * (MAX * SCALE_MAX) scales_q = jnp.clip( (amax / prev_amax) * SCALE_MAX, 0, SCALE_MAX - ).astype(config.scale_type) - x_q = jnp.where( - amax <= prev_amax, - (x * MAX) / amax, - jnp.clip((x * MAX) / prev_amax, -MAX, MAX), - ).astype(config.data_type) + ) + scaled_x = x / scales_q + scales_q = scales_q.astype(config.scale_type) else: raise ValueError(f"Unrecognized mode: {config.mode}.") + clipped_x = jnp.clip(scaled_x, -MAX, MAX) + x_q = clipped_x.astype(config.data_type) + x_q = x_q.reshape(x_shape) # shape = (B, M, K) scales_q = jnp.reshape(scales_q, scales_q.shape[:-1]).view( config.scale_type @@ -652,6 +650,7 @@ def scaled_dot_bwd(dimension_numbers, preferred_element_type, configs, res, g): SCALE_MAX = jnp.finfo(configs[0].scale_type).max.astype(lhs.dtype) grad_lhs = jnp.where(jnp.abs(lhs) <= configs[0].global_scale * MAX * SCALE_MAX, grad_lhs, 0) grad_rhs = jnp.where(jnp.abs(rhs) <= configs[1].global_scale * MAX * SCALE_MAX, grad_rhs, 0) + return (grad_lhs, grad_rhs) diff --git a/tests/scaled_matmul_stablehlo_test.py b/tests/scaled_matmul_stablehlo_test.py index 224f6b6204e5..b53ffcd5b977 100644 --- a/tests/scaled_matmul_stablehlo_test.py +++ b/tests/scaled_matmul_stablehlo_test.py @@ -174,7 +174,7 @@ def update_global_scale(config, new_global_scale): config.global_scale = new_global_scale return config -def generate_nvfp4_quantized_tensors(dot_config, output_type): +def generate_nvfp4_quantized_tensors(dot_config, output_type, enable_grad_clip=False): k1, k2 = jax.random.split(jax.random.key(0), 2) a_shape, b_shape, dimension_numbers = dot_config @@ -195,7 +195,7 @@ def generate_nvfp4_quantized_tensors(dot_config, output_type): amax_b = jnp.max(jnp.abs(b)).astype(jnp.float32) # To emulate calibrated amax - amax_sf = 0.9 + amax_sf = 0.9 if enable_grad_clip else 1.0 amax_a *= amax_sf amax_b *= amax_sf @@ -513,6 +513,68 @@ def fn(a): self.assertArraysAllClose(out_q, out_q_ref, rtol=1e-5, atol=1e-5) self.assertArraysAllClose(scale, scale_ref, rtol=1e-5, atol=1e-5) + @jtu.sample_product( + enable_grad_clip=[True, False], + configs=[ + # a_shape, b_shape, dimension_numbers + ((1, 128, 128), (1, 128, 128), (([2], [2]), ([0], [0]))), + ((30, 64), (100, 64), (([1], [1]), ([], []))), + ] + ) + @jtu.run_on_devices("cuda") + def test_nvfp4_gradient_clip(self, enable_grad_clip, configs): + output_type = jnp.float32 + (a_raw, b_raw), (a_dq, b_dq), _, block_scale_configs = ( + generate_nvfp4_quantized_tensors(configs, output_type, enable_grad_clip) + ) + a_gs = block_scale_configs[0].global_scale + b_gs = block_scale_configs[1].global_scale + dimension_numbers = configs[2] + + scaled_dot_general = partial( + scaled_dot_general_wrapper, + configs=block_scale_configs + ) + + def fwd(a, b, use_normalized=False): + y = scaled_dot_general( + a, b, dimension_numbers, + preferred_element_type=output_type + ) + return jnp.sum(y) + + j_train = jax.jit(jax.value_and_grad(fwd, argnums=[0, 1])) + _, (x_grad, w_grad) = j_train(a_raw, b_raw) + + data_max = jnp.finfo(jnp.float4_e2m1fn).max.astype(output_type) + scale_max = jnp.finfo(jnp.float8_e4m3fn).max.astype(output_type) + prev_amax_a = a_gs * data_max * scale_max + prev_amax_b = b_gs * data_max * scale_max + + # Use a large value to ensure no clipping + threshold_a = prev_amax_a if enable_grad_clip else 1e9 + threshold_b = prev_amax_b if enable_grad_clip else 1e9 + + # Verify gradients are clipped to 0 where |input| > global_scale * MAX * SCALE_MAX + self.assertArraysEqual( + jnp.where(jnp.abs(a_raw) > threshold_a, x_grad, 0), + jnp.zeros_like(x_grad), + ) + self.assertArraysEqual( + jnp.where(jnp.abs(b_raw) > threshold_b, w_grad, 0), + jnp.zeros_like(w_grad), + ) + if enable_grad_clip: + # Verify gradients are preserved where |input| <= global_scale * MAX * SCALE_MAX + self.assertArraysEqual( + jnp.where(jnp.abs(a_raw) <= prev_amax_a, x_grad, 0), + x_grad, + ) + self.assertArraysEqual( + jnp.where(jnp.abs(b_raw) <= prev_amax_b, w_grad, 0), + w_grad, + ) + @jtu.sample_product( configs=[ # a_shape, b_shape, dimension_numbers, is_training @@ -584,15 +646,6 @@ def _grad_clip(amax, x, grad): w_grad_ref = _grad_clip(prev_amax_b, b_raw, w_grad_ref) self.assertArraysAllClose(x_grad, x_grad_ref, rtol=1e-2, atol=1e1) self.assertArraysAllClose(w_grad, w_grad_ref, rtol=1e-2, atol=1e1) - # Verify straight_through_estimator - self.assertArraysEqual( - jnp.where(jnp.abs(a_raw) > prev_amax_a, x_grad, 0), - jnp.zeros_like(x_grad) - ) - self.assertArraysEqual( - jnp.where(jnp.abs(b_raw) > prev_amax_b, w_grad, 0), - jnp.zeros_like(w_grad) - ) else: j_inference = jax.jit(fwd) j_inference_ref = jax.jit(partial(fwd, is_ref=True, use_normalized=True)) From 0eb430c128cfe9448971441ec2fef61e13548592 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Thu, 20 Mar 2025 23:26:37 +0100 Subject: [PATCH 067/483] Increased test timeout in TSAN CI Description: - Increased test timeout in TSAN CI - Skip tests: testMishGrad and testSquareplusGrad --- .github/workflows/tsan.yaml | 2 +- tests/nn_test.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/tsan.yaml b/.github/workflows/tsan.yaml index 7d93707e4e92..6c97b7347ceb 100644 --- a/.github/workflows/tsan.yaml +++ b/.github/workflows/tsan.yaml @@ -210,7 +210,7 @@ jobs: --test_env=JAX_TEST_NUM_THREADS=8 \ --test_output=errors \ --local_test_jobs=32 \ - --test_timeout=600 \ + --test_timeout=1800 \ --config=resultstore \ --config=rbe_cache \ //tests:cpu_tests diff --git a/tests/nn_test.py b/tests/nn_test.py index ed016ec349ef..1a1670444ef8 100644 --- a/tests/nn_test.py +++ b/tests/nn_test.py @@ -422,6 +422,7 @@ def testSparseplusAndSparseSigmoid(self): jax.grad(nn.sparse_plus)(-2.), nn.sparse_sigmoid(-2.), check_dtypes=False) + @jtu.skip_on_flag("jax_skip_slow_tests", True) def testSquareplusGrad(self): check_grads(nn.squareplus, (1e-8,), order=4, rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None) @@ -442,6 +443,7 @@ def testSquareplusGradNan(self): def testSquareplusZero(self, dtype): self.assertEqual(dtype(1), nn.squareplus(dtype(0), dtype(4))) + @jtu.skip_on_flag("jax_skip_slow_tests", True) def testMishGrad(self): check_grads(nn.mish, (1e-8,), order=4, rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None) From 4b7ead4d02f866077f11dcfcca7507533a441bcc Mon Sep 17 00:00:00 2001 From: shuw Date: Wed, 19 Feb 2025 22:51:35 +0000 Subject: [PATCH 068/483] Bump ml_dtypes>=0.5.0 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index dbb7040d2d2b..bdaeb624bf38 100644 --- a/setup.py +++ b/setup.py @@ -55,7 +55,7 @@ def load_version_module(pkg_path): python_requires='>=3.10', install_requires=[ f'jaxlib >={_minimum_jaxlib_version}, <={_jax_version}', - 'ml_dtypes>=0.4.0', + 'ml_dtypes>=0.5.0', 'numpy>=1.25', "numpy>=1.26.0; python_version>='3.12'", 'opt_einsum', From c7d6b653cea5c4144bed123d5f9ff1fa4b668e73 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 20 Mar 2025 22:18:18 -0700 Subject: [PATCH 069/483] [sharding_in_types] Add `core.ShardingTypeError` as a new Exception that are sharding-in-types specific errors should raise. This is so that we can catch this exception in backward_pass/vmap and add extra message to inform users that this is a potential JAX bug. They should file an issue on the repo. Currently, we only raise `ShardingTypeError` in one place, but we can expand to all other places in follow up changes. This change sets the machinery up. Previous error: ``` jax._src.core.ShardingTypeError: dynamic_update_slice update sharding must be equal to operand sharding, got update sharding float32[2@x]({Explicit: ('x',)}) for operand sharding float32[16]({}). ``` New error: ``` jax._src.core.ShardingTypeError: dynamic_update_slice update sharding must be equal to operand sharding, got update sharding float32[2@x]({Explicit: ('x',)}) for operand sharding float32[16]({}). This is a potential JAX bug. Please file an issue at https://github.com/jax-ml/jax/issues ``` The new added message of `This is a potential JAX bug...` is important because this error is raised in the backward pass which is 100% a JAX bug given that forward pass did not error. PiperOrigin-RevId: 739053305 --- jax/_src/core.py | 4 ++++ jax/_src/interpreters/ad.py | 6 ++++++ jax/_src/lax/slicing.py | 2 +- jax/_src/lax/utils.py | 2 +- 4 files changed, 12 insertions(+), 2 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index 36ce2f004ed4..243ffc871042 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1786,6 +1786,10 @@ def _invalid_shape_error(shape: Shape, context: str=""): return TypeError(msg) +class ShardingTypeError(Exception): + pass + + # TODO(dougalm): Cast scalar, numpy arrays, etc to jax arrays so that values # passed to primitives are always have avals, etc i.e. they are canonical. def canonicalize_value(val): diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index ddf96af6a010..e47e518a11f2 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -407,6 +407,12 @@ def write_primal(v, val): try: cts_out = get_primitive_transpose(eqn.primitive)( cts_in, *invals, **eqn.params) + except core.ShardingTypeError as e: + extra_msg = ("This is a potential JAX bug. Please file an issue at" + " https://github.com/jax-ml/jax/issues") + if extra_msg in str(e): + raise + raise core.ShardingTypeError(f"{str(e)}\n{extra_msg}") except (FloatingPointError, ZeroDivisionError) as e: msg = "When differentiating the code at the top of the callstack:" if msg not in e.args[0]: diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index b4a8817fbb8d..b3a0a8e2d0c1 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -1607,7 +1607,7 @@ def _dynamic_update_slice_shape_rule(operand, update, *start_indices): def _dynamic_update_slice_sharding_rule(operand, update, *start_indices): if operand.sharding != update.sharding: - raise TypeError( + raise core.ShardingTypeError( "dynamic_update_slice update sharding must be equal to operand" " sharding, got update sharding" f" {update.str_short(mesh_axis_types=True)} for operand sharding" diff --git a/jax/_src/lax/utils.py b/jax/_src/lax/utils.py index f39d925ac2ad..9fc9ba16a604 100644 --- a/jax/_src/lax/utils.py +++ b/jax/_src/lax/utils.py @@ -74,7 +74,7 @@ def call_sharding_rule(prim, rule, num_out, *avals, **kwargs): s = NamedSharding(aval_mesh, P()) return s if num_out is None else [s] * num_out if rule is None: - raise ValueError( + raise core.ShardingTypeError( f'sharding rule for {prim.name} is not implemented. Please file a' ' bug at https://github.com/jax-ml/jax/issues. You can work around' ' this error by dropping that operation into full auto sharding' From 7953e6d0f88c068e0af9f38c0dd0c0c3ce05688a Mon Sep 17 00:00:00 2001 From: Zac Mustin Date: Fri, 21 Mar 2025 00:35:54 -0700 Subject: [PATCH 070/483] Add tests for varying `{batch, feature}_group_count`s for roofline `conv`. We'll need to use batch/feature when calculating flops, so it'll help reduce the size of the "calculating-flops" change if we can include them in our tests now. PiperOrigin-RevId: 739081930 --- tests/roofline_test.py | 88 ++++++++++++++++++++++++++++-------------- 1 file changed, 58 insertions(+), 30 deletions(-) diff --git a/tests/roofline_test.py b/tests/roofline_test.py index 564b4a9a1f9e..f94f5a328c46 100644 --- a/tests/roofline_test.py +++ b/tests/roofline_test.py @@ -572,40 +572,63 @@ def test_dot_general(self): result.unfused_hbm_bytes, self._bytes_per_word * (3 * 7 + 7 * 5 + 3 * 5) ) - def get_conv_output_dim(self, i, k, pad_low, pad_high, stride): + def get_conv_output_dim(self, i, k, pad_low, pad_high, stride) -> int: return jnp.floor((i - k + pad_low + pad_high) / stride) + 1 - @jtu.parameterized.named_parameters( - dict( - testcase_name="simple", - window_strides=(1, 1), - padding=((0, 0), (0, 0)), - ), - dict( - testcase_name="padding", - window_strides=(1, 1), - padding=((1, 2), (3, 4)), - ), - dict( - testcase_name="window_strides", - window_strides=(2, 2), - padding=((0, 0), (0, 0)), - ), - dict( - testcase_name="window_strides_and_padding", - window_strides=(3, 3), - padding=((1, 2), (3, 4)), - ), + def get_conv_num_output_channels( + self, batch_group_count: int, feature_group_count: int + ) -> int: + if batch_group_count > 1: + return batch_group_count + elif feature_group_count > 1: + return feature_group_count + else: + return 1 + + @jtu.parameterized.product( + window_strides=[(1, 1), (2, 2)], + padding=[((0, 0), (0, 0)), ((1, 2), (3, 4))], + # batch must be divisible by batch_group_count, so we only include factors + # of batch_group_count. + batch=[6, 12], + batch_group_count=[1, 3], + # num_input_channels must be divisible by feature_group_count, so we only + # include factors of feature_group_count. + num_input_channels=[6, 12], + feature_group_count=[1, 3], ) def test_conv_general_dilated_unfused_hbm_bytes( - self, window_strides: Sequence[int, int], padding: Sequence[int, int] + self, + window_strides: Sequence[int, int], + padding: Sequence[int, int], + batch: int, + batch_group_count: int, + num_input_channels: int, + feature_group_count: int, ): + if batch_group_count > 1 and feature_group_count > 1: + self.skipTest( + "batch_group_count and feature_group_count cannot both be > 1" + ) + + num_output_channels = self.get_conv_num_output_channels( + batch_group_count, feature_group_count + ) + + num_input_features = int(num_input_channels / feature_group_count) iw, ih = 100, 200 kw, kh = 7, 7 - input_data = jnp.zeros((1, 1, iw, ih), dtype=int) - kernel_data = jnp.ones((1, 1, kw, kh), dtype=int) + input_data = jnp.zeros((batch, num_input_channels, iw, ih), dtype=int) + kernel_data = jnp.ones( + (num_output_channels, num_input_features, kw, kh), dtype=int + ) conv = lambda a, b: lax.conv_general_dilated( - lhs=a, rhs=b, window_strides=window_strides, padding=padding + lhs=a, + rhs=b, + window_strides=window_strides, + padding=padding, + batch_group_count=batch_group_count, + feature_group_count=feature_group_count, ) _, result = roofline.roofline( @@ -615,8 +638,8 @@ def test_conv_general_dilated_unfused_hbm_bytes( out_specs=P(), )(input_data, kernel_data) - expected_input_size = 1 * 1 * iw * ih - expected_kernel_size = 1 * 1 * kw * kh + expected_input_size = batch * num_input_channels * iw * ih + expected_kernel_size = num_output_channels * num_input_features * kw * kh ow = self.get_conv_output_dim( iw, kw, padding[0][0], padding[0][1], window_strides[0] @@ -624,7 +647,10 @@ def test_conv_general_dilated_unfused_hbm_bytes( oh = self.get_conv_output_dim( ih, kh, padding[1][0], padding[1][1], window_strides[1] ) - expected_output_size = 1 * 1 * ow * oh + expected_output_shape = jnp.array( + (batch / batch_group_count, num_output_channels, ow, oh) + ) + expected_output_size = jnp.prod((expected_output_shape)) # Bytes accessed is sum of inputs and output. expected_unfused_hbm_bytes = self._bytes_per_word * ( expected_input_size + expected_kernel_size + expected_output_size @@ -642,7 +668,9 @@ def test_conv_general_dilated_unfused_hbm_bytes( padding="SAME_LOWER", ), ) - def test_conv_general_dilated_padding_string_unfused_hbm_bytes(self, padding: str): + def test_conv_general_dilated_padding_string_unfused_hbm_bytes( + self, padding: str + ): input_data = jnp.zeros((1, 1, 10, 20), dtype=int) kernel_data = jnp.ones((1, 1, 3, 3), dtype=int) conv = lambda a, b: lax.conv_general_dilated( From ad21b62bfec5560d4c612ed3c8412eb2d240468b Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 21 Mar 2025 02:41:42 -0700 Subject: [PATCH 071/483] [AutoPGLE] Prevent an AutoPGLE to run if user launched an external profiler. PiperOrigin-RevId: 739109278 --- tests/pgle_test.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/tests/pgle_test.py b/tests/pgle_test.py index 7f9ea598d51b..7dabd809d95e 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -65,7 +65,11 @@ def testPGLEProfilerGetFDOProfile(self): jax.jit, in_shardings=NamedSharding(mesh, PartitionSpec('x')), out_shardings=NamedSharding(mesh, PartitionSpec('x')), - compiler_options={'xla_gpu_enable_latency_hiding_scheduler': 'True'}, + compiler_options={ + 'xla_gpu_enable_latency_hiding_scheduler': 'True', + # Make sure that matmul is not emitted as Triton GEMM. + 'xla_gpu_enable_triton_gemm': 'False', + }, ) def f(x, y): return x @ y @@ -93,6 +97,8 @@ def testPGLEProfilerGetFDOProfileLarge(self): compiler_options = { 'xla_gpu_enable_latency_hiding_scheduler': 'True', + # Make sure that matmul is not emitted as Triton GEMM. + 'xla_gpu_enable_triton_gemm': 'False', } # TODO(b/37664749): Remove this flag once the bug is fixed. compiler_options['xla_gpu_enable_command_buffer'] = '' @@ -321,7 +327,11 @@ def testPassingFDOProfile(self): jax.jit, in_shardings=NamedSharding(mesh, PartitionSpec('x')), out_shardings=NamedSharding(mesh, PartitionSpec('x')), - compiler_options={'xla_gpu_enable_latency_hiding_scheduler': 'True'}, + compiler_options={ + 'xla_gpu_enable_latency_hiding_scheduler': 'True', + # Make sure that matmul is not emitted as Triton GEMM. + 'xla_gpu_enable_triton_gemm': 'False', + }, ) def f(x, y): return x @ y From 5fef4cff7a37d0bb4d7004741189880b357699a2 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 21 Mar 2025 03:41:01 -0700 Subject: [PATCH 072/483] Update XLA dependency to use revision http://github.com/openxla/xla/commit/469329ec36be093fd71d29e4518402300e04aeec. PiperOrigin-RevId: 739121877 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index b20048193bd1..00f985cdf352 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "6e2c9b024cec7dca4b2e1b07cc89373574c9c5af" -XLA_SHA256 = "387917467d6f6e8358d54ba2b89f3fef14a00e62d8b0a096bf07acc8186444d4" +XLA_COMMIT = "469329ec36be093fd71d29e4518402300e04aeec" +XLA_SHA256 = "9de006d7b51c36057898c81111fa9723b59f024eced067572fe5f6b1df63abdd" def repo(): tf_http_archive( From be5713309521d5cf0d2252b9c8f1d38ab50952d1 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 21 Mar 2025 05:18:03 -0700 Subject: [PATCH 073/483] Delay the unflattening in `jnp.array` PiperOrigin-RevId: 739143346 --- jax/_src/numpy/lax_numpy.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 96efc48062e1..16355695792d 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -49,15 +49,16 @@ from jax._src.lax.lax import (PrecisionLike,_array_copy, _sort_le_comparator, _sort_lt_comparator) from jax._src.lib import xla_client as xc -from jax._src.numpy.array_creation import (empty, empty_like, full, - ones, ones_like, zeros, zeros_like) from jax._src.numpy import indexing from jax._src.numpy import reductions from jax._src.numpy import tensor_contractions from jax._src.numpy import ufuncs from jax._src.numpy import util +from jax._src.numpy.array_creation import (empty, empty_like, full, + ones, ones_like, zeros, zeros_like) from jax._src.numpy.sorting import argsort, sort from jax._src.numpy.vectorize import vectorize +from jax._src.sharding_impls import SingleDeviceSharding from jax._src.typing import ( Array, ArrayLike, DType, DTypeLike, DeprecatedArg, DimSize, Shape ) @@ -65,8 +66,7 @@ NumpyComplexWarning, canonicalize_axis as _canonicalize_axis, ceil_of_ratio, safe_zip, set_module, unzip2) from jax.sharding import Sharding -from jax._src.sharding_impls import SingleDeviceSharding -from jax.tree_util import tree_leaves, tree_map +from jax.tree_util import tree_flatten, tree_map import numpy as np export = set_module('jax.numpy') @@ -5504,9 +5504,7 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, object = xc._xla.cuda_array_interface_to_buffer( cai=cai, gpu_backend=backend, device_id=device_id) - object = tree_map(lambda leaf: leaf.__jax_array__() - if hasattr(leaf, "__jax_array__") else leaf, object) - leaves = tree_leaves(object, is_leaf=lambda x: x is None) + leaves, treedef = tree_flatten(object, is_leaf=lambda x: x is None) if any(leaf is None for leaf in leaves): # Added Nov 16 2023 if deprecations.is_accelerated("jax-numpy-array-none"): @@ -5515,7 +5513,13 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, "None encountered in jnp.array(); this is currently treated as NaN. " "In the future this will result in an error.", FutureWarning, stacklevel=2) - leaves = tree_leaves(object) + leaves, treedef = tree_flatten(object) + leaves = [ + leaf + if (leaf_jax_array := getattr(leaf, "__jax_array__", None)) is None + else leaf_jax_array() + for leaf in leaves + ] if dtype is None: # Use lattice_result_type rather than result_type to avoid canonicalization. # Otherwise, weakly-typed inputs would have their dtypes canonicalized. @@ -5530,8 +5534,8 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, if not weak_type: dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True) # type: ignore[assignment] + object = treedef.unflatten(leaves) out: ArrayLike - if all(not isinstance(leaf, Array) for leaf in leaves): # TODO(jakevdp): falling back to numpy here fails to overflow for lists # containing large integers; see discussion in From 7f0f185abd84b9b704d64f89c7fce0236b7c3403 Mon Sep 17 00:00:00 2001 From: Arno Eigenwillig Date: Fri, 21 Mar 2025 05:56:03 -0700 Subject: [PATCH 074/483] In JEP-12049, fix link to EAFP in the Python glossary: the anchor became mixed-case as of Python 3.10. PiperOrigin-RevId: 739150752 --- docs/jep/12049-type-annotations.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/jep/12049-type-annotations.md b/docs/jep/12049-type-annotations.md index 7a20958c5cab..5ed760dd6c5c 100644 --- a/docs/jep/12049-type-annotations.md +++ b/docs/jep/12049-type-annotations.md @@ -122,7 +122,7 @@ All told, the array-type-granularity challenge is less of an issue than the othe ### Challenge 5: imprecise APIs inherited from NumPy A large part of JAX’s user-facing API is inherited from NumPy within the {mod}`jax.numpy` submodule. -NumPy’s API was developed years before static type checking became part of the Python language, and follows Python’s historic recommendations to use a [duck-typing](https://docs.python.org/3/glossary.html#term-duck-typing)/[EAFP](https://docs.python.org/3/glossary.html#term-eafp) coding style, in which strict type-checking at runtime is discouraged. As a concrete example of this, consider the {func}`numpy.tile` function, which is defined like this: +NumPy’s API was developed years before static type checking became part of the Python language, and follows Python’s historic recommendations to use a [duck-typing](https://docs.python.org/3/glossary.html#term-duck-typing)/[EAFP](https://docs.python.org/3/glossary.html#term-EAFP) coding style, in which strict type-checking at runtime is discouraged. As a concrete example of this, consider the {func}`numpy.tile` function, which is defined like this: ```python def tile(A, reps): From a93035f6250672230675290af82a829f0b0dd862 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 21 Mar 2025 06:25:37 -0700 Subject: [PATCH 075/483] Migrate xla_client and its Python tests out of XLA into JAX. This change copies targets into jaxlib, and a subsequent change will delete them from XLA. We separate these into two phases because we cannot atomically change both JAX and XLA. Future changes will migrate more of the C++ pieces of XLA:Python. PiperOrigin-RevId: 739158120 --- jax/_src/lib/BUILD | 3 +- jax/_src/lib/__init__.py | 22 +- jaxlib/BUILD | 4 +- jaxlib/jax.bzl | 3 + jaxlib/xla/BUILD | 162 + jaxlib/xla/config_test.py | 71 + jaxlib/xla/custom_calls_testlib.cc | 128 + jaxlib/xla/jax_jit_test.py | 47 + jaxlib/xla/pytree_test.py | 144 + jaxlib/xla/weakref_lru_cache_test.py | 257 ++ jaxlib/xla/xla_client.py | 1044 +++++ jaxlib/xla/xla_client.pyi | 322 ++ .../xla_client_backend_independent_test.py | 195 + jaxlib/xla/xla_client_test.py | 3714 +++++++++++++++++ pyproject.toml | 5 +- 15 files changed, 6106 insertions(+), 15 deletions(-) create mode 100644 jaxlib/xla/BUILD create mode 100644 jaxlib/xla/config_test.py create mode 100644 jaxlib/xla/custom_calls_testlib.cc create mode 100644 jaxlib/xla/jax_jit_test.py create mode 100644 jaxlib/xla/pytree_test.py create mode 100644 jaxlib/xla/weakref_lru_cache_test.py create mode 100644 jaxlib/xla/xla_client.py create mode 100644 jaxlib/xla/xla_client.pyi create mode 100644 jaxlib/xla/xla_client_backend_independent_test.py create mode 100644 jaxlib/xla/xla_client_test.py diff --git a/jax/_src/lib/BUILD b/jax/_src/lib/BUILD index 1fcbd4b6b7ef..1f4f41132e9e 100644 --- a/jax/_src/lib/BUILD +++ b/jax/_src/lib/BUILD @@ -44,6 +44,7 @@ py_library_providing_imports_info( "//jaxlib/mosaic/python:tpu_dialect", "//jaxlib:cpu_feature_guard", "//jaxlib:utils", + "//jaxlib/xla:xla_client", "//jaxlib/triton", "//jaxlib/mlir/_mlir_libs:register_jax_dialects", "//jaxlib/mlir:arithmetic_dialect", @@ -60,6 +61,6 @@ py_library_providing_imports_info( "//jaxlib/mlir:sparse_tensor_dialect", "//jaxlib/mlir:stablehlo_dialect", "//jaxlib/mlir:vector_dialect", - # xla_client + # xla_extension ]), ) diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py index 70dc914668cf..be551449aa17 100644 --- a/jax/_src/lib/__init__.py +++ b/jax/_src/lib/__init__.py @@ -40,7 +40,7 @@ raise ImportError(msg) from err -# Checks the jaxlib version before importing anything else from jaxlib. +# Checks the jaxlib version before importing anything else. # Returns the jaxlib version string. def check_jaxlib_version(jax_version: str, jaxlib_version: str, minimum_jaxlib_version: str) -> tuple[int, ...]: @@ -77,20 +77,23 @@ def _parse_version(v: str) -> tuple[int, ...]: jaxlib_version=jaxlib.version.__version__, minimum_jaxlib_version=jax.version._minimum_jaxlib_version) -# Before importing any C compiled modules from jaxlib, first import the CPU +# Before importing any C compiled modules, first import the CPU # feature guard module to verify that jaxlib was compiled in a way that only # uses instructions that are present on this machine. import jaxlib.cpu_feature_guard as cpu_feature_guard cpu_feature_guard.check_cpu_features() -import jaxlib.utils as utils # noqa: F401 -import jaxlib.xla_client as xla_client import jaxlib.lapack as lapack # noqa: F401 +import jaxlib.utils as utils # noqa: F401 +import jaxlib.xla_extension as xla_extension # noqa: F401 +from jaxlib.xla_extension import guard_lib as guard_lib # noqa: F401 +from jaxlib.xla_extension import jax_jit as jax_jit # noqa: F401 +from jaxlib.xla_extension import pmap_lib as pmap_lib # noqa: F401 +from jaxlib.xla_extension import pytree as pytree # noqa: F401 +import jaxlib.xla_client as xla_client # noqa: F401 + +from jaxlib.xla_extension import Device as Device # noqa: F401 -xla_extension = xla_client._xla -pytree = xla_client._xla.pytree -jax_jit = xla_client._xla.jax_jit -pmap_lib = xla_client._xla.pmap_lib # XLA garbage collection: see https://github.com/jax-ml/jax/issues/14882 def _xla_gc_callback(*args): @@ -167,6 +170,3 @@ def _try_cuda_nvcc_import() -> str | None: return None cuda_path = _cuda_path() - -guard_lib = xla_client._xla.guard_lib -Device = xla_client._xla.Device diff --git a/jaxlib/BUILD b/jaxlib/BUILD index faf52a702386..2397639fddf2 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -81,7 +81,7 @@ py_library_providing_imports_info( "//jaxlib/mlir:vector_dialect", "//jaxlib/mosaic", "//jaxlib/triton", - "@xla//xla/python:xla_extension", + "//jaxlib/xla:xla_client", ], ) @@ -94,7 +94,7 @@ symlink_files( symlink_files( name = "xla_client", - srcs = ["@xla//xla/python:xla_client"], + srcs = ["//jaxlib/xla:xla_client"], dst = ".", flatten = True, ) diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 02e6b10b1de1..4403915154bc 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -132,6 +132,9 @@ def pytype_strict_library(name, pytype_srcs = [], **kwargs): new_kwargs = {k: v for k, v in kwargs.items() if k != "data"} native.py_library(name = name, data = data, **new_kwargs) +py_strict_library = native.py_library +py_strict_test = native.py_test + def py_library_providing_imports_info(*, name, lib_rule = native.py_library, pytype_srcs = [], **kwargs): data = pytype_srcs + (kwargs["data"] if "data" in kwargs else []) new_kwargs = {k: v for k, v in kwargs.items() if k != "data"} diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD new file mode 100644 index 000000000000..41152d642fc8 --- /dev/null +++ b/jaxlib/xla/BUILD @@ -0,0 +1,162 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load( + "//jaxlib:jax.bzl", + "nanobind_extension", + "py_deps", + "py_strict_library", + "py_strict_test", + "pytype_strict_library", +) + +licenses(["notice"]) + +package( + default_applicable_licenses = [], + default_visibility = ["//jax:internal"], +) + +package_group( + name = "xla_python", + includes = [ + "//jax:internal", + ], +) + +pytype_strict_library( + name = "xla_client", + srcs = ["xla_client.py"], + pytype_srcs = ["xla_client.pyi"], + visibility = [":xla_python"], + deps = py_deps([ + "numpy", + "ml_dtypes", + ]) + ["@xla//xla/python:xla_extension"], +) + +py_strict_test( + name = "xla_client_backend_independent_test", + srcs = ["xla_client_backend_independent_test.py"], + deps = [ + ":xla_client", + ] + py_deps([ + "absl/testing", + "numpy", + "portpicker", + ]), +) + +py_strict_library( + name = "xla_client_test", + testonly = 1, + srcs = ["xla_client_test.py"], + visibility = [":xla_python"], + deps = [ + ":xla_client", + "//jax", + "//jax:test_util", + "//jaxlib", + ] + py_deps([ + "absl/flags", + "absl/logging", + "absl/testing", + "ml_dtypes", + "numpy", + ]), +) + +nanobind_extension( + name = "custom_calls_testlib", + testonly = 1, + srcs = ["custom_calls_testlib.cc"], + deps = [ + "@com_google_absl//absl/status", + "@nanobind", + "@xla//xla/ffi/api:c_api", + "@xla//xla/ffi/api:ffi", + ], +) + +py_strict_test( + name = "xla_client_test_cpu", + srcs = ["xla_client_test.py"], + args = ["--backend=cpu"], + env = { + "XLA_FLAGS": "--xla_force_host_platform_device_count=4", + }, + main = "xla_client_test.py", + deps = [ + ":custom_calls_testlib", + ":xla_client", + "//jax", + "//jax:test_util", + "//jaxlib", + ] + py_deps([ + "absl/flags", + "absl/logging", + "absl/testing", + "ml_dtypes", + "numpy", + ]), +) + +py_strict_test( + name = "weakref_lru_cache_test", + srcs = ["weakref_lru_cache_test.py"], + deps = [ + ":xla_client", + ] + py_deps([ + "absl/flags", + "absl/logging", + "absl/testing", + ]), +) + +py_strict_test( + name = "pytree_test", + srcs = ["pytree_test.py"], + deps = [ + ":xla_client", + ] + py_deps([ + "absl/flags", + "absl/logging", + "absl/testing", + ]), +) + +py_strict_test( + name = "config_test", + srcs = ["config_test.py"], + deps = [ + ":xla_client", + ] + py_deps([ + "absl/flags", + "absl/logging", + "absl/testing", + ]), +) + +py_strict_test( + name = "jax_jit_test", + srcs = ["jax_jit_test.py"], + deps = [ + ":xla_client", + ] + py_deps([ + "absl/flags", + "absl/logging", + "absl/testing", + "numpy", + ]), +) diff --git a/jaxlib/xla/config_test.py b/jaxlib/xla/config_test.py new file mode 100644 index 000000000000..8701a37acd1d --- /dev/null +++ b/jaxlib/xla/config_test.py @@ -0,0 +1,71 @@ +# Copyright 2024 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import threading + +from absl.testing import absltest + +from jax.jaxlib.xla import xla_client + +config = xla_client._xla.config + + +class ConfigTest(absltest.TestCase): + + def testBasic(self): + c = config.Config(1) + self.assertEqual(c.value, 1) + self.assertEqual(c.get_global(), 1) + self.assertEqual(c.get_local(), config.unset) + + c.set_global(2) + self.assertEqual(c.value, 2) + self.assertEqual(c.get_global(), 2) + self.assertEqual(c.get_local(), config.unset) + + c.set_local(3) + self.assertEqual(c.value, 3) + self.assertEqual(c.get_global(), 2) + self.assertEqual(c.get_local(), 3) + + c.set_global(4) + self.assertEqual(c.value, 3) + self.assertEqual(c.get_global(), 4) + self.assertEqual(c.get_local(), 3) + + c.set_local(config.unset) + self.assertEqual(c.value, 4) + self.assertEqual(c.get_global(), 4) + self.assertEqual(c.get_local(), config.unset) + + def testThreading(self): + c = config.Config(1) + + def Body(): + for i in range(100): + c.set_local(i) + self.assertEqual(c.get_local(), i) + self.assertEqual(c.get_global(), 1) + self.assertEqual(c.value, i) + + threads = [threading.Thread(target=Body) for _ in range(4)] + for t in threads: + t.start() + for t in threads: + t.join() + + +if __name__ == "__main__": + absltest.main() diff --git a/jaxlib/xla/custom_calls_testlib.cc b/jaxlib/xla/custom_calls_testlib.cc new file mode 100644 index 000000000000..d06105fce76f --- /dev/null +++ b/jaxlib/xla/custom_calls_testlib.cc @@ -0,0 +1,128 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "third_party/nanobind/include/nanobind/nanobind.h" +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/api/ffi.h" + +namespace xla::ffi { +namespace nb = ::nanobind; + +// Implement custom calls as static functions with XLA FFI types in the function +// signature that gives access to the arguments and results buffers together +// with their types and dimensions. See `ffi/api/ffi_test.cc` for more XLA FFI +// examples and features (e.g. binding attributes, custom user-defined structs +// and arbitrary execution context). + +static Error AlwaysFail(Result) { + return Error(XLA_FFI_Error_Code_INTERNAL, "Failed intentionally"); +} + +static Error AlwaysSucceed(Result) { return Error::Success(); } + +static Error Subtract(BufferR0 a, BufferR0 b, + Result> out) { + *out->typed_data() = *a.typed_data() - *b.typed_data(); + return Error::Success(); +} + +static Error SubtractCst(BufferR0 a, + Result> out, float cst) { + *out->typed_data() = *a.typed_data() - cst; + return Error::Success(); +} + +// Define XLA FFI handlers from the implementations defined above using explicit +// XLA FFI binding API to describe type signatures of custom calls. + +XLA_FFI_DEFINE_HANDLER(kAlwaysFail, AlwaysFail, Ffi::Bind().Ret()); + +XLA_FFI_DEFINE_HANDLER(kAlwaysSucceed, AlwaysSucceed, + Ffi::Bind().Ret()); + +XLA_FFI_DEFINE_HANDLER(kSubtract, Subtract, + Ffi::Bind() + .Arg>() + .Arg>() + .Ret>()); + +XLA_FFI_DEFINE_HANDLER(kSubtractCst, SubtractCst, + Ffi::Bind() + .Arg>() + .Ret>() + .Attr("cst")); + +// XLA FFI calls can also be stateful. +struct TestFfiState { + static TypeId id; + explicit TestFfiState(int32_t value) : value(value) {} + int32_t value; +}; +TypeId TestFfiState::id = {}; + +static ErrorOr> StateInstantiate() { + return std::make_unique(42); +} + +static Error StateExecute(TestFfiState* state, + Result> out) { + *out->typed_data() = state->value; + return Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER(kStateInstantiate, StateInstantiate, + Ffi::BindInstantiate()); +XLA_FFI_DEFINE_HANDLER( + kStateExecute, StateExecute, + Ffi::Bind().Ctx>().Ret>()); + +template +static auto BindFunction(T* fn) { + return nb::capsule(reinterpret_cast(fn)); +} + +template +static auto BindTypeId(T* typeId) { + return nb::capsule(reinterpret_cast(typeId)); +} + +// Custom calls registration library that exports function pointers to XLA FFI +// handlers to the python users. +NB_MODULE(custom_calls_testlib, m) { + m.def("registrations", []() { + nb::dict dict; + dict["always_fail"] = BindFunction(kAlwaysFail); + dict["always_succeed"] = BindFunction(kAlwaysSucceed); + dict["subtract_f32"] = BindFunction(kSubtract); + dict["subtract_f32_cst"] = BindFunction(kSubtractCst); + + nb::dict bundle; + bundle["instantiate"] = BindFunction(kStateInstantiate); + bundle["execute"] = BindFunction(kStateExecute); + dict["stateful"] = bundle; + + return dict; + }); + m.def("type_ids", []() { + nb::dict type_ids; + type_ids["test_ffi_state"] = BindTypeId(&TestFfiState::id); + return type_ids; + }); +} + +} // namespace xla::ffi diff --git a/jaxlib/xla/jax_jit_test.py b/jaxlib/xla/jax_jit_test.py new file mode 100644 index 000000000000..a090bc8dfadd --- /dev/null +++ b/jaxlib/xla/jax_jit_test.py @@ -0,0 +1,47 @@ +# Copyright 2024 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for jax_jit helper functions.""" + +from absl.testing import absltest + +from jax.jaxlib.xla import xla_client + +jax_jit = xla_client._xla.jax_jit +pytree = xla_client._xla.pytree + +pytree_registry = pytree.default_registry() + + +class JaxJitTest(absltest.TestCase): + + def testParseArguments(self): + sig, args = jax_jit.parse_arguments( + positional_args=[1, 2, 3], + keyword_args=[4, 5], + kwnames=("a", "b"), + static_argnums=[0, 2], + static_argnames=["a"], + pytree_registry=pytree_registry, + ) + self.assertEqual(args, [2, 5]) + self.assertEqual(sig.static_args, [1, 3, 4]) + self.assertEqual(sig.static_arg_names, ["a"]) + _, leaf = pytree_registry.flatten(0) + self.assertEqual(sig.dynamic_arg_names, ["b"]) + self.assertEqual(sig.dynamic_arg_treedefs, [leaf, leaf]) + + +if __name__ == "__main__": + absltest.main() diff --git a/jaxlib/xla/pytree_test.py b/jaxlib/xla/pytree_test.py new file mode 100644 index 000000000000..b5ac7dd5b4d2 --- /dev/null +++ b/jaxlib/xla/pytree_test.py @@ -0,0 +1,144 @@ +# Copyright 2023 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import collections +import dataclasses +import gc + +from absl.testing import absltest + +from jax.jaxlib.xla import xla_client + +pytree = xla_client._xla.pytree + + +ExampleType = collections.namedtuple("ExampleType", "field0 field1") + +registry = pytree.PyTreeRegistry() + + +class ExampleType2: + + def __init__(self, field0, field1): + self.field0 = field0 + self.field1 = field1 + + def to_iterable(self): + return [self.field0, self.field1], (None,) + + +def from_iterable(state, values): + del state + return ExampleType2(field0=values[0], field1=values[1]) + + +registry.register_node(ExampleType2, ExampleType2.to_iterable, from_iterable) + + +@dataclasses.dataclass +class Custom: + a: int + b: str + + +registry.register_dataclass_node(Custom, ["a"], ["b"]) + + +class PyTreeTest(absltest.TestCase): + + def roundtrip(self, example): + original = registry.flatten(example)[1] + self.assertEqual( + pytree.PyTreeDef.deserialize_using_proto( + registry, original.serialize_using_proto() + ), + original, + ) + + def testSerializeDeserializeNoPickle(self): + o = object() + self.roundtrip(({"a": o, "b": o}, [o, (o, o), None])) + + def testSerializeWithFallback(self): + o = object() + with self.assertRaises(ValueError): + self.roundtrip({"a": ExampleType(field0=o, field1=o)}) + + def testRegisteredType(self): + o = object() + with self.assertRaises(ValueError): + self.roundtrip({"a": ExampleType2(field0=o, field1=o)}) + + def roundtrip_node_data(self, example): + original = registry.flatten(example)[1] + restored = pytree.PyTreeDef.make_from_node_data_and_children( + registry, original.node_data(), original.children() + ) + self.assertEqual(restored, original) + + def testRoundtripNodeData(self): + o = object() + self.roundtrip_node_data([o, o, o]) + self.roundtrip_node_data((o, o, o)) + self.roundtrip_node_data({"a": o, "b": o}) + self.roundtrip_node_data({22: o, 88: o}) + self.roundtrip_node_data(None) + self.roundtrip_node_data(o) + self.roundtrip_node_data(ExampleType(field0=o, field1=o)) + self.roundtrip_node_data(ExampleType2(field0=o, field1=o)) + + def testCompose(self): + x = registry.flatten(0)[1] + y = registry.flatten((0, 0))[1] + self.assertEqual((x.compose(y)).num_leaves, 2) + + def testDataclassMakeFromNodeData(self): + c = Custom(1, "a") + c_leafs, c_tree = registry.flatten(c) + c_tree2 = c_tree.make_from_node_data_and_children( + registry, c_tree.node_data(), c_tree.children() + ) + self.assertEqual(c_tree2.unflatten(c_leafs), c) + self.assertEqual(str(c_tree2), str(c_tree)) + + def testTpTraverse(self): + self.assertContainsSubset( + [ + pytree.PyTreeRegistry, + ExampleType2, + ExampleType2.to_iterable, + from_iterable, + ], + gc.get_referents(registry), + ) + k1 = "k1" + k2 = "k2" + + t = ExampleType("a", "b") + _, treedef = registry.flatten([1, {k1: 2, k2: t}, 5, t]) + + self.assertContainsSubset( + [ + pytree.PyTreeDef, + registry, + k1, + k2, + ExampleType, + ], + gc.get_referents(treedef), + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/jaxlib/xla/weakref_lru_cache_test.py b/jaxlib/xla/weakref_lru_cache_test.py new file mode 100644 index 000000000000..6ac3bfd71075 --- /dev/null +++ b/jaxlib/xla/weakref_lru_cache_test.py @@ -0,0 +1,257 @@ +# Copyright 2023 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import gc +import threading +import time +import weakref + +from absl.testing import absltest + +from jax.jaxlib.xla import xla_client + + +class WeakrefLRUCacheTest(absltest.TestCase): + + def testMultiThreaded(self): + insert_evs = [threading.Event() for _ in range(2)] + insert_evs_i = 0 + + class WRKey: + pass + + class ClashingKey: + + def __eq__(self, other): + return False + + def __hash__(self): + return 333 # induce maximal caching problems. + + class GilReleasingCacheKey: + + def __eq__(self, other): + nonlocal insert_evs_i + if isinstance(other, GilReleasingCacheKey) and insert_evs_i < len( + insert_evs + ): + insert_evs[insert_evs_i].set() + insert_evs_i += 1 + time.sleep(0.01) + return False + + def __hash__(self): + return 333 # induce maximal caching problems. + + def CacheFn(obj, gil_releasing_cache_key): + del obj + del gil_releasing_cache_key + return None + + cache = xla_client.weakref_lru_cache(lambda: None, CacheFn, 2048) + + wrkey = WRKey() + + def Body(): + for insert_ev in insert_evs: + insert_ev.wait() + for _ in range(20): + cache(wrkey, ClashingKey()) + + t = threading.Thread(target=Body) + t.start() + for _ in range(3): + cache(wrkey, GilReleasingCacheKey()) + t.join() + + def testAnotherMultiThreaded(self): + num_workers = 5 + barrier = threading.Barrier(num_workers) + cache = xla_client.weakref_lru_cache(lambda: None, lambda x, y: y, 2048) + + class WRKey: + pass + + def WorkerAddToCache(): + barrier.wait() + wrkey = WRKey() + for i in range(10): + cache(wrkey, i) + + def WorkerCleanCache(): + barrier.wait() + for _ in range(10): + cache.cache_clear() + + workers = [ + threading.Thread(target=WorkerAddToCache) + for _ in range(num_workers - 1) + ] + [threading.Thread(target=WorkerCleanCache)] + + for t in workers: + t.start() + + for t in workers: + t.join() + + def testKwargsDictOrder(self): + miss_id = 0 + + class WRKey: + pass + + def CacheFn(obj, kwkey1, kwkey2): + del obj, kwkey1, kwkey2 + nonlocal miss_id + miss_id += 1 + return miss_id + + cache = xla_client.weakref_lru_cache(lambda: None, CacheFn, 4) + + wrkey = WRKey() + + self.assertEqual(cache(wrkey, kwkey1="a", kwkey2="b"), 1) + self.assertEqual(cache(wrkey, kwkey1="b", kwkey2="a"), 2) + self.assertEqual(cache(wrkey, kwkey2="b", kwkey1="a"), 1) + + def testGetKeys(self): + def CacheFn(obj, arg): + del obj + return arg + "extra" + + cache = xla_client.weakref_lru_cache(lambda: None, CacheFn, 4) + + class WRKey: + pass + + wrkey = WRKey() + + self.assertEmpty(cache.cache_keys()) + cache(wrkey, "arg1") + cache(wrkey, "arg2") + self.assertLen(cache.cache_keys(), 2) + + def testNonWeakreferenceableKey(self): + class NonWRKey: + __slots__ = () + + non_wr_key = NonWRKey() + with self.assertRaises(TypeError): + weakref.ref(non_wr_key) + + cache = xla_client.weakref_lru_cache(lambda: None, lambda x: 2048) + for _ in range(100): + with self.assertRaises(TypeError): + cache(non_wr_key) + + def testCrashingKey(self): + class WRKey: + pass + + class CrashingKey: + # A key that raises exceptions if eq or hash is called. + + def __eq__(self, other): + raise ValueError("eq") + + def __hash__(self): + raise ValueError("hash") + + cache = xla_client.weakref_lru_cache(lambda: None, lambda x, y: y, 2048) + wrkey = WRKey() + with self.assertRaises(ValueError): + for _ in range(100): + cache(wrkey, CrashingKey()) + + def testPrintingStats(self): + class WRKey: + pass + + cache = xla_client.weakref_lru_cache(lambda: None, lambda x, y: y, 2048) + wrkey = WRKey() + for i in range(10): + cache(wrkey, i) + for i in range(5): + cache(wrkey, i) + + self.assertEqual( + repr(cache.cache_info()), + "WeakrefLRUCache(hits=5, misses=10, maxsize=2048, currsize=10)", + ) + + def testGCKeys(self): + class WRKey: + + def __init__(self, x): + self.x = x + + def __eq__(self, other): + return self.x == other.x + + def __hash__(self): + return hash(self.x) + + cache = xla_client.weakref_lru_cache(lambda: None, lambda x, y: y, 2048) + keys = [WRKey(i) for i in range(10)] + for i in range(10): + cache(keys[i], i) + + # Delete some keys, to exercise the weakref callback behavior. + del keys[::2] + + for key in keys: + cache(key, 7) + + def testTpTraverse(self): + class WRKey: + pass + + def CacheContextFn(): + return None + + def CallFn(x, y, *args, **kwargs): + del x, args, kwargs + return y + + cache = xla_client.weakref_lru_cache(CacheContextFn, CallFn, 2048) + + keys = [WRKey() for _ in range(10)] + values = [str(i) for i in range(10)] + args = [str(i) for i in range(10)] + kwargs = {"a": "b"} + + for key, value in zip(keys, values): + cache(key, value, *args, **kwargs) + + expected_refs = ( + [ + CacheContextFn, + CallFn, + xla_client._xla.WeakrefLRUCache, + kwargs, + ] + + [weakref.getweakrefs(key)[0] for key in keys] + + values + + args + ) + + # Can't use assertContainsSubset because it doesn't support kwargs since + # dicts aren't hashable. + for ref in expected_refs: + self.assertIn(ref, gc.get_referents(cache)) + + +if __name__ == "__main__": + absltest.main() diff --git a/jaxlib/xla/xla_client.py b/jaxlib/xla/xla_client.py new file mode 100644 index 000000000000..b6c5707d05dd --- /dev/null +++ b/jaxlib/xla/xla_client.py @@ -0,0 +1,1044 @@ +# Copyright 2017 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""An XLA client in Python.""" + +from __future__ import annotations + +import atexit +from collections.abc import Mapping, Sequence +import contextlib +import enum # pylint: disable=g-bad-import-order +import gzip +import inspect +import logging +import os +import threading +from typing import Any, Protocol, Union + +import ml_dtypes +import numpy as np + +from jaxlib import xla_extension as _xla + +# Note this module does *not* depend on any Python protocol buffers. The XLA +# Python bindings are currently packaged both as part of jaxlib and as part +# of TensorFlow. If we use protocol buffers here, then importing both jaxlib +# and TensorFlow may fail with duplicate protocol buffer message definitions. + +# Most functions are snake_case for consistency with other modules, some +# method names are CamelCase for consistency with XLA. +# pylint: disable=invalid-name + +# Pylint has false positives for type annotations. +# pylint: disable=invalid-sequence-index + +ifrt_programs = _xla.ifrt_programs +ops = _xla.ops +profiler = _xla.profiler + +# Just an internal arbitrary increasing number to help with backward-compatible +# changes. In JAX, reference this via jax._src.lib.xla_extension_version. +_version = 320 + +# Version number for MLIR:Python components. +mlir_api_version = 58 + +xla_platform_names = { + 'cpu': 'Host', + 'gpu': 'CUDA', +} + +logger = logging.getLogger(__name__) + +_NameValueMapping = Mapping[str, Union[str, int, list[int], float, bool]] + + +def make_cpu_client( + asynchronous=True, + distributed_client=None, + node_id=0, + num_nodes=1, + collectives=None, + num_devices=None, +) -> ...: + register_custom_call_handler('cpu', _xla.register_custom_call_target) + register_custom_type_id_handler('cpu', _xla.register_custom_type_id) + return _xla.get_tfrt_cpu_client( + asynchronous=asynchronous, + distributed_client=distributed_client, + node_id=node_id, + num_nodes=num_nodes, + collectives=collectives, + num_devices=num_devices, + ) + + +def make_gpu_client( + distributed_client=None, + node_id=0, + num_nodes=1, + platform_name=None, + allowed_devices=None, + mock=False, + mock_gpu_topology=None, +): + """Returns a GPU client. BFC allocator is used by default.""" + options = generate_pjrt_gpu_plugin_options() + allocator = options['allocator'] + config = _xla.GpuAllocatorConfig() + if allocator == 'default': + config.kind = _xla.GpuAllocatorConfig.Kind.DEFAULT + if allocator == 'platform': + config.kind = _xla.GpuAllocatorConfig.Kind.PLATFORM + if allocator == 'bfc': + config.kind = _xla.GpuAllocatorConfig.Kind.BFC + if allocator == 'cuda_async': + config.kind = _xla.GpuAllocatorConfig.Kind.CUDA_ASYNC + if 'memory_fraction' in options: + config.memory_fraction = options['memory_fraction'] + if 'preallocate' in options: + config.preallocate = options['preallocate'] + if 'collective_memory_size' in options: + config.collective_memory_size = options['collective_memory_size'] + register_custom_call_handler('CUDA', _xla.register_custom_call_target) + register_custom_call_handler('ROCM', _xla.register_custom_call_target) + register_custom_type_id_handler('CUDA', _xla.register_custom_type_id) + register_custom_type_id_handler('ROCM', _xla.register_custom_type_id) + + return _xla.get_gpu_client( + asynchronous=True, + allocator_config=config, + distributed_client=distributed_client, + node_id=node_id, + num_nodes=num_nodes, + platform_name=platform_name, + allowed_devices=allowed_devices, + mock=mock, + mock_gpu_topology=mock_gpu_topology, + ) + + +def make_tfrt_tpu_c_api_client(options: _NameValueMapping | None = None): + assert pjrt_plugin_loaded('tpu') + if not pjrt_plugin_initialized('tpu'): + initialize_pjrt_plugin('tpu') + if options is None: + options = {} + return _xla.get_c_api_client('tpu', options) + + +DeviceTopology = _xla.DeviceTopology +get_topology_for_devices = _xla.get_topology_for_devices + + +def make_tfrt_tpu_c_api_device_topology( + topology_name: str = '', **kwargs +) -> DeviceTopology: + """Creates a PJRT C API TopologyDescription.""" + return _xla.get_default_c_api_topology('tpu', topology_name, dict(**kwargs)) + + +def make_c_api_device_topology( + c_api: Any, topology_name: str = '', **kwargs +) -> DeviceTopology: + """Creates a PJRT C API TopologyDescription.""" + return _xla.get_c_api_topology(c_api, topology_name, dict(**kwargs)) + + +def pjrt_plugin_loaded(plugin_name: str) -> bool: + return _xla.pjrt_plugin_loaded(plugin_name) + + +def load_pjrt_plugin_dynamically(plugin_name: str, library_path: str) -> Any: + return _xla.load_pjrt_plugin(plugin_name, library_path, c_api=None) + + +def load_pjrt_plugin_with_c_api(plugin_name: str, c_api: Any) -> None: + return _xla.load_pjrt_plugin(plugin_name, None, c_api) + + +def pjrt_plugin_initialized(plugin_name: str) -> bool: + return _xla.pjrt_plugin_initialized(plugin_name) + + +def initialize_pjrt_plugin(plugin_name: str) -> None: + """Initializes a PJRT plugin. + + The plugin needs to be loaded first (through load_pjrt_plugin_dynamically or + static linking) before this method is called. + Args: + plugin_name: the name of the PJRT plugin. + """ + _xla.initialize_pjrt_plugin(plugin_name) + + +def make_c_api_client( + plugin_name: str, + options: _NameValueMapping | None = None, + distributed_client: _xla.DistributedRuntimeClient | None = None, +): + """Creates a PJRT C API client for a PJRT plugin. + + It is required that load_pjrt_plugin_dynamically is called once with the same + plugin_name before this method is called. + + Args: + plugin_name: the name of the PJRT plugin. + options: extra platform-specific options. + distributed_client: distributed client. + + Returns: + A PJRT C API client for plugin_name. + """ + if options is None: + options = {} + return _xla.get_c_api_client(plugin_name, options, distributed_client) + + +def make_tpu_client( + library_path: str | None = None, options: _NameValueMapping | None = None +): + """Returns a TPU client. Defaults to allowing 32 in-flight computations.""" + if not pjrt_plugin_loaded('tpu'): + c_api = load_pjrt_plugin_dynamically('tpu', library_path or 'libtpu.so') + profiler.register_plugin_profiler(c_api) + return make_tfrt_tpu_c_api_client(options) + + +def generate_pjrt_gpu_plugin_options() -> _NameValueMapping: + """Generates the PjRt GPU plugin options. + + Returns: + A dictionary of plugin options. + """ + + options = {} + options['platform_name'] = 'cuda' + allocator = os.getenv('XLA_PYTHON_CLIENT_ALLOCATOR', 'default').lower() + memory_fraction = os.getenv('XLA_CLIENT_MEM_FRACTION', '') + deprecated_memory_fraction = os.getenv('XLA_PYTHON_CLIENT_MEM_FRACTION', '') + if deprecated_memory_fraction: + if memory_fraction: + raise ValueError( + 'XLA_CLIENT_MEM_FRACTION is specified together ' + 'with XLA_PYTHON_CLIENT_MEM_FRACTION. ' + 'Remove the latter one, it is deprecated.' + ) + else: + memory_fraction = deprecated_memory_fraction + preallocate = os.getenv('XLA_PYTHON_CLIENT_PREALLOCATE', '') + collective_memory_size = os.getenv( + 'XLA_PYTHON_CLIENT_COLLECTIVE_MEM_SIZE_MB', '' + ) + if allocator not in ('default', 'platform', 'bfc', 'cuda_async'): + raise ValueError( + 'XLA_PYTHON_CLIENT_ALLOCATOR env var must be "default", "platform", ' + '"bfc", or "cuda_async", got "%s"' % allocator + ) + options['allocator'] = allocator + if memory_fraction: + options['memory_fraction'] = float(memory_fraction) + if preallocate: + options['preallocate'] = preallocate not in ('false', 'False', '0') + if collective_memory_size: + options['collective_memory_size'] = int(collective_memory_size) * (1 << 20) + return options + + +class OpMetadata: + """Python representation of a xla.OpMetadata protobuf.""" + + __slots__ = ('op_type', 'op_name', 'source_file', 'source_line') + + def __init__(self, op_type='', op_name='', source_file='', source_line=0): + self.op_type = op_type + self.op_name = op_name + self.source_file = source_file + self.source_line = source_line + + +def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): + """Helper for use in source mapping that returns an OpMetadata object.""" + full_filename, lineno = inspect.stack()[skip_frames][1:3] + filename = os.path.basename(full_filename) + return OpMetadata( + op_type=op_type, op_name=op_name, source_file=filename, source_line=lineno + ) + + +PrimitiveType = _xla.PrimitiveType + +bfloat16 = ml_dtypes.bfloat16 +# TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. +# Also, it would be better to conditionally import these based on whether they +# are in the current version of ml_dtypes. +# float4_e2m1fn = ml_dtypes.float4_e2m1fn +# float8_e3m4 = ml_dtypes.float8_e3m4 +# float8_e4m3 = ml_dtypes.float8_e4m3 +# float8_e8m0fnu = ml_dtypes.float8_e8m0fnu +float8_e4m3fn = ml_dtypes.float8_e4m3fn +float8_e4m3b11fnuz = ml_dtypes.float8_e4m3b11fnuz +float8_e4m3fnuz = ml_dtypes.float8_e4m3fnuz +float8_e5m2 = ml_dtypes.float8_e5m2 +float8_e5m2fnuz = ml_dtypes.float8_e5m2fnuz + +XLA_ELEMENT_TYPE_TO_DTYPE = { + PrimitiveType.PRED: np.dtype('bool'), + PrimitiveType.S4: np.dtype('int4'), + PrimitiveType.S8: np.dtype('int8'), + PrimitiveType.S16: np.dtype('int16'), + PrimitiveType.S32: np.dtype('int32'), + PrimitiveType.S64: np.dtype('int64'), + PrimitiveType.U4: np.dtype('uint4'), + PrimitiveType.U8: np.dtype('uint8'), + PrimitiveType.U16: np.dtype('uint16'), + PrimitiveType.U32: np.dtype('uint32'), + PrimitiveType.U64: np.dtype('uint64'), + # TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. + # PrimitiveType.F4E2M1FN: np.dtype(float4_e2m1fn), + # PrimitiveType.F8E3M4: np.dtype(float8_e3m4), + # PrimitiveType.F8E4M3: np.dtype(float8_e4m3), + # PrimitiveType.F8E8M0FNU: np.dtype(float8_e8m0fnu), + PrimitiveType.F8E4M3FN: np.dtype(float8_e4m3fn), + PrimitiveType.F8E4M3B11FNUZ: np.dtype(float8_e4m3b11fnuz), + PrimitiveType.F8E5M2: np.dtype(float8_e5m2), + PrimitiveType.F8E4M3FNUZ: np.dtype(float8_e4m3fnuz), + PrimitiveType.F8E5M2FNUZ: np.dtype(float8_e5m2fnuz), + PrimitiveType.BF16: np.dtype(bfloat16), + PrimitiveType.F16: np.dtype('float16'), + PrimitiveType.F32: np.dtype('float32'), + PrimitiveType.F64: np.dtype('float64'), + PrimitiveType.C64: np.dtype('complex64'), + PrimitiveType.C128: np.dtype('complex128'), + PrimitiveType.TUPLE: np.dtype(np.object_), + PrimitiveType.TOKEN: np.dtype(np.object_), +} + +# Note the conversion on the key. Numpy has a known issue wherein dtype hashing +# doesn't work as expected (https://github.com/numpy/numpy/issues/7242). Thus, +# when keying by dtype in this dict, we use the string form of dtypes. +DTYPE_TO_XLA_ELEMENT_TYPE = { + str(dt): et for et, dt in XLA_ELEMENT_TYPE_TO_DTYPE.items() +} + + +def dtype_to_etype(dtype): + """Convenience function for reading DTYPE_TO_XLA_ELEMENT_TYPE.""" + return DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))] + + +Shape = _xla.Shape +Shape.__doc__ = """ +A Shape is an object defined in C++ that duck types like the following class: + +class Shape: + '''Represents an XLA shape. + + A shape is either an array shape, having rank-many integer + dimensions and an element type (represented by a Numpy dtype), or it + is a tuple shape, having a shape for every tuple component: + + type shape = + TupleShape of shape list + | ArrayShape of { dimensions: int list; element_type: dtype } + ''' + + @staticmethod + def tuple_shape(tuple_shapes) -> Shape: + "Construct a tuple shape." + + @staticmethod + def array_shape(element_type, dimensions, minor_to_major=None) -> Shape: + + @staticmethod + def from_pyval(pyval) -> Shape: + "Returns a Shape that describes a tuple-tree of Numpy arrays." + + def __init__(self, str) -> Shape: + "Parses a shape string." + def __eq__(self, other: Shape) -> bool: + def __ne__(self, other: Shape) -> bool: + def __hash__(self): + def __repr__(self): + def is_tuple(self) -> bool: + def is_array(self) -> bool: + def tuple_shapes(self) -> [Shape]: + def numpy_dtype(self) -> np.dtype: + "Like element_type(), but returns dtype('O') for a tuple shape." + def xla_element_type(self) -> PrimitiveType: + def element_type(self) -> np.dtype: + def dimensions(self) -> (int, int, ...): + def rank(self) -> int: + def with_major_to_minor_layout_if_absent(self) -> Shape: + "Returns a copy with missing layouts set to major-to-minor." + + def to_serialized_proto(self) -> bytes: + "Returns 'shape' as a serialized proto." +""" + +ProgramShape = _xla.ProgramShape +ProgramShape.__doc__ = """ +A ProgramShape is a C++ object that duck types like the following class. + +class ProgramShape: + def __init__(self, parameter_shapes, result_shape): + def parameter_shapes(self) -> [Shape]: + def result_shape(self) -> Shape: + def __repr__(self): +""" + +ShapeIndex = _xla.ShapeIndex +ShapeIndex.__doc__ = """ +A Shape is an object defined in C++ that duck types like the following class: + +class ShapeIndex: + '''Represents an XLA ShapeIndex. + + An index for specifying a particular nested subshape within a shape. Used in + ShapeUtil::GetSubshape and other interfaces. ShapeIndex defines a path through + the Shape tree where each element of ShapeIndex indexes into a tuple (or + nested tuple) within the shape. For a non-nested tuple, an index has a single + element. + ''' + + def __init__(self, List[int]) -> ShapeIndex: + def __eq__(self, other: Shape) -> bool: + def __ne__(self, other: Shape) -> bool: + def __hash__(self): + def __repr__(self): +""" + + +def shape_from_pyval(pyval, layout: Sequence[int] | None = None): + """Returns a Shape that describes a tuple-tree of Numpy arrays.""" + + def convert(pyval): + if isinstance(pyval, tuple): + if layout is not None: + raise NotImplementedError( + 'shape_from_pyval does not support layouts for tuple shapes' + ) + return Shape.tuple_shape(tuple(convert(elt) for elt in pyval)) + else: + return Shape.array_shape(pyval.dtype, np.shape(pyval), layout) + + return convert(pyval) + + +DeviceAssignment = _xla.DeviceAssignment +DeviceAssignment.__doc__ = """ +A DeviceAssignment is a C++ object with the following signature. + +def create(assignment): + '''Builds a device assignment. + + Args: + assignment: a 2D numpy array of device ordinal integers, indexed by + [replica][computation_in_replica]. + Returns: + A device assignment. + ''' + +def replica_count(): + '''Returns the number of replicas.''' +def computation_count(): + '''Returns the number of computations per replica.''' +""" + +Device = _xla.Device +CompileOptions = _xla.CompileOptions + +HostBufferSemantics = _xla.HostBufferSemantics + +# An Executable is a C++ class that duck types with the following API: +# class Executable: +# def local_devices(self) -> [Device]: +# def execute(self, arguments : [Buffer]) -> Buffer: +# """Execute on one replica with Buffer arguments and return value.""" +# +# def size_of_generated_code_in_bytes(self) -> int: +# """Return generated binary size, or -1 if not known.""" +# +# def execute_sharded_on_local_devices(self, arguments: [[Buffer]]) +# -> [Buffer]: +# """Execute on many replicas with Buffer arguments and return value. +# +# Args: +# arguments: A sequence of sequences of Buffers. The i'th element of each +# sequence comprises the arguments for execution on the i'th local +# device. +# +# Returns: +# A list of the computation's outputs as a list of Buffers for each +# device. +# """ +# +# There are different implementations of Executable for different backends. + + +class PaddingType(enum.Enum): + VALID = 1 + SAME = 2 + + +def window_padding_type_to_pad_values( + padding_type, lhs_dims, rhs_dims, window_strides +): + """Maps PaddingType or string to pad values (list of pairs of ints).""" + if not isinstance(padding_type, (str, PaddingType)): + msg = 'padding_type must be str or PaddingType, got {}.' + raise TypeError(msg.format(type(padding_type))) + + if isinstance(padding_type, str): + if padding_type.upper() == 'VALID': + padding_type = PaddingType.VALID + elif padding_type.upper() == 'SAME': + padding_type = PaddingType.SAME + else: + msg = 'Unknown padding type string: expected "VALID" or "SAME", got {}.' + raise ValueError(msg.format(padding_type)) + + if padding_type == PaddingType.VALID: + return [(0, 0)] * len(window_strides) + elif padding_type == PaddingType.SAME: + out_shape = np.ceil(np.true_divide(lhs_dims, window_strides)).astype(int) + pad_sizes = [ + max((out_size - 1) * stride + filter_size - in_size, 0) + for out_size, stride, filter_size, in_size in zip( + out_shape, window_strides, rhs_dims, lhs_dims + ) + ] + return [(pad_size // 2, pad_size - pad_size // 2) for pad_size in pad_sizes] + else: + msg = 'Unexpected PaddingType value: {}' + raise ValueError(msg.format(padding_type)) + + +XlaBuilder = _xla.XlaBuilder +XlaComputation = _xla.XlaComputation +XlaOp = _xla.XlaOp +FftType = _xla.FftType +Client = _xla.Client +Memory = _xla.Memory +ArrayImpl = _xla.ArrayImpl +LoadedExecutable = _xla.LoadedExecutable +DeviceList = _xla.DeviceList +OpSharding = _xla.OpSharding +HloSharding = _xla.HloSharding +Sharding = _xla.Sharding +NamedSharding = _xla.NamedSharding +SingleDeviceSharding = _xla.SingleDeviceSharding +PmapSharding = _xla.PmapSharding +GSPMDSharding = _xla.GSPMDSharding +PjRtLayout = _xla.PjRtLayout +AutotuneCacheMode = _xla.AutotuneCacheMode +ResultAccuracyMode = _xla.ResultAccuracy_Mode + + +def LoadedExecutable_execute(self, arguments, device=None): + del device + results = self.execute_sharded(arguments) + return [x[0] for x in results.disassemble_into_single_device_arrays()] + + +def LoadedExecutable_execute_with_token(self, arguments, device=None): + del device + results = self.execute_sharded(arguments, with_tokens=True) + return ( + [x[0] for x in results.disassemble_into_single_device_arrays()], + results.consume_token().get_token(0), + ) + + +LoadedExecutable.execute = LoadedExecutable_execute +LoadedExecutable.execute_with_token = LoadedExecutable_execute_with_token + + +class CustomCallTargetTraits(enum.IntFlag): + DEFAULT = 0 + # Calls to custom call are safe to trace into the command buffer. It means + # that calls to custom call always launch exactly the same device operations + # (can depend on attribute values) that can be captured and then replayed. + # + # Supported only for custom calls implemented with XLA FFI. + COMMAND_BUFFER_COMPATIBLE = 1 + + +class CustomCallHandler(Protocol): + + def __call__( + self, + name: str, + fn: Any, + platform: str, + /, + api_version: int = ..., + traits: CustomCallTargetTraits = ..., + ) -> None: + ... + + +_custom_callback_handler: dict[str, CustomCallHandler] = {} +# Key is xla_platform_name, value is (function_name, function, api_version) +_custom_callback: dict[ + str, list[tuple[str, Any, int, CustomCallTargetTraits]] +] = {} +_custom_callback_lock = threading.Lock() + + +def register_custom_call_target( + name: str, + fn: Any, + platform: str = 'cpu', + api_version: int = 0, + traits: CustomCallTargetTraits = CustomCallTargetTraits.DEFAULT, +) -> None: + """Registers a custom call target. + + Args: + name: bytes containing the name of the function. + fn: a PyCapsule object containing the function pointer. + platform: the target platform. + api_version: the XLA FFI version to use. Supported versions are: 0 for the + untyped FFI and 1 for the typed FFI. + traits: custom call traits corresponding to XLA FFI handler traits. + """ + # To support AMD GPUs, we need to have xla_platform_names["gpu"] == "ROCM" + # Since that is hardcoded to CUDA, we are using the following as workaround. + xla_platform_name = xla_platform_names.get(platform, platform) + with _custom_callback_lock: + if xla_platform_name in _custom_callback_handler: + _custom_callback_handler[xla_platform_name]( + name, fn, xla_platform_name, api_version, traits + ) + else: + _custom_callback.setdefault(xla_platform_name, []).append( + (name, fn, api_version, traits) + ) + + +def register_custom_call_handler( + platform: str, handler: CustomCallHandler +) -> None: + """Registers a custom handler and use it to register existing custom calls. + + If a custom call handler for the platform already exist, calling this method + is a no-op and it will not register a new handler. + + Args: + platform: the target platform. + handler: the function to register a custom call. + """ + xla_platform_name = xla_platform_names.get(platform, platform) + with _custom_callback_lock: + if xla_platform_name in _custom_callback_handler: + logger.debug( + 'Custom call handler for %s is already register. Will not register a' + ' new one', + xla_platform_name, + ) + return + _custom_callback_handler[xla_platform_name] = handler + if xla_platform_name in _custom_callback: + for name, fn, api_version, traits in _custom_callback[xla_platform_name]: + handler(name, fn, xla_platform_name, api_version, traits) + del _custom_callback[xla_platform_name] + + +class CustomTypeIdHandler(Protocol): + + def __call__(self, name: str, capsule: Any) -> None: + ... + + +_custom_type_id_handler: dict[str, CustomTypeIdHandler] = {} +_custom_type_id: dict[str, Any] = {} +_custom_type_id_lock = threading.Lock() + + +def register_custom_type_id( + type_name: str, + type_id: Any, + platform: str = 'cpu', +) -> None: + """Register a custom type id for use with the FFI. + + Args: + type_name: a unique name for the type. + type_id: a PyCapsule object containing a pointer to the ``ffi::TypeId``. + platform: the target platform. + """ + xla_platform_name = xla_platform_names.get(platform, platform) + with _custom_type_id_lock: + if xla_platform_name in _custom_type_id_handler: + _custom_type_id_handler[xla_platform_name](type_name, type_id) + else: + _custom_type_id.setdefault(xla_platform_name, []).append( + (type_name, type_id) + ) + + +def register_custom_type_id_handler( + platform: str, handler: CustomTypeIdHandler +) -> None: + """Register a custom type id handler and use it to register existing type ids. + + If a custom type id handler for the platform already exist, calling this + method is a no-op and it will not register a new handler. + + Args: + platform: the target platform. + handler: the function to register a custom type id. + """ + xla_platform_name = xla_platform_names.get(platform, platform) + with _custom_callback_lock: + if xla_platform_name in _custom_type_id_handler: + logger.debug( + 'Custom type id handler for %s is already register. Will not ' + 'register a new one', + xla_platform_name, + ) + return + _custom_type_id_handler[xla_platform_name] = handler + if xla_platform_name in _custom_type_id: + for name, capsule in _custom_type_id[xla_platform_name]: + handler(name, capsule) + del _custom_type_id[xla_platform_name] + + +register_custom_call_partitioner = _xla.register_custom_call_partitioner +encode_inspect_sharding_callback = _xla.encode_inspect_sharding_callback +hlo_sharding_util = _xla.hlo_sharding_util +register_custom_call_as_batch_partitionable = ( + _xla.register_custom_call_as_batch_partitionable +) + + +class PaddingConfigDimension: + """Python representation of a xla.PaddingConfigDimension protobuf.""" + + __slots__ = ('edge_padding_low', 'edge_padding_high', 'interior_padding') + + edge_padding_low: int + edge_padding_high: int + interior_padding: int + + def __init__(self): + self.edge_padding_low = 0 + self.edge_padding_high = 0 + self.interior_padding = 0 + + +class PaddingConfig: + """Python representation of a xla.PaddingConfig protobuf.""" + + __slots__ = ('dimensions',) + + def __init__(self): + self.dimensions = [] + + +def make_padding_config( + padding_config: Union[PaddingConfig, Sequence[tuple[int, int, int]]] +) -> PaddingConfig: + """Create PaddingConfig proto from list of triples of integers. + + Args: + padding_config: either a PaddingConfig or a list of integer triples + (edge_padding_low, edge_padding_high, interior_padding) representing the + configuration of the padding operation. + + Returns: + A `PaddingConfig` object. + """ + if not isinstance(padding_config, PaddingConfig): + triples = padding_config + padding_config = PaddingConfig() + for lo, hi, interior in triples: + dimension = PaddingConfigDimension() + dimension.edge_padding_low = lo + dimension.edge_padding_high = hi + dimension.interior_padding = interior + padding_config.dimensions.append(dimension) + return padding_config + + +class DotDimensionNumbers: + """Python representation of a xla.DotDimensionNumbers protobuf.""" + + __slots__ = ( + 'lhs_contracting_dimensions', + 'rhs_contracting_dimensions', + 'lhs_batch_dimensions', + 'rhs_batch_dimensions', + ) + + def __init__(self): + self.lhs_contracting_dimensions = [] + self.rhs_contracting_dimensions = [] + self.lhs_batch_dimensions = [] + self.rhs_batch_dimensions = [] + + +def make_dot_dimension_numbers( + dimension_numbers: Union[ + DotDimensionNumbers, + tuple[tuple[list[int], list[int]], tuple[list[int], list[int]]], + ] +) -> DotDimensionNumbers: + """Builds a DotDimensionNumbers object from a specification. + + Args: + dimension_numbers: either a `DotDimensionNumbers` or a nested tuple + `((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))` of lists of + integers representing the dimensions to treat as contracting dimensions + and batch dimensions on each input operand. + + Returns: + A `DotDimensionNumbers` object. + """ + if isinstance(dimension_numbers, (list, tuple)): + (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers + dot_dims_proto = DotDimensionNumbers() + dot_dims_proto.lhs_contracting_dimensions.extend(lhs_contract) + dot_dims_proto.rhs_contracting_dimensions.extend(rhs_contract) + dot_dims_proto.lhs_batch_dimensions.extend(lhs_batch) + dot_dims_proto.rhs_batch_dimensions.extend(rhs_batch) + return dot_dims_proto + else: + return dimension_numbers + + +class ConvolutionDimensionNumbers: + """Python representation of a xla.ConvolutionDimensionNumbers protobuf.""" + + __slots__ = ( + 'input_batch_dimension', + 'input_feature_dimension', + 'input_spatial_dimensions', + 'kernel_input_feature_dimension', + 'kernel_output_feature_dimension', + 'kernel_spatial_dimensions', + 'output_batch_dimension', + 'output_feature_dimension', + 'output_spatial_dimensions', + ) + + def __init__(self): + self.input_batch_dimension = 0 + self.input_feature_dimension = 0 + self.input_spatial_dimensions = [] + self.kernel_input_feature_dimension = 0 + self.kernel_output_feature_dimension = 0 + self.kernel_spatial_dimensions = [] + self.output_batch_dimension = 0 + self.output_feature_dimension = 0 + self.output_spatial_dimensions = [] + + +def make_convolution_dimension_numbers( + dimension_numbers: Union[ + None, ConvolutionDimensionNumbers, tuple[str, str, str] + ], + num_spatial_dimensions: int, +) -> ConvolutionDimensionNumbers: + """Builds a ConvolutionDimensionNumbers object from a specification. + + Args: + dimension_numbers: optional, either a ConvolutionDimensionNumbers object or + a tuple (lhs_spec, rhs_spec, out_spec). Each element is a string of length + N+2 identifying by position: (1) batch dimensions in lhs, rhs, and the + output with the character 'N', (2) feature dimensions in lhs and the + output with the character 'C', (3) input and output feature dimensions in + rhs with the characters 'I' and 'O' respectively, and (4) spatial + dimension correspondences between lhs, rhs, and the output using any + distinct characters. For example, to indicate dimension numbers consistent + with the Conv operation with two spatial dimensions, one could use + ('NCHW', 'OIHW', 'NCHW'). As another example, to indicate dimension + numbers consistent with the TensorFlow Conv2D operation, one could use + ('NHWC', 'HWIO', 'NHWC'). When using the latter form of convolution + dimension specification, window strides are associated with spatial + dimension character labels according to the order in which the labels + appear in the rhs_spec string, so that window_strides[0] is matched with + the dimension corresponding to the first character appearing in rhs_spec + that is not 'I' or 'O'. By default, use the same dimension numbering as + Conv and ConvWithGeneralPadding. + num_spatial_dimensions: the number of spatial dimensions. + + Returns: + A `ConvolutionDimensionNumbers` object. + """ + if dimension_numbers is None: + nd = num_spatial_dimensions + dimension_numbers = ConvolutionDimensionNumbers() + dimension_numbers.input_batch_dimension = 0 + dimension_numbers.input_feature_dimension = 1 + dimension_numbers.output_batch_dimension = 0 + dimension_numbers.output_feature_dimension = 1 + dimension_numbers.kernel_output_feature_dimension = 0 + dimension_numbers.kernel_input_feature_dimension = 1 + dimension_numbers.input_spatial_dimensions.extend(range(2, 2 + nd)) + dimension_numbers.kernel_spatial_dimensions.extend(range(2, 2 + nd)) + dimension_numbers.output_spatial_dimensions.extend(range(2, 2 + nd)) + elif isinstance(dimension_numbers, tuple): + lhs_spec, rhs_spec, out_spec = dimension_numbers + dimension_numbers = ConvolutionDimensionNumbers() + + dimension_numbers.input_batch_dimension = lhs_spec.index('N') + dimension_numbers.input_feature_dimension = lhs_spec.index('C') + dimension_numbers.output_batch_dimension = out_spec.index('N') + dimension_numbers.output_feature_dimension = out_spec.index('C') + dimension_numbers.kernel_output_feature_dimension = rhs_spec.index('O') + dimension_numbers.kernel_input_feature_dimension = rhs_spec.index('I') + + dimension_numbers.kernel_spatial_dimensions.extend( + i for i, c in enumerate(rhs_spec) if c not in {'I', 'O'} + ) + dimension_numbers.input_spatial_dimensions.extend( + sorted( + (i for i, c in enumerate(lhs_spec) if c not in {'N', 'C'}), + key=lambda i: rhs_spec.index(lhs_spec[i]), + ) + ) + dimension_numbers.output_spatial_dimensions.extend( + sorted( + (i for i, c in enumerate(out_spec) if c not in {'N', 'C'}), + key=lambda i: rhs_spec.index(out_spec[i]), + ) + ) + return dimension_numbers + + +class PrecisionConfig: + """Python representation of a xla.PrecisionConfig protobuf.""" + + __slots__ = ('operand_precision',) + + Precision = _xla.PrecisionConfig_Precision + + def __init__(self): + self.operand_precision = [] + + +class ResultAccuracy: + """Python representation of a xla.ResultAccuracy protobuf.""" + + __slots__ = ('mode', 'atol', 'rtol', 'ulps') + + def __init__(self): + self.mode = _xla.ResultAccuracy_Mode.DEFAULT + self.atol = 0.0 + self.rtol = 0.0 + self.ulps = 0 + + +class GatherDimensionNumbers: + """Python representation of a xla.GatherDimensionNumbers protobuf.""" + + __slots__ = ( + 'offset_dims', + 'collapsed_slice_dims', + 'start_index_map', + 'index_vector_dim', + ) + + def __init__(self): + self.offset_dims = [] + self.collapsed_slice_dims = [] + self.start_index_map = [] + self.index_vector_dim = 0 + + +class ScatterDimensionNumbers: + """Python representation of a xla.ScatterDimensionNumbers protobuf.""" + + __slots__ = ( + 'update_window_dims', + 'inserted_window_dims', + 'scatter_dims_to_operand_dims', + 'index_vector_dim', + ) + + def __init__(self): + self.update_window_dims = [] + self.inserted_window_dims = [] + self.scatter_dims_to_operand_dims = [] + self.index_vector_dim = 0 + + +class ReplicaGroup: + """Python representation of a xla.ReplicaGroup protobuf.""" + + __slots__ = ('replica_ids',) + + def __init__(self): + self.replica_ids = [] + + +def _make_replica_group_proto(replica_group): + replica_group_proto = ReplicaGroup() + replica_group_proto.replica_ids.extend(replica_group) + return replica_group_proto + + +def make_replica_groups(replica_groups): + if replica_groups is None: + replica_groups_protos = [] # special value for XLA API + else: + replica_groups = list(replica_groups) + replica_groups_protos = [ + _make_replica_group_proto(group) for group in replica_groups + ] + return replica_groups_protos + + +Traceback = _xla.Traceback +Frame = _xla.Frame + + +@contextlib.contextmanager +def tracebacks(enabled=True): + """Context manager that enables or disables traceback collection.""" + saved = Traceback.enabled + Traceback.enabled = enabled + try: + yield + finally: + Traceback.enabled = saved + + +def heap_profile(client: Client) -> bytes: + """Returns a gzipped pprof protocol buffer containing a heap profile.""" + return gzip.compress(client.heap_profile()) + + +XlaRuntimeError = _xla.XlaRuntimeError + +# Perform one last garbage collection of deferred Python references. This is +# mostly to keep ASAN happy. +atexit.register(_xla.collect_garbage) + +weakref_lru_cache = _xla.weakref_lru_cache +array_result_handler = _xla.array_result_handler +batched_copy_array_to_devices_with_sharding = ( + _xla.batched_copy_array_to_devices_with_sharding +) +batched_device_put = _xla.batched_device_put +reorder_shards = _xla.reorder_shards +batched_block_until_ready = _xla.batched_block_until_ready +check_and_canonicalize_memory_kind = _xla.check_and_canonicalize_memory_kind +Layout = _xla.Layout +custom_call_targets = _xla.custom_call_targets +ArrayCopySemantics = _xla.ArrayCopySemantics diff --git a/jaxlib/xla/xla_client.pyi b/jaxlib/xla/xla_client.pyi new file mode 100644 index 000000000000..234af8f7b87d --- /dev/null +++ b/jaxlib/xla/xla_client.pyi @@ -0,0 +1,322 @@ +# Copyright 2021 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import annotations + +from collections.abc import Callable, Mapping, Sequence +import enum +from typing import Any, Union + +import numpy + +from jaxlib import xla_extension as _xla +from jaxlib.xla_extension import ArrayImpl as ArrayImpl +from jaxlib.xla_extension import AutotuneCacheMode as AutotuneCacheMode +from jaxlib.xla_extension import Client as Client +from jaxlib.xla_extension import CompileOptions as CompileOptions +from jaxlib.xla_extension import Device as Device +from jaxlib.xla_extension import DeviceAssignment as DeviceAssignment +from jaxlib.xla_extension import DeviceList as DeviceList +from jaxlib.xla_extension import DeviceTopology as DeviceTopology +from jaxlib.xla_extension import DistributedRuntimeClient as DistributedRuntimeClient +from jaxlib.xla_extension import FftType as FftType +from jaxlib.xla_extension import Frame as Frame +from jaxlib.xla_extension import GSPMDSharding as GSPMDSharding +from jaxlib.xla_extension import HloSharding as HloSharding +from jaxlib.xla_extension import HostBufferSemantics as HostBufferSemantics +from jaxlib.xla_extension import ifrt_programs as ifrt_programs +from jaxlib.xla_extension import Layout as Layout +from jaxlib.xla_extension import LoadedExecutable as LoadedExecutable +from jaxlib.xla_extension import Memory as Memory +from jaxlib.xla_extension import NamedSharding as NamedSharding +from jaxlib.xla_extension import ops as ops +from jaxlib.xla_extension import OpSharding as OpSharding +from jaxlib.xla_extension import PjRtLayout as PjRtLayout +from jaxlib.xla_extension import PmapSharding as PmapSharding +from jaxlib.xla_extension import PrimitiveType as PrimitiveType +from jaxlib.xla_extension import ArrayCopySemantics as ArrayCopySemantics +from jaxlib.xla_extension import profiler as profiler +from jaxlib.xla_extension import Shape as Shape +from jaxlib.xla_extension import Sharding as Sharding +from jaxlib.xla_extension import SingleDeviceSharding as SingleDeviceSharding +from jaxlib.xla_extension import Traceback as Traceback +from jaxlib.xla_extension import XlaBuilder as XlaBuilder +from jaxlib.xla_extension import XlaComputation as XlaComputation +from jaxlib.xla_extension import XlaOp as XlaOp + +_version: int + +mlir_api_version: int + +bfloat16: type[numpy.generic] +# TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. +# float4_e2m1fn: type[numpy.generic] +# float8_e3m4: type[numpy.generic] +# float8_e4m3: type[numpy.generic] +# float8_e8m0fnu: type[numpy.generic] +float8_e4m3fn: type[numpy.generic] +float8_e4m3b11fnuz: type[numpy.generic] +float8_e4m3fnuz: type[numpy.generic] +float8_e5m2: type[numpy.generic] +float8_e5m2fnuz: type[numpy.generic] +XLA_ELEMENT_TYPE_TO_DTYPE: dict[PrimitiveType, numpy.dtype] + +_NameValueMapping = Mapping[str, Union[str, int, list[int], float, bool]] + +def dtype_to_etype(dtype: numpy.dtype) -> PrimitiveType: + ... + +def shape_from_pyval(pyval: Any, layout: Sequence[int] | None = None) -> Any: ... + +def heap_profile(client: Client) -> bytes: + ... + +XlaRuntimeError = _xla.XlaRuntimeError + +def make_cpu_client( + asynchronous: bool = ..., + distributed_client: DistributedRuntimeClient | None = ..., + node_id: int = ..., + num_nodes: int = ..., + collectives: _xla.CpuCollectives | None = ..., + num_devices: int | None = ..., +) -> Client: + ... + +def make_gpu_client( + distributed_client: DistributedRuntimeClient | None = ..., + node_id: int = ..., + num_nodes: int = ..., + platform_name: str | None = ..., + allowed_devices: set[int] | None = ..., + mock: bool | None = ..., + mock_gpu_topology: str | None = ..., +) -> Client: + ... + +def make_tfrt_tpu_c_api_client(options: _NameValueMapping | None = None) -> Client: + ... + +def make_tfrt_tpu_c_api_device_topology( + topology_name: str | None = None, **kwargs +) -> DeviceTopology: + ... + +def make_c_api_device_topology(c_api: Any, topology_name: str = '', **kwargs) -> DeviceTopology: + ... + +def get_topology_for_devices(devices: list[Device]) -> DeviceTopology: + ... + +def make_tpu_client( + library_path: str | None, options: _NameValueMapping | None = None +) -> Client: + ... + +def make_c_api_client( + plugin_name: str, + options: _NameValueMapping | None = None, + distributed_client: DistributedRuntimeClient | None = None, +) -> Client: + ... + +def pjrt_plugin_loaded(plugin_name: str) -> bool: + ... + +def load_pjrt_plugin_dynamically(plugin_name: str, library_path: str) -> Any: + ... + +def load_pjrt_plugin_with_c_api(plugin_name: str, c_api: Any) -> None: + ... + +def pjrt_plugin_initialized(plugin_name: str) -> bool: + ... + +def initialize_pjrt_plugin(plugin_name: str) -> None: + ... + +def generate_pjrt_gpu_plugin_options() -> _NameValueMapping: + ... + +class OpMetadata: + + def __init__( + self, + op_type: str | None = ..., + op_name: str | None = ..., + source_file: str | None = ..., + source_line: int | None = ..., + ): + ... + op_type: str | None + op_name: str | None + source_file: str | None + source_line: int | None + +class PaddingConfigDimension: + edge_padding_low: int + edge_padding_high: int + interior_padding: int + +class PaddingConfig: + dimensions: list[PaddingConfigDimension] + +def make_padding_config( + padding_config: Union[PaddingConfig, Sequence[tuple[int, int, int]]], +) -> PaddingConfig: + ... + +class PaddingType(enum.Enum): + VALID = 1 + SAME = 2 + +class DotDimensionNumbers: + lhs_contracting_dimensions: list[int] + rhs_contracting_dimensions: list[int] + lhs_batch_dimensions: list[int] + rhs_batch_dimensions: list[int] + +def make_dot_dimension_numbers( + dimension_numbers: Union[ + DotDimensionNumbers, + tuple[tuple[list[int], list[int]], tuple[list[int], list[int]]], + ], +) -> DotDimensionNumbers: + ... + +class ConvolutionDimensionNumbers: + input_batch_dimension: int + input_feature_dimension: int + input_spatial_dimensions: list[int] + kernel_input_feature_dimension: int + kernel_output_feature_dimension: int + kernel_spatial_dimensions: list[int] + output_batch_dimension: int + output_feature_dimension: int + output_spatial_dimensions: list[int] + +def make_convolution_dimension_numbers( + dimension_numbers: Union[ + None, ConvolutionDimensionNumbers, tuple[str, str, str] + ], + num_spatial_dimensions: int, +) -> ConvolutionDimensionNumbers: + ... + +class PrecisionConfig: + Precision = _xla.PrecisionConfig_Precision + operand_precision: list[_xla.PrecisionConfig_Precision] + +class ResultAccuracy: + mode: _xla.ResultAccuracy_Mode + atol: float + rtol: float + ulps: int + +class GatherDimensionNumbers: + offset_dims: list[int] + collapsed_slice_dims: list[int] + start_index_map: list[int] + index_vector_dim: int + operand_batching_dims: list[int] + start_indices_batching_dims: list[int] + +class ScatterDimensionNumbers: + update_window_dims: list[int] + inserted_window_dims: list[int] + scatter_dims_to_operand_dims: list[int] + index_vector_dim: int + input_batching_dims: list[int] + scatter_indices_batching_dims: list[int] + +class ReplicaGroup: + replica_ids: list[int] + +def make_replica_groups( + replica_groups: Sequence[Sequence[int]] | None, +) -> list[ReplicaGroup]: + ... + +def weakref_lru_cache(cache_context_fn: Callable, call: Callable, maxsize=...) -> _xla.WeakrefLRUCache: + ... + +def batched_copy_array_to_devices_with_sharding( + arrays: Sequence[ArrayImpl], + devices: Sequence[list[Device]], + sharding: Sequence[Any], + array_copy_semantics: Sequence[ArrayCopySemantics], +) -> Sequence[ArrayImpl]: ... + +def batched_device_put( + aval: Any, + sharding: Any, + shards: Sequence[Any], + devices: list[Device], + committed: bool = ..., + force_copy: bool = ..., + host_buffer_semantics: Any = ..., +) -> ArrayImpl: ... + +def reorder_shards( + x: ArrayImpl, + dst_sharding: Any, + array_copy_semantics: ArrayCopySemantics, +) -> ArrayImpl: ... + +def batched_block_until_ready(x: Sequence[ArrayImpl]) -> None: ... + +def check_and_canonicalize_memory_kind( + memory_kind: str | None, device_list: DeviceList +) -> str | None: ... + +def array_result_handler( + aval: Any, + sharding: Any, + committed: bool, + _skip_checks: bool = ...) -> Callable: + ... + +class CustomCallTargetTraits(enum.IntFlag): + DEFAULT = 0 + COMMAND_BUFFER_COMPATIBLE = 1 + +def register_custom_call_target( + name: str, + fn: Any, + platform: str = ..., + api_version: int = ..., + traits: CustomCallTargetTraits = ..., +) -> None: ... + +def register_custom_call_handler( + xla_platform_name: str, handler: Any +) -> None: ... + +def custom_call_targets(platform: str) -> dict[str, Any]: ... + +def register_custom_type_id( + type_name: str, + type_id: Any, + platform: str = ..., +) -> None: ... + +def register_custom_type_id_handler(platform: str, handler: Any) -> None: ... + +def encode_inspect_sharding_callback(handler: Any) -> bytes: ... + +register_custom_call_partitioner = _xla.register_custom_call_partitioner +register_custom_call_as_batch_partitionable = ( + _xla.register_custom_call_as_batch_partitionable +) diff --git a/jaxlib/xla/xla_client_backend_independent_test.py b/jaxlib/xla/xla_client_backend_independent_test.py new file mode 100644 index 000000000000..ee1c33feb40c --- /dev/null +++ b/jaxlib/xla/xla_client_backend_independent_test.py @@ -0,0 +1,195 @@ +# Copyright 2017 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Backend-independent tests for the Python XLA client.""" + +import unittest + +from absl.testing import absltest +import numpy as np + +from jax.jaxlib.xla import xla_client + +# pylint: disable=g-import-not-at-top +try: + import portpicker +except ImportError: + portpicker = None +# pylint: enable=g-import-not-at-top + +ops = xla_client.ops + + +class ShapeTest(absltest.TestCase): + + def testInvalidShapes(self): + with self.assertRaisesRegex(xla_client.XlaRuntimeError, "invalid shape"): + xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, [-2, 4]) + + with self.assertRaisesRegex( + RuntimeError, "layout minor_to_major field contains 1 element.*"): + xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, [2, 4], [3]) + + with self.assertRaisesRegex( + RuntimeError, "layout minor_to_major field has out-of-bounds value.*"): + xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, [2, 4], + [1, -1]) + + +class ComputationPrinting(absltest.TestCase): + + def ExampleComputation(self): + builder = xla_client.XlaBuilder("acomputation") + p0 = ops.Parameter(builder, 0, xla_client.shape_from_pyval(np.float32(0))) + p1 = ops.Parameter(builder, 1, + xla_client.shape_from_pyval(np.zeros((4,), np.float32))) + x = ops.Mul(p0, p1) + ops.Add(x, x) + return builder.build() + + def testComputationToHloText(self): + computation = self.ExampleComputation() + hlo_text = computation.as_hlo_text() + self.assertTrue(hlo_text.startswith("HloModule acomputation")) + + def testComputationToHloGraph(self): + computation = self.ExampleComputation() + hlo_dot_graph = computation.as_hlo_dot_graph() + self.assertTrue(hlo_dot_graph.startswith("digraph ")) + + def testHloModuleToHloText(self): + computation = self.ExampleComputation() + hlo_text = computation.as_hlo_module().to_string() + self.assertTrue(hlo_text.startswith("HloModule acomputation")) + + def testHloModuleFromText(self): + hlo_module_text = """HloModule test + add { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT add = f32[] add(x, y) + } + ENTRY entry { + p0 = f32[2,3] parameter(0) + start = f32[2,3] all-reduce-start(p0), to_apply=add + ROOT done = f32[2,3] all-reduce-done(start) + }""" + hlo_module = xla_client._xla.hlo_module_from_text(hlo_module_text) + hlo_text = hlo_module.to_string() + self.assertTrue(hlo_text.startswith("HloModule test")) + + def testHloModuleToHloGraph(self): + computation = self.ExampleComputation() + hlo_dot_graph = xla_client._xla.hlo_module_to_dot_graph( + computation.as_hlo_module()) + self.assertTrue(hlo_dot_graph.startswith("digraph ")) + + +class ComputationHashTest(absltest.TestCase): + + def testHash(self): + builder0 = xla_client.XlaBuilder("computation0") + p0 = ops.Parameter(builder0, 0, xla_client.shape_from_pyval(np.float32(0))) + p1 = ops.Parameter(builder0, 1, + xla_client.shape_from_pyval(np.zeros((4,), np.float32))) + ops.Mul(p0, p1) + computation0 = builder0.build() + + builder1 = xla_client.XlaBuilder("computation1") + p0 = ops.Parameter(builder1, 0, xla_client.shape_from_pyval(np.float32(0))) + p1 = ops.Parameter(builder1, 1, + xla_client.shape_from_pyval(np.zeros((4,), np.float32))) + ops.Mul(p0, p1) + computation1 = builder1.build() + + self.assertEqual(computation0.hash(), computation1.hash()) + + +class AliasTest(absltest.TestCase): + + def testSetUpAlias(self): + c = xla_client.XlaBuilder(self.id()) + p1 = ops.Parameter( + c, 0, + xla_client.shape_from_pyval(np.array( + 1.0, np.float32)).with_major_to_minor_layout_if_absent()) + p2 = ops.Parameter( + c, 1, + xla_client.shape_from_pyval(np.array( + 1.0, np.float32)).with_major_to_minor_layout_if_absent()) + out = ops.Add(p1, p2) + c.setup_alias([], 0, []) + c.build(out) + + +class ProfilerTest(absltest.TestCase): + + def testTraceMe(self): + # TODO(phawkins): These tests just check that the TraceMe context manager + # acts like a context manager and doesn't explode. Ideally we'd check that + # the profiler saw the traceme too. + with xla_client.profiler.TraceMe("test1"): + pass + with xla_client.profiler.TraceMe("test2", foo=123): + pass + with self.assertRaises(ValueError): + with xla_client.profiler.TraceMe("test3"): + raise ValueError("test") + + @unittest.skipIf(portpicker is None, "Test requires portpicker") + def testStartServer(self): + port = portpicker.pick_unused_port() + server = xla_client.profiler.start_server(port) + del server + + +class HloModuleGroupTest(absltest.TestCase): + + def testHloModuleGroup(self): + builder0 = xla_client.XlaBuilder("computation0") + p0 = ops.Parameter(builder0, 0, xla_client.shape_from_pyval(np.float32(0))) + p1 = ops.Parameter(builder0, 1, + xla_client.shape_from_pyval(np.zeros((4,), np.float32))) + root = ops.Mul(p0, p1) + computation0 = builder0.build(root) + + m = computation0.get_hlo_module() + mg_name = "test_module_group" + mg = xla_client._xla.HloModuleGroup(mg_name, [m]) + self.assertEqual(mg.name, mg_name) + + modules = mg.to_modules() + self.assertLen(modules, 1) + self.assertEqual(m.to_string(), modules[0].to_string()) + + +class RunHloPassTest(absltest.TestCase): + + def testHloDCE(self): + b = xla_client.XlaBuilder("acomputation") + p0 = ops.Parameter(b, 0, xla_client.shape_from_pyval(np.float32(0))) + p1 = ops.Parameter(b, 1, + xla_client.shape_from_pyval(np.zeros((4,), np.float32))) + root = ops.Mul(p0, p1) + + # Dead instructions + p2 = ops.Parameter(b, 2, xla_client.shape_from_pyval(np.float32(0))) + ops.Add(p2, p2) + + hlo_module = b.build(root).get_hlo_module() + self.assertTrue(xla_client._xla.HloDCE().run(hlo_module)) + + +if __name__ == "__main__": + absltest.main() diff --git a/jaxlib/xla/xla_client_test.py b/jaxlib/xla/xla_client_test.py new file mode 100644 index 000000000000..e228905637cb --- /dev/null +++ b/jaxlib/xla/xla_client_test.py @@ -0,0 +1,3714 @@ +# Copyright 2017 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Backend-dependent tests for the Python XLA client.""" + +import collections +import functools +import itertools +import re +import threading +import traceback +from typing import Sequence +import unittest + +from absl import flags +from absl import logging +from absl.testing import absltest +from absl.testing import parameterized +import ml_dtypes +import numpy as np + +from jax.jaxlib.xla import xla_client +import jax +import jax._src.test_util + +# pylint: disable=g-import-not-at-top +try: + from jax.jaxlib.xla import custom_calls_testlib +except ImportError: + custom_calls_testlib = None + +xla_client._xla.jax_jit.set_thread_local_state_initialization_callback( + lambda: None +) + +bfloat16 = ml_dtypes.bfloat16 +float4_e2m1fn = ml_dtypes.float4_e2m1fn +float8_e3m4 = ml_dtypes.float8_e3m4 +float8_e4m3 = ml_dtypes.float8_e4m3 +float8_e8m0fnu = ml_dtypes.float8_e8m0fnu +float8_e4m3fn = ml_dtypes.float8_e4m3fn +float8_e4m3fnuz = ml_dtypes.float8_e4m3fnuz +float8_e4m3b11fnuz = ml_dtypes.float8_e4m3b11fnuz +float8_e5m2 = ml_dtypes.float8_e5m2 +float8_e5m2fnuz = ml_dtypes.float8_e5m2fnuz +ops = xla_client.ops +xla_computation_to_mlir_module = ( + xla_client._xla.mlir.xla_computation_to_mlir_module) + + +def execute_with_python_values(executable, arguments, backend): # pylint: disable=invalid-name + """Execute on one replica with Python values as arguments and output.""" + + def put(arg): # pylint: disable=invalid-name + return backend.buffer_from_pyval(arg, device=executable.local_devices()[0]) + + arguments = [put(arg) for arg in arguments] + outputs = executable.execute(arguments) + return [np.asarray(x) for x in outputs] + + +# pylint: disable=invalid-name +def jax_array_convert_to_array(self, dtype=None, copy=None): + del copy + out, _ = self._single_device_array_to_np_array_did_copy() + if dtype is not None: + out = out.astype(dtype) + return out + + +def jax_array_device(self): + return self._sharding._device + + +def jax_array_copy_to_host_async(self): + self._copy_single_device_array_to_host_async() + + +Array = xla_client.ArrayImpl +Array.__array__ = jax_array_convert_to_array +Array.copy_to_host_async = jax_array_copy_to_host_async +Array.device = jax_array_device +xla_client.SingleDeviceSharding.device_set = property( + lambda self: {self._device} +) +# pylint: enable=invalid-name + + +FLAGS = flags.FLAGS + +# We choose to ignore pylint's complaints about complex comprehensions, which we +# use widely for parameterizing tests. +# pylint: disable=g-complex-comprehension + +_CUSTOM_CALLS_REGISTERED = False + + +# XLA' alignment is 16 bytes at the moment, but it should match what Eigen +# supports, and that can go up to 128 bytes on hardware with HVX. +_XLA_CPU_MAX_ALIGNMENT = 128 + + +# Minimum possible alignment for XLA. +_XLA_CPU_MIN_ALIGNMENT = 16 + + +# Return a copy of `x` with the given alignment. Does nothing if `x` is already +# aligned. We do this manually, because numpy doesn't support custom alignment +# value. +def _Aligned(x, alignment=_XLA_CPU_MAX_ALIGNMENT): + if (x.ctypes.data % alignment) == 0: + return x + + # Create temporary buffer with extra space for alignment. + assert alignment % x.itemsize == 0 + extra = alignment // x.itemsize + buf = np.empty(x.size + extra, dtype=x.dtype) + + # Create a view of the temporary buffer with such an offset, that the result + # buffer is aligned. + offset = (-buf.ctypes.data % alignment) // x.itemsize + result = buf[offset : offset + x.size].reshape(x.shape) + + # Copy the data to the result buffer and return it. + np.copyto(result, x) + return result + + +# Return an unaligned copy of `x`. The result buffer's memory address is +# guaranteed to not be aligned to `alignment`. This function is useful for +# testing failiures. +def _Unaligned(x, alignment=_XLA_CPU_MIN_ALIGNMENT): + if (x.ctypes.data % alignment) != 0: + return x + + # Create temporary buffer with extra space. + assert (x.itemsize % alignment) != 0 + offset = 1 + buf = np.empty(x.size + offset, dtype=x.dtype) + + if (buf.ctypes.data % alignment) != 0: + # If the temporary buffer is already unaligned, return it. + result = buf + else: + # Otherwise, create a view of the temporary buffer with an offset. + result = buf[offset : offset + x.size].reshape(x.shape) + assert (result.ctypes.data % alignment) != 0 + + # Copy the data to the result buffer and return it. + np.copyto(result, x) + return result + + +def TestFactory(xla_backend, + cloud_tpu=False, + tfrt_tpu=False, + pjrt_c_api=False, + pathways=False, + pathways_ifrt=False): + tests = [] + + int_dtypes = [np.int32, np.int64, np.uint32, np.uint64] + # TODO(phawkins): test np.float16, where supported. + float_dtypes = [bfloat16, np.float32, np.float64] + complex_dtypes = [np.complex64, np.complex128] + standard_dtypes = int_dtypes + float_dtypes + complex_dtypes + [np.bool_] + # TODO(zhangqiaorjc): test fp8 types when XLA support is complete. + # standard_dtypes is only used for BufferProtocolTest so we only test fp8 + # round trip tests. + fp8_dtypes = [float8_e4m3b11fnuz, float8_e4m3fn, float8_e5m2] + standard_dtypes += fp8_dtypes + # TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. + # standard_dtypes += [float4_e2m1fn, float8_e3m4, float8_e4m3, float8_e8m0fnu] + dlpack_dtypes = int_dtypes + float_dtypes + [np.bool_] + complex_dtypes + + class ComputationTest(parameterized.TestCase): + """Base class for running an XLA Computation through the local client.""" + + def setUp(self): + super(ComputationTest, self).setUp() + self.backend = xla_backend() + + global _CUSTOM_CALLS_REGISTERED + if self.backend.platform == "cpu" and not _CUSTOM_CALLS_REGISTERED: + for name, fn in custom_calls_testlib.registrations().items(): + xla_client.register_custom_call_target( + name, fn, platform="cpu", api_version=1 + ) + for name, val in custom_calls_testlib.type_ids().items(): + xla_client.register_custom_type_id(name, val, platform="cpu") + _CUSTOM_CALLS_REGISTERED = True + + def _NewComputation(self, name=None): + if name is None: + name = self.id() + return xla_client.XlaBuilder(name) + + def _Execute(self, c, arguments): + compiled_c = self.backend.compile( + xla_computation_to_mlir_module(c.build())) + return execute_with_python_values( + compiled_c, arguments, backend=self.backend) + + def _ExecuteAndAssertWith(self, assert_func, c, arguments, expected): + assert expected is not None + results = self._Execute(c, arguments) + self.assertLen(results, len(expected)) + for result, e in zip(results, expected): + # Numpy's comparison methods are a bit too lenient by treating inputs as + # "array-like", meaning that scalar 4 will be happily compared equal to + # [[4]]. We'd like to be more strict so assert shapes as well. + self.assertEqual(np.asanyarray(result).shape, np.asanyarray(e).shape) + assert_func(result, e) + + def _ExecuteAndCompareExact(self, c, arguments=(), expected=None): + self._ExecuteAndAssertWith(np.testing.assert_equal, c, arguments, + expected) + + def _ExecuteAndCompareClose(self, + c, + arguments=(), + expected=None, + rtol=1e-4, + atol=0): + self._ExecuteAndAssertWith( + functools.partial(np.testing.assert_allclose, rtol=rtol, atol=atol), + c, arguments, expected) + + def NumpyArrayF32(*args, **kwargs): + """Convenience wrapper to create Numpy arrays with a np.float32 dtype.""" + return np.array(*args, dtype=np.float32, **kwargs) + + def NumpyArrayF64(*args, **kwargs): + """Convenience wrapper to create Numpy arrays with a np.float64 dtype.""" + return np.array(*args, dtype=np.float64, **kwargs) + + def NumpyArrayS32(*args, **kwargs): + """Convenience wrapper to create Numpy arrays with a np.int32 dtype.""" + return np.array(*args, dtype=np.int32, **kwargs) + + def NumpyArrayBool(*args, **kwargs): + """Convenience wrapper to create Numpy arrays with a np.bool_ dtype.""" + return np.array(*args, dtype=np.bool_, **kwargs) + + class ComputationPrinting(absltest.TestCase): + + def setUp(self): + super(ComputationPrinting, self).setUp() + self.backend = xla_backend() + + def ExampleComputation(self): + builder = xla_client.XlaBuilder("acomputation") + p0 = ops.Parameter(builder, 0, xla_client.shape_from_pyval(np.float32(0))) + p1 = ops.Parameter( + builder, 1, xla_client.shape_from_pyval(np.zeros((4,), np.float32))) + x = ops.Mul(p0, p1) + ops.Add(x, x) + return builder.build() + + @unittest.skipIf(cloud_tpu or pathways, "not implemented") + def testCompiledHloModuleToHloText(self): + computation = self.ExampleComputation() + executable = self.backend.compile( + xla_computation_to_mlir_module(computation)) + hlo_modules = executable.hlo_modules() + self.assertLen(hlo_modules, 1) + hlo_text = hlo_modules[0].to_string() + self.assertTrue(hlo_text.startswith("HloModule acomputation")) + self.assertIn("fusion", hlo_text) + + @unittest.skipIf(cloud_tpu or pathways, "not implemented") + def testCompiledHloModuleAsSerializedProto(self): + computation = self.ExampleComputation() + executable = self.backend.compile( + xla_computation_to_mlir_module(computation)) + hlo_modules = executable.hlo_modules() + self.assertLen(hlo_modules, 1) + hlo_text = hlo_modules[0].to_string() + proto = hlo_modules[0].as_serialized_hlo_module_proto() + hlo_module_roundtrip = xla_client.XlaComputation(proto).get_hlo_module() + hlo_text_roundtrip = hlo_module_roundtrip.to_string() + self.assertEqual(hlo_text, hlo_text_roundtrip) + + @unittest.skipIf(cloud_tpu or pathways, "not implemented") + def testStableComputationSerialization(self): + # Ideally we would test identical computations produced in different + # processes. For now we have this limited smoke test. + computation = self.ExampleComputation() + ref = computation.as_serialized_hlo_module_proto() + for _ in range(10): + self.assertEqual(computation.as_serialized_hlo_module_proto(), ref) + + # TODO(b/261771737): some version of this should work with pjrt_c_api=True + @unittest.skipIf(cloud_tpu or pathways or pathways_ifrt or pjrt_c_api, + "not implemented") + def testFlopEstimate(self): + computation = self.ExampleComputation() + properties = xla_client._xla.hlo_module_cost_analysis( + self.backend, computation.as_hlo_module()) + self.assertEqual(properties["flops"], 8.0) + + def testFingerprint(self): + computation = self.ExampleComputation() + executable = self.backend.compile( + xla_computation_to_mlir_module(computation)) + fingerprint = executable.fingerprint + if ( + self.backend.platform == "tpu" + or self.backend.platform == "gpu" + or self.backend.platform == "cpu" + ) and not (cloud_tpu or pathways or pathways_ifrt): + logging.info("fingerprint: %s", fingerprint) + self.assertNotEmpty(fingerprint) + else: + self.assertIsNone(fingerprint) + + tests.append(ComputationPrinting) + + class ComputationsWithConstantsTest(ComputationTest): + """Tests focusing on Constant ops.""" + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in int_dtypes + float_dtypes) + def testConstantScalarSum(self, dtype): + c = self._NewComputation() + ops.Add(ops.Constant(c, dtype(1.11)), ops.Constant(c, dtype(3.14))) + self._ExecuteAndCompareClose(c, expected=[dtype(1.11) + dtype(3.14)]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testConstantVectorMul(self, dtype): + c = self._NewComputation() + ops.Mul( + ops.Constant(c, np.array([2.5, 3.3, -1.2, 0.7], dtype)), + ops.Constant(c, np.array([-1.2, 2, -2, -3], dtype))) + self._ExecuteAndCompareClose( + c, expected=[[-3, 6.6, 2.4, -2.1]], rtol=3e-3) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testConstantVectorScalarDiv(self, dtype): + c = self._NewComputation() + ops.Div( + ops.Constant(c, np.array([1.5, 2.5, 3.0, -10.8], dtype=dtype)), + ops.Constant(c, dtype(2.0))) + self._ExecuteAndCompareClose( + c, expected=[[0.75, 1.25, 1.5, -5.4]], rtol=2e-3) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testConstantVectorScalarPow(self, dtype): + c = self._NewComputation() + ops.Pow( + ops.Constant(c, np.array([1.5, 2.5, 3.0], dtype=dtype)), + ops.Constant(c, dtype(2.))) + self._ExecuteAndCompareClose(c, expected=[[2.25, 6.25, 9.]]) + + def testIota(self): + c = self._NewComputation() + ops.Iota(c, xla_client.PrimitiveType.F32, 10) + self._ExecuteAndCompareExact( + c, expected=[np.arange(10, dtype=np.float32)]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in int_dtypes) + def testBroadcastedIota(self, dtype): + c = self._NewComputation() + shape = xla_client.Shape.array_shape( + xla_client.dtype_to_etype(dtype), (2, 3)) + ops.Iota(c, shape, 1) + expected = np.array([[0, 1, 2], [0, 1, 2]], dtype=dtype) + self._ExecuteAndCompareExact(c, expected=[expected]) + + def testBooleanAnd(self): + c = self._NewComputation() + ops.And( + ops.Constant(c, NumpyArrayBool([True, False, True, False])), + ops.Constant(c, NumpyArrayBool([True, True, False, False]))) + self._ExecuteAndCompareExact(c, expected=[[True, False, False, False]]) + + def testBooleanOr(self): + c = self._NewComputation() + ops.Or( + ops.Constant(c, NumpyArrayBool([True, False, True, False])), + ops.Constant(c, NumpyArrayBool([True, True, False, False]))) + self._ExecuteAndCompareExact(c, expected=[[True, True, True, False]]) + + def testBooleanXor(self): + c = self._NewComputation() + ops.Xor( + ops.Constant(c, NumpyArrayBool([True, False, True, False])), + ops.Constant(c, NumpyArrayBool([True, True, False, False]))) + self._ExecuteAndCompareExact(c, expected=[[False, True, True, False]]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testSum2D(self, dtype): + c = self._NewComputation() + ops.Add( + ops.Constant(c, np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype)), + ops.Constant(c, np.array([[1, -1, 1], [-1, 1, -1]], dtype=dtype))) + self._ExecuteAndCompareClose(c, expected=[[[2, 1, 4], [3, 6, 5]]]) + + def testShiftLeft(self): + c = self._NewComputation() + ops.ShiftLeft( + ops.Constant(c, NumpyArrayS32([3])), + ops.Constant(c, NumpyArrayS32([2]))) + self._ExecuteAndCompareClose(c, expected=[[12]]) + + def testShiftRightArithmetic(self): + c = self._NewComputation() + ops.ShiftRightArithmetic( + ops.Constant(c, NumpyArrayS32([-2])), + ops.Constant(c, NumpyArrayS32([1]))) + self._ExecuteAndCompareClose(c, expected=[[-1]]) + + def testShiftRightLogical(self): + c = self._NewComputation() + ops.ShiftRightLogical( + ops.Constant(c, NumpyArrayS32([-1])), + ops.Constant(c, NumpyArrayS32([1]))) + self._ExecuteAndCompareClose(c, expected=[[2**31 - 1]]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testSum2DWith1DBroadcastDim0(self, dtype): + # sum of a 2D array with a 1D array where the latter is replicated across + # dimension 0 to match the former's shape. + c = self._NewComputation() + ops.Add( + ops.Constant(c, + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], + dtype=dtype)), + ops.Constant(c, np.array([10, 20, 30], dtype=dtype)), + broadcast_dimensions=(0,)) + self._ExecuteAndCompareClose( + c, expected=[[[11, 12, 13], [24, 25, 26], [37, 38, 39]]]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testSum2DWith1DBroadcastDim1(self, dtype): + # sum of a 2D array with a 1D array where the latter is replicated across + # dimension 1 to match the former's shape. + c = self._NewComputation() + ops.Add( + ops.Constant(c, + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], + dtype=dtype)), + ops.Constant(c, np.array([10, 20, 30], dtype=dtype)), + broadcast_dimensions=(1,)) + self._ExecuteAndCompareClose( + c, expected=[[[11, 22, 33], [14, 25, 36], [17, 28, 39]]]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testConstantAxpy(self, dtype): + c = self._NewComputation() + ops.Add( + ops.Mul( + ops.Constant(c, dtype(2)), + ops.Constant(c, np.array([2.2, 3.3, 4.4, 5.5], dtype=dtype))), + ops.Constant(c, np.array([100, -100, 200, -200], dtype))) + self._ExecuteAndCompareClose( + c, expected=[[104.4, -93.4, 208.8, -189]], rtol=2e-3) + + def testCustomCall(self): + if self.backend.platform != "cpu": + self.skipTest("Test requires cpu platform") + c = self._NewComputation() + ops.CustomCallWithLayout( + c, + b"subtract_f32", + operands=[ + ops.Constant(c, np.float32(1.25)), + ops.Constant(c, np.float32(0.5)) + ], + shape_with_layout=xla_client.Shape.array_shape( + np.dtype(np.float32), (), ()), + operand_shapes_with_layout=[ + xla_client.Shape.array_shape(np.dtype(np.float32), (), ()), + xla_client.Shape.array_shape(np.dtype(np.float32), (), ()), + ], + api_version=xla_client.ops.CustomCallApiVersion + .API_VERSION_TYPED_FFI) + self._ExecuteAndCompareClose(c, expected=[0.75]) + + def testCustomCallWithUnifiedApiUnknownTarget(self): + if self.backend.platform != "cpu": + self.skipTest("Test requires cpu platform") + c = self._NewComputation() + + ops.CustomCallWithLayout( + c, + b"not_existing", + operands=[], + shape_with_layout=xla_client.Shape.array_shape( + np.dtype(np.float32), (), () + ), + operand_shapes_with_layout=[], + api_version=xla_client.ops.CustomCallApiVersion + .API_VERSION_STATUS_RETURNING_UNIFIED, + ) + with self.assertRaisesRegex( + xla_client.XlaRuntimeError, expected_regex="NOT_FOUND" + ): + self._Execute(c, arguments=()) + + def testCustomCallTypedFfiUnknownTarget(self): + if self.backend.platform != "cpu": + self.skipTest("Test requires cpu platform") + c = self._NewComputation() + + ops.CustomCallWithLayout( + c, + b"not_existing", + operands=[], + shape_with_layout=xla_client.Shape.array_shape( + np.dtype(np.float32), (), () + ), + operand_shapes_with_layout=[], + api_version=xla_client.ops.CustomCallApiVersion.API_VERSION_TYPED_FFI, + ) + with self.assertRaises(xla_client.XlaRuntimeError): + self._Execute(c, arguments=()) + + def testCustomCallTypedFfiAlwaysFail(self): + if self.backend.platform != "cpu": + self.skipTest("Test requires cpu platform") + c = self._NewComputation() + + ops.CustomCallWithLayout( + c, + b"always_fail", + operands=[], + shape_with_layout=xla_client.Shape.array_shape( + np.dtype(np.float32), (), () + ), + operand_shapes_with_layout=[], + api_version=xla_client.ops.CustomCallApiVersion.API_VERSION_TYPED_FFI, + ) + + with self.assertRaisesRegex( + Exception, expected_regex="Failed intentionally" + ): + self._Execute(c, arguments=()) + + def testCustomCallTypedFfiAlwaysSucceed(self): + if self.backend.platform != "cpu": + self.skipTest("Test requires cpu platform") + c = self._NewComputation() + + ops.CustomCallWithLayout( + c, + b"always_succeed", + operands=[], + shape_with_layout=xla_client.Shape.array_shape( + np.dtype(np.float32), (), () + ), + operand_shapes_with_layout=[], + api_version=xla_client.ops.CustomCallApiVersion.API_VERSION_TYPED_FFI, + ) + + self._Execute(c, arguments=()) + + def testCustomCallTypedFfiSubtract(self): + if self.backend.platform != "cpu": + self.skipTest("Test requires cpu platform") + c = self._NewComputation() + + ops.CustomCallWithLayout( + c, + b"subtract_f32_cst", + operands=[ops.Constant(c, np.float32(1.25))], + shape_with_layout=xla_client.Shape.array_shape( + np.dtype(np.float32), (), () + ), + operand_shapes_with_layout=[ + xla_client.Shape.array_shape(np.dtype(np.float32), (), ()), + ], + opaque=b"{cst = 3.0 : f32}", + api_version=xla_client.ops.CustomCallApiVersion.API_VERSION_TYPED_FFI, + ) + self._ExecuteAndCompareClose(c, expected=[-1.75]) + + def testStatefulCustomCall(self): + if self.backend.platform != "cpu": + self.skipTest("Test requires cpu platform") + c = self._NewComputation() + ops.CustomCallWithLayout( + c, + b"stateful", + operands=[], + shape_with_layout=xla_client.Shape.array_shape( + np.dtype(np.int32), (), ()), + operand_shapes_with_layout=[], + api_version=xla_client.ops.CustomCallApiVersion + .API_VERSION_TYPED_FFI) + self._ExecuteAndCompareClose(c, expected=[42]) + + def testCustomCallLookup(self): + if self.backend.platform != "cpu": + self.skipTest("Test requires cpu platform") + if xla_client._version < 241: + self.skipTest("Test requires jaxlib version 241") + + self.assertTrue(_CUSTOM_CALLS_REGISTERED) + xla_client.make_cpu_client() + self.assertContainsSubset( + list(custom_calls_testlib.registrations().keys()), + xla_client.custom_call_targets("Host").keys(), + ) + + tests.append(ComputationsWithConstantsTest) + + class ComputationFromProtoTest(absltest.TestCase): + """Test computation execution from HLO proto.""" + + def setUp(self): + super(ComputationFromProtoTest, self).setUp() + self.backend = xla_backend() + + def testExecuteFromProto(self): + # Build the HLO proto + b = xla_client.XlaBuilder("computation") + ops.Add(ops.Constant(b, np.int32(1)), ops.Constant(b, np.int32(2))) + serialized_proto = b.build().as_serialized_hlo_module_proto() + + # Load and execute the proto + c = xla_client.XlaComputation(serialized_proto) + m = xla_computation_to_mlir_module(c) + ans, = execute_with_python_values( + self.backend.compile(m), (), backend=self.backend) + np.testing.assert_equal(ans, np.int32(3)) + + tests.append(ComputationFromProtoTest) + + class ParametersTest(ComputationTest): + """Tests focusing on Parameter ops and argument-passing.""" + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in int_dtypes) + def testScalarTimesVector(self, dtype): + c = self._NewComputation() + arg0 = np.array(3, dtype=dtype) + if np.issubdtype(dtype, np.unsignedinteger): + arg1 = np.array([10, 15, 2, 7], dtype=dtype) + else: + arg1 = np.array([10, 15, -2, 7], dtype=dtype) + p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg0)) + p1 = ops.Parameter(c, 1, xla_client.shape_from_pyval(arg1)) + ops.Mul(p0, p1) + self._ExecuteAndCompareExact( + c, arguments=[arg0, arg1], expected=[arg0 * arg1]) + + # TODO(phawkins): test comparison harness doesn't support bfloat16 + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes if dtype != bfloat16) + def testScalarMinusVectorExplicitNumbering(self, dtype): + # Use explicit numbering and pass parameter_num first. Sub is used since + # it's not commutative and can help catch parameter reversal within the + # computation. + c = self._NewComputation() + arg0 = np.array(2.0, dtype=dtype) + arg1 = np.array([-2.3, 3.3, -4.3, 5.3], dtype=dtype) + p1 = ops.Parameter(c, 1, xla_client.shape_from_pyval(arg1)) + p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg0)) + ops.Sub(p1, p0) + self._ExecuteAndCompareClose( + c, arguments=[arg0, arg1], expected=[arg1 - arg0]) + + tests.append(ParametersTest) + + class LayoutsTest(ComputationTest): + """Tests related to getting and setting on-device memory layouts.""" + + def _minor_to_major(self, layout: xla_client.PjRtLayout): # pylint: disable=invalid-name + m2m_str = re.search("{([0-9,]*)", str(layout)).group(1) + if not m2m_str: + return () + return tuple(int(x) for x in m2m_str.split(",")) + + @unittest.skipIf(pathways, "not implemented") + def testGetArgumentLayouts(self): + # Create computation with a few parameters. + c = self._NewComputation() + param_count = 0 + + def MakeArg(shape, dtype): + nonlocal param_count + shape = xla_client.Shape.array_shape(np.dtype(dtype), shape) + param = ops.Parameter(c, param_count, shape) + param_count += 1 + return param + + p0 = MakeArg((2, 3, 4), np.float32) + MakeArg((3, 2), np.int32) + MakeArg((), np.float64) + + ops.Add(p0, ops.Constant(c, np.ones((2, 3, 4), np.float32))) + executable = self.backend.compile( + xla_computation_to_mlir_module(c.build())) + + # Test that compiled executable returns plausible layouts. + layouts: Sequence[xla_client.Layout] = executable.get_parameter_layouts() + self.assertLen(layouts, 3) + self.assertLen(self._minor_to_major(layouts[0]), 3) + self.assertLen(self._minor_to_major(layouts[1]), 2) + self.assertEmpty(self._minor_to_major(layouts[2])) + + @unittest.skipIf(pathways, "not implemented") + def testGetArgumentLayoutsTupled(self): + # Generated with: + # jax.jit(lambda x, y, z: (x, y, z))(np.ones((1024, 8, 128)), + # np.int32(42), + # np.ones(10)) + module_str = """ +module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, + mhlo.num_replicas = 1 : i32} { + func.func public @main( + %arg0: tensor<1024x8x128xf32> {mhlo.sharding = "{replicated}"}, + %arg1: tensor {mhlo.sharding = "{replicated}"}, + %arg2: tensor<10xf32> {mhlo.sharding = "{replicated}"}) + -> (tensor<1024x8x128xf32> {jax.result_info = "[0]"}, + tensor {jax.result_info = "[1]"}, + tensor<10xf32> {jax.result_info = "[2]"}) { + return %arg0, %arg1, %arg2 : tensor<1024x8x128xf32>, tensor, tensor<10xf32> + } +} +""" + options = xla_client.CompileOptions() + # 'parameter_is_tupled_arguments' causes MLIR untupled arguments to get + # turned into HLO tupled arguments. + options.parameter_is_tupled_arguments = True + executable = self.backend.compile(module_str, compile_options=options) + + # Test that compiled executable returns plausible layouts. + layouts: Sequence[xla_client.Layout] = executable.get_parameter_layouts() + self.assertLen(layouts, 3) + self.assertLen(self._minor_to_major(layouts[0]), 3) + self.assertEmpty(self._minor_to_major(layouts[1])) + self.assertLen(self._minor_to_major(layouts[2]), 1) + + @unittest.skipIf(pathways, "not implemented") + def testGetOutputLayouts(self): + # Generated with jax.jit(lambda: (np.ones((1024, 128)), np.int32(42), + # np.ones(10)))() + module_str = """ +module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, + mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<1024x128xf32> {jax.result_info = "[0]"}, + tensor {jax.result_info = "[1]"}, + tensor<10xf32> {jax.result_info = "[2]"}) { + %0 = stablehlo.constant dense<1.000000e+00> : tensor<1024x128xf32> + %1 = stablehlo.constant dense<1.000000e+00> : tensor<10xf32> + %2 = stablehlo.constant dense<42> : tensor + return %0, %2, %1 : tensor<1024x128xf32>, tensor, tensor<10xf32> + } +} +""" + executable = self.backend.compile(module_str) + + # Test that compiled executable returns plausible layouts. + layouts: Sequence[xla_client.Layout] = executable.get_output_layouts() + self.assertLen(layouts, 3) + self.assertLen(self._minor_to_major(layouts[0]), 2) + self.assertEmpty(self._minor_to_major(layouts[1])) + self.assertLen(self._minor_to_major(layouts[2]), 1) + + @unittest.skipIf(pathways, "not implemented") + def testSetArgumentLayouts(self): + # TODO(b/309682374): implement on CPU and GPU + if self.backend.platform != "tpu": + raise self.skipTest("mhlo.layout_mode only implemented on TPU") + + # Hand-edited version of: + # jax.jit(lambda x, y, z: (x, y, z))(np.ones((1024, 8, 128)), + # np.int32(42), + # np.ones(10)) + module_str = """ +module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, + mhlo.num_replicas = 1 : i32} { + func.func public @main( + %arg0: tensor<1024x8x128xf32> {mhlo.sharding = "{replicated}", + mhlo.layout_mode = "{0,1,2}"}, + %arg1: tensor {mhlo.sharding = "{replicated}", + mhlo.layout_mode = "{}"}, + %arg2: tensor<10xf32> {mhlo.sharding = "{replicated}", + mhlo.layout_mode = "{0}"}) + -> (tensor<1024x8x128xf32> {jax.result_info = "[0]"}, + tensor {jax.result_info = "[1]"}, + tensor<10xf32> {jax.result_info = "[2]"}) { + return %arg0, %arg1, %arg2 : tensor<1024x8x128xf32>, tensor, tensor<10xf32> + } +} + """ + executable = self.backend.compile(module_str) + + # Check input layouts. + input_layouts = executable.get_parameter_layouts() + self.assertLen(input_layouts, 3) + self.assertEqual(self._minor_to_major(input_layouts[0]), (0, 1, 2)) + self.assertEqual(self._minor_to_major(input_layouts[1]), ()) + self.assertEqual(self._minor_to_major(input_layouts[2]), (0,)) + + # Compile a version with default arg0 layout so we can make sure we + # actually set it above. + default_executable = self.backend.compile( + module_str.replace('"{0,1,2}"', '"default"') + ) + self.assertNotEqual( + self._minor_to_major(input_layouts[0]), + self._minor_to_major(default_executable.get_parameter_layouts()[0]), + ) + + @unittest.skipIf(pathways or pathways_ifrt, "not implemented") + def testSetArgumentLayoutsLegacy(self): + """Tests setting the arg layouts with compile_options (deprecated). + + New code should use the mhlo.layout_mode string attr on parameters. + """ + # Create computation with custom input layouts. + c = self._NewComputation() + param_count = 0 + + def MakeArg(shape, dtype, layout): + nonlocal param_count + arr = np.arange(np.prod(shape), dtype=dtype).reshape(shape) + param = ops.Parameter(c, param_count, + xla_client.shape_from_pyval(arr, layout)) + param_count += 1 + shape = xla_client.Shape.array_shape(np.dtype(dtype), shape, layout) + return arr, param, shape + + arg0, p0, shape0 = MakeArg((2, 3, 4), np.float32, (1, 2, 0)) + arg1, p1, shape1 = MakeArg((3, 2), np.int32, (0, 1)) + arg2, p2, shape2 = MakeArg((), np.float64, ()) + + ops.Tuple(c, [ + ops.Add(p0, ops.Constant(c, np.ones(arg0.shape, arg0.dtype))), + ops.Add(p1, ops.Constant(c, np.ones(arg1.shape, arg1.dtype))), + ops.Add(p2, ops.Constant(c, np.ones(arg2.shape, arg2.dtype))), + ]) + + # We also need to set the input layouts in the compile options. + options = xla_client.CompileOptions() + options.argument_layouts = [shape0, shape1, shape2] + executable = self.backend.compile( + xla_computation_to_mlir_module(c.build()), compile_options=options) + + # Test that compiled executable has expected layouts. + expected_layouts: Sequence[xla_client.Shape] = [shape0, shape1, shape2] + actual_layouts: Sequence[xla_client.Layout] = ( + executable.get_parameter_layouts()) + self.assertEqual(len(actual_layouts), len(expected_layouts)) + for actual, expected in zip(actual_layouts, expected_layouts): + self.assertEqual( + self._minor_to_major(actual), + expected.layout().minor_to_major(), + ) + + @unittest.skipIf(pathways, "not implemented") + def testSetOutputLayouts(self): + # TODO(b/309682374): implement on CPU and GPU + if self.backend.platform != "tpu": + raise self.skipTest("mhlo.layout_mode only implemented on TPU") + + # Hand-edited version of: + # jax.jit(lambda x, y, z: (x, y, z))(np.ones((1024, 8, 128)), + # np.int32(42), + # np.ones(10)) + module_str = """ +module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, + mhlo.num_replicas = 1 : i32} { + func.func public @main( + %arg0: tensor<1024x8x128xf32> {mhlo.sharding = "{replicated}"}, + %arg1: tensor {mhlo.sharding = "{replicated}"}, + %arg2: tensor<10xf32> {mhlo.sharding = "{replicated}"}) + -> (tensor<1024x8x128xf32> {jax.result_info = "[0]", + mhlo.layout_mode = "{0,1,2}"}, + tensor {jax.result_info = "[1]", + mhlo.layout_mode = "{}"}, + tensor<10xf32> {jax.result_info = "[2]", + mhlo.layout_mode = "{0}"}) { + return %arg0, %arg1, %arg2 : tensor<1024x8x128xf32>, tensor, tensor<10xf32> + } +} + """ + executable = self.backend.compile(module_str) + + # Check output layouts. + output_layouts = executable.get_output_layouts() + self.assertLen(output_layouts, 3) + self.assertEqual(self._minor_to_major(output_layouts[0]), (0, 1, 2)) + self.assertEqual(self._minor_to_major(output_layouts[1]), ()) + self.assertEqual(self._minor_to_major(output_layouts[2]), (0,)) + + # Compile a version with default first output layout so we can make sure + # we actually set it above. + default_executable = self.backend.compile( + module_str.replace('"{0,1,2}"', '"default"') + ) + self.assertNotEqual( + self._minor_to_major(output_layouts[0]), + self._minor_to_major(default_executable.get_output_layouts()[0]), + ) + + @unittest.skipIf(pathways, "not implemented") + def SetLayoutsSharded(self): + # TODO(b/309682374): implement on CPU and GPU + if self.backend.platform != "tpu": + raise self.skipTest("mhlo.layout_mode only implemented on TPU") + + # Hand-edited version of: + # sharding = PositionalSharding(mesh_utils.create_device_mesh((8,))) + # x = jax.device_put(np.ones((1024, 128)), sharding.reshape(4, 2)) + # jax.jit(lambda x, y: x + y, out_shardings=sharding)(x, 1.) + # + # This also lightly tests mixed default + user-specified input layouts. + module_str = """ +module @jit__lambda_ attributes {mhlo.num_partitions = 8 : i32, + mhlo.num_replicas = 1 : i32} { + func.func public @main( + %arg0: tensor<1024x128xf32> {mhlo.sharding = "{devices=[4,2]0,1,2,3,4,5,6,7}", + mhlo.layout_mode = "{0,1}"}, + %arg1: tensor {mhlo.sharding = "{replicated}"}) + -> (tensor<1024x128xf32> {jax.result_info = "", + mhlo.sharding = "{devices=[4,2]0,1,2,3,4,5,6,7}", + mhlo.layout_mode = "{0,1}"}) { + %0 = stablehlo.convert %arg1 : tensor + %1 = stablehlo.broadcast_in_dim %0, dims = [] : (tensor) -> tensor<1024x128xf32> + %2 = stablehlo.add %arg0, %1 : tensor<1024x128xf32> + return %2 : tensor<1024x128xf32> + } +} + """ + executable = self.backend.compile(module_str) + + # Check input layouts. + input_layouts = executable.get_parameter_layouts() + self.assertLen(input_layouts, 2) + self.assertEqual(self._minor_to_major(input_layouts[0]), (0, 1)) + self.assertEqual(self._minor_to_major(input_layouts[1]), ()) + + # Check output layout. + output_layouts = executable.get_output_layouts() + self.assertLen(output_layouts, 1) + self.assertEqual(self._minor_to_major(input_layouts[0]), (0, 1)) + + # Compile a version with default layouts so we can make sure we actually + # set it above. + default_executable = self.backend.compile( + module_str.replace('"{0,1}"', '"default"') + ) + self.assertNotEqual( + self._minor_to_major(input_layouts[0]), + self._minor_to_major(default_executable.get_parameter_layouts()[0]), + ) + self.assertNotEqual( + self._minor_to_major(output_layouts[0]), + self._minor_to_major(default_executable.get_output_layouts()[0]), + ) + + @unittest.skipIf(pathways, "not implemented") + def testAutoArgumentLayouts(self): + # TODO(b/309682374): implement on CPU and GPU + if self.backend.platform != "tpu": + raise self.skipTest("mhlo.layout_mode only implemented on TPU") + + # Hand-edited version of: + # jax.numpy.einsum("...a,ahd->...hd", ...) + module_str = """ +module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, + mhlo.num_replicas = 1 : i32} { + func.func public @main( + %arg0: tensor<1024x1024xf32> {mhlo.sharding = "{replicated}", + mhlo.layout_mode = "auto"}, + %arg1: tensor<1024x8x128xf32> {mhlo.sharding = "{replicated}", + mhlo.layout_mode = "auto"}) + -> (tensor<1024x8x128xf32> {jax.result_info = ""}) { + %0 = stablehlo.dot_general %arg0, %arg1, + contracting_dims = [1] x [0], + precision = [DEFAULT, DEFAULT] : (tensor<1024x1024xf32>, + tensor<1024x8x128xf32>) + -> tensor<1024x8x128xf32> + return %0 : tensor<1024x8x128xf32> + } +} +""" + executable = self.backend.compile(module_str) + + # Check input layouts. + input_layouts = executable.get_parameter_layouts() + self.assertEqual(self._minor_to_major(input_layouts[0]), (1, 0)) + self.assertEqual(self._minor_to_major(input_layouts[1]), (2, 0, 1)) + + # Compile a version with default layouts so we can make sure the compiler + # is actually choosing above. + default_executable = self.backend.compile( + module_str.replace('"auto"', '"default"') + ) + # We expect the compiler to choose a non-default layout for the second + # (1024,8,128) argument. + self.assertNotEqual( + self._minor_to_major(input_layouts[1]), + self._minor_to_major(default_executable.get_parameter_layouts()[1]), + ) + + @unittest.skipIf(pathways, "not implemented") + def testAutoOutputLayouts(self): + # TODO(b/309682374): implement on CPU and GPU + if self.backend.platform != "tpu": + raise self.skipTest("mhlo.layout_mode only implemented on TPU") + + # Generated with jax.numpy.einsum("...a,ahd->...hd", ...) + module_str = """ +module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, + mhlo.num_replicas = 1 : i32} { + func.func public @main( + %arg0: tensor<1024x1024xf32> {mhlo.sharding = "{replicated}"}, + %arg1: tensor<1024x8x128xf32> {mhlo.sharding = "{replicated}"}) + -> (tensor<1024x8x128xf32> {jax.result_info = "", + mhlo.layout_mode = "auto"}) { + %0 = stablehlo.dot_general %arg0, %arg1, + contracting_dims = [1] x [0], + precision = [DEFAULT, DEFAULT] : (tensor<1024x1024xf32>, + tensor<1024x8x128xf32>) + -> tensor<1024x8x128xf32> + return %0 : tensor<1024x8x128xf32> + } +} +""" + executable = self.backend.compile(module_str) + + # Check output layout + output_layout, = executable.get_output_layouts() + self.assertEqual(self._minor_to_major(output_layout), (2, 0, 1)) + + # Compile a version with default layouts so we can make sure the compiler + # is actually choosing above. + default_executable = self.backend.compile( + module_str.replace('"auto"', '"default"') + ) + # We expect the compiler to choose a non-default output layout. + self.assertNotEqual( + self._minor_to_major(output_layout), + self._minor_to_major(default_executable.get_output_layouts()[0]), + ) + + tests.append(LayoutsTest) + + class BufferTest(ComputationTest): + """Tests focusing on execution with Buffers.""" + + def testConstantSum(self): + c = self._NewComputation() + ops.Add( + ops.Constant(c, np.float32(1.11)), ops.Constant(c, np.float32(3.14))) + self._ExecuteAndCompareClose(c, expected=[4.25]) + + def testOneParameterSum(self): + c = self._NewComputation() + ops.Add( + ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0.))), + ops.Constant(c, np.float32(3.14))) + self._ExecuteAndCompareClose( + c, arguments=[NumpyArrayF32(1.11)], expected=[4.25]) + + def testTwoParameterSum(self): + c = self._NewComputation() + ops.Add( + ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0.))), + ops.Parameter(c, 1, xla_client.shape_from_pyval(NumpyArrayF32(0.)))) + self._ExecuteAndCompareClose( + c, + arguments=[NumpyArrayF32(1.11), + NumpyArrayF32(3.14)], + expected=[4.25]) + + @unittest.skipIf(cloud_tpu or pathways, "not implemented") + def testCannotCallWithDeletedBuffers(self): + c = self._NewComputation() + ops.Add( + ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0.))), + ops.Constant(c, np.float32(3.14))) + arg = NumpyArrayF32(1.11) + compiled_c = self.backend.compile( + xla_computation_to_mlir_module(c.build())) + arg_buffer = self.backend.buffer_from_pyval(arg) + arg_buffer.delete() + with self.assertRaises(xla_client.XlaRuntimeError): + compiled_c.execute([arg_buffer]) + + def testXlaShapeIndex(self): + a = xla_client.ShapeIndex((1, 2)) + b = xla_client.ShapeIndex((1, 2)) + c = xla_client.ShapeIndex((2, 3)) + self.assertEqual(a, b) + self.assertNotEqual(b, c) + + def testLayout(self): + f32 = xla_client.PrimitiveType.F32 + a = xla_client.Shape.array_shape(f32, (2, 3), (0, 1)).layout() + b = xla_client.Shape.array_shape(f32, (2, 3), (0, 1)).layout() + c = xla_client.Shape.array_shape(f32, (2, 3), (1, 0)).layout() + self.assertEqual(a.minor_to_major(), (0, 1)) + self.assertEqual(b.minor_to_major(), (0, 1)) + self.assertEqual(c.minor_to_major(), (1, 0)) + self.assertEqual(a, b) + self.assertNotEqual(a, c) + self.assertNotEqual(b, c) + self.assertEqual(hash(a), hash(b)) + self.assertNotEqual(hash(a), hash(c)) + self.assertNotEqual(hash(b), hash(c)) + + def testBlockUntilReadyWorks(self): + arg = np.array([[1., 2.]], np.float32) + arg_buffer = self.backend.buffer_from_pyval(arg) + arg_buffer.block_until_ready() + # This test merely checks that nothing goes awry when we call + # block_until_ready(); it's difficult to test anything else. + + def testBlockUntilReadyRaisesOnDeletedBuffer(self): + arg = np.array([[1., 2.]], np.float32) + buffer = self.backend.buffer_from_pyval(arg) + buffer.delete() + with self.assertRaisesRegex( + RuntimeError, + re.escape( + "BlockHostUntilReady() called on deleted or donated buffer")): + buffer.block_until_ready() + + @unittest.skipIf(pathways_ifrt, "not implemented") + def testOnDeviceSizeInBytes(self): + if not isinstance(self.backend, xla_client.Client): + self.skipTest("TPU Driver doesn't support OnDeviceSizeInBytes.") + arg0 = np.array([]) + arg1 = np.array([[0., 1., 2.]], np.float32) + arg2 = np.array([[3., 4., 5.]], bfloat16) + arg0_buffer = self.backend.buffer_from_pyval(arg0) + arg1_buffer = self.backend.buffer_from_pyval(arg1) + arg2_buffer = self.backend.buffer_from_pyval(arg2) + self.assertEqual(arg0_buffer.on_device_size_in_bytes(), 0) + # OnDeviceSizeInBytes varies depending on the platform. Confirm there's + # a reasonable value. + self.assertGreater(arg1_buffer.on_device_size_in_bytes(), 0) + self.assertGreater(arg2_buffer.on_device_size_in_bytes(), 0) + + def testLiveBuffers(self): + if not isinstance(self.backend, xla_client.Client): + self.skipTest("TPU Driver doesn't support LiveBuffers().") + self.assertEmpty(self.backend.live_buffers()) + arg0 = np.array([]) + arg1 = np.array([[0., 1., 2.]], np.float32) + arg2 = np.array([[3., 4., 5.]], bfloat16) + arg0_buffer = self.backend.buffer_from_pyval(arg0) + arg1_buffer = self.backend.buffer_from_pyval(arg1) + arg2_buffer = self.backend.buffer_from_pyval(arg2) + self.assertLen(self.backend.live_buffers(), 3) + self.assertIs(self.backend.live_buffers()[0], arg2_buffer) + self.assertIs(self.backend.live_buffers()[1], arg1_buffer) + self.assertIs(self.backend.live_buffers()[2], arg0_buffer) + + arg1_buffer.delete() + self.assertLen(self.backend.live_buffers(), 2) + self.assertIs(self.backend.live_buffers()[0], arg2_buffer) + self.assertIs(self.backend.live_buffers()[1], arg0_buffer) + + arg0_buffer.delete() + arg2_buffer.delete() + self.assertEmpty(self.backend.live_buffers()) + + def testCopyToHost(self): + arg0 = np.array([[1., 2.]], np.float32) + arg1 = np.array([[3., 4.]], np.float32) + arg0_buffer = self.backend.buffer_from_pyval(arg0) + arg1_buffer = self.backend.buffer_from_pyval(arg1) + # Prefetch two buffers using copy_to_host_async, and then retrieve their + # values using np.asarray(). + arg0_buffer.copy_to_host_async() + arg0_buffer.copy_to_host_async() # Duplicate calls don't do anything. + arg1_buffer.copy_to_host_async() + np.testing.assert_equal(arg0, np.asarray(arg0_buffer)) + np.testing.assert_equal(arg1, np.asarray(arg1_buffer)) + # copy_to_host_async does nothing after np.asarray() is called. + arg0_buffer.copy_to_host_async() + np.testing.assert_equal(arg0, np.asarray(arg0_buffer)) + + def testDevice(self): + x = np.arange(8, dtype=np.int32) + for device in self.backend.local_devices(): + buf = self.backend.buffer_from_pyval(x, device=device) + self.assertEqual(buf.device(), device) + np.testing.assert_equal(x, np.asarray(buf)) + + def testStandardTypes(self): + for dtype in standard_dtypes: + if dtype == np.complex128: + continue + # float8_e4m3b11fnuz not supported on some TPU backends. + if ( + dtype in [float8_e5m2fnuz, float8_e4m3fnuz, float8_e4m3b11fnuz] + and self.backend.platform == "tpu" + ): + if self.backend.platform_version.find("TPU") == -1: + continue + arr = self.backend.buffer_from_pyval(np.array([0, 1], dtype)) + arr = np.asarray(arr) + self.assertEqual(dtype, type(arr[0])) + + @unittest.skipIf(pathways_ifrt, "not implemented") + def testUnsafeBufferPointer(self): + if not isinstance(self.backend, xla_client.Client): + self.skipTest("TPU Driver doesn't support UnsafeBufferPointer().") + arg0 = np.array([]) + arg1 = np.array([[0., 1., 2.]], np.float32) + arg2 = np.array([[3., 4., 5.]], bfloat16) + arg0_buffer = self.backend.buffer_from_pyval(arg0) + arg1_buffer = self.backend.buffer_from_pyval(arg1) + arg2_buffer = self.backend.buffer_from_pyval(arg2) + self.assertGreaterEqual(arg0_buffer.unsafe_buffer_pointer(), 0) + self.assertGreaterEqual(arg1_buffer.unsafe_buffer_pointer(), 0) + self.assertGreaterEqual(arg2_buffer.unsafe_buffer_pointer(), 0) + + @unittest.skipIf(cloud_tpu or pathways or pathways_ifrt, "not implemented") + def testClone(self): + x = np.array([[3., 4., 5.]], np.float32) + y = self.backend.buffer_from_pyval(x) + z = y.clone() + self.assertNotEqual(id(x), id(y)) + np.testing.assert_array_equal(np.asarray(y), np.asarray(z)) + self.assertEqual(y.unsafe_buffer_pointer(), z.unsafe_buffer_pointer()) + + tests.append(BufferTest) + + class SingleOpTest(ComputationTest): + """Tests for single ops. + + The goal here is smoke testing - to exercise the most basic functionality of + single XLA ops. As minimal as possible number of additional ops are added + around the op being tested. + """ + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testConcatenate(self, dtype): + c = self._NewComputation() + args = ( + ops.Constant(c, np.array([1.0, 2.0, 3.0], dtype=dtype)), + ops.Constant(c, np.array([4.0, 5.0, 6.0], dtype=dtype)), + ) + ops.ConcatInDim(c, args, dimension=0) + self._ExecuteAndCompareExact( + c, expected=[np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=dtype)]) + + # pyformat: disable + @parameterized.named_parameters({ + "testcase_name": "_{}_{}".format(src_dtype.__name__, + dst_dtype.__name__), + "src_dtype": src_dtype, + "dst_dtype": dst_dtype, + } for src_dtype, dst_dtype in itertools.permutations( + [np.bool_, np.int32, np.int64, np.float32, np.float64], 2)) + # pyformat: enable + def testConvertElementType(self, src_dtype, dst_dtype): + if ((src_dtype in [np.int64, np.float64] or + dst_dtype in [np.int64, np.float64]) and + self.backend.platform == "tpu"): + self.skipTest("TPU doesn't support float64") + c = self._NewComputation() + x = np.array([0, 1, 0, 0, 1], dtype=src_dtype) + ops.ConvertElementType( + ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype)) + + result = execute_with_python_values( + self.backend.compile(xla_computation_to_mlir_module(c.build())), (), + backend=self.backend) + self.assertLen(result, 1) + expected = np.array(x, dtype=dst_dtype) + + self.assertEqual(result[0].shape, expected.shape) + self.assertEqual(result[0].dtype, expected.dtype) + np.testing.assert_equal(result[0], expected) + + # pyformat: disable + @parameterized.named_parameters( + { + "testcase_name": "_{}_{}".format(src_dtype.__name__, + dst_dtype.__name__), + "src_dtype": src_dtype, + "dst_dtype": dst_dtype, + } + for dtypes in [[np.int32, np.float32], [np.int64, np.float64]] + for src_dtype, dst_dtype in itertools.permutations(dtypes, 2)) + # pyformat: enable + def testBitcastConvertType(self, src_dtype, dst_dtype): + if (np.float64 in (src_dtype, dst_dtype) and + self.backend.platform == "tpu"): + self.skipTest("TPU doesn't support float64") + c = self._NewComputation() + x = np.array([0, 1, 0, 0, 1], dtype=src_dtype) + ops.BitcastConvertType( + ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype)) + + result = execute_with_python_values( + self.backend.compile(xla_computation_to_mlir_module(c.build())), (), + backend=self.backend) + self.assertLen(result, 1) + expected = x.view(dst_dtype) + + self.assertEqual(result[0].shape, expected.shape) + self.assertEqual(result[0].dtype, expected.dtype) + np.testing.assert_equal(result[0], expected) + + # TODO(b/123523486) implement AllToAll on CPU + def DISABLED_testAllToAllOneReplica(self): + samples = [ + NumpyArrayF32([97.0]), + NumpyArrayF32([64.0, 117.0]), + NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]), + ] + for lhs in samples[:1]: + c = self._NewComputation() + ops.AllToAll(ops.Constant(c, lhs), 0, 0) + self._ExecuteAndCompareExact(c, expected=[lhs]) + + def testCrossReplicaSumOneReplica(self): + samples = [ + NumpyArrayF32(42.0), + NumpyArrayF32([97.0]), + NumpyArrayF32([64.0, 117.0]), + NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]), + ] + for lhs in samples: + c = self._NewComputation() + ops.CrossReplicaSum(ops.Constant(c, lhs)) + self._ExecuteAndCompareExact(c, expected=[lhs]) + + def testReplicaId(self): + c = self._NewComputation() + _ = ops.ReplicaId(c) + self._ExecuteAndCompareExact(c, expected=[0]) + + def testCrossReplicaSumOneReplicaWithSingletonGroup(self): + samples = [ + NumpyArrayF32(42.0), + NumpyArrayF32([97.0]), + NumpyArrayF32([64.0, 117.0]), + NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]), + ] + for lhs in samples: + c = self._NewComputation() + ops.CrossReplicaSum( + ops.Constant(c, lhs), xla_client.make_replica_groups([[0]])) + self._ExecuteAndCompareExact(c, expected=[lhs]) + + # TODO(phawkins): np.dot implementation doesn't support bfloat16 + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes if dtype != bfloat16) + def testDotMatrixVector(self, dtype): + c = self._NewComputation() + lhs = np.array([[2.0, 3.0], [4.0, 5.0]], dtype=dtype) + rhs = np.array([[10.0], [20.0]], dtype=dtype) + ops.Dot(ops.Constant(c, lhs), ops.Constant(c, rhs)) + self._ExecuteAndCompareClose(c, expected=[np.dot(lhs, rhs)]) + + # TODO(phawkins): np.dot implementation doesn't support bfloat16 + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes if dtype != bfloat16) + def testDotMatrixMatrix(self, dtype): + c = self._NewComputation() + lhs = np.array([[2.0, 3.0], [4.0, 5.0]], dtype=dtype) + rhs = np.array([[10.0, 20.0], [100.0, 200.0]], dtype=dtype) + ops.Dot(ops.Constant(c, lhs), ops.Constant(c, rhs)) + self._ExecuteAndCompareClose(c, expected=[np.dot(lhs, rhs)]) + + def testDotGeneral(self): + c = self._NewComputation() + rng = np.random.RandomState(0) + lhs = NumpyArrayF32(rng.randn(10, 3, 4)) + rhs = NumpyArrayF32(rng.randn(10, 4, 5)) + dimension_numbers = xla_client.make_dot_dimension_numbers( + (([2], [1]), ([0], [0]))) + ops.DotGeneral( + ops.Constant(c, lhs), ops.Constant(c, rhs), dimension_numbers) + self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=1e-6) + + def testDotGeneralWithDotDimensionNumbersProto(self): + c = self._NewComputation() + rng = np.random.RandomState(0) + lhs = NumpyArrayF32(rng.randn(10, 3, 4)) + rhs = NumpyArrayF32(rng.randn(10, 4, 5)) + + dimension_numbers = xla_client.DotDimensionNumbers() + dimension_numbers.lhs_contracting_dimensions.append(2) + dimension_numbers.rhs_contracting_dimensions.append(1) + dimension_numbers.lhs_batch_dimensions.append(0) + dimension_numbers.rhs_batch_dimensions.append(0) + + ops.DotGeneral( + ops.Constant(c, lhs), ops.Constant(c, rhs), dimension_numbers) + self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=1e-6) + + def testDotGeneralWithPrecisionConfig(self): + c = self._NewComputation() + rng = np.random.RandomState(0) + lhs = NumpyArrayF32(rng.randn(10, 3, 4)) + rhs = NumpyArrayF32(rng.randn(10, 4, 5)) + dimension_numbers = xla_client.make_dot_dimension_numbers( + (([2], [1]), ([0], [0]))) + config = xla_client.PrecisionConfig() + config.operand_precision.append(config.Precision.HIGH) + config.operand_precision.append(config.Precision.HIGHEST) + ops.DotGeneral( + ops.Constant(c, lhs), + ops.Constant(c, rhs), + dimension_numbers, + precision_config=config) + self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=1e-6) + + def testConvGeneralDilatedF32(self): + c = self._NewComputation() + a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 1, 2, 3) + rhs = a(1, 1, 1, 2) * 10 + strides = [1, 1] + pads = [(1, 0), (0, 1)] + lhs_dilation = (2, 1) + rhs_dilation = (1, 1) + dimension_numbers = xla_client.make_convolution_dimension_numbers( + ("NCHW", "OIHW", "NCHW"), 2) + ops.ConvGeneralDilated( + ops.Constant(c, lhs), ops.Constant(c, rhs), strides, pads, + lhs_dilation, rhs_dilation, dimension_numbers) + result = np.array([[[ + [0., 0., 0.], + [10., 20., 0.], + [0., 0., 0.], + [40., 50., 0.], + ]]]) + self._ExecuteAndCompareClose(c, expected=[result]) + + def testConvGeneralDilatedF32WithPrecisionConfig(self): + c = self._NewComputation() + a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 1, 2, 3) + rhs = a(1, 1, 1, 2) * 10 + strides = [1, 1] + pads = [(1, 0), (0, 1)] + lhs_dilation = (2, 1) + rhs_dilation = (1, 1) + dimension_numbers = xla_client.make_convolution_dimension_numbers( + ("NCHW", "OIHW", "NCHW"), 2) + config = xla_client.PrecisionConfig() + config.operand_precision.append(config.Precision.HIGHEST) + config.operand_precision.append(config.Precision.DEFAULT) + ops.ConvGeneralDilated( + ops.Constant(c, lhs), + ops.Constant(c, rhs), + strides, + pads, + lhs_dilation, + rhs_dilation, + dimension_numbers, + precision_config=config) + result = np.array([[[ + [0., 0., 0.], + [10., 20., 0.], + [0., 0., 0.], + [40., 50., 0.], + ]]]) + self._ExecuteAndCompareClose(c, expected=[result]) + + def testConvGeneralDilatedPermutedF32(self): + c = self._NewComputation() + a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 1, 2, 3) + rhs = a(1, 1, 1, 2) * 10 + strides = [1, 1] + pads = [(1, 0), (0, 1)] + lhs_dilation = (2, 1) + rhs_dilation = (1, 1) + + dimension_numbers = xla_client.make_convolution_dimension_numbers( + ("NHWC", "OIHW", "CWNH"), 2) + ops.ConvGeneralDilated( + ops.Constant(c, np.transpose(lhs, + (0, 2, 3, 1))), ops.Constant(c, rhs), + strides, pads, lhs_dilation, rhs_dilation, dimension_numbers) + result = np.array([[[[0., 0., 0.], [10., 20., 0.], [0., 0., 0.], + [40., 50., 0.]]]]) + self._ExecuteAndCompareClose( + c, expected=[np.transpose(result, (1, 3, 0, 2))]) + + def testConvGeneralDilatedGroupedConvolutionF32(self): + c = self._NewComputation() + a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 2, 2, 3) + rhs = a(2, 1, 1, 2) * 10 + strides = [1, 1] + pads = [(1, 0), (0, 1)] + lhs_dilation = (2, 1) + rhs_dilation = (1, 1) + dimension_numbers = xla_client.make_convolution_dimension_numbers( + ("NCHW", "OIHW", "NCHW"), 2) + feature_group_count = 2 + ops.ConvGeneralDilated( + ops.Constant(c, lhs), ops.Constant(c, rhs), strides, pads, + lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count) + result = np.array([[[ + [0., 0., 0.], + [10., 20., 0.], + [0., 0., 0.], + [40., 50., 0.], + ], [ + [0., 0., 0.], + [330., 380., 160.], + [0., 0., 0.], + [480., 530., 220.], + ]]]) + self._ExecuteAndCompareClose(c, expected=[result]) + + def testConvGeneralDilatedWindowReversalF32(self): + c = self._NewComputation() + a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 1, 2, 3) + rhs = a(1, 1, 1, 2) * 10 + strides = [1, 1] + pads = [(1, 0), (0, 1)] + lhs_dilation = (2, 1) + rhs_dilation = (1, 1) + window_reversal = [False, True] + dimension_numbers = xla_client.make_convolution_dimension_numbers( + ("NCHW", "OIHW", "NCHW"), 2) + ops.ConvGeneralDilated( + ops.Constant(c, lhs), + ops.Constant(c, rhs), + strides, + pads, + lhs_dilation, + rhs_dilation, + dimension_numbers, + window_reversal=window_reversal) + result = np.array([[[ + [0., 0., 0.], + [0., 10., 20.], + [0., 0., 0.], + [30., 40., 50.], + ]]]) + self._ExecuteAndCompareClose(c, expected=[result]) + + def testBooleanNot(self): + c = self._NewComputation() + arr = NumpyArrayBool([True, False, True]) + ops.Not(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[~arr]) + + def testPopulationCount(self): + c = self._NewComputation() + arr = NumpyArrayS32([3, 0, 1]) + ops.PopulationCount(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.array([2, 0, 1])]) + + def testCountLeadingZeros(self): + c = self._NewComputation() + arr = NumpyArrayS32([0x7FFF, 0x12345678]) + ops.Clz(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[[17, 3]]) + + def testExp(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + ops.Exp(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.exp(arr)]) + + def testExpWithResultAccuracy(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + accuracy = xla_client.ResultAccuracy() + accuracy.mode = xla_client.ResultAccuracyMode.DEFAULT + ops.Exp(ops.Constant(c, arr), accuracy) + self._ExecuteAndCompareClose(c, expected=[np.exp(arr)]) + + def testExpm1(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + ops.Expm1(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.expm1(arr)]) + + def testExpm1WithResultAccuracy(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + accuracy = xla_client.ResultAccuracy() + accuracy.mode = xla_client.ResultAccuracyMode.DEFAULT + ops.Expm1(ops.Constant(c, arr), accuracy) + self._ExecuteAndCompareClose(c, expected=[np.expm1(arr)]) + + def testRound(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + ops.Round(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.round(arr)]) + + def testLog(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + ops.Log(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.log(arr)]) + + def testLog1p(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + ops.Log1p(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.log1p(arr)]) + + def testNeg(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + ops.Neg(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[-arr]) + + def testFloor(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + ops.Floor(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.floor(arr)]) + + def testCeil(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + ops.Ceil(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.ceil(arr)]) + + def testAbs(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, -12.1, 2.4, -1.]) + ops.Abs(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.abs(arr)]) + + def testTanF32(self): + c = self._NewComputation() + arr = NumpyArrayF32([-0.2, 3.3, 12.1, 0.1, 0.0001]) + ops.Tan(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.tan(arr)]) + + def testTanhF32(self): + c = self._NewComputation() + arr = NumpyArrayF32([-0.2, 3.3, 12.1, 0.1, 0.0001]) + ops.Tanh(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.tanh(arr)]) + + def testTanhF64(self): + if self.backend.platform == "tpu": + self.skipTest("TPU doesn't support 64bit tanh") + c = self._NewComputation() + arr = NumpyArrayF64([-0.2, 3.3, 12.1, 0.1, 0.0001]) + ops.Tanh(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.tanh(arr)], rtol=1e-12) + + def testTranspose(self): + + def _TransposeAndTest(array, permutation): + c = self._NewComputation() + ops.Transpose(ops.Constant(c, array), permutation) + expected = np.transpose(array, permutation) + self._ExecuteAndCompareClose(c, expected=[expected]) + + _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [0, 1]) + _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [1, 0]) + _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [0, 1]) + _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [1, 0]) + + arr = np.random.RandomState(0).randn(2, 3, 4).astype(np.float32) + for permutation in itertools.permutations(range(arr.ndim)): + _TransposeAndTest(arr, permutation) + _TransposeAndTest(np.asfortranarray(arr), permutation) + + def testEq(self): + c = self._NewComputation() + ops.Eq( + ops.Constant(c, NumpyArrayS32([1, 2, 3, 4])), + ops.Constant(c, NumpyArrayS32([4, 2, 3, 1]))) + self._ExecuteAndCompareExact(c, expected=[[False, True, True, False]]) + + def testNe(self): + c = self._NewComputation() + ops.Ne( + ops.Constant(c, NumpyArrayS32([1, 2, 3, 4])), + ops.Constant(c, NumpyArrayS32([4, 2, 3, 1]))) + self._ExecuteAndCompareExact(c, expected=[[True, False, False, True]]) + + ops.Ne( + ops.Constant(c, NumpyArrayF32([-2.0, 0.0, + float("nan"), + float("nan")])), + ops.Constant(c, NumpyArrayF32([2.0, -0.0, 1.0, + float("nan")]))) + self._ExecuteAndAssertWith( + np.testing.assert_allclose, + c, (), + expected=[[True, False, True, True]]) + + def testGt(self): + c = self._NewComputation() + ops.Gt( + ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 9])), + ops.Constant(c, NumpyArrayS32([1, 0, 2, 7, 12]))) + self._ExecuteAndCompareExact( + c, expected=[[False, True, True, False, False]]) + + def testGe(self): + c = self._NewComputation() + ops.Ge( + ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 9])), + ops.Constant(c, NumpyArrayS32([1, 0, 2, 7, 12]))) + self._ExecuteAndCompareExact( + c, expected=[[True, True, True, False, False]]) + + def testLt(self): + c = self._NewComputation() + ops.Lt( + ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 9])), + ops.Constant(c, NumpyArrayS32([1, 0, 2, 7, 12]))) + self._ExecuteAndCompareExact( + c, expected=[[False, False, False, True, True]]) + + def testLe(self): + c = self._NewComputation() + ops.Le( + ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 9])), + ops.Constant(c, NumpyArrayS32([1, 0, 2, 7, 12]))) + self._ExecuteAndCompareExact( + c, expected=[[True, False, False, True, True]]) + + def testMax(self): + c = self._NewComputation() + ops.Max( + ops.Constant(c, NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])), + ops.Constant(c, NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0]))) + self._ExecuteAndCompareExact(c, expected=[[1.0, 2.0, 3.0, 7.0, 12.0]]) + + def testMaxExplicitBroadcastDim0(self): + c = self._NewComputation() + ops.Max( + ops.Constant(c, NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + ops.Constant(c, NumpyArrayF32([3, 4, 5])), + broadcast_dimensions=(0,)) + self._ExecuteAndCompareExact( + c, expected=[[[3, 3, 3], [4, 5, 6], [7, 8, 9]]]) + + def testMaxExplicitBroadcastDim1(self): + c = self._NewComputation() + ops.Max( + ops.Constant(c, NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + ops.Constant(c, NumpyArrayF32([3, 4, 5])), + broadcast_dimensions=(1,)) + self._ExecuteAndCompareExact( + c, expected=[[[3, 4, 5], [4, 5, 6], [7, 8, 9]]]) + + def testMin(self): + c = self._NewComputation() + ops.Min( + ops.Constant(c, NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])), + ops.Constant(c, NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0]))) + self._ExecuteAndCompareExact(c, expected=[[1.0, 0.0, 2.0, 4.0, 9.0]]) + + def testPad(self): + c = self._NewComputation() + ops.Pad( + ops.Constant(c, NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])), + ops.Constant(c, NumpyArrayF32(0.0)), + xla_client.make_padding_config([(1, 2, 1), (0, 1, 0)])) + self._ExecuteAndCompareClose( + c, + expected=[[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0], + [3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]) + + def testPadWithPaddingConfig(self): + c = self._NewComputation() + padding_config = xla_client.PaddingConfig() + for lo, hi, interior in [(1, 2, 1), (0, 1, 0)]: + dimension = xla_client.PaddingConfigDimension() + dimension.edge_padding_low = lo + dimension.edge_padding_high = hi + dimension.interior_padding = interior + padding_config.dimensions.append(dimension) + ops.Pad( + ops.Constant(c, NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])), + ops.Constant(c, NumpyArrayF32(0.0)), padding_config) + self._ExecuteAndCompareClose( + c, + expected=[[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0], + [3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]) + + def testReshape(self): + c = self._NewComputation() + ops.Reshape( + ops.Constant(c, NumpyArrayS32([[1, 2], [3, 4], [5, 6]])), + new_sizes=[2, 3]) + self._ExecuteAndCompareExact(c, expected=[[[1, 2, 3], [4, 5, 6]]]) + + def testCollapse(self): + c = self._NewComputation() + ops.Collapse( + ops.Constant(c, NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])), + dimensions=[1, 2]) + self._ExecuteAndCompareExact(c, expected=[[[1, 2, 3, 4], [5, 6, 7, 8]]]) + + def testRev(self): + c = self._NewComputation() + ops.Rev( + ops.Constant(c, NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])), + dimensions=[0, 2]) + self._ExecuteAndCompareExact( + c, expected=[[[[6, 5], [8, 7]], [[2, 1], [4, 3]]]]) + + def testReducePrecision(self): + c = self._NewComputation() + ops.ReducePrecision( + ops.Constant(c, NumpyArrayF32([float.fromhex("0x1.32fffep-3")])), + exponent_bits=8, + mantissa_bits=7) + self._ExecuteAndCompareClose(c, expected=[[float.fromhex("0x1.32p-3")]]) + + def testClampF32(self): + c = self._NewComputation() + ops.Clamp( + ops.Constant(c, NumpyArrayF32(-1)), + ops.Constant(c, NumpyArrayF32([-2, -1, 0, 1, 2, 3])), + ops.Constant(c, NumpyArrayF32(2))) + self._ExecuteAndCompareExact(c, expected=[[-1, -1, 0, 1, 2, 2]]) + + def testClampS32(self): + c = self._NewComputation() + ops.Clamp( + ops.Constant(c, NumpyArrayS32(-1)), + ops.Constant(c, NumpyArrayS32([-2, -1, 0, 1, 2, 3])), + ops.Constant(c, NumpyArrayS32(2))) + self._ExecuteAndCompareExact(c, expected=[[-1, -1, 0, 1, 2, 2]]) + + def testSelect(self): + c = self._NewComputation() + ops.Select( + ops.Constant(c, NumpyArrayBool([True, False, False, True, False])), + ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 5])), + ops.Constant(c, NumpyArrayS32([-1, -2, -3, -4, -5]))) + self._ExecuteAndCompareExact(c, expected=[[1, -2, -3, 4, -5]]) + + def testSlice(self): + c = self._NewComputation() + ops.Slice( + ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + [1, 0], [3, 2], [1, 1]) + self._ExecuteAndCompareExact(c, expected=[[[4, 5], [7, 8]]]) + + def testSliceInDim(self): + c = self._NewComputation() + ops.SliceInDim( + ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + start_index=1, + limit_index=2, + stride=1, + dimno=1) + self._ExecuteAndCompareExact(c, expected=[[[2], [5], [8]]]) + ops.SliceInDim( + ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + start_index=0, + limit_index=3, + stride=2, + dimno=0) + self._ExecuteAndCompareExact(c, expected=[[[1, 2, 3], [7, 8, 9]]]) + + def testDynamicSlice(self): + c = self._NewComputation() + ops.DynamicSlice( + ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), [ + ops.Constant(c, NumpyArrayS32(1)), + ops.Constant(c, NumpyArrayS32(0)) + ], [2, 2]) + self._ExecuteAndCompareExact(c, expected=[[[4, 5], [7, 8]]]) + + def testDynamicUpdateSlice(self): + c = self._NewComputation() + ops.DynamicUpdateSlice( + ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + ops.Constant(c, NumpyArrayS32([[1, 2], [3, 4]])), [ + ops.Constant(c, NumpyArrayS32(1)), + ops.Constant(c, NumpyArrayS32(1)) + ]) + self._ExecuteAndCompareExact( + c, expected=[[[1, 2, 3], [4, 1, 2], [7, 3, 4]]]) + + def testTuple(self): + c = self._NewComputation() + ops.Tuple(c, [ + ops.Constant(c, np.int32(42)), + ops.Constant(c, NumpyArrayF32([1.0, 2.0])), + ops.Constant(c, NumpyArrayBool([True, False, False, True])) + ]) + result = execute_with_python_values( + self.backend.compile(xla_computation_to_mlir_module(c.build())), (), + backend=self.backend) + self.assertLen(result, 3) + np.testing.assert_equal(result[0], 42) + np.testing.assert_allclose(result[1], [1.0, 2.0]) + np.testing.assert_equal(result[2], [True, False, False, True]) + + def testGetTupleElement(self): + c = self._NewComputation() + ops.GetTupleElement( + ops.Tuple(c, [ + ops.Constant(c, np.int32(42)), + ops.Constant(c, NumpyArrayF32([1.0, 2.0])), + ops.Constant(c, NumpyArrayBool([True, False, False, True])) + ]), 1) + self._ExecuteAndCompareClose(c, expected=[[1.0, 2.0]]) + + def testBroadcast(self): + c = self._NewComputation() + ops.Broadcast( + ops.Constant(c, NumpyArrayS32([10, 20, 30, 40])), sizes=(3,)) + self._ExecuteAndCompareExact( + c, expected=[[[10, 20, 30, 40], [10, 20, 30, 40], [10, 20, 30, 40]]]) + + def testBroadcastInDim(self): + c = self._NewComputation() + ops.BroadcastInDim(ops.Constant(c, NumpyArrayS32([1, 2])), [2, 2], [0]) + self._ExecuteAndCompareExact(c, expected=[[[1, 1], [2, 2]]]) + ops.BroadcastInDim(ops.Constant(c, NumpyArrayS32([1, 2])), [2, 2], [1]) + self._ExecuteAndCompareExact(c, expected=[[[1, 2], [1, 2]]]) + + def testRngNormal(self): + shape = (2, 3) + c = self._NewComputation() + ops.RngNormal( + ops.Constant(c, NumpyArrayF32(0.)), + ops.Constant(c, NumpyArrayF32(1.)), + shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, + shape)) + result = execute_with_python_values( + self.backend.compile(xla_computation_to_mlir_module(c.build())), (), + backend=self.backend) + # since the result is random, we just check shape and uniqueness + self.assertLen(result, 1) + self.assertEqual(result[0].shape, shape) + self.assertLen(np.unique(result[0]), np.prod(shape)) + + def testRngUniformF32(self): + lo, hi = 2., 4. + shape = (2, 3) + c = self._NewComputation() + ops.RngUniform( + ops.Constant(c, NumpyArrayF32(lo)), + ops.Constant(c, NumpyArrayF32(hi)), + shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, + shape)) + result = execute_with_python_values( + self.backend.compile(xla_computation_to_mlir_module(c.build())), (), + backend=self.backend) + # since the result is random, we just check shape, uniqueness, and range + self.assertLen(result, 1) + self.assertEqual(result[0].shape, shape) + self.assertLen(np.unique(result[0]), np.prod(shape)) + self.assertTrue(np.all(lo <= result[0])) + self.assertTrue(np.all(result[0] < hi)) + + def testRngUniformS32(self): + lo, hi = 2, 4 + shape = (2, 3) + c = self._NewComputation() + ops.RngUniform( + ops.Constant(c, NumpyArrayS32(lo)), + ops.Constant(c, NumpyArrayS32(hi)), + shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.S32, + shape)) + result = execute_with_python_values( + self.backend.compile(xla_computation_to_mlir_module(c.build())), (), + backend=self.backend) + # since the result is random, we just check shape, integrality, and range + self.assertLen(result, 1) + self.assertEqual(result[0].shape, shape) + self.assertEqual(result[0].dtype, np.int32) + self.assertTrue(np.all(lo <= result[0])) + self.assertTrue(np.all(result[0] < hi)) + + def testCholesky(self): + l = np.array([[4, 0, 0, 0], [6, 5, 0, 0], [2, 14, 16, 0], [3, 6, 1, 4]], + dtype=np.float32) + c = self._NewComputation() + ops.Cholesky(ops.Constant(c, np.tril(np.dot(l, l.T)))) + self._ExecuteAndCompareClose(c, expected=[l], rtol=1e-4) + + def testSort(self): + keys = np.array([[2, 4, 1, 3], [3, 1, 4, 2]], dtype=np.float32) + c = self._NewComputation() + ops.Sort(c, [ops.Constant(c, keys)], is_stable=True) + self._ExecuteAndCompareClose( + c, + expected=[np.array([[1, 2, 3, 4], [1, 2, 3, 4]], dtype=np.float32)]) + + def testSortKeyVal(self): + keys = np.array([[2, 4, 1, 3], [3, 1, 4, 2]], dtype=np.float32) + values = np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.int32) + c = self._NewComputation() + ops.Sort(c, (ops.Constant(c, keys), ops.Constant(c, values)), dimension=0) + result = execute_with_python_values( + self.backend.compile(xla_computation_to_mlir_module(c.build())), (), + backend=self.backend) + self.assertLen(result, 2) + np.testing.assert_allclose(result[0], [[2, 1, 1, 2], [3, 4, 4, 3]]) + np.testing.assert_equal(result[1], [[0, 5, 2, 7], [4, 1, 6, 3]]) + + def testSortCustomComparator(self): + b = self._NewComputation("comparator") + p0 = ops.Parameter(b, 0, xla_client.shape_from_pyval(NumpyArrayF32(0))) + q0 = ops.Parameter(b, 1, xla_client.shape_from_pyval(NumpyArrayF32(0))) + p1 = ops.Parameter(b, 2, xla_client.shape_from_pyval(NumpyArrayS32(0))) + q1 = ops.Parameter(b, 3, xla_client.shape_from_pyval(NumpyArrayS32(0))) + ops.Or(ops.Lt(p0, q0), ops.And(ops.Eq(p0, q0), ops.Gt(p1, q1))) + comparator = b.build() + + keys = np.array([[2, 3, 1, 3], [3, 1, 2, 2]], dtype=np.float32) + values = np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.int32) + c = self._NewComputation() + ops.Sort( + c, (ops.Constant(c, keys), ops.Constant(c, values)), + dimension=1, + comparator=comparator) + result = execute_with_python_values( + self.backend.compile(xla_computation_to_mlir_module(c.build())), (), + backend=self.backend) + self.assertLen(result, 2) + np.testing.assert_allclose(result[0], [[1, 2, 3, 3], [1, 2, 2, 3]]) + np.testing.assert_equal(result[1], [[2, 0, 3, 1], [5, 7, 6, 4]]) + + def testQR(self): + a = np.array([[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], + [10, 63, 166, 310]], + dtype=np.float32) + c = self._NewComputation() + ops.Tuple(c, ops.QR(ops.Constant(c, a), full_matrices=True)) + q, r = self._Execute(c, ()) + np.testing.assert_allclose(np.dot(q, r), a, rtol=1e-4) + + def testEigh(self): + a = np.array([[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], + [10, 63, 166, 310]], + dtype=np.float32) + a = (a + a.T) / 2 + + c = self._NewComputation() + ops.Tuple(c, ops.Eigh(ops.Constant(c, a), lower=True)) + # TODO(b/129396575): Turn this test back on when it passes without + # fastmath. + # v, w = self._Execute(c, ()) + # self.assertLess(np.linalg.norm(np.dot(a, v) - w * v), 1e-3) + + def testSVD(self): + a = np.array([[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], + [10, 63, 166, 310]], + dtype=np.float32) + c = self._NewComputation() + ops.Tuple(c, ops.SVD(ops.Constant(c, a))) + u, d, v = self._Execute(c, ()) + self.assertLess(np.linalg.norm(a - np.matmul(u * d, v.T)), 1e-3) + + def testTriangularSolve(self): + a_vals = np.array( + [[2, 0, 0, 0], [3, 6, 0, 0], [4, 7, 9, 0], [5, 8, 10, 11]], + dtype=np.float32) + b_vals = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], + dtype=np.float32) + + c = self._NewComputation() + ops.TriangularSolve( + ops.Constant(c, a_vals), + ops.Constant(c, b_vals), + left_side=False, + lower=True, + transpose_a=ops.TriangularSolveOptions_Transpose.TRANSPOSE, + unit_diagonal=False) + self._ExecuteAndCompareClose( + c, + expected=[ + np.array([ + [0.5, 0.08333334, 0.04629629, 0.03367003], + [2.5, -0.25, -0.1388889, -0.1010101], + [4.5, -0.58333331, -0.32407406, -0.23569024], + ], + dtype=np.float32) + ], + rtol=1e-4) + + def testApproxTopK(self): + if self.backend.platform != "tpu": + self.skipTest("ApproxTopK is only supported on TPU") + k = 10 + qy_size = 256 + db_size = 3000 + feature = 128 + recall_target = 0.95 + b = self._NewComputation() + p0 = ops.Parameter(b, 0, xla_client.shape_from_pyval(NumpyArrayF32(0))) + q0 = ops.Parameter(b, 1, xla_client.shape_from_pyval(NumpyArrayF32(0))) + ops.Parameter(b, 2, xla_client.shape_from_pyval(NumpyArrayS32(0))) + ops.Parameter(b, 3, xla_client.shape_from_pyval(NumpyArrayS32(0))) + ops.Gt(p0, q0) + comparator = b.build() + qy_shape = [qy_size, feature] + db_shape = [feature, db_size] + rng = np.random.RandomState(0) + qy_arg = rng.randn(*qy_shape).astype(np.float32) + db_arg = rng.randn(*db_shape).astype(np.float32) + b = self._NewComputation() + qy = ops.Parameter(b, 0, xla_client.shape_from_pyval(qy_arg)) + db = ops.Parameter(b, 1, xla_client.shape_from_pyval(db_arg)) + scores = ops.Dot(qy, db) + iota = ops.Iota( + b, + xla_client.Shape.array_shape(xla_client.PrimitiveType.S32, + (qy_size, db_size)), 1) + init_val = ops.Constant(b, np.float32(-1)) + init_arg = ops.Constant(b, np.int32(-1)) + ground_truth = ops.TopK(scores, k=k) + approx_topk = ops.ApproxTopK( + b, [scores, iota], [init_val, init_arg], + top_k=k, + reduction_dim=1, + comparator=comparator, + recall_target=recall_target) + ops.Tuple(b, [ + ops.GetTupleElement(ground_truth, 1), + ops.GetTupleElement(approx_topk, 1) + ]) + results = self._Execute(b, [qy_arg, db_arg]) + ground_truth_docids = [set(x) for x in results[0]] + hits = sum( + len([x for x in approx_topk_per_q if x in ground_truth_docids[q]]) + for q, approx_topk_per_q in enumerate(results[1]) + ) + self.assertGreater(hits / (qy_size * k), recall_target) + + def testIsConstant(self): + c = self._NewComputation() + a = ops.Constant(c, np.int32(3)) + b = ops.Constant(c, np.int32(1)) + x = ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayS32(0))) + const_expr = ops.Sub(b, a) + non_const_expr = ops.Mul(const_expr, x) + self.assertTrue(c.is_constant(const_expr)) + self.assertFalse(c.is_constant(non_const_expr)) + + def testGather(self): + a = np.arange(9).astype(np.int32).reshape((3, 3)) + indices = np.array([[[0, 2], [2, 1]], [[1, 2], [2, 0]]], dtype=np.int32) + dnums = xla_client.GatherDimensionNumbers() + dnums.offset_dims.append(1) + dnums.offset_dims.append(2) + dnums.start_index_map.append(0) + dnums.start_index_map.append(1) + dnums.index_vector_dim = 2 + c = self._NewComputation() + ops.Gather( + ops.Constant(c, a), + ops.Constant(c, indices), + dnums, + slice_sizes=[1, 1]) + g, = self._Execute(c, ()) + expected = np.array([[[[2, 7]]], [[[5, 6]]]], dtype=np.int32) + np.testing.assert_allclose(g, expected, rtol=1e-4) + + def testAllGather(self): + a = np.arange(9).astype(np.int32).reshape((3, 3)) + c = self._NewComputation() + ops.AllGather( + operand=ops.Constant(c, a), + all_gather_dimension=0, + shard_count=1, + replica_groups=xla_client.make_replica_groups([[0]]), + use_global_device_ids=False) + [g] = self._Execute(c, ()) + np.testing.assert_equal(g, a) + + def testFft(self): + if self.backend.platform == "tpu": + self.skipTest("TPU only supports 1D FFT") + shape = [2, 3, 4, 5] + rng = np.random.RandomState(0) + a = rng.randn(*shape) + 1.0j * rng.randn(*shape) + a = a.astype(np.complex64) + # FFT + c = self._NewComputation() + ops.Fft(ops.Constant(c, a), xla_client.FftType.FFT, shape[-3:]) + self._ExecuteAndCompareClose( + c, expected=[np.fft.fftn(a, axes=(1, 2, 3))], rtol=1e-4) + # IFFT + c = self._NewComputation() + ops.Fft(ops.Constant(c, a), xla_client.FftType.IFFT, shape[-3:]) + self._ExecuteAndCompareClose( + c, expected=[np.fft.ifftn(a, axes=(1, 2, 3))], rtol=1e-4) + # RFFT + b = rng.randn(*shape).astype(np.float32) + c = self._NewComputation() + ops.Fft(ops.Constant(c, b), xla_client.FftType.RFFT, shape[-3:]) + self._ExecuteAndCompareClose( + c, expected=[np.fft.rfftn(b, axes=(1, 2, 3))], rtol=1e-4) + # IRFFT + c = self._NewComputation() + ops.Fft(ops.Constant(c, a), xla_client.FftType.IRFFT, [3, 4, 8]) + self._ExecuteAndCompareClose( + c, expected=[np.fft.irfftn(a, axes=(1, 2, 3))], rtol=2e-4 + ) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes + fp8_dtypes) + def testNextAfter(self, dtype): + if dtype == np.float64 and self.backend.platform == "tpu": + self.skipTest("TPU doesn't support float64") + if dtype == bfloat16 and self.backend.platform == "tpu": + self.skipTest("b/371119032: Test fails on TPUs with bfloat16") + finfo = ml_dtypes.finfo(dtype) + eps = finfo.eps + c = self._NewComputation() + # Each row is (value, direction, expected), where + # 'nextafter(value, direction)' should be 'expected'. + data = np.array( + [ + [1, 2, 1 + finfo.eps], + [2, 1, 2 - eps], + [-0., 1, finfo.smallest_subnormal], + [0., -1, -finfo.smallest_subnormal], + [-finfo.smallest_subnormal, 1, -0.], + [finfo.smallest_subnormal, 1, 2 * finfo.smallest_subnormal], + [finfo.smallest_subnormal, -1, 0], + ], + dtype=dtype, + ) + + ops.NextAfter(ops.Constant(c, data[:, 0]), ops.Constant(c, data[:, 1])) + out, = self._Execute(c, ()) + np.testing.assert_equal(out, data[:, 2]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testRegularizedIncompleteBeta(self, dtype): + x = np.array([0.53787335, 0.24015466, 0.47494545, 0.13567594, 0.95114538], + dtype=dtype) + a = np.array([0.00753073, 0.34813385, 0.30485708, 1.29298632, 0.51472606], + dtype=dtype) + b = np.array([0.55688389, 0.59794214, 0.42661022, 1.59748339, 0.95047677], + dtype=dtype) + c = self._NewComputation() + ops.RegularizedIncompleteBeta( + ops.Constant(c, a), ops.Constant(c, b), ops.Constant(c, x)) + expected = np.array( + [0.98923271, 0.48575411, 0.57952568, 0.12579775, 0.96989155]) + self._ExecuteAndCompareClose(c, expected=[expected], rtol=2e-2) + + tests.append(SingleOpTest) + + class EmbeddedComputationsTest(ComputationTest): + """Tests for XLA graphs with embedded computations (such as maps).""" + + def _CreateConstantComputation(self, in_dtype, out_dtype): + """Computation (A) -> B that returns a constant 1 for any input.""" + c = self._NewComputation("constant_{}_{}_one".format( + in_dtype.__name__, out_dtype.__name__)) + ops.Parameter( + c, 0, + xla_client.shape_from_pyval(np.array( + 0, dtype=in_dtype)).with_major_to_minor_layout_if_absent()) + ops.Constant(c, out_dtype(1)) + return c.build() + + def _CreateMulBy2Computation(self, dtype): + """Computation (dtype) -> dtype that multiplies its parameter by 2.""" + c = self._NewComputation("mul_f32_by2") + ops.Mul( + ops.Parameter( + c, 0, + xla_client.shape_from_pyval(np.array( + 0, dtype=dtype)).with_major_to_minor_layout_if_absent()), + ops.Constant(c, dtype(2.0))) + return c.build() + + def _CreateMulF32ByParamComputation(self): + """Computation (f32) -> f32 that multiplies one parameter by the other.""" + c = self._NewComputation("mul_f32_by_param") + ops.Mul( + ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0))), + ops.Parameter(c, 1, xla_client.shape_from_pyval(NumpyArrayF32(0)))) + return c.build() + + def _CreateBinaryAddComputation(self, dtype): + """Computation (dtype, dtype) -> dtype that adds its two parameters.""" + c = self._NewComputation("add_param0_by_param1") + shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype)) + shape = shape.with_major_to_minor_layout_if_absent() + ops.Add(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape)) + return c.build() + + def _CreateBinaryGeComputation(self, dtype): + """Computation (dtype, dtype) -> bool that tests param0 >= param1.""" + c = self._NewComputation("param0_lt_param1") + shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype)) + shape = shape.with_major_to_minor_layout_if_absent() + ops.Ge(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape)) + return c.build() + + def _MakeSample3DArray(self, dtype): + return np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]], + [[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]], + dtype=dtype) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testCall(self, dtype): + c = self._NewComputation() + ops.Call( + c, + self._CreateMulBy2Computation(dtype), + operands=(ops.Constant(c, dtype(5.0)),)) + self._ExecuteAndCompareClose(c, expected=[10.0]) + + @parameterized.named_parameters({ + "testcase_name": "_{}_{}".format(in_dtype.__name__, out_dtype.__name__), + "in_dtype": in_dtype, + "out_dtype": out_dtype, + } for in_dtype, out_dtype in [[np.float32, np.int32]]) + def testMapEachElementToConstant(self, in_dtype, out_dtype): + c = self._NewComputation() + ops.Map(c, + [ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=in_dtype))], + self._CreateConstantComputation(in_dtype, out_dtype), [0]) + self._ExecuteAndCompareExact(c, expected=[[1, 1, 1, 1]]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testMapMulBy2(self, dtype): + if dtype == np.float64 and self.backend.platform == "tpu": + self.skipTest("TPU doesn't support float64") + c = self._NewComputation() + ops.Map(c, [ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype))], + self._CreateMulBy2Computation(dtype), [0]) + self._ExecuteAndCompareClose(c, expected=[[2.0, 4.0, 6.0, 8.0]]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testSimpleMapChain(self, dtype): + if dtype == np.float64 and self.backend.platform == "tpu": + self.skipTest("TPU doesn't support float64") + # Chains a map of constant-out with a map of mul-by-2 + c = self._NewComputation() + const = ops.Map( + c, [ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype))], + self._CreateConstantComputation(dtype, dtype), [0]) + ops.Map(c, [const], self._CreateMulBy2Computation(dtype), [0]) + self._ExecuteAndCompareClose(c, expected=[[2.0, 2.0, 2.0, 2.0]]) + + # TODO(b/154752816): bfloat16 crashes in evaluator. + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes if dtype != bfloat16) + def testDivVectorsWithMap(self, dtype): + + def DivComputation(): + c = self._NewComputation("div_param0_by_param1") + shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype)) + ops.Div(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape)) + return c.build() + + c = self._NewComputation() + ops.Map(c, (ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype)), + ops.Constant(c, np.array([5.0, 5.0, 4.0, 4.0], dtype=dtype))), + DivComputation(), [0]) + self._ExecuteAndCompareClose( + c, expected=[[0.2, 0.4, 0.75, 1.0]], rtol=1e-3) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testSelectAndScatter(self, dtype): + if dtype == np.float64 and self.backend.platform == "tpu": + self.skipTest("TPU doesn't support float64") + c = self._NewComputation() + operand = ops.Constant( + c, np.array([[1., 2., 6.], [4., 5., 3.]], dtype=dtype)) + window_dimensions = (2, 1) + window_strides = (1, 2) + padding = xla_client.window_padding_type_to_pad_values( + xla_client.PaddingType.VALID, + c.get_shape(operand).dimensions(), window_dimensions, window_strides) + ops.SelectAndScatterWithGeneralPadding( + operand, + select=self._CreateBinaryGeComputation(dtype), + window_dimensions=window_dimensions, + window_strides=window_strides, + padding=padding, + source=ops.Constant(c, np.array([[0.1, 0.2]], dtype=dtype)), + init_value=ops.Constant(c, np.array(1, dtype=dtype)), + scatter=self._CreateBinaryAddComputation(dtype)) + self._ExecuteAndCompareClose( + c, expected=[[[1., 1., 1.2], [1.1, 1., 1.]]], rtol=5e-3) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testReduce1DtoScalar(self, dtype): + c = self._NewComputation() + ops.Reduce( + c, + operands=[ + ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype)) + ], + init_values=[ops.Constant(c, dtype(0))], + computation=self._CreateBinaryAddComputation(dtype), + dimensions_to_reduce=[0]) + self._ExecuteAndCompareClose(c, expected=[10]) + + # TODO(phawkins): test comparison harness doesn't support bfloat16 + @unittest.skipIf(pjrt_c_api, "b/264473047: hangs") + @parameterized.named_parameters({ + "testcase_name": "_{}_dim{}".format(dtype.__name__, dim), + "dtype": dtype, + "dim": dim, + } for dtype in float_dtypes if dtype != bfloat16 for dim in range(2)) + def testReduce2DTo1D(self, dtype, dim): + input_array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype) + c = self._NewComputation() + ops.Reduce( + c, + operands=[ops.Constant(c, input_array)], + init_values=[ops.Constant(c, dtype(0))], + computation=self._CreateBinaryAddComputation(dtype), + dimensions_to_reduce=[dim]) + self._ExecuteAndCompareClose(c, expected=[np.sum(input_array, axis=dim)]) + + @unittest.skipIf(pjrt_c_api, "b/264473047: hangs") + @parameterized.named_parameters({ + "testcase_name": "_{}_dims[{}]".format(dtype.__name__, dims), + "dtype": dtype, + "dims": tuple(dims) + } for dtype in float_dtypes for dims in itertools.permutations(range(3))) + def testReduce3DAllPossibleWaysF32(self, dtype, dims): + input_array = self._MakeSample3DArray(dtype) + c = self._NewComputation() + ops.Reduce( + c, + operands=[ops.Constant(c, input_array)], + init_values=[ops.Constant(c, dtype(0))], + computation=self._CreateBinaryAddComputation(dtype), + dimensions_to_reduce=dims) + self._ExecuteAndCompareClose(c, expected=[np.sum(input_array, axis=dims)]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testReduceWindowValidUnitStrides(self, dtype): + if dtype == np.float64 and self.backend.platform == "tpu": + self.skipTest("TPU doesn't support float64") + input_array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype) + c = self._NewComputation() + window_dimensions = (2, 1) + window_strides = (1, 1) + padding = xla_client.window_padding_type_to_pad_values( + xla_client.PaddingType.VALID, input_array.shape, window_dimensions, + window_strides) + ops.ReduceWindowWithGeneralPadding( + operand=ops.Constant(c, input_array), + init_value=ops.Constant(c, dtype(0)), + computation=self._CreateBinaryAddComputation(dtype), + window_dimensions=window_dimensions, + window_strides=window_strides, + base_dilations=[], + window_dilations=[], + padding=padding) + self._ExecuteAndCompareClose(c, expected=[[[5., 7., 9.]]]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testReduceWindowSameUnitStrides(self, dtype): + if dtype == np.float64 and self.backend.platform == "tpu": + self.skipTest("TPU doesn't support float64") + input_array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype) + c = self._NewComputation() + window_dimensions = (2, 1) + window_strides = (1, 1) + padding = xla_client.window_padding_type_to_pad_values( + xla_client.PaddingType.SAME, input_array.shape, window_dimensions, + window_strides) + ops.ReduceWindowWithGeneralPadding( + operand=ops.Constant(c, input_array), + init_value=ops.Constant(c, dtype(0)), + computation=self._CreateBinaryAddComputation(dtype), + window_dimensions=window_dimensions, + window_strides=window_strides, + base_dilations=[], + window_dilations=[], + padding=padding) + self._ExecuteAndCompareClose(c, expected=[[[5., 7., 9.], [4., 5., 6.]]]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testReduceWindowValidGeneralStrides(self, dtype): + if dtype == np.float64 and self.backend.platform == "tpu": + self.skipTest("TPU doesn't support float64") + input_array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype) + c = self._NewComputation() + window_dimensions = (2, 1) + window_strides = (1, 2) + padding = xla_client.window_padding_type_to_pad_values( + xla_client.PaddingType.VALID, input_array.shape, window_dimensions, + window_strides) + ops.ReduceWindowWithGeneralPadding( + operand=ops.Constant(c, input_array), + init_value=ops.Constant(c, dtype(0)), + computation=self._CreateBinaryAddComputation(dtype), + window_dimensions=window_dimensions, + window_strides=window_strides, + base_dilations=[], + window_dilations=[], + padding=padding) + self._ExecuteAndCompareClose(c, expected=[[[5., 9.]]]) + + @unittest.skipIf(pjrt_c_api, "b/264473047: hangs") + def testReduceWindowVariadic(self): + c = self._NewComputation("reducer") + shape = xla_client.shape_from_pyval(np.array(0, dtype=np.int32)) + shape = shape.with_major_to_minor_layout_if_absent() + ps = [ops.Parameter(c, i, shape) for i in range(4)] + which = ops.Ge(ps[0], ps[2]) + ops.Tuple( + c, [ops.Select(which, ps[0], ps[2]), + ops.Select(which, ps[1], ps[3])]) + reducer = c.build() + + key_array = np.array([[1, 5, 6], [4, 2, 3]], dtype=np.int32) + val_array = np.array([[7, 8, 9], [10, 11, 12]], dtype=np.int32) + c = self._NewComputation() + window_dimensions = (2, 1) + window_strides = (1, 1) + padding = xla_client.window_padding_type_to_pad_values( + xla_client.PaddingType.VALID, key_array.shape, window_dimensions, + window_strides) + ops.ReduceWindowWithGeneralPadding( + operands=[ops.Constant(c, key_array), + ops.Constant(c, val_array)], + init_values=[ + ops.Constant(c, np.int32(0)), + ops.Constant(c, np.int32(0)) + ], + computation=reducer, + window_dimensions=window_dimensions, + window_strides=window_strides, + base_dilations=[], + window_dilations=[], + padding=padding) + self._ExecuteAndCompareClose(c, expected=[[[4, 5, 6]], [[10, 8, 9]]]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testWhile(self, dtype): + + def LessThan10Cond(): + c = self._NewComputation("test_lt_10") + shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype)) + ops.Lt(ops.Parameter(c, 0, shape), ops.Constant(c, dtype(10.))) + return c.build() + + cond = LessThan10Cond() + body = self._CreateMulBy2Computation(dtype) + c = self._NewComputation() + init = ops.Constant(c, dtype(1.)) + ops.While(cond, body, init) + self._ExecuteAndCompareClose(c, expected=[16.]) + + def testConditionalTrue(self): + c = self._NewComputation() + pred = ops.Constant(c, np.bool_(True)) + true_operand = ops.Constant(c, np.float32(3.)) + true_computation = self._CreateMulBy2Computation(np.float32) + false_operand = ops.Constant(c, np.float32(2.)) + false_computation = self._CreateConstantComputation( + np.float32, np.float32) + ops.Conditional(pred, true_operand, true_computation, false_operand, + false_computation) + self._ExecuteAndCompareClose(c, expected=[6.]) + + def testConditionalFalse(self): + c = self._NewComputation() + pred = ops.Constant(c, np.bool_(False)) + true_operand = ops.Constant(c, np.float32(3.)) + true_computation = self._CreateMulBy2Computation(np.float32) + false_operand = ops.Constant(c, np.float32(2.)) + false_computation = self._CreateConstantComputation( + np.float32, np.float32) + ops.Conditional(pred, true_operand, true_computation, false_operand, + false_computation) + self._ExecuteAndCompareClose(c, expected=[1.]) + + @unittest.skipIf(cloud_tpu or pathways or pathways_ifrt or pjrt_c_api, + "not implemented") + def testInfeedS32Values(self): + to_infeed = NumpyArrayS32([1, 2, 3, 4]) + c = self._NewComputation() + ops.GetTupleElement( + ops.InfeedWithToken( + ops.CreateToken(c), + xla_client.shape_from_pyval( + to_infeed[0]).with_major_to_minor_layout_if_absent()), 0) + compiled_c = self.backend.compile( + xla_computation_to_mlir_module(c.build())) + device = self.backend.local_devices()[0] + for item in to_infeed: + device.transfer_to_infeed(item) + + for item in to_infeed: + result, = execute_with_python_values( + compiled_c, (), backend=self.backend) + self.assertEqual(result, item) + + @unittest.skipIf(cloud_tpu or pathways or pathways_ifrt or pjrt_c_api, + "not implemented") + def testInfeedTuple(self): + to_infeed = (NumpyArrayS32([1, 2, 3, 4]), NumpyArrayS32([[7], [8]])) + c = self._NewComputation() + ops.GetTupleElement( + ops.InfeedWithToken( + ops.CreateToken(c), + xla_client.shape_from_pyval( + to_infeed).with_major_to_minor_layout_if_absent()), 0) + compiled_c = self.backend.compile( + xla_computation_to_mlir_module(c.build())) + device = self.backend.local_devices()[0] + device.transfer_to_infeed(to_infeed) + + result = execute_with_python_values( + compiled_c, (), backend=self.backend) + self.assertLen(result, 2) + np.testing.assert_equal(result[0], to_infeed[0]) + np.testing.assert_equal(result[1], to_infeed[1]) + + @unittest.skipIf(cloud_tpu or pathways or pathways_ifrt or pjrt_c_api, + "not implemented") + def testInfeedThenOutfeedS32(self): + to_round_trip = NumpyArrayS32([1, 2, 3, 4]) + c = self._NewComputation() + x_and_token = ops.InfeedWithToken( + ops.CreateToken(c), + xla_client.shape_from_pyval( + to_round_trip[0]).with_major_to_minor_layout_if_absent()) + x = ops.GetTupleElement(x_and_token, 0) + token = ops.GetTupleElement(x_and_token, 1) + outfeed_shape = xla_client.shape_from_pyval( + to_round_trip[0]).with_major_to_minor_layout_if_absent() + ops.OutfeedWithToken(x, token, outfeed_shape) + ops.Tuple(c, ()) + + compiled_c = self.backend.compile( + xla_computation_to_mlir_module(c.build())) + device = self.backend.local_devices()[0] + + for want in to_round_trip: + execution = threading.Thread(target=lambda: compiled_c.execute([])) + execution.start() + device.transfer_to_infeed(want) + got = device.transfer_from_outfeed(outfeed_shape) + execution.join() + self.assertEqual(want, got) + + def testScatter(self): + a = np.arange(9).astype(np.int32).reshape((3, 3)) + scatter_indices = np.array([0, 2], dtype=np.int32) + updates = np.array([[10, 20, 30], [70, 80, 90]], dtype=np.int32) + + dnums = xla_client.ScatterDimensionNumbers() + dnums.update_window_dims.append(1) + dnums.inserted_window_dims.append(0) + dnums.scatter_dims_to_operand_dims.append(0) + dnums.index_vector_dim = 1 + + c = self._NewComputation() + ops.Scatter( + ops.Constant(c, a), ops.Constant(c, scatter_indices), + ops.Constant(c, updates), self._CreateBinaryAddComputation(np.int32), + dnums) + expected = np.array([[10, 21, 32], [3, 4, 5], [76, 87, 98]], + dtype=np.int32) + self._ExecuteAndCompareClose(c, expected=[expected]) + + class DeviceTest(ComputationTest): + + def testDevices(self): + self.assertNotEmpty(self.backend.devices()) + + def testLocalDevices(self): + self.assertNotEmpty(self.backend.local_devices()) + if self.backend.platform == "cpu": + self.assertLen(self.backend.local_devices(), 2) + + def testGetAllDevices(self): + # TODO(hyeontaek): Remove this method once we have a unified API for + # enumerating devices with different criteria. + self.assertNotEmpty(self.backend._get_all_devices()) # pylint: disable=protected-access + + def testPlatform(self): + for device in self.backend.local_devices(): + self.assertEqual(device.platform, self.backend.platform) + + def testCoreCount(self): + if self.backend.platform != "gpu": + self.skipTest("core_count is only supported on GPU") + for device in self.backend.local_devices(): + self.assertGreater(device.core_count, 0) + + def testLocalHardwareId(self): + for device in self.backend.devices(): + local_hardware_id = device.local_hardware_id + if local_hardware_id is not None: + self.assertGreaterEqual(local_hardware_id, 0) + + @unittest.skipIf(pathways_ifrt, "not implemented") + def testLocalDeviceFromLocalHardwareId(self): + for device in self.backend.local_devices(): + if device.local_hardware_id is not None: + lookup_device = self.backend.device_from_local_hardware_id( + device.local_hardware_id) + self.assertEqual(lookup_device, device) + + @unittest.skipIf(pathways, "not implemented") + @unittest.skipIf(pathways_ifrt, "not implemented") + def testMemoryStats(self): + for device in self.backend.local_devices(): + stats = device.memory_stats() + if ( + self.backend.platform != "tpu" or not tfrt_tpu + ) and self.backend.platform not in ("gpu", "cuda", "rocm"): + self.assertIsNone(stats) + else: + self.assertIsNotNone(stats) + # Spot check a few fields + self.assertEqual(type(stats["num_allocs"]), int) + self.assertGreaterEqual(stats["num_allocs"], 0) + self.assertEqual(type(stats["bytes_in_use"]), int) + self.assertGreaterEqual(stats["bytes_in_use"], 0) + self.assertEqual(type(stats["peak_bytes_in_use"]), int) + self.assertGreaterEqual(stats["peak_bytes_in_use"], 0) + self.assertEqual(type(stats["largest_alloc_size"]), int) + self.assertGreaterEqual(stats["largest_alloc_size"], 0) + + @unittest.skipIf(pathways, "not implemented") + def testMemory(self): + for device in self.backend.local_devices(): + for memory in device.addressable_memories(): + self.assertEqual(memory.process_index, device.process_index) + self.assertEqual(memory.platform, device.platform) + self.assertIn(device, memory.addressable_by_devices()) + self.assertEqual(memory, device.memory(memory.kind)) + + tests.append(DeviceTest) + + class ErrorTest(ComputationTest): + + def setUp(self): + super(ErrorTest, self).setUp() + self.f32_scalar_2 = NumpyArrayF32(2.0) + self.s32_scalar_2 = NumpyArrayS32(2) + + def testCompileWithWrongElementTypeInLayout(self): + c = self._NewComputation() + c.set_op_metadata(xla_client.CurrentSourceInfoMetadata()) + ops.Parameter(c, 0, xla_client.shape_from_pyval(self.s32_scalar_2)) + c.clear_op_metadata() + + options = xla_client.CompileOptions() + options.argument_layouts = [ + xla_client.Shape.array_shape(np.dtype(np.float32), []) + ] + + def TestFun(): + return self.backend.compile(c.build(), compile_options=options) + + self.assertRaisesRegex( + RuntimeError, r".*Invalid argument shape.*" + r"expected s32\[\], got f32\[\].*", TestFun) + + def testInvokeWithWrongElementType(self): + c = self._NewComputation() + c.set_op_metadata(xla_client.CurrentSourceInfoMetadata()) + ops.Parameter(c, 0, xla_client.shape_from_pyval(self.s32_scalar_2)) + c.clear_op_metadata() + + def TestFun(): + return execute_with_python_values( + self.backend.compile(xla_computation_to_mlir_module(c.build())), + [self.f32_scalar_2], self.backend) + + self.assertRaisesRegex( + RuntimeError, r"Invalid argument: Argument does not match.*" + r"want s32\[\], got f32\[\].*", TestFun) + + tests.append(EmbeddedComputationsTest) + + class ComputationRootTest(ComputationTest): + """Tests related to setting the root of the computation.""" + + def testComputationRootDifferentFromLastOp(self): + c = self._NewComputation() + x = ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(2.0))) + result = ops.Add(x, ops.Constant(c, np.float32(3.14))) + ops.Add(result, ops.Constant(c, np.float32(1.618))) + + arg = NumpyArrayF32(1.0) + compiled_c = self.backend.compile( + xla_computation_to_mlir_module(c.build(result))) + ans, = execute_with_python_values( + compiled_c, [arg], backend=self.backend) + np.testing.assert_allclose(ans, 4.14) + + tests.append(ComputationRootTest) + + class SetShardingTest(ComputationTest): + """Tests related to set OpSharding.""" + + def testSetSharding(self): + c = self._NewComputation() + sharding = xla_client.OpSharding() + sharding.type = xla_client.OpSharding.Type.REPLICATED + sharding.tile_assignment_dimensions = [1] + sharding.tile_assignment_devices = [0] + c.set_sharding(sharding) + x = ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(2.0))) + c.clear_sharding() + + result = ops.Add(x, ops.Constant(c, np.float32(3.14))) + ops.Add(result, ops.Constant(c, np.float32(1.618))) + arg = NumpyArrayF32(1.0) + compiled_c = self.backend.compile( + xla_computation_to_mlir_module(c.build(result))) + ans, = execute_with_python_values( + compiled_c, [arg], backend=self.backend) + np.testing.assert_allclose(ans, 4.14) + + tests.append(SetShardingTest) + + testcase_shapes = [ + (), + (1,), + (2, 3), + (2, 0), + (0, 7), + (4, 1, 2), + (2, 1, 3), + (2, 4, 1), + (3, 1), + (1, 3), + ] + + def FormatShapeAndDtype(shape, dtype): + return "_{}[{}]".format(np.dtype(dtype).name, ",".join(map(str, shape))) + + class DLPackTest(parameterized.TestCase): + + def setUp(self): + super(DLPackTest, self).setUp() + self.backend = xla_backend() + if self.backend.platform not in ("cpu", "gpu", "cuda", "rocm"): + self.skipTest("DLPack requires CPU or GPU") + self.cpu_backend = ( + self.backend + if self.backend.platform == "cpu" else xla_client.make_cpu_client()) + self.gpu_backend = ( + self.backend + if self.backend.platform in ("gpu", "cuda", "rocm") + else None + ) + + def tearDown(self): + super().tearDown() + del self.backend + del self.cpu_backend + del self.gpu_backend + + @classmethod + def _GetStreamFromDevice(cls, device): + try: + return device.get_stream_for_external_ready_events() + except xla_client.XlaRuntimeError as err: # type: ignore + if "UNIMPLEMENTED" in str(err): + return None + else: + raise + + def _DLPackManagedTensorToBuffer( + self, tensor, use_legacy_api, backend=None + ): + if use_legacy_api: + return xla_client._xla.dlpack_managed_tensor_to_buffer( + tensor, self.cpu_backend, self.gpu_backend + ) + else: + if not backend: + backend = self.backend + device = backend.local_devices()[0] + stream = DLPackTest._GetStreamFromDevice(device) + return xla_client._xla.dlpack_managed_tensor_to_buffer( + tensor, device, stream + ) + + # pylint: disable=g-complex-comprehension + # pyformat: disable + @parameterized.named_parameters( + { + "testcase_name": "{}_gpu={}{}".format( + FormatShapeAndDtype(shape, dtype), + gpu, + "_legacy" if use_legacy_api else "", + ), + "dtype": dtype, + "shape": shape, + "gpu": gpu, + "use_legacy_api": use_legacy_api, + } + for dtype in dlpack_dtypes + for shape in testcase_shapes + for gpu in [False, True] + for use_legacy_api in [False, True] + ) + # pyformat: enable + def testRoundTrip(self, dtype, shape, gpu, use_legacy_api): + if gpu and self.gpu_backend is None: + raise unittest.SkipTest("Test not running with GPU support") + backend = self.gpu_backend if gpu else self.cpu_backend + if dtype == np.bool_: + x = np.random.randint(0, 2, size=shape).astype(np.bool_) + else: + x = np.array(np.random.rand(*shape) * 100, dtype=dtype) + buffer = backend.buffer_from_pyval(x) + dlt = xla_client._xla.buffer_to_dlpack_managed_tensor(buffer) + del buffer # Free "buffer" to make sure dlt retains ownership. + self.assertEqual(type(dlt).__name__, "PyCapsule") + y = self._DLPackManagedTensorToBuffer(dlt, use_legacy_api, backend) + np.testing.assert_array_equal( + x.astype(np.uint8) if dtype == np.bool_ else x, np.asarray(y)) + + @parameterized.named_parameters( + { + "testcase_name": "{}".format("_legacy" if use_legacy_api else ""), + "use_legacy_api": use_legacy_api, + } + for use_legacy_api in [False, True] + ) + def testTensorsCanBeConsumedOnceOnly(self, use_legacy_api): + x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32) + buffer = self.backend.buffer_from_pyval(x) + dlt = xla_client._xla.buffer_to_dlpack_managed_tensor(buffer) + + def ConsumeDLPackTensor(): + _ = self._DLPackManagedTensorToBuffer(dlt, use_legacy_api) + + ConsumeDLPackTensor() + self.assertRaisesRegex( + RuntimeError, ".*a DLPack tensor may be consumed at most once.*", + ConsumeDLPackTensor) + + @parameterized.named_parameters( + { + "testcase_name": "{}".format("_legacy" if use_legacy_api else ""), + "use_legacy_api": use_legacy_api, + } + for use_legacy_api in [False, True] + ) + def testNonOwnedDlpackCanBeViewedTwice(self, use_legacy_api): + x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32) + buffer = self.backend.buffer_from_pyval(x) + d1 = xla_client._xla.buffer_to_dlpack_managed_tensor(buffer) + d2 = xla_client._xla.buffer_to_dlpack_managed_tensor(buffer) + + y = self._DLPackManagedTensorToBuffer(d1, use_legacy_api) + z = self._DLPackManagedTensorToBuffer(d2, use_legacy_api) + del d1, d2 + np.testing.assert_array_equal(x, np.asarray(buffer)) + np.testing.assert_array_equal(x, np.asarray(y)) + np.testing.assert_array_equal(x, np.asarray(z)) + + @parameterized.parameters(False, True) + def testZeroCopyOnAlignedDlpackTensor(self, use_legacy_api): + # Using CPU only, since this test is about CPU memory alignment. + if self.backend.platform != "cpu": + self.skipTest("Test requires CPU") + + # Create a numpy array that is aligned to XLA requirements. + x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32) + x = _Aligned(x) + + # Convert it to a DLPack tensor, and then to an XLA buffer. + dlpack_tensor = x.__dlpack__() + buffer = self._DLPackManagedTensorToBuffer(dlpack_tensor, use_legacy_api) + y = np.array(buffer, copy=False) + + # The input was sufficiently aligned, so input and output should alias. + x_ptr = x.__array_interface__["data"][0] + y_ptr = y.__array_interface__["data"][0] + self.assertEqual( + x_ptr, + y_ptr, + msg=f"Buffers are not aliased ({hex(x_ptr)} != {hex(y_ptr)}).", + ) + + @parameterized.named_parameters( + { + "testcase_name": "{}{}".format( + "_legacy" if use_legacy_api else "", + "_transpose" if transpose else "", + ), + "use_legacy_api": use_legacy_api, + "transpose": transpose, + } + for use_legacy_api in [False, True] + for transpose in [False, True] + ) + def testReturnCopyOnUnalignedDlpackTensor(self, use_legacy_api, transpose): + # Using CPU only, since this test is about CPU memory alignment. + if self.backend.platform != "cpu": + self.skipTest("Test requires CPU") + + if transpose and use_legacy_api: + self.skipTest("Non-default layout is not supported in legacy API") + + # Create a numpy array that is not aligned to XLA requirements. XLA's + # alignment requirements differ for different hardware, so we use the + # smallest possible value. If we make sure the buffer is not aligned to + # this value (16 bytes), then it is also not aligned to its multiples (32, + # 64 etc.) + x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32) + x = _Unaligned(x, alignment=_XLA_CPU_MIN_ALIGNMENT) + + # Transpose the array to test non-default layout with trivial striding. + if transpose: + x = x.transpose((0, 2, 1, 3)) + + # Convert it to a DLPack tensor, and then to an XLA buffer. + dlpack_tensor = x.__dlpack__() + buffer = self._DLPackManagedTensorToBuffer(dlpack_tensor, use_legacy_api) + y = np.array(buffer, copy=False) + + # The input was not sufficiently aligned, so input and output should not + # alias (output should be a copy of input, and it should be aligned). + x_ptr = x.__array_interface__["data"][0] + y_ptr = y.__array_interface__["data"][0] + self.assertNotEqual( + x_ptr, + y_ptr, + msg=( + f"Buffers aliased, but should not be ({hex(x_ptr)} ==" + f" {hex(y_ptr)})" + ), + ) + self.assertEqual( + y_ptr % _XLA_CPU_MIN_ALIGNMENT, + 0, + msg="Output buffer not aligned: {hex(y_ptr)}", + ) + np.testing.assert_array_equal(y, x) + + tests.append(DLPackTest) + + class BufferProtocolTest(parameterized.TestCase): + + def setUp(self): + super(BufferProtocolTest, self).setUp() + self.backend = xla_backend() + if self.backend.platform != "cpu": + self.skipTest("Test requires CPU") + + # pylint: disable=g-complex-comprehension + @parameterized.named_parameters({ + "testcase_name": FormatShapeAndDtype(shape, dtype), + "dtype": dtype, + "shape": shape + } for dtype in standard_dtypes if dtype != bfloat16 + for shape in testcase_shapes) + def testRoundTrip(self, dtype, shape): + x = np.array(np.random.rand(*shape) * 100, dtype=dtype) + + x = _Aligned(x) + x_ptr = x.__array_interface__["data"][0] + buffer = self.backend.buffer_from_pyval( + x, host_buffer_semantics=xla_client.HostBufferSemantics.ZERO_COPY) + y = np.array(buffer, copy=False) + y_ptr = y.__array_interface__["data"][0] + np.testing.assert_array_equal(x, y) + + # The input was sufficiently aligned, so input and output should alias. + self.assertEqual(x_ptr, y_ptr) + self.assertEqual(y_ptr, buffer.unsafe_buffer_pointer()) + + during_call = xla_client.HostBufferSemantics.IMMUTABLE_ONLY_DURING_CALL + buffer2 = self.backend.buffer_from_pyval( + x, host_buffer_semantics=during_call) + z = np.array(buffer2, copy=False) + self.assertNotEqual(x.__array_interface__["data"][0], + z.__array_interface__["data"][0]) + + def testDeleteWithActiveView(self): + x = np.random.randn(20, 10) + buffer = self.backend.buffer_from_pyval(x) + buffer_ptr = buffer.unsafe_buffer_pointer() + y = np.array(buffer, copy=False) + buffer.delete() + # It is still legal to access `y`; the array view must keep it alive. + np.testing.assert_array_equal(x, y) + self.assertEqual(y.__array_interface__["data"][0], buffer_ptr) + + tests.append(BufferProtocolTest) + + class TracebackTest(absltest.TestCase): + + def setUp(self): + super(TracebackTest, self).setUp() + self.backend = xla_backend() + + def testNoTracebacksIfDisabled(self): + with xla_client.tracebacks(enabled=False): + self.assertEqual(None, xla_client.Traceback.get_traceback()) + buffer = self.backend.buffer_from_pyval(np.array(7, np.int32)) + self.assertEqual(None, buffer.traceback) + + b = xla_client.XlaBuilder("computation") + ops.Add(ops.Constant(b, np.int32(1)), ops.Constant(b, np.int32(2))) + e = self.backend.compile(xla_computation_to_mlir_module(b.build())) + self.assertEqual(None, e.traceback) + + def assertIsTracebackContaining(self, tb, function): + self.assertIsInstance(tb, xla_client.Traceback) + self.assertIn(function, str(tb)) + self.assertTrue(any(f.function_name == function for f in tb.frames)) + + def testTracebacks(self): + with xla_client.tracebacks(enabled=True): + tb = xla_client.Traceback.get_traceback() + self.assertIsTracebackContaining(tb, "testTracebacks") + + # Tracebacks are not implemented on the TPU driver extension's variant + # of buffers and executables. + if not isinstance(self.backend, xla_client.Client): + return + + buffer = self.backend.buffer_from_pyval(np.array(7, np.int32)) + self.assertIsTracebackContaining(buffer.traceback, "testTracebacks") + + b = xla_client.XlaBuilder("computation") + ops.Add(ops.Constant(b, np.int32(1)), ops.Constant(b, np.int32(2))) + e = self.backend.compile(xla_computation_to_mlir_module(b.build())) + self.assertIsTracebackContaining(e.traceback, "testTracebacks") + + def testNestedFunction(self): + + def AFunction(): + + def AnotherFunction(): + return xla_client.Traceback.get_traceback() + + return AnotherFunction() + + with xla_client.tracebacks(enabled=True): + tb = AFunction() + self.assertIsInstance(tb, xla_client.Traceback) + frames = tb.frames + i = next( + i for (i, f) in enumerate(frames) if f.function_name == "AFunction") + self.assertEqual(frames[i - 1].function_name, "AnotherFunction") + self.assertEqual(frames[i + 1].function_name, "testNestedFunction") + + def testPythonTracebackHasCorrectLineNumbers(self): + def B(): + return xla_client.Traceback.get_traceback() + + def A(): + return B() + + tb = A().as_python_traceback() + for frame, lineno in traceback.walk_tb(tb): + if frame.f_code.co_name == "A": + line = A.__code__.co_firstlineno + self.assertBetween(lineno, line, line + 2) + elif frame.f_code.co_name == "B": + line = B.__code__.co_firstlineno + self.assertBetween(lineno, line, line + 2) + + def testAccessingLocalsDoesNotCrash(self): + # https://github.com/google/jax/issues/16027 + tb = xla_client.Traceback.get_traceback() + python_tb = tb.as_python_traceback() + for frame, _ in traceback.walk_tb(python_tb): + _ = frame.f_locals # should not crash + + def testTracebackFromFrames(self): + def FooFn(x): + return x + 1 + + def BarFn(y): + y = y + 1 + y = y + 2 + return y * 2 + + frame_foo = xla_client.Frame( + __file__, + FooFn.__code__.co_name, + FooFn.__code__.co_firstlineno, + FooFn.__code__.co_firstlineno + 1, + ) + frame_bar = xla_client.Frame( + __file__, + BarFn.__code__.co_name, + BarFn.__code__.co_firstlineno, + BarFn.__code__.co_firstlineno + 2, + ) + frames = [frame_foo, frame_bar] + tb = xla_client.Traceback.traceback_from_frames(frames) + + with self.subTest("WalkDoesNotError"): + for frame, _ in traceback.walk_tb(tb): + _ = frame.f_locals # should not crash + + with self.subTest("TracebackCorrectness"): + tb_string = traceback.format_tb(tb) + # The traceback should have the format: + # File , line N in BarFn + # y = y + 2 + # File , line N in FooFn + # return x + 1 + self.assertLen(tb_string, len(frames)) + bar_frame = tb_string[0].split("\n") + self.assertEndsWith(bar_frame[0], "BarFn") + self.assertEqual(bar_frame[1].strip(), "y = y + 2") + foo_frame = tb_string[1].split("\n") + self.assertEndsWith(foo_frame[0], "FooFn") + self.assertEqual(foo_frame[1].strip(), "return x + 1") + + tests.append(TracebackTest) + + class ClientTest(ComputationTest): + + def setUp(self): + super(ClientTest, self).setUp() + self.backend = xla_backend() + + def testPlatformVersion(self): + version = self.backend.platform_version + logging.info("platform_version:\n%s", version) + if self.backend.platform == "cpu": + self.assertEqual(version, "cpu") + elif self.backend.platform in ("gpu", "cuda", "rocm"): + # Following is false if not built with --config=cuda + if version != "": + self.assertTrue( + re.match(r"^cuda \d{4,}$", version), + msg=f"Expected CUDA version string; got {repr(version)}") + elif self.backend.platform == "tpu" and not (pathways or pathways_ifrt): + self.assertIn("tpu", version.lower()) + self.assertIn("cl/", version) + self.assertIn("Built on ", version) + + @unittest.skipIf( + not cloud_tpu and not pjrt_c_api, "PJRT version only exist for plugins" + ) + def testPjRtCApiVersion(self): + self.assertGreaterEqual(self.backend.pjrt_c_api_major_version, 0) + self.assertGreaterEqual(self.backend.pjrt_c_api_minor_version, 0) + + @unittest.skipUnless( + not pjrt_c_api and tfrt_tpu, + "Test that attributes are zero for non-plugin tfrt_tpu", + ) + def testStaticTfrtTpuAttributes(self): + self.assertEqual(self.backend.pjrt_c_api_major_version, 0) + self.assertEqual(self.backend.pjrt_c_api_minor_version, 0) + # CL number is defined as -1 when running as test. + self.assertEqual(self.backend.__getattr__("cl_number"), -1) + + @unittest.skipIf( + cloud_tpu or pjrt_c_api or (not pjrt_c_api and tfrt_tpu), + "PJRT version only exist for plugins", + ) + def testNotExistPjRtCApiVersion(self): + with self.assertRaises(AttributeError): + self.backend.pjrt_c_api_major_version # pylint: disable=pointless-statement + with self.assertRaises(AttributeError): + self.backend.pjrt_c_api_minor_version # pylint: disable=pointless-statement + + @unittest.skipIf(pathways or pathways_ifrt, "has different behavior") + def testPluginProgramDoesNotCompile(self): + program = xla_client.ifrt_programs.make_plugin_program("foobar") + options = xla_client.ifrt_programs.make_plugin_compile_options() + with self.assertRaisesRegex( + xla_client.XlaRuntimeError, "PjRtCompiler requires an HloProgram" + ): + self.backend.compile_ifrt_program(program, options) + + @unittest.skipIf(pathways, "does not work with non-ifrt legacy pathways") + def testHloProgramViaIfrtProgram(self): + c = self._NewComputation() + ops.Iota(c, xla_client.PrimitiveType.F32, 10) + program = xla_client.ifrt_programs.make_hlo_program( + xla_computation_to_mlir_module(c.build()) + ) + options = xla_client.ifrt_programs.make_xla_compile_options( + xla_client.CompileOptions(), [] + ) + + compiled_c = self.backend.compile_ifrt_program(program, options) + results = execute_with_python_values( + compiled_c, arguments=(), backend=self.backend + ) + + self.assertLen(results, 1) + np.testing.assert_equal(results[0], np.arange(10, dtype=np.float32)) + + @unittest.skipIf(cloud_tpu or pathways or pathways_ifrt or tfrt_tpu, + "not implemented") + def testExecutableSerialization(self): + if self.backend.platform != "tpu": + self.skipTest("Test requires tpu platform") + + c = self._NewComputation() + ops.Add( + ops.Constant(c, NumpyArrayS32([1, 2])), + ops.Constant(c, NumpyArrayS32([3, 4]))) + + options = xla_client.CompileOptions() + executable = self.backend.compile( + xla_computation_to_mlir_module(c.build()), options) + self.assertLen(executable.hlo_modules(), 1) + + serialized = self.backend.serialize_executable(executable) + deserialized = self.backend.deserialize_executable(serialized, options) + + expected, = execute_with_python_values(executable, (), self.backend) + actual, = execute_with_python_values(deserialized, (), self.backend) + self.assertTrue(np.all(actual == expected)) + + def testCompileOptionsSerialization(self): + options = xla_client.CompileOptions() + executable_build_options = options.executable_build_options + options.num_replicas = 3 + options.num_partitions = 2 + options.profile_version = 1337 + options.compile_portable_executable = True + executable_build_options.num_replicas = 3 + executable_build_options.num_partitions = 2 + deb_opt = executable_build_options.debug_options + deb_opt.xla_cpu_enable_fast_math = True + deb_opt.xla_test_all_input_layouts = True + deb_opt.xla_gpu_kernel_cache_file = "/foo/bar" + deb_opt.xla_gpu_enable_llvm_module_compilation_parallelism = True + deb_opt.xla_gpu_per_fusion_autotune_cache_dir = "/bar/foo/" + deb_opt.xla_gpu_experimental_autotune_cache_mode = ( + xla_client.AutotuneCacheMode.READ + ) + + b = options.SerializeAsString() + restored = xla_client.CompileOptions.ParseFromString(b) + + for name in ("num_replicas", "num_partitions", "profile_version", + "compile_portable_executable"): + self.assertEqual(getattr(options, name), getattr(restored, name), + msg=name) + + for name in ("num_replicas", "num_partitions"): + self.assertEqual(getattr(options.executable_build_options, name), + getattr(restored.executable_build_options, name), + msg=name) + + for name in ( + "xla_cpu_enable_fast_math", + "xla_test_all_input_layouts", + "xla_gpu_kernel_cache_file", + "xla_gpu_enable_llvm_module_compilation_parallelism", + "xla_gpu_per_fusion_autotune_cache_dir", + "xla_gpu_experimental_autotune_cache_mode", + ): + self.assertEqual( + getattr(options.executable_build_options.debug_options, name), + getattr(restored.executable_build_options.debug_options, name), + msg=name) + + tests.append(ClientTest) + + # TODO(b/182461453): Add TFRT and cloud TPU implementation of + # ReadDynamicShapes + @unittest.skip("Test fails HLO -> MHLO conversion") + class DynamicReshapeTest(ComputationTest): + """Tests related to DynamicReshape.""" + + def _CompareToPyAndBufferProtocol(self, builder, args, expected_results, + test_fn): + compiled = self.backend.compile( + xla_computation_to_mlir_module(builder.build())) + output_buffers = compiled.execute([ + self.backend.buffer_from_pyval( + arg, device=compiled.local_devices()[0]) for arg in args + ]) + self.assertLen(output_buffers, len(expected_results)) + for buf, expected in zip(output_buffers, expected_results): + to_py_result = np.asarray(buf) + self.assertEqual(expected.shape, to_py_result.shape) + test_fn(expected, to_py_result) + if self.backend.platform == "cpu" and buf.dtype != bfloat16: + mview = memoryview(buf) + self.assertEqual(expected.shape, mview.shape) + test_fn(expected, np.asarray(mview)) + else: + # Buffer protocol expected to fail on non-cpu platforms and bfloat16 + # Note that np.asarray(buf) doesn't throw an exception. To test if the + # error was thrown properly we must use memoryview(buf). + with self.assertRaises(BufferError): + memoryview(buf) + + # 1D reshape of full size, half size, and size of 0. + @unittest.skip("not implemented") + @parameterized.parameters((5), (3), (0)) + def testReshape1D(self, reshape_size): + full_size = 5 + c = self._NewComputation() + arg = np.array(reshape_size, dtype=np.int32) + expected = np.array(range(reshape_size), dtype=np.int32) + p = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg)) + ops.DynamicReshape( + ops.Constant(c, NumpyArrayS32(range(full_size))), [p], [full_size], + [True]) + self._CompareToPyAndBufferProtocol(c, [arg], [expected], + np.testing.assert_equal) + + # 2D reshape with an slice on the minor dimension. We test different types + # where the strides may differ between the host and devices. The reshaped + # physical memory layout is not consecutive, and we test if the program can + # return the correct logical view of the data. + @unittest.skipIf( + cloud_tpu or pathways or tfrt_tpu or pjrt_c_api, + "not implemented") + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in int_dtypes + float_dtypes) + def testReshape2D(self, dtype): + arg0 = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype) + arg1 = np.array(2, dtype=np.int32) + expected = np.array([[1, 2], [4, 5]], dtype=np.int32) + c = self._NewComputation() + p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg0)) + p1 = ops.Parameter(c, 1, xla_client.shape_from_pyval(arg1)) + ops.DynamicReshape(p0, [p1, p1], [2, 3], [False, True]) + self._CompareToPyAndBufferProtocol(c, [arg0, arg1], [expected], + np.testing.assert_equal) + + @unittest.skipIf(cloud_tpu or pathways or tfrt_tpu, "not implemented") + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in int_dtypes + float_dtypes) + def testDynamicShapeArgs(self, dtype): + full_size = 10 + dynamic_shape_size = 4 + # subcomputation 1 + binary_add_builder = self._NewComputation() + scalar_shape = xla_client.Shape.scalar_shape(np.dtype(dtype)) + ops.Add( + ops.Parameter(binary_add_builder, 0, scalar_shape), + ops.Parameter(binary_add_builder, 1, scalar_shape)) + # subcomputation 2 + reshape_reduce_builder = self._NewComputation() + dshape = xla_client.Shape.array_shape( + np.dtype(dtype), dims=[full_size], dynamic_dimensions=[True]) + reshape_reduce_p = ops.Parameter(reshape_reduce_builder, 0, dshape) + ops.Reduce( + reshape_reduce_builder, + operands=[reshape_reduce_p], + init_values=[ops.Constant(reshape_reduce_builder, dtype(0))], + computation=binary_add_builder.build(), + dimensions_to_reduce=[0]) + # main computation: sum(range(full_size)[:dynamic_shape_size]) + c = self._NewComputation() + arg = np.array(dynamic_shape_size, dtype=np.int32) + p = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg)) + reshaped = ops.DynamicReshape( + ops.Constant(c, np.array(range(full_size), dtype=dtype)), [p], + [full_size], [True]) + ops.Call(c, reshape_reduce_builder.build(), operands=(reshaped,)) + self._ExecuteAndCompareClose(c, [arg], [dtype(6)]) + + tests.append(DynamicReshapeTest) + + class DeviceAssignmentTest(ComputationTest): + + def testSerialize(self): + shape = (3, 4) + device_assignment = xla_client.DeviceAssignment.create( + np.arange(np.prod(shape)).reshape(*shape)) + self.assertEqual(device_assignment.replica_count(), shape[0]) + self.assertEqual(device_assignment.computation_count(), shape[1]) + serialized = device_assignment.serialize() + self.assertIsInstance(serialized, bytes) + self.assertNotEmpty(serialized) + + tests.append(DeviceAssignmentTest) + + class TokenTest(ComputationTest): + """Tests related to PyToken.""" + + def testExecuteWithToken(self): + c = self._NewComputation() + ops.Mul( + ops.Constant(c, np.array([2.5, 3.3, -1.2, 0.7], np.float32)), + ops.Constant(c, np.array([-1.2, 2, -2, -3], np.float32))) + compiled_c = self.backend.compile( + xla_computation_to_mlir_module(c.build())) + results, token = compiled_c.execute_with_token([]) + token.block_until_ready() + self.assertLen(results, 1) + np.testing.assert_allclose( + np.asarray(results[0]), np.float32([-3, 6.6, 2.4, -2.1]), rtol=3e-3) + + def testExecuteShardedOnLocalDevicesWithTokens(self): + c = self._NewComputation() + ops.Mul( + ops.Constant(c, np.array([2.5, 3.3, -1.2, 0.7], np.float32)), + ops.Constant(c, np.array([-1.2, 2, -2, -3], np.float32))) + num_replicas = 1 + options = xla_client.CompileOptions() + options.num_replicas = num_replicas + compiled_c = self.backend.compile( + xla_computation_to_mlir_module(c.build()), compile_options=options) + results, sharded_token = ( + compiled_c.execute_sharded_on_local_devices_with_tokens([]) + ) + sharded_token.block_until_ready() + self.assertLen(results, 1) + self.assertLen(results[0], 1) + np.testing.assert_allclose( + np.asarray(results[0][0]), + np.float32([-3, 6.6, 2.4, -2.1]), + rtol=3e-3) + + tests.append(TokenTest) + + class ExecutePortableTest(ComputationTest): + + @unittest.skip("Test does not work under IFRT") + def testExecutePortable(self): + devices_by_kind = collections.defaultdict(list) + for device in self.backend.devices(): + devices_by_kind[device.device_kind].append(device) + multi_devices = [d for d in devices_by_kind.values() if len(d) > 1] + if not multi_devices: + raise unittest.SkipTest("Test needs multiple identical devices") + devices = multi_devices[0] + + c = self._NewComputation() + args = [ + np.array(3, dtype=np.int32), + np.array([10, 15, -2, 7], dtype=np.int32) + ] + p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(args[0])) + p1 = ops.Parameter(c, 1, xla_client.shape_from_pyval(args[1])) + ops.Mul(p0, p1) + options = xla_client.CompileOptions() + options.compile_portable_executable = True + compiled_c = self.backend.compile(c.build(), compile_options=options) + for device in devices: + out, = compiled_c.execute( + [self.backend.buffer_from_pyval(a, device=device) for a in args], + device=device) + np.testing.assert_array_equal(np.asarray(out), args[0] * args[1]) + + tests.append(ExecutePortableTest) + + class ExecuteShardedOverloadTest(ComputationTest): + + def testExecuteShardedOverloadEmptyInput(self): + c = self._NewComputation() + ops.Constant(c, np.array([2.5, 3.3, -1.2, 0.7], np.float32)) + options = xla_client.CompileOptions() + options.num_replicas = 1 + compiled_c = self.backend.compile( + xla_computation_to_mlir_module(c.build()), compile_options=options) + + results = compiled_c.execute_sharded_on_local_devices([]) + self.assertLen(results, 1) + self.assertIsInstance(results[0], list) + self.assertLen(results[0], 1) + results[0][0].block_until_ready() + self.assertIsInstance(results[0][0], xla_client.ArrayImpl) + + results, _ = compiled_c.execute_sharded_on_local_devices_with_tokens([]) + self.assertLen(results, 1) + self.assertIsInstance(results[0], list) + self.assertLen(results[0], 1) + results[0][0].block_until_ready() + self.assertIsInstance(results[0][0], xla_client.ArrayImpl) + + def testExecuteShardedOverloadBufferInput(self): + arg = np.arange(12, dtype=np.int16).reshape(3, 4) + c = self._NewComputation() + ops.Parameter(c, 0, xla_client.shape_from_pyval(arg)) + + options = xla_client.CompileOptions() + options.num_replicas = 1 + compiled_c = self.backend.compile( + xla_computation_to_mlir_module(c.build()), compile_options=options) + + buffer = self.backend.buffer_from_pyval(arg) + + results = compiled_c.execute_sharded_on_local_devices([[buffer]]) + self.assertLen(results, 1) + self.assertIsInstance(results[0], list) + self.assertLen(results[0], 1) + results[0][0].block_until_ready() + self.assertIsInstance(results[0][0], xla_client.ArrayImpl) + + results, _ = compiled_c.execute_sharded_on_local_devices_with_tokens( + [[buffer]]) + self.assertLen(results, 1) + self.assertIsInstance(results[0], list) + self.assertLen(results[0], 1) + results[0][0].block_until_ready() + self.assertIsInstance(results[0][0], xla_client.ArrayImpl) + + tests.append(ExecuteShardedOverloadTest) + + return tests + + +def InstantiateTests(globals_dict, backend_fn, test_prefix="", **kw): + # Avoid creating a new backend per test (this causes GPU OOM, and is probably + # inefficient). + backend_fn = functools.lru_cache(maxsize=None)(backend_fn) + for klass in TestFactory(backend_fn, **kw): + test = type(test_prefix + klass.__name__, (klass,), {}) + # Clean up the qualified names of the tests to not include the test factory. + test.__qualname__ = test.__name__ + globals_dict[test.__name__] = test + + +backends = { + "cpu": functools.partial(xla_client.make_cpu_client, num_devices=2), +} + +if __name__ == "__main__": + flags.DEFINE_string("backend", "cpu", "Target platform.") + jax.config.parse_flags_with_absl() + # pylint: disable=unnecessary-lambda + InstantiateTests(globals(), lambda: backends[FLAGS.backend]()) + # pylint: enable=unnecessary-lambda + absltest.main() diff --git a/pyproject.toml b/pyproject.toml index a1b9e7dd446a..be29e16beb9c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,8 +23,11 @@ module = [ "jax.experimental.jax2tf.tests.back_compat_testdata", "jax.experimental.jax2tf.tests.flax_models", "jax_cuda12_plugin.*", - "jaxlib.*", + "jaxlib.cpu_feature_guard", + "jaxlib.cuda.*", "jaxlib.mlir.*", + "jaxlib.utils", + "jaxlib.xla_extension.utils", "jraph.*", "libtpu.*", "matplotlib.*", From c2e7c3e72d7ef9f6f30ae155f2142089fe1d6e48 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Fri, 21 Mar 2025 07:10:29 -0700 Subject: [PATCH 076/483] [Mosaic GPU] Add a transform inference rule for `memref.subview`. This will be used when lowering from Pallas, in order to handle calls to `memref_slice` in the `_handle_indexing` util. Currently, we only allow propagating a restricted set of transforms (tile and swizzle transforms), and only when they can be passed through the op bidirectionally without modification. PiperOrigin-RevId: 739168839 --- .../mosaic/gpu/transform_inference.py | 91 ++++++++++++++-- tests/mosaic/gpu_transform_inference_test.py | 101 ++++++++++++++++++ 2 files changed, 184 insertions(+), 8 deletions(-) diff --git a/jax/experimental/mosaic/gpu/transform_inference.py b/jax/experimental/mosaic/gpu/transform_inference.py index d285e5df188f..80ab6077755a 100644 --- a/jax/experimental/mosaic/gpu/transform_inference.py +++ b/jax/experimental/mosaic/gpu/transform_inference.py @@ -30,6 +30,7 @@ from jax._src.lib.mlir.dialects import gpu from jax._src.lib.mlir.dialects import memref from jax._src.lib.mlir.dialects import vector +from jax._src.util import safe_zip from . import fragmented_array as fa from . import inference_utils @@ -184,21 +185,20 @@ def _infer_slice_smem_transforms(op: mgpu.SliceSMEMOp) -> OptionalTransforms: out_transforms = inference_utils.in_transforms_for_operand( consumer, op_user ) - if transforms is not None and out_transforms is not None: - if transforms != out_transforms: - raise NotImplementedError( - f"Conflicting transforms for {op_user} in {op}: " - f"{transforms} != {out_transforms}." - ) - elif out_transforms is not None: + if transforms is None: transforms = out_transforms + elif out_transforms is not None and transforms != out_transforms: + raise NotImplementedError( + f"Conflicting transforms for {op_user} in {op}: " + f"{transforms} != {out_transforms}." + ) return None if transforms is None else ([], [transforms]) # TODO(bchetioui,apaszke): this empty rule is necessary while Mosaic doesn't use # the dialect in all cases. -# The rule is necessary in order to handle the lowering of `utils.memref_ptr` +# The rule is necessary in order to handle the lowering of `utils.memref_ptr` # which is used in `_construct_smem_reftree`. @partial(_add_transform_inference_rule, builtin.UnrealizedConversionCastOp) def _infer_unrealized_conversion_cast_transforms( @@ -250,6 +250,81 @@ def _infer_dynamic_smem_transforms( return None +# This is used by Pallas' "_handle_indexing" memory transform. +@partial(_add_transform_inference_rule, memref.SubViewOp) +def _infer_memref_subview_transforms( + op: memref.SubViewOp, +) -> OptionalTransforms: + transforms = None + + for result_use in cast(ir.OpResult, op.result).uses: + consumer = result_use.owner + op_user = consumer.operands[result_use.operand_number] + user_transforms = inference_utils.in_transforms_for_operand( + consumer, op_user + ) + if transforms is None: + transforms = user_transforms + elif user_transforms is not None and transforms != user_transforms: + raise NotImplementedError( + f"Conflicting transforms for {op_user} in {op}: " + f"{transforms} != {user_transforms}." + ) + + in_transforms = inference_utils.value_transforms(op.source) + if transforms is None: + transforms = in_transforms + elif in_transforms is not None and transforms != in_transforms: + raise ValueError( + f"Conflicting transforms for {op.source} in {op}: " + f"{transforms} != {in_transforms}." + ) + + if transforms is None: + return None + + # Here, we have some transforms to propagate one way or the other. For now, + # we implement only the following basic propagation rules: + # - A tile transform can be propagated bidirectionally if the axes being + # tiled are not sliced, and are the logical minor axes of the source. + # - A swizzle transform can be propagated towards the input of a subview if + # the physical minormost dimension is unchanged. + # - We only propagate transforms if they consist of a single tile transform + # and a single swizzle transform. + # TODO(bchetioui): implement more complex propagation rules. + if len(transforms) == 2: + tile_transform, swizzle_transform = transforms + if not ( + mgpu.TileTransformAttr.isinstance(tile_transform) + and mgpu.SwizzleTransformAttr.isinstance(swizzle_transform) + ): + raise NotImplementedError(f"Can't propagate transforms {transforms}.") + else: + raise NotImplementedError(f"Can't propagate transforms {transforms}.") + + # Check swizzle transform propagation. + strides, _ = ir.MemRefType.get_strides_and_offset(op.source.type) + minor_dim = strides.index(min(strides)) + if op.source.type.shape[minor_dim] != op.static_sizes[minor_dim]: + raise NotImplementedError( + "Swizzle transforms can only propagated if the minor dimension is " + "unchanged." + ) + + # Check tile transform propagation. + num_tiled_axes = len(mgpu.TileTransformAttr(tile_transform).tiling) + last_n_dims = op.source.type.shape[-num_tiled_axes:] + last_n_sizes = list(op.static_sizes)[-num_tiled_axes:] + for slice_size, dim_size in safe_zip(last_n_sizes, last_n_dims): + if slice_size != dim_size: + raise NotImplementedError( + "Tile transforms are only propagated if the tiled axes are not " + "sliced." + ) + + return [transforms], [transforms] + + def _should_have_transforms(op: ir.OpView) -> bool: """Returns 'True' if the operation should be assigned in/out transforms.""" return any( diff --git a/tests/mosaic/gpu_transform_inference_test.py b/tests/mosaic/gpu_transform_inference_test.py index b7cd146dfdb6..983efebc4f86 100644 --- a/tests/mosaic/gpu_transform_inference_test.py +++ b/tests/mosaic/gpu_transform_inference_test.py @@ -24,7 +24,9 @@ from jax._src.interpreters import mlir as mlir_interpreter from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith +from jax._src.lib.mlir.dialects import builtin from jax._src.lib.mlir.dialects import func +from jax._src.lib.mlir.dialects import memref from jax._src.lib.mlir.dialects import vector import jax.experimental.mosaic.gpu as mgpu from jax.experimental.mosaic.gpu import fragmented_array as fa @@ -418,6 +420,105 @@ def body(offset): with self.assertRaisesRegex(NotImplementedError, "Conflicting transforms"): mgpu.infer_transforms(self.module) + @parameterized.parameters([False, True]) + def test_infer_transforms_for_subview_op_propagates_undisturbed_tile_and_swizzle_transforms( + self, annotate_input + ): + subview_op = user_op = None + shape = (2, 64, 64) + elt_ty = ir.BF16Type.get() + smem = ir.Attribute.parse("#gpu.address_space") + + in_ref_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem) + out_ref_ty = ir.MemRefType.get(shape[2:], elt_ty, memory_space=smem) + + def body(in_ref): + nonlocal subview_op, user_op + subview_op = memref.SubViewOp( + out_ref_ty, + in_ref, + [], + [], + [], + static_offsets=[1, 0, 0], + static_sizes=[1, 64, 64], + static_strides=[1, 1, 1], + ) + user_op = builtin.UnrealizedConversionCastOp( + [out_ref_ty], [subview_op.result] + ) + + with ir.InsertionPoint(self.module.body): + f = func.FuncOp.from_py_func(in_ref_ty)(body).func_op + + transforms = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((32, 16)), + mgpu.dialect.SwizzleTransformAttr.get(32), + ]) + + if annotate_input: + f.attributes["in_transforms"] = ir.ArrayAttr.get([transforms]) + else: + user_op.attributes["in_transforms"] = ir.ArrayAttr.get([transforms]) + + mgpu.infer_transforms(self.module) + + self.assertSequenceEqual( + inference_utils.in_transforms(subview_op), [transforms] + ) + self.assertSequenceEqual( + inference_utils.out_transforms(subview_op), [transforms] + ) + + @parameterized.parameters([False, True]) + def test_infer_transforms_for_subview_op_raises_on_disturbed_transforms( + self, annotate_input + ): + subview_op = user_op = None + shape = (2, 64, 64) + elt_ty = ir.BF16Type.get() + smem = ir.Attribute.parse("#gpu.address_space") + + in_ref_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem) + out_ref_ty = ir.MemRefType.get((2, 64, 32), elt_ty, memory_space=smem) + + def body(in_ref): + nonlocal subview_op, user_op + subview_op = memref.SubViewOp( + out_ref_ty, + in_ref, + [], + [], + [], + static_offsets = [1, 0, 0], + static_sizes = [2, 64, 32], + static_strides = [1, 1, 1] + ) + user_op = builtin.UnrealizedConversionCastOp( + [out_ref_ty], [subview_op.result] + ) + + with ir.InsertionPoint(self.module.body): + f = func.FuncOp.from_py_func(in_ref_ty)(body).func_op + + transforms = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((32, 16)), + mgpu.dialect.SwizzleTransformAttr.get(32), + ]) + + if annotate_input: + f.attributes["in_transforms"] = ir.ArrayAttr.get([transforms]) + else: + user_op.attributes["in_transforms"] = ir.ArrayAttr.get([transforms]) + + if annotate_input: + f.attributes["in_transforms"] = ir.ArrayAttr.get([transforms]) + else: + user_op.attributes["in_transforms"] = ir.ArrayAttr.get([transforms]) + + with self.assertRaises(NotImplementedError): + mgpu.infer_transforms(self.module) + if __name__ == "__main__": parameterized.absltest.main(testLoader=jtu.JaxTestLoader()) From dac5247cca85df9a8bcac65b7a033038739ebb90 Mon Sep 17 00:00:00 2001 From: Ayaka Date: Fri, 21 Mar 2025 07:26:10 -0700 Subject: [PATCH 077/483] Ensure traceback correctness in error checking PiperOrigin-RevId: 739172653 --- tests/error_check_test.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/error_check_test.py b/tests/error_check_test.py index ad67cadfb074..5bf71a9eb592 100644 --- a/tests/error_check_test.py +++ b/tests/error_check_test.py @@ -13,6 +13,8 @@ # limitations under the License. +import traceback + from absl.testing import absltest from absl.testing import parameterized import jax @@ -108,6 +110,32 @@ def g(x): with self.assertRaisesRegex(JaxValueError, "x must be greater than 0 in g"): error_check.raise_if_error() + @parameterized.product(jit=[True, False]) + def test_error_includes_traceback(self, jit): + def function_that_triggers_error_for_traceback_test(x): + error_check.set_error_if( # This line must be included in the traceback. + x <= 0, "x must be greater than 0" + ) + return x + 1 + + if jit: + function_that_triggers_error_for_traceback_test = jax.jit( + function_that_triggers_error_for_traceback_test + ) + + x = jnp.zeros((4,), dtype=jnp.int32) + function_that_triggers_error_for_traceback_test(x) + + tb_string = "" + try: + error_check.raise_if_error() + except JaxValueError as e: + tb_string = traceback.format_tb(e.__traceback__) + tb_string = "".join(tb_string) + + self.assertIn("function_that_triggers_error_for_traceback_test", tb_string) + self.assertIn("This line must be included in the traceback", tb_string) + @parameterized.product(jit=[True, False]) def test_error_check_works_with_cond(self, jit): def f(x): From be6585d0005340d1f6ef3830bdac64d7e7e52b8c Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Fri, 21 Mar 2025 08:07:51 -0700 Subject: [PATCH 078/483] [pallas] Add support for `DotAlgorithmPreset.BF16_BF16_F32_X3` in Triton lowering. PiperOrigin-RevId: 739183661 --- jax/_src/pallas/triton/lowering.py | 28 ++++++++++++++++++++++++++++ tests/pallas/pallas_test.py | 14 +++++++++++++- 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index f3a8dd175ec1..64bf635a34ed 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -1260,6 +1260,10 @@ def _cmp( ) +def _is_nan(x: ir.Value) -> ir.Value: + return arith_dialect.cmpf(arith_dialect.CmpFPredicate.UNO, x, x) + + _JAX_TO_TRITON_BINARY = { lax.add_p: _add, lax.sub_p: _sub, @@ -2237,6 +2241,7 @@ def _dot_general_lowering( | lax.DotAlgorithmPreset.F16_F16_F32 | lax.DotAlgorithmPreset.BF16_BF16_BF16 | lax.DotAlgorithmPreset.BF16_BF16_F32 + | lax.DotAlgorithmPreset.BF16_BF16_F32_X3 ): input_precision = None case _: @@ -2276,6 +2281,29 @@ def _dot_general_lowering( m, _ = a_type.shape _, n = b_type.shape acc = _full(ir.RankedTensorType.get([m, n], _dtype_to_ir_type(acc_dtype)), 0) + + if precision == lax.DotAlgorithmPreset.BF16_BF16_F32_X3: + bf16 = _dtype_to_ir_type(jnp.bfloat16) + f32 = _dtype_to_ir_type(jnp.float32) + as_bf16 = lambda x: _ir_cast(x, bf16, signed=False) + as_f32 = lambda x: _ir_cast(x, f32, signed=False) + + a_bf16 = as_bf16(a) + b_bf16 = as_bf16(b) + a_err0 = as_bf16(_sub(a, as_f32(a_bf16))) + b_err0 = as_bf16(_sub(b, as_f32(b_bf16))) + # Accumulate the smallest values first to reduce the numeric error. + acc = tt_dialect.dot(a_err0, b_bf16, acc) + acc = tt_dialect.dot(a_bf16, b_err0, acc) + # If `a_err0` will be zero and `b` is infinite, then `acc` may contain + # `NaN`s (as `0 * inf = NaN`), and vice versa. + acc = arith_dialect.select( + _is_nan(acc), + _full(ir.RankedTensorType.get([m, n], _dtype_to_ir_type(acc_dtype)), 0), + acc, + ) + a, b = a_bf16, b_bf16 + acc = tt_dialect.dot(a, b, acc, input_precision=input_precision) return _cast(acc, acc_dtype, out_aval.dtype) diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 745c30ba98cb..0ce68a5c023c 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -702,6 +702,7 @@ def f(x): ("float32", jax.lax.DotAlgorithmPreset.DEFAULT), ("float32", jax.lax.DotAlgorithmPreset.F16_F16_F32), ("float32", jax.lax.DotAlgorithmPreset.BF16_BF16_F32), + ("float32", jax.lax.DotAlgorithmPreset.BF16_BF16_F32_X3), ("float32", jax.lax.DotAlgorithmPreset.TF32_TF32_F32), ("float32", jax.lax.DotAlgorithmPreset.TF32_TF32_F32_X3), ("float32", jax.lax.DotAlgorithmPreset.F32_F32_F32), @@ -731,7 +732,18 @@ def dot_kernel(x_ref, y_ref, o_ref): precision=jax.lax.Precision.HIGHEST, preferred_element_type=jnp.float32, ) - self.assertAllClose(dot_kernel(x, y), expected, atol=5e-2, rtol=5e-3) + if dtype == "bfloat16" or precision in ( + jax.lax.Precision.HIGHEST, jax.lax.DotAlgorithmPreset.F32_F32_F32 + ): + atol = 0 + elif precision in ( + jax.lax.DotAlgorithmPreset.BF16_BF16_F32_X3, + jax.lax.DotAlgorithmPreset.TF32_TF32_F32_X3, + ): + atol = 5e-4 + else: + atol = 5e-2 + self.assertAllClose(dot_kernel(x, y), expected, atol=atol, rtol=atol / 10) @parameterized.parameters(jnp.int8, jnp.uint8) def test_integer_dot(self, dtype): From f1ff64f404c522210b9edd23a6e5f76cf77ef896 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Fri, 21 Mar 2025 08:08:31 -0700 Subject: [PATCH 079/483] [Mosaic GPU][NFC] Factor our transform resolution into a `_resolve_transforms` util. PiperOrigin-RevId: 739183876 --- .../mosaic/gpu/transform_inference.py | 81 +++++++++---------- 1 file changed, 37 insertions(+), 44 deletions(-) diff --git a/jax/experimental/mosaic/gpu/transform_inference.py b/jax/experimental/mosaic/gpu/transform_inference.py index 80ab6077755a..c76af4fb07e2 100644 --- a/jax/experimental/mosaic/gpu/transform_inference.py +++ b/jax/experimental/mosaic/gpu/transform_inference.py @@ -61,6 +61,37 @@ def _set_transform_attributes( op.attributes["out_transforms"] = ir.ArrayAttr.get(out_transforms) +def _resolve_transforms( + transforms: ir.ArrayAttr | None, + other_transforms: ir.ArrayAttr | None, +) -> ir.ArrayAttr | None: + """Resolves two sets of competing transforms to a single compatible set. + + Args: + transforms: one optional set of transforms. + other_transforms: another optional set of transforms. + + Returns: + A single set of transforms that is compatible with both `transforms` and + `other_transforms`, or `None` if both transforms are `None`. + Raises: + NotImplementedError: if the two sets of transforms can't be resolved to a + single set. + """ + if transforms is None: + return other_transforms + + if other_transforms is None: + return transforms + + if transforms != other_transforms: + raise NotImplementedError( + f"Conflicting transforms {transforms} != {other_transforms}." + ) + + return transforms + + def infer_transforms_for_wgmma_ref(ref_ty: ir.MemRefType) -> ir.ArrayAttr: if len(ref_ty.shape) != 2: raise ValueError(f"Expected a 2D memref, got {ref_ty}") @@ -157,21 +188,8 @@ def _infer_vector_load_store_transforms( f"Got layout {layout} which is not yet supported" ) - if transforms is not None and layout_transforms is not None: - if transforms != layout_transforms: - raise NotImplementedError( - f"Conflicting transforms for {op.base} in {op}: " - f"{transforms} != {layout_transforms}." - ) - return [transforms], [] - - if transforms is not None: - return [transforms], [] - - if layout_transforms is not None: - return [layout_transforms], [] - - return None + transforms = _resolve_transforms(transforms, layout_transforms) + return None if transforms is None else ([transforms], []) @partial(_add_transform_inference_rule, mgpu.SliceSMEMOp) @@ -185,13 +203,7 @@ def _infer_slice_smem_transforms(op: mgpu.SliceSMEMOp) -> OptionalTransforms: out_transforms = inference_utils.in_transforms_for_operand( consumer, op_user ) - if transforms is None: - transforms = out_transforms - elif out_transforms is not None and transforms != out_transforms: - raise NotImplementedError( - f"Conflicting transforms for {op_user} in {op}: " - f"{transforms} != {out_transforms}." - ) + transforms = _resolve_transforms(transforms, out_transforms) return None if transforms is None else ([], [transforms]) @@ -227,14 +239,7 @@ def _infer_memref_view_transforms(op: memref.ViewOp) -> OptionalTransforms: out_transforms = inference_utils.in_transforms_for_operand( consumer, op_user ) - if transforms is not None and out_transforms is not None: - if transforms != out_transforms: - raise ValueError( - f"Conflicting transforms for {op_user} in {op}: " - f"{transforms} != {out_transforms}." - ) - elif out_transforms is not None: - transforms = out_transforms + transforms = _resolve_transforms(transforms, out_transforms) # TODO(bchetioui): do we actually need to assign a transform to the input of # the view op? Presumably, it'll only be used to access scratch memory. @@ -263,22 +268,10 @@ def _infer_memref_subview_transforms( user_transforms = inference_utils.in_transforms_for_operand( consumer, op_user ) - if transforms is None: - transforms = user_transforms - elif user_transforms is not None and transforms != user_transforms: - raise NotImplementedError( - f"Conflicting transforms for {op_user} in {op}: " - f"{transforms} != {user_transforms}." - ) + transforms = _resolve_transforms(transforms, user_transforms) in_transforms = inference_utils.value_transforms(op.source) - if transforms is None: - transforms = in_transforms - elif in_transforms is not None and transforms != in_transforms: - raise ValueError( - f"Conflicting transforms for {op.source} in {op}: " - f"{transforms} != {in_transforms}." - ) + transforms = _resolve_transforms(transforms, in_transforms) if transforms is None: return None From 59d25f4642e383f4236fd85dc753c642ef2307aa Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Fri, 21 Mar 2025 08:21:42 -0700 Subject: [PATCH 080/483] [Mosaic GPU] Add transform inference rule for `memref.load`. PiperOrigin-RevId: 739187660 --- jax/experimental/mosaic/gpu/transform_inference.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/jax/experimental/mosaic/gpu/transform_inference.py b/jax/experimental/mosaic/gpu/transform_inference.py index c76af4fb07e2..3438a654f90a 100644 --- a/jax/experimental/mosaic/gpu/transform_inference.py +++ b/jax/experimental/mosaic/gpu/transform_inference.py @@ -318,6 +318,16 @@ def _infer_memref_subview_transforms( return [transforms], [transforms] +# `memref.load` is used to load barrier phases---the rule needn't do anything +# interesting, but we need to have it in order to avoid crashing on it. +@partial(_add_transform_inference_rule, memref.LoadOp) +def _infer_memref_load_transforms(op: memref.LoadOp) -> OptionalTransforms: + if not ir.MemRefType(op.memref.type).shape: + # memref.load returns a scalar, so there is nothing interesting to do here. + return None + raise NotImplementedError("Non-scalar memref.load transforms") + + def _should_have_transforms(op: ir.OpView) -> bool: """Returns 'True' if the operation should be assigned in/out transforms.""" return any( From 27b30190be173d759fcab223242211a33ab6e3f3 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Fri, 21 Mar 2025 08:36:52 -0700 Subject: [PATCH 081/483] [Pallas/Mosaic GPU] Add lowering for WGMMA using warpgroup semantics. When using warpgroup semantics, the transforms are inferred by the transform inference pass---except for transposition which will still get propagated down from Pallas. Also turn on `transform_inference` in the Pallas->Mosaic GPU lowering pipeline. PiperOrigin-RevId: 739191716 --- jax/_src/pallas/mosaic_gpu/lowering.py | 17 +++--- jax/_src/pallas/mosaic_gpu/primitives.py | 65 ++++++++++++++++++---- tests/pallas/mosaic_gpu_test.py | 68 ++++++++++++++---------- 3 files changed, 106 insertions(+), 44 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index ef4c80cb4649..004a6e7f2760 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -766,6 +766,7 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): # Run Python lowering passes. The remaining passes will be run in C++ in # jax/jaxlib/mosaic/gpu/custom_call.cc mgpu.infer_layout(module) # pytype: disable=attribute-error + mgpu.infer_transforms(module) # pytype: disable=attribute-error mgpu.lower_mgpu_dialect(module, launch_ctx) # pytype: disable=attribute-error mgpu_core._initialize_scratch(launch_ctx, scratch_arr) @@ -1837,13 +1838,15 @@ def _run_scoped_lowering_rule( for v in jaxpr.invars: aval = v.aval if isinstance(aval, gpu_core.WGMMAAbstractAccumulatorRef): - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Warpgroup: - # TODO(bchetioui): Fix this and remove the NotImplementedError. - raise NotImplementedError( - "WGMMA accumulators are not supported with Warpgroup semantics." - ) - mlir_dtype = mlir.dtype_to_ir_type(aval.dtype) - input_refs.append(mgpu.WGMMAAccumulator.zero(*aval.shape, mlir_dtype)) + dtype = mlir.dtype_to_ir_type(aval.dtype) + if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + input_refs.append(mgpu.WGMMAAccumulator.zero(*aval.shape, dtype)) + else: + zero = arith_dialect.constant(dtype, ir.FloatAttr.get(dtype, 0.0)) + acc = vector_dialect.splat(ir.VectorType.get(aval.shape, dtype), zero) + acc = mgpu.dialect.optimization_barrier([acc]) + nvvm_dialect.wgmma_fence_aligned() + input_refs.append(acc) should_discharge.append(True) elif isinstance(aval.dtype, gpu_core.BarrierType): input_refs.append( diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 7f26f5d2b6a3..edfae55fb288 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -517,11 +517,7 @@ def commit_smem_to_gmem_group() -> None: wgmma_ref_p.multiple_results = True -def wgmma( - acc: gpu_core.WGMMAAbstractAccumulatorRef, - a, - b: pallas_core.TransformedRef, -) -> None: +def wgmma(acc: gpu_core.WGMMAAbstractAccumulatorRef, a, b) -> None: """Performs an asynchronous warp group matmul-accumulate on the given references. Conceptually, this is equivalent to doing ``acc[...] += a[...] @ b[...]``, @@ -555,12 +551,17 @@ def wgmma( a = a.ref else: a_transforms_leaves, a_transforms_tree = [], None - b_transforms_leaves, b_transforms_tree = jax.tree.flatten(b.transforms) + + if isinstance(b, pallas_core.TransformedRef): + b_transforms_leaves, b_transforms_tree = jax.tree.flatten(b.transforms) + b = b.ref + else: + b_transforms_leaves, b_transforms_tree = [], None wgmma_ref_p.bind( acc, a, - b.ref, + b, *a_transforms_leaves, *b_transforms_leaves, a_transforms_tree=a_transforms_tree, @@ -674,6 +675,40 @@ def _wgmma_lowering( return new_acc +@lowering.register_lowering_rule(wgmma_p, mgpu.ThreadSemantics.Warpgroup) +def _wgmma_warpgroup_lowering( + ctx: lowering.LoweringRuleContext, + acc, + a, + b, + *transforms_leaves, + a_transforms_tree, + b_transforms_tree, +): + del ctx, transforms_leaves # Unused. + if a_transforms_tree is not None: + match a_transforms_tree: + case gpu_core.TransposeRef((1, 0)): + raise NotImplementedError("WGMMA lhs transpose not supported.") + case _: + raise ValueError( + f"WGMMA lhs has unsupported transforms: {a_transforms_tree}." + ) + + if b_transforms_tree is not None: + match b_transforms_tree: + case gpu_core.TransposeRef((1, 0)): + raise NotImplementedError("WGMMA rhs transpose not supported.") + case _: + raise ValueError( + f"WGMMA rhs has unsupported transforms: {b_transforms_tree}." + ) + + new_acc = mgpu.dialect.wgmma(acc, a, b) + nvvm_dialect.wgmma_commit_group_sync_aligned() + return new_acc + + @wgmma_p.def_effectful_abstract_eval def _wgmma_effectful_abstract_eval(acc, lhs_ref, *args, **kwargs): del args, kwargs @@ -698,6 +733,7 @@ def wgmma_wait_effectful_abstract_eval(_): @lowering.register_lowering_rule(wgmma_wait_p, mgpu.ThreadSemantics.Lane) +@lowering.register_lowering_rule(wgmma_wait_p, mgpu.ThreadSemantics.Warpgroup) def _wgmma_wait_lowering(ctx: lowering.LoweringRuleContext, allow_groups): del ctx nvvm_dialect.wgmma_wait_group_sync_aligned(allow_groups) @@ -728,11 +764,19 @@ def _wgmma_accumulator_deref_discharge(in_avals, out_avals, acc): return (None,), wgmma_accumulator_deref_p.bind(acc) -@lowering.register_lowering_rule(wgmma_accumulator_deref_p, mgpu.ThreadSemantics.Lane) +@lowering.register_lowering_rule( + wgmma_accumulator_deref_p, mgpu.ThreadSemantics.Lane +) +@lowering.register_lowering_rule( + wgmma_accumulator_deref_p, mgpu.ThreadSemantics.Warpgroup +) def _wgmma_accumulator_deref_lowering(ctx: lowering.LoweringRuleContext, acc): - del ctx nvvm_dialect.wgmma_wait_group_sync_aligned(0) - return acc.value + return ( + acc.value + if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane + else acc + ) class Layout(enum.Enum): @@ -835,6 +879,7 @@ def _commit_smem_abstract_eval(): @lowering.register_lowering_rule(commit_smem_p, mgpu.ThreadSemantics.Lane) @lowering.register_lowering_rule(commit_smem_p, mgpu.ThreadSemantics.Warpgroup) def _commit_smem_lowering(ctx: lowering.LoweringRuleContext): + # TODO(bchetioui): add primitive for commit smem to mosaic_gpu dialect. mgpu.commit_shared() return () diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 94c2620f7ae6..408b8bdf5713 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1302,7 +1302,8 @@ def scope(acc_ref): acc_ini = jnp.ones((64, 64), dtype=jnp.float16) np.testing.assert_array_equal(kernel(acc_ini), jnp.full((64, 64), 5, dtype=jnp.float16)) - def test_realistic_matmul(self): + @parameterized.parameters([*plgpu.ThreadSemantics]) + def test_realistic_matmul(self, thread_semantics): dtype = jnp.float16 swizzle = 128 elems_128b = swizzle // jnp.dtype(dtype).itemsize @@ -1326,34 +1327,46 @@ def _epilogue(): a = jax.random.uniform(key1, shape=(m, k), dtype=dtype) b = jax.random.uniform(key2, shape=(k, n), dtype=dtype) + lhs_spec = pl.BlockSpec( + (tile_m, tile_k), + lambda m, n, k: (m, k), + ) + rhs_spec = pl.BlockSpec( + (tile_k, tile_n), + lambda m, n, k: (k, n), + ) + out_spec = pl.BlockSpec( + (tile_m, tile_n), + lambda m, n, k: (m, n), + ) + + if thread_semantics == plgpu.ThreadSemantics.Lane: + lhs_spec = plgpu.GPUBlockSpec( + lhs_spec.block_shape, lhs_spec.index_map, + transforms=( + plgpu.TilingTransform((64, elems_128b)), + plgpu.SwizzleTransform(128), + ) + ) + rhs_spec = plgpu.GPUBlockSpec( + rhs_spec.block_shape, rhs_spec.index_map, + transforms=( + plgpu.TilingTransform((elems_128b, elems_128b)), + plgpu.SwizzleTransform(128), + ) + ) + out_spec = plgpu.GPUBlockSpec( + out_spec.block_shape, out_spec.index_map, + transforms=( + plgpu.TilingTransform((64, elems_128b)), + plgpu.SwizzleTransform(128), + ) + ) + res = pl.pallas_call( kernel, - in_specs=[ - plgpu.GPUBlockSpec( - (tile_m, tile_k), - lambda m, n, k: (m, k), - transforms=( - plgpu.TilingTransform((64, elems_128b)), - plgpu.SwizzleTransform(128), - ), - ), - plgpu.GPUBlockSpec( - (tile_k, tile_n), - lambda m, n, k: (k, n), - transforms=( - plgpu.TilingTransform((elems_128b, elems_128b)), - plgpu.SwizzleTransform(128), - ), - ), - ], - out_specs=plgpu.GPUBlockSpec( - (tile_m, tile_n), - lambda m, n, k: (m, n), - transforms=( - plgpu.TilingTransform((64, elems_128b)), - plgpu.SwizzleTransform(128), - ), - ), + in_specs=[lhs_spec, rhs_spec], + out_specs=out_spec, out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16), scratch_shapes=[plgpu.ACC((tile_m, tile_n), jnp.float32)], grid=(grid_m, grid_n, grid_k), @@ -1361,6 +1374,7 @@ def _epilogue(): dimension_semantics=["parallel", "parallel", "sequential"], max_concurrent_steps=2, delay_release=1, + thread_semantics=thread_semantics, ), )(a, b) np.testing.assert_allclose(res, a @ b, rtol=1e-3) From 92f5d9caa33f59b1c8511f4fc0676e1a155ab4a2 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 21 Mar 2025 08:53:50 -0700 Subject: [PATCH 082/483] Deprecated `jax.tree_util.build_tree` We have no usages of it neither in JAX nor internally, but we still have to go through the deprecation cycle, becuase `jax.tree_util` is public API. PiperOrigin-RevId: 739196514 --- CHANGELOG.md | 5 +++++ jax/_src/tree_util.py | 9 ++------- jax/tree_util.py | 24 ++++++++++++++++++++++-- 3 files changed, 29 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9a817ce80937..17fb421fcc06 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. ## Unreleased +* Deprecations + + * {func}`jax.tree_util.build_tree` is deprecated. Use {func}`jax.tree.unflatten` + instead. + ## jax 0.5.3 (Mar 19, 2025) * New Features diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index 6c7e15a042e5..883937fcce6e 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -362,6 +362,8 @@ def tree_map(f: Callable[..., Any], def build_tree(treedef: PyTreeDef, xs: Any) -> Any: """Build a treedef from a nested iterable structure + DEPRECATED: Use :func:`jax.tree.unflatten` instead. + Args: treedef: the PyTreeDef structure to build. xs: nested iterables matching the arity as the treedef @@ -376,13 +378,6 @@ def build_tree(treedef: PyTreeDef, xs: Any) -> Any: >>> import jax >>> tree = [(1, 2), {'a': 3, 'b': 4}] >>> treedef = jax.tree.structure(tree) - - Both ``build_tree`` and :func:`jax.tree_util.tree_unflatten` can reconstruct - the tree from new values, but ``build_tree`` takes these values in terms of - a nested rather than flat structure: - - >>> jax.tree_util.build_tree(treedef, [[10, 11], [12, 13]]) - [(10, 11), {'a': 12, 'b': 13}] >>> jax.tree_util.tree_unflatten(treedef, [10, 11, 12, 13]) [(10, 11), {'a': 12, 'b': 13}] """ diff --git a/jax/tree_util.py b/jax/tree_util.py index 956d79b9b4ef..3d24c457b3f8 100644 --- a/jax/tree_util.py +++ b/jax/tree_util.py @@ -48,13 +48,13 @@ PyTreeDef as PyTreeDef, SequenceKey as SequenceKey, all_leaves as all_leaves, - build_tree as build_tree, + build_tree as _deprecated_build_tree, default_registry as default_registry, keystr as keystr, + register_dataclass as register_dataclass, register_pytree_node_class as register_pytree_node_class, register_pytree_node as register_pytree_node, register_pytree_with_keys_class as register_pytree_with_keys_class, - register_dataclass as register_dataclass, register_pytree_with_keys as register_pytree_with_keys, register_static as register_static, tree_all as tree_all, @@ -72,3 +72,23 @@ treedef_is_leaf as treedef_is_leaf, treedef_tuple as treedef_tuple, ) + +_deprecations = { + # Added March 21, 2025: + "build_tree": ( + ( + "jax.tree_util.build_tree is deprecated. Use jax.tree.unflatten" + " instead." + ), + _deprecated_build_tree, + ), +} + +import typing as _typing +if _typing.TYPE_CHECKING: + from jax._src.tree_util import build_tree as build_tree +else: + from jax._src.deprecations import deprecation_getattr + __getattr__ = deprecation_getattr(__name__, _deprecations) + del deprecation_getattr, _deprecations +del _typing From 027195489a5b9a1d253550c2954f8aa11fa03370 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 21 Mar 2025 08:55:59 -0700 Subject: [PATCH 083/483] Reorder C++ imports (nanobind). PiperOrigin-RevId: 739197184 --- examples/ffi/src/jax_ffi_example/gpu_examples.cc | 2 +- jaxlib/cpu/lapack.cc | 2 +- jaxlib/cuda/cuda_plugin_extension.cc | 2 +- jaxlib/cuda/versions.cc | 3 +-- jaxlib/gpu/blas.cc | 4 ++-- jaxlib/gpu/gpu_plugin_extension.cc | 6 +++--- jaxlib/gpu/hybrid.cc | 2 +- jaxlib/gpu/py_client_gpu.cc | 2 +- jaxlib/gpu/solver.cc | 4 ++-- jaxlib/gpu/sparse.cc | 4 ++-- jaxlib/gpu/triton.cc | 2 +- jaxlib/kernel_nanobind_helpers.h | 2 +- jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc | 2 +- jaxlib/mlir/_mlir_libs/tpu_ext.cc | 10 +++++----- jaxlib/mlir/_mlir_libs/triton_ext.cc | 2 +- jaxlib/mosaic/gpu/mosaic_gpu_ext.cc | 4 ++-- jaxlib/rocm/rocm_plugin_extension.cc | 2 +- jaxlib/utils.cc | 2 +- jaxlib/xla/custom_calls_testlib.cc | 2 +- 19 files changed, 29 insertions(+), 30 deletions(-) diff --git a/examples/ffi/src/jax_ffi_example/gpu_examples.cc b/examples/ffi/src/jax_ffi_example/gpu_examples.cc index 921039debe5d..79a4ee91e8c6 100644 --- a/examples/ffi/src/jax_ffi_example/gpu_examples.cc +++ b/examples/ffi/src/jax_ffi_example/gpu_examples.cc @@ -16,8 +16,8 @@ limitations under the License. #include #include -#include "nanobind/nanobind.h" #include "cuda_runtime_api.h" +#include "nanobind/nanobind.h" #include "xla/ffi/api/ffi.h" namespace nb = nanobind; diff --git a/jaxlib/cpu/lapack.cc b/jaxlib/cpu/lapack.cc index c104019777e5..7cc4fa9e2dbd 100644 --- a/jaxlib/cpu/lapack.cc +++ b/jaxlib/cpu/lapack.cc @@ -15,8 +15,8 @@ limitations under the License. #include -#include "nanobind/nanobind.h" #include "absl/base/call_once.h" +#include "nanobind/nanobind.h" #include "jaxlib/cpu/lapack_kernels.h" #include "jaxlib/kernel_nanobind_helpers.h" diff --git a/jaxlib/cuda/cuda_plugin_extension.cc b/jaxlib/cuda/cuda_plugin_extension.cc index 789227e273b6..6655128b9842 100644 --- a/jaxlib/cuda/cuda_plugin_extension.cc +++ b/jaxlib/cuda/cuda_plugin_extension.cc @@ -16,10 +16,10 @@ limitations under the License. #include #include -#include "nanobind/nanobind.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "third_party/gpus/cuda/include/cuda.h" +#include "nanobind/nanobind.h" #include "jaxlib/gpu/gpu_plugin_extension.h" #include "jaxlib/gpu/py_client_gpu.h" #include "jaxlib/kernel_nanobind_helpers.h" diff --git a/jaxlib/cuda/versions.cc b/jaxlib/cuda/versions.cc index 8d6577f46709..d9f9f4c86865 100644 --- a/jaxlib/cuda/versions.cc +++ b/jaxlib/cuda/versions.cc @@ -13,9 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "jaxlib/cuda/versions_helpers.h" - #include "nanobind/nanobind.h" +#include "jaxlib/cuda/versions_helpers.h" #include "jaxlib/gpu/vendor.h" namespace jax::cuda { diff --git a/jaxlib/gpu/blas.cc b/jaxlib/gpu/blas.cc index e8761bd32ac9..4a58859016f1 100644 --- a/jaxlib/gpu/blas.cc +++ b/jaxlib/gpu/blas.cc @@ -17,10 +17,10 @@ limitations under the License. #include #include -#include "nanobind/nanobind.h" -#include "nanobind/stl/pair.h" #include "absl/container/flat_hash_map.h" #include "absl/strings/str_format.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/pair.h" #include "jaxlib/gpu/blas_kernels.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_nanobind_helpers.h" diff --git a/jaxlib/gpu/gpu_plugin_extension.cc b/jaxlib/gpu/gpu_plugin_extension.cc index 5726e0929ee5..d026806e9479 100644 --- a/jaxlib/gpu/gpu_plugin_extension.cc +++ b/jaxlib/gpu/gpu_plugin_extension.cc @@ -20,13 +20,13 @@ limitations under the License. #include #include -#include "nanobind/nanobind.h" -#include "nanobind/stl/string.h" // IWYU pragma: keep -#include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "jaxlib/kernel_nanobind_helpers.h" #include "xla/ffi/api/c_api.h" #include "xla/pjrt/c/pjrt_c_api.h" diff --git a/jaxlib/gpu/hybrid.cc b/jaxlib/gpu/hybrid.cc index 94975a5b969f..71c320a60f02 100644 --- a/jaxlib/gpu/hybrid.cc +++ b/jaxlib/gpu/hybrid.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "nanobind/nanobind.h" #include "absl/base/call_once.h" +#include "nanobind/nanobind.h" #include "jaxlib/cpu/lapack_kernels.h" #include "jaxlib/gpu/hybrid_kernels.h" #include "jaxlib/gpu/vendor.h" diff --git a/jaxlib/gpu/py_client_gpu.cc b/jaxlib/gpu/py_client_gpu.cc index cf701574959b..3e140411770d 100644 --- a/jaxlib/gpu/py_client_gpu.cc +++ b/jaxlib/gpu/py_client_gpu.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include -#include "nanobind/nanobind.h" #include "absl/base/casts.h" #include "absl/log/check.h" #include "absl/status/status.h" @@ -29,6 +28,7 @@ limitations under the License. #include "absl/strings/numbers.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "nanobind/nanobind.h" #include "jaxlib/gpu/vendor.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/host_callback.h" diff --git a/jaxlib/gpu/solver.cc b/jaxlib/gpu/solver.cc index 357a38eecfd5..1cf799bbb491 100644 --- a/jaxlib/gpu/solver.cc +++ b/jaxlib/gpu/solver.cc @@ -17,11 +17,11 @@ limitations under the License. #include #include -#include "nanobind/nanobind.h" -#include "nanobind/stl/pair.h" #include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/pair.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/solver_handle_pool.h" #include "jaxlib/gpu/solver_kernels.h" diff --git a/jaxlib/gpu/sparse.cc b/jaxlib/gpu/sparse.cc index 429c8018dc7a..a7f8dbebc2b3 100644 --- a/jaxlib/gpu/sparse.cc +++ b/jaxlib/gpu/sparse.cc @@ -19,13 +19,13 @@ limitations under the License. #include #include -#include "nanobind/nanobind.h" -#include "nanobind/stl/pair.h" #include "absl/base/casts.h" #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/pair.h" #include "jaxlib/absl_status_casters.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/sparse_kernels.h" diff --git a/jaxlib/gpu/triton.cc b/jaxlib/gpu/triton.cc index 500034af3ebb..135410568f6b 100644 --- a/jaxlib/gpu/triton.cc +++ b/jaxlib/gpu/triton.cc @@ -5,13 +5,13 @@ #include #include +#include "absl/status/statusor.h" #include "nanobind/nanobind.h" #include "nanobind/stl/pair.h" #include "nanobind/stl/string.h" #include "nanobind/stl/string_view.h" #include "nanobind/stl/tuple.h" #include "nanobind/stl/vector.h" -#include "absl/status/statusor.h" #include "jaxlib/absl_status_casters.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/triton.pb.h" diff --git a/jaxlib/kernel_nanobind_helpers.h b/jaxlib/kernel_nanobind_helpers.h index fde37e695349..127d89f702c8 100644 --- a/jaxlib/kernel_nanobind_helpers.h +++ b/jaxlib/kernel_nanobind_helpers.h @@ -19,8 +19,8 @@ limitations under the License. #include #include -#include "nanobind/nanobind.h" #include "absl/base/casts.h" +#include "nanobind/nanobind.h" #include "jaxlib/kernel_helpers.h" #include "xla/ffi/api/c_api.h" #include "xla/tsl/python/lib/core/numpy.h" // NOLINT diff --git a/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc b/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc index 7483d7ed1eea..c73084abc99d 100644 --- a/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc +++ b/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc @@ -16,9 +16,9 @@ limitations under the License. #include #include -#include "nanobind/nanobind.h" #include "mlir-c/IR.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" // IWYU pragma: keep +#include "nanobind/nanobind.h" #include "jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h" #include "jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h" diff --git a/jaxlib/mlir/_mlir_libs/tpu_ext.cc b/jaxlib/mlir/_mlir_libs/tpu_ext.cc index 2b5ec898ad3e..7d616968b9aa 100644 --- a/jaxlib/mlir/_mlir_libs/tpu_ext.cc +++ b/jaxlib/mlir/_mlir_libs/tpu_ext.cc @@ -43,13 +43,13 @@ limitations under the License. // clang-format off #include "mlir-c/Bindings/Python/Interop.h" // clang-format on +#include "absl/log/check.h" #include "nanobind/nanobind.h" #include "nanobind/stl/optional.h" // IWYU pragma: keep -#include "nanobind/stl/pair.h" // IWYU pragma: keep -#include "nanobind/stl/string.h" // IWYU pragma: keep -#include "nanobind/stl/variant.h" // IWYU pragma: keep -#include "nanobind/stl/vector.h" // IWYU pragma: keep -#include "absl/log/check.h" +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep #include "jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h" #include "xla/python/nb_numpy.h" #include "xla/tsl/python/lib/core/numpy.h" diff --git a/jaxlib/mlir/_mlir_libs/triton_ext.cc b/jaxlib/mlir/_mlir_libs/triton_ext.cc index e824d4058d7e..2a13c40d963f 100644 --- a/jaxlib/mlir/_mlir_libs/triton_ext.cc +++ b/jaxlib/mlir/_mlir_libs/triton_ext.cc @@ -15,9 +15,9 @@ limitations under the License. #include -#include "nanobind/nanobind.h" #include "mlir-c/IR.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "nanobind/nanobind.h" #include "jaxlib/triton/triton_dialect_capi.h" namespace nb = nanobind; diff --git a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc index 2c7242b6e6c0..ee11b22020dc 100644 --- a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc +++ b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc @@ -22,11 +22,11 @@ limitations under the License. #include #include +#include "absl/cleanup/cleanup.h" +#include "absl/strings/str_cat.h" #include "nanobind/nanobind.h" #include "nanobind/stl/tuple.h" #include "nanobind/stl/vector.h" -#include "absl/cleanup/cleanup.h" -#include "absl/strings/str_cat.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_nanobind_helpers.h" #include "xla/ffi/api/c_api.h" diff --git a/jaxlib/rocm/rocm_plugin_extension.cc b/jaxlib/rocm/rocm_plugin_extension.cc index 1e8013f2bc1b..454f4741d667 100644 --- a/jaxlib/rocm/rocm_plugin_extension.cc +++ b/jaxlib/rocm/rocm_plugin_extension.cc @@ -16,10 +16,10 @@ limitations under the License. #include #include -#include "nanobind/nanobind.h" #include "absl/log/log.h" #include "absl/strings/str_cat.h" #include "rocm/include/hip/hip_runtime.h" +#include "nanobind/nanobind.h" #include "jaxlib/gpu/gpu_plugin_extension.h" #include "jaxlib/gpu/py_client_gpu.h" #include "jaxlib/kernel_nanobind_helpers.h" diff --git a/jaxlib/utils.cc b/jaxlib/utils.cc index bf50b3a5254d..e5bb45e999da 100644 --- a/jaxlib/utils.cc +++ b/jaxlib/utils.cc @@ -19,12 +19,12 @@ limitations under the License. #include #include -#include "nanobind/nanobind.h" #include "absl/cleanup/cleanup.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/synchronization/mutex.h" +#include "nanobind/nanobind.h" namespace nb = nanobind; diff --git a/jaxlib/xla/custom_calls_testlib.cc b/jaxlib/xla/custom_calls_testlib.cc index d06105fce76f..58f4818a431e 100644 --- a/jaxlib/xla/custom_calls_testlib.cc +++ b/jaxlib/xla/custom_calls_testlib.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include -#include "third_party/nanobind/include/nanobind/nanobind.h" +#include "nanobind/nanobind.h" #include "xla/ffi/api/c_api.h" #include "xla/ffi/api/ffi.h" From 40ce44d143e160d7c44f5453fe3f49d413598301 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 21 Mar 2025 09:25:11 -0700 Subject: [PATCH 084/483] Add `ShardingTypeError` to all sharding rules in JAX PiperOrigin-RevId: 739205830 --- jax/_src/lax/lax.py | 25 ++++++++++--------- jax/_src/lax/linalg.py | 4 +-- jax/_src/lax/slicing.py | 7 ++---- jax/_src/lax/utils.py | 2 +- jax/_src/lax/windowed_reductions.py | 4 +-- tests/pjit_test.py | 38 +++++++++++++++-------------- 6 files changed, 40 insertions(+), 40 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 388ad49ec83d..f6ab848ccd5b 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -3805,7 +3805,7 @@ def broadcasting_sharding_rule(name, *avals): for a in avals: if a.sharding is not None and not a.sharding.mesh.empty: if mesh is not None and mesh != a.sharding.mesh: - raise ValueError( + raise core.ShardingTypeError( f'Mesh for all inputs should be equal. Got one mesh: {mesh} and' f' another mesh: {a.sharding.mesh}') mesh = a.sharding.mesh @@ -3839,7 +3839,7 @@ def broadcasting_sharding_rule(name, *avals): result_specs[i] = s elif (result_specs[i] is not None and s is not None and result_specs[i] != s): - raise TypeError( + raise core.ShardingTypeError( f'{name} got incompatible shardings for broadcasting: ' f'{", ".join(map(str, map(tuple, specs)))}.') return NamedSharding(mesh, P(*result_specs)) @@ -4990,13 +4990,13 @@ def _dot_general_shape_computation(lhs_shape, rhs_shape, dimension_numbers): def _check_specs_match(lhs_spec, rhs_spec, msg): for l, r in zip(lhs_spec, rhs_spec): if l is not None and r is not None and l != r: - raise TypeError(msg) + raise core.ShardingTypeError(msg) def _dot_general_sharding_rule(lhs, rhs, *, dimension_numbers, precision, preferred_element_type: DTypeLike | None, out_sharding): if lhs.sharding.mesh != rhs.sharding.mesh: - raise ValueError( + raise core.ShardingTypeError( 'Mesh of both lhs and rhs should match. Got lhs:' f' {lhs.sharding.mesh} and rhs: {rhs.sharding.mesh}') @@ -5020,7 +5020,7 @@ def _dot_general_sharding_rule(lhs, rhs, *, dimension_numbers, precision, for l, r in zip(lhs_contracting_spec, rhs_contracting_spec): if l is not None and r is not None: - raise ValueError( + raise core.ShardingTypeError( 'Contracting dimensions are sharded and it is ambiguous how the' ' output should be sharded. Please specify the output sharding via' ' the `out_sharding` parameter of einsum. Or reshard your input via' @@ -6378,7 +6378,7 @@ def _concatenate_sharding_rule(*operands, **kwargs): return core.get_cur_mesh_sharding() if not all(s == non_empty_s[0] for s in non_empty_s): ss = ", ".join(str(o.sharding) for o in operands) - raise TypeError( + raise core.ShardingTypeError( f"All operands should have the same sharding. Got shardings {ss}") return non_empty_s[0] @@ -6697,7 +6697,7 @@ def _split_on_one_axis(op_shape, new_sizes, name): else: count += 1 if count > 1: - raise ValueError( + raise core.ShardingTypeError( f'{name} on more than 1 axis is not supported. Please specify' ' the sharding of the output via the `sharding` argument of' f' jax.lax.reshape. Got operand.shape={op_shape} and {new_sizes=}') @@ -6738,7 +6738,7 @@ def _reshape_sharding_rule(operand, *, new_sizes, dimensions, sharding): return _merge_an_axis_sharding_rule(operand, operand_merge, new_sizes, dimensions) - raise ValueError( + raise core.ShardingTypeError( 'This reshape is not supported. Please specify the sharding of' ' the output via the `out_sharding` argument of jax.lax.reshape. Got' f' operand shape: {operand.shape}, new sizes: {new_sizes} and' @@ -6771,7 +6771,7 @@ def _split_an_axis_sharding_rule(operand, out_split, new_sizes, dimensions): elif dimensions is None and out[0] % _get_spec_size(sp, mesh) == 0: new_spec.extend([sp] + [None] * (len(out) - 1)) else: - raise ValueError( + raise core.ShardingTypeError( 'This reshape is not supported. Please specify the sharding of the' ' output via the `sharding` argument of jax.lax.reshape. Got' f' operand shape: {operand.shape}, new sizes: {new_sizes} and' @@ -6796,7 +6796,7 @@ def _merge_an_axis_sharding_rule(operand, operand_merge, new_sizes, dimensions): assert new_size % _get_spec_size(sp[0], mesh) == 0 new_spec.append(sp[0]) else: - raise ValueError( + raise core.ShardingTypeError( 'This reshape is not supported. Please specify the sharding of the' ' output via the `sharding` argument of jax.lax.reshape. Got' f' operand shape: {operand.shape}, new sizes: {new_sizes} and' @@ -6979,10 +6979,11 @@ def _select_sharding_rule(which, *cases): return core.get_cur_mesh_sharding() if any(s != non_empty_s[0] for s in non_empty_s[1:]): msg = "select cases must have the same shardings, got [{}]." - raise TypeError(msg.format(", ".join([str(c.sharding) for c in cases]))) + raise core.ShardingTypeError( + msg.format(", ".join([str(c.sharding) for c in cases]))) if (which.shape and not which.sharding.mesh.empty and which.sharding != non_empty_s[0]): - raise TypeError( + raise core.ShardingTypeError( 'select `which` must be scalar or have the same sharding as cases, got' f' `which` sharding {which.sharding} but case sharding' f' {cases[0].sharding}.') diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 3e9077d0a51c..027ec8b801b9 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -717,14 +717,14 @@ def linalg_sharding_rule( spec = aval.sharding.spec batch_spec, rest_spec = spec[:len(spec) - rank], spec[len(spec) - rank:] if not all(s is None for s in rest_spec): - raise ValueError( + raise core.ShardingTypeError( f"Input {i} to {name} must be unsharded on non-batch dimensions, " f"but got {spec}." ) batch_specs.append(batch_spec) batch_spec = batch_specs[0] if any(b != batch_spec for b in batch_specs[1:]): - raise ValueError( + raise core.ShardingTypeError( f"All inputs to {name} must have the same batch sharding, but got " f"{batch_specs}." ) diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index b3a0a8e2d0c1..d3bcb6da2807 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -1333,7 +1333,7 @@ def _get_sharding_for_varying_out_shape(out_shape, operand, name): operand.shape, out_shape, operand.sharding.spec): if (op_sh != out_sh and op_spec is not None and out_sh % _get_sub_spec_size(mesh, op_spec) != 0): - raise NotImplementedError( + raise core.ShardingTypeError( f"{name} on sharded dims where out dim ({out_sh}) is not divisble by" f" mesh axes ({_get_sub_spec_size(mesh, op_spec)}) with spec" f" ({op_spec}) is not implemented.") @@ -1922,9 +1922,6 @@ def _gather_shape_computation(indices, dimension_numbers, slice_sizes): else next(indices_shape_gen) for i in range(output_shape_rank)) return ans -class GatherShardingError(Exception): - pass - def _gather_sharding_rule(operand, indices, *, dimension_numbers, slice_sizes, unique_indices, indices_are_sorted, mode, fill_value): @@ -1936,7 +1933,7 @@ def _gather_sharding_rule(operand, indices, *, dimension_numbers, all(s is None for s in operand.sharding.spec) and all(s is None for s in indices.sharding.spec)): return core.get_cur_mesh_sharding() - raise GatherShardingError( + raise core.ShardingTypeError( "Use `.at[...].get(out_sharding=)` to provide output PartitionSpec for" " the gather indexing.") diff --git a/jax/_src/lax/utils.py b/jax/_src/lax/utils.py index 9fc9ba16a604..206a8312ba8c 100644 --- a/jax/_src/lax/utils.py +++ b/jax/_src/lax/utils.py @@ -96,7 +96,7 @@ def call_shape_dtype_sharding_rule(prim, shape_rule, dtype_rule, sharding_rule, mesh = mesh_lib.empty_abstract_mesh if e.mesh is None else e.mesh out_aval_str = core.str_short_aval(out_shapes, out_dtypes, mesh, e.pspec, short_dtypes=True) - raise TypeError( + raise core.ShardingTypeError( f'{prim} operation with inputs: {avals_str} produces an illegally' f' sharded result: {out_aval_str}') from e return out_shapes, out_dtypes, out_shardings diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py index 400646f6238f..42b2e9278889 100644 --- a/jax/_src/lax/windowed_reductions.py +++ b/jax/_src/lax/windowed_reductions.py @@ -525,7 +525,7 @@ def reduce_window_sharding_rule(operand, window_dimensions, window_strides, if spec is None: continue if not (wdim == 1 and ws == 1 and pd == 1 and bd == 1 and wdil == 1): - raise NotImplementedError( + raise core.ShardingTypeError( "Only trivial windowing is supported along non-replicated" f" dimensions. Got {operand.sharding.spec=}") return operand.sharding @@ -826,7 +826,7 @@ def _select_and_gather_add_sharding_rule( tangents, operand, *, select_prim, window_dimensions, window_strides, padding, base_dilation, window_dilation): if tangents.sharding != operand.sharding: - raise TypeError( + raise core.ShardingTypeError( "select_and_gather_add tangents and operand shardings must match, " f"got {tangents.sharding} and {operand.sharding}.") return reduce_window_sharding_rule( diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 6a1a73fe4301..6fdfa62887b9 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -4995,11 +4995,13 @@ def g(x, y): return x * y with self.assertRaisesRegex( - TypeError, "mul got incompatible shardings for broadcasting"): + core.ShardingTypeError, + "mul got incompatible shardings for broadcasting"): g(arr1, jax.device_put(np_inp1, NamedSharding(mesh, P('y', 'x')))) with self.assertRaisesRegex( - TypeError, "mul got incompatible shardings for broadcasting"): + core.ShardingTypeError, + "mul got incompatible shardings for broadcasting"): g(arr1, jax.device_put(np_inp1, NamedSharding(mesh, P(('x', 'y'))))) @parameterized.named_parameters( @@ -5098,14 +5100,14 @@ def f(x, y): @parameterized.named_parameters( ('fail1', P('x', None), P(None, 'x'), "dot_general operation.*produces an illegally sharded result", - TypeError), + core.ShardingTypeError), ('fail2', P('x', 'y'), P('x', 'y'), "dot_general requires contracting dimensions to have consistent sharding", - TypeError), + core.ShardingTypeError), ('contracting1', P('x', 'y'), P('y', None), - 'Contracting dimensions are sharded', ValueError), + 'Contracting dimensions are sharded', core.ShardingTypeError), ('other_half_tp', P(None, 'y'), P('y', None), - 'Contracting dimensions are sharded', ValueError), + 'Contracting dimensions are sharded', core.ShardingTypeError), ) @jtu.with_user_mesh((2, 2), ('x', 'y')) def test_dot_general_error(self, spec1, spec2, error_msg, error_type, mesh): @@ -5127,14 +5129,14 @@ def test_dot_general_batch_error(self, mesh): arr2 = jax.device_put(np.ones((8, 2, 4)), NamedSharding(mesh, P('y', 'z', 'x'))) with self.assertRaisesRegex( - TypeError, + core.ShardingTypeError, 'dot_general requires lhs batch dimensions and rhs batch dimensions to' ' have the consistent sharding'): jax.lax.dot_general( arr1, arr2, dimension_numbers=(([2], [1]), ([0], [0]))) with self.assertRaisesRegex( - TypeError, + core.ShardingTypeError, 'dot_general requires lhs batch dimensions and rhs batch dimensions to' ' have the consistent sharding'): jnp.einsum('abc,acz->abz', arr1, arr2) @@ -5569,7 +5571,7 @@ def f(x): return y if error_msg: - with self.assertRaisesRegex(ValueError, error_msg): + with self.assertRaisesRegex(core.ShardingTypeError, error_msg): f(arr) else: out = f(arr) @@ -5608,7 +5610,7 @@ def f(pred, on_true, on_false): arr3 = jax.device_put(np_inp, NamedSharding(mesh, P('y', 'x'))) with self.assertRaisesRegex( - TypeError, "select cases must have the same shardings"): + core.ShardingTypeError, "select cases must have the same shardings"): f(arr1 == arr2, arr1, arr3) def test_explicit_mode_no_context_mesh(self): @@ -5778,10 +5780,10 @@ def g(x): out = jax.jit(jax.grad(g))(arr) self.assertEqual(out.sharding, arr.sharding) - with self.assertRaisesRegex(NotImplementedError, "slicing on sharded dims"): + with self.assertRaisesRegex(core.ShardingTypeError, "slicing on sharded dims"): f(jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))) - with self.assertRaisesRegex(NotImplementedError, "slicing on sharded dims"): + with self.assertRaisesRegex(core.ShardingTypeError, "slicing on sharded dims"): f(jax.device_put(np_inp, NamedSharding(mesh, P(None, ('x', 'y'))))) @jtu.with_user_mesh((2, 2), ('x', 'y')) @@ -5842,13 +5844,13 @@ def g(x): out = jax.jit(jax.grad(g))(arr) self.assertEqual(out.sharding, arr.sharding) - with self.assertRaisesRegex(NotImplementedError, "padding on sharded dims"): + with self.assertRaisesRegex(core.ShardingTypeError, "padding on sharded dims"): f(arr, ((2, 3, 0), ), None) - with self.assertRaisesRegex(NotImplementedError, "padding on sharded dims"): + with self.assertRaisesRegex(core.ShardingTypeError, "padding on sharded dims"): f(arr, ((0, 3, 0), ), None) - with self.assertRaisesRegex(NotImplementedError, "padding on sharded dims"): + with self.assertRaisesRegex(core.ShardingTypeError, "padding on sharded dims"): arr = jax.device_put(np_inp, NamedSharding(mesh, P(('x', 'y')))) f(arr, ((4, 4, 1),), None) @@ -5879,7 +5881,7 @@ def f(x, y, method='jnp'): self.assertArraysEqual(out, np.concatenate([arr1, arr2], axis=1)) with self.assertRaisesRegex( - TypeError, "All operands should have the same sharding"): + core.ShardingTypeError, "All operands should have the same sharding"): arr3 = jax.device_put(np.arange(4.).reshape(4, 1), NamedSharding(mesh, P('x'))) f(arr1, arr3) @@ -6147,7 +6149,7 @@ def f(x, sizes=(4, 4), axis=0): f(arr) self.check_wsc_in_lowered(f.lower(arr).as_text()) - with self.assertRaisesRegex(NotImplementedError, "split on sharded dims"): + with self.assertRaisesRegex(core.ShardingTypeError, "split on sharded dims"): f(arr, sizes=(1, 1), axis=1) def g(x): @@ -6452,7 +6454,7 @@ def f(x, y, z): # Errors out on the intermediate einsum: `bthj,bthD->bthjD` # because of a conflict with self.assertRaisesRegex( - TypeError, + core.ShardingTypeError, 'dot_general operation.*produces an illegally sharded result'): f(arr1, arr2, arr3) From 3bf2eea259107cfaadedd8ee59d0b586401eeb08 Mon Sep 17 00:00:00 2001 From: Ayaka Date: Fri, 21 Mar 2025 10:16:42 -0700 Subject: [PATCH 085/483] Add AOT support for error checking PiperOrigin-RevId: 739222389 --- jax/_src/error_check.py | 198 ++++++++++++++++++++++++++++++++------ tests/error_check_test.py | 59 +++++++++++- 2 files changed, 224 insertions(+), 33 deletions(-) diff --git a/jax/_src/error_check.py b/jax/_src/error_check.py index 88dcec7063d9..e78b9bc82115 100644 --- a/jax/_src/error_check.py +++ b/jax/_src/error_check.py @@ -14,8 +14,12 @@ from __future__ import annotations +import dataclasses from functools import partial +import json import threading +import traceback as tb_lib +from types import TracebackType import warnings import jax @@ -23,19 +27,17 @@ from jax._src import source_info_util from jax._src import traceback_util import jax._src.mesh as mesh_lib -from jax.experimental.shard_map import shard_map +from jax.experimental import shard_map +import jax.export import jax.numpy as jnp from jax.sharding import NamedSharding, PartitionSpec as P -Traceback = source_info_util.Traceback - - traceback_util.register_exclusion(__file__) class JaxValueError(ValueError): - """Exception raised for failed runtime error checks in JAX.""" + """Exception raised for runtime errors detected within JAX computations.""" #: The default error code for no error. @@ -45,8 +47,9 @@ class JaxValueError(ValueError): _NO_ERROR = jnp.iinfo(jnp.uint32).max -_error_list_lock = threading.Lock() -_error_list: list[tuple[str, Traceback]] = [] # (error_message, traceback) pair +_error_list_lock = threading.RLock() +# (error_message, traceback) pairs. Traceback is `str` when imported from AOT. +_error_list: list[tuple[str, TracebackType | str]] = [] class _ErrorStorage(threading.local): @@ -65,22 +68,21 @@ def _initialize_error_code_ref() -> None: In single-device environments, the array is a scalar. In multi-device environments, its shape and size match those of the mesh. """ - with core.eval_context(): - # Get mesh from the context. - mesh = mesh_lib.get_concrete_mesh() - - if mesh is None: # single-device case. - error_code = jnp.uint32(_NO_ERROR) - - else: # multi-device case. - sharding = NamedSharding(mesh, P(*mesh.axis_names)) - error_code = jnp.full( - mesh.axis_sizes, - jnp.uint32(_NO_ERROR), - device=sharding, - ) + # Get mesh from the context. + mesh = mesh_lib.get_concrete_mesh() + + if mesh is None: # single-device case. + error_code = jnp.uint32(_NO_ERROR) + + else: # multi-device case. + sharding = NamedSharding(mesh, P(*mesh.axis_names)) + error_code = jnp.full( + mesh.axis_sizes, + jnp.uint32(_NO_ERROR), + device=sharding, + ) - _error_storage.ref = core.mutable_array(error_code) + _error_storage.ref = core.mutable_array(error_code) class error_checking_context: @@ -105,7 +107,8 @@ def __init__(self): def __enter__(self): self.old_ref = _error_storage.ref - _initialize_error_code_ref() + with core.eval_context(): + _initialize_error_code_ref() return self def __exit__(self, exc_type, exc_value, traceback): @@ -126,22 +129,33 @@ def set_error_if(pred: jax.Array, /, msg: str) -> None: will not override the existing error. For multi-device environments, in explicit mode, users must call - :func:`error_checking_context()` to initialize a new error tracking state that + :func:`error_checking_context` to initialize a new error tracking state that matches the device mesh. In auto mode, implicit cross-device communication may occur inside this function, which could impact performance. A warning is issued in such cases. + When exporting a function with `jax.export`, error checking must be explicitly + wrapped using :func:`wrap_for_export` before export and + :func:`unwrap_from_import` after import. + Args: pred: A JAX boolean array. If any element of `pred` is `True`, the internal error state will be set. msg: The corresponding error message to be raised later. """ if _error_storage.ref is None: - _initialize_error_code_ref() + with core.eval_context(): + _initialize_error_code_ref() assert _error_storage.ref is not None + # Get the traceback. traceback = source_info_util.current().traceback assert traceback is not None + traceback = traceback.as_python_traceback() + assert isinstance(traceback, TracebackType) + traceback = traceback_util.filter_traceback(traceback) + assert isinstance(traceback, TracebackType) + with _error_list_lock: new_error_code = jnp.uint32(len(_error_list)) _error_list.append((msg, traceback)) @@ -171,7 +185,7 @@ def set_error_if(pred: jax.Array, /, msg: str) -> None: "Please use `with error_checking_context()` to redefine the error " "code state based on the mesh." ) - pred = shard_map( + pred = shard_map.shard_map( partial(jnp.any, keepdims=True), mesh=out_sharding.mesh, in_specs=in_sharding.spec, @@ -179,7 +193,7 @@ def set_error_if(pred: jax.Array, /, msg: str) -> None: )(pred) # perform per-device reduction error_code = _error_storage.ref[...] - should_update = jnp.logical_and(pred, error_code == jnp.uint32(_NO_ERROR)) + should_update = jnp.logical_and(error_code == jnp.uint32(_NO_ERROR), pred) error_code = jnp.where(should_update, new_error_code, error_code) # TODO(ayx): support vmap and shard_map. _error_storage.ref[...] = error_code @@ -216,8 +230,128 @@ def raise_if_error() -> None: device=_error_storage.ref.sharding, ) # clear the error code - msg, traceback = _error_list[error_code] - exc = JaxValueError(msg) - traceback = traceback.as_python_traceback() - filtered_traceback = traceback_util.filter_traceback(traceback) - raise exc.with_traceback(filtered_traceback) + with _error_list_lock: + msg, traceback = _error_list[error_code] + if isinstance(traceback, str): # from imported AOT functions + exc = JaxValueError( + f"{msg}\nThe original traceback is shown below:\n{traceback}" + ) + raise exc + else: + exc = JaxValueError(msg) + raise exc.with_traceback(traceback) + + +@dataclasses.dataclass(frozen=True) +class _ErrorClass: + """A class to store error information for AOT compilation. + + This class is used internally by the wrapper functions `wrap_for_export` and + `unwrap_from_import` to encapsulate error-related data within an exported + function. + + Attributes: + error_code (jax.Array): A JAX array representing the final error state of + the function to be exported. This value is local to the wrapper function. + error_list (list[tuple[str, str]]): A list of `(error_message, traceback)` + pairs containing error messages and corresponding stack traces. This error + list is local to the wrapper function, and does not contain pairs of error + information from other functions. + """ + + error_code: jax.Array + error_list: list[tuple[str, str]] + + +jax.tree_util.register_dataclass( + _ErrorClass, data_fields=("error_code",), meta_fields=("error_list",) +) +jax.export.register_pytree_node_serialization( + _ErrorClass, + serialized_name=f"{_ErrorClass.__module__}.{_ErrorClass.__name__}", + serialize_auxdata=lambda x: json.dumps(x, ensure_ascii=False).encode( + "utf-8" + ), + deserialize_auxdata=lambda x: json.loads(x.decode("utf-8")), +) + + +def _traceback_to_str(traceback: TracebackType) -> str: + """Convert a traceback to a string for export.""" + return "".join(tb_lib.format_list(tb_lib.extract_tb(traceback))).rstrip("\n") + + +def wrap_for_export(f): + """Wrap a function with error checking to make it compatible with AOT mode. + + Error checking relies on global state, which cannot be serialized across + processes. This wrapper ensures that the error state remains within the + function scope, making it possible to export the function and later import in + other processes. + + This function should only be applied once to a function; wrapping the same + function multiple times is unnecessary. + """ + + def inner(*args, **kwargs): + global _error_list + + # 1. Save the old state and initialize a new state. + with core.eval_context(): + old_ref = _error_storage.ref + _initialize_error_code_ref() + with _error_list_lock: + old_error_list, _error_list = _error_list, [] + + # 2. Trace the function. + out = f(*args, **kwargs) + error_code = _error_storage.ref[...].min() + + # 3. Restore the old state. + _error_list, new_error_list = old_error_list, _error_list + with core.eval_context(): + _error_storage.ref = old_ref + + new_error_list = [ + (msg, _traceback_to_str(traceback)) for msg, traceback in new_error_list + ] + return out, _ErrorClass(error_code, new_error_list) + + return inner + + +def unwrap_from_import(f): + """Unwrap a function after AOT import to restore error checking. + + When an AOT-exported function is imported in a new process, its error state is + separate from the global error state of the current process. This wrapper + ensures that errors detected during execution are correctly integrated into + the global error checking mechanism of the current process. + """ + if _error_storage.ref is None: + with core.eval_context(): + _initialize_error_code_ref() + assert _error_storage.ref is not None + + def inner(*args, **kwargs): + out, error_class = f(*args, **kwargs) + new_error_code, error_list = error_class.error_code, error_class.error_list + + # Update the global error list. + with _error_list_lock: + offset = len(_error_list) + _error_list.extend(error_list) + + # Update the global error code array. + error_code = _error_storage.ref[...] + should_update = jnp.logical_and( + error_code == jnp.uint32(_NO_ERROR), + new_error_code != jnp.uint32(_NO_ERROR), + ) + error_code = jnp.where(should_update, new_error_code + offset, error_code) + # TODO(ayx): support vmap and shard_map. + _error_storage.ref[...] = error_code + + return out + + return inner diff --git a/tests/error_check_test.py b/tests/error_check_test.py index 5bf71a9eb592..69e292fb6704 100644 --- a/tests/error_check_test.py +++ b/tests/error_check_test.py @@ -22,6 +22,7 @@ from jax._src import error_check from jax._src import mesh as mesh_lib from jax._src import test_util as jtu +import jax.export import jax.numpy as jnp from jax.sharding import NamedSharding, PartitionSpec as P @@ -33,7 +34,9 @@ jtu.request_cpu_devices(4) -@jtu.with_config(jax_check_tracer_leaks=True) +# TODO: AOT tests fails with the tracer leak checker. +# Reenable once https://github.com/jax-ml/jax/issues/27315 is fixed. +# @jtu.with_config(jax_check_tracer_leaks=True) class ErrorCheckTests(jtu.JaxTestCase): @parameterized.product(jit=[True, False]) @@ -280,6 +283,60 @@ def f(x): with self.assertRaisesRegex(JaxValueError, "x must be greater than 0"): error_check.raise_if_error() + def test_error_check_aot(self): + def run_export(): + def f(x): + error_check.set_error_if(x <= 0, "x must be greater than 0") + return x + 1 + + f = jax.jit(error_check.wrap_for_export(jax.jit(f))) + x = jax.ShapeDtypeStruct((), jnp.float32) + serialized = jax.export.export(f)(x).serialize() + return serialized + + def run_import(serialized): + f = jax.export.deserialize(serialized).call + f = jax.jit(error_check.unwrap_from_import(jax.jit(f))) + x = jnp.float32(-3.) + _ = f(x) + with self.assertRaisesRegex(JaxValueError, "x must be greater than 0"): + error_check.raise_if_error() + + serialized = run_export() + run_import(serialized) + + def test_error_check_aot_should_not_override_existing_error(self): + def f1(x): + error_check.set_error_if(x <= 0, "x must be greater than 0 in f1") + return x + 1 + + def run_export(): + def f2(x): + error_check.set_error_if(x <= 0, "x must be greater than 0 in f2") + return x + 1 + + f2 = jax.jit(error_check.wrap_for_export(jax.jit(f2))) + x = jax.ShapeDtypeStruct((), jnp.float32) + serialized = jax.export.export(f2)(x).serialize() + return serialized + + def run_import(serialized): + f2 = jax.export.deserialize(serialized).call + f2 = jax.jit(error_check.unwrap_from_import(jax.jit(f2))) + return f2 + + x = jnp.float32(-3.) + _ = f1(x) # check fails. so it should set error + + serialized = run_export() + f2 = run_import(serialized) + _ = f2(x) # check fails, but should not override the error + + with self.assertRaisesRegex( + JaxValueError, "x must be greater than 0 in f1" + ): + error_check.raise_if_error() + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 3163fbaac43c8d8187efbd58ee42b333560cf42f Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 21 Mar 2025 10:25:38 -0700 Subject: [PATCH 086/483] Add varying manual axes rules to `mul_p` and `convert_element_type_p`. There are 2 things that need to be added: 1. At the lax level, before we bind the primitive, we need to insert pbroadcasts if some inputs are varying. This is equivalent to the rewrite rules that shard_map has. 2. In abstract_eval rules of primitives, we need to check if all inputs are varying across the same mesh axes and then add the `varying_manual_axes` to the output ShapedArray. This in turn requires us to support `pbroadcast2` and `psum2` primitives in shard_map.py. These primitives don't need to insert any pbroadcasts (equivalent to `no_rewrite` in shard_map) but need to do checks and update the output aval in their abstract_eval rules. * pbroadcast_p: Union the existing aval.varying_manual_axes + axes (passed to pbroadcast) to calculate the output vma. For checks we need to make sure that the intersection of `aval.varying_manual_axes` and `axes` is empty. * psum2_p: Remove the named axes from aval.varying_manual_axes to calculate the output vma. For checks we need to make sure that the intersection of `aval.varying_manual_axes` and `axes` is NOT empty. Majority of the primitives should use the standard_insert_pbroadcast and standard_vma_rule and I'll add those in the follow up CLs to other primitives PiperOrigin-RevId: 739225392 --- jax/_src/core.py | 44 +++++++++++++++++------- jax/_src/lax/lax.py | 19 +++++++++-- jax/_src/lax/linalg.py | 2 +- jax/_src/lax/utils.py | 14 +++++--- jax/experimental/shard_map.py | 64 +++++++++++++++++++++++++++++++---- tests/shard_map_test.py | 22 ++++++++++++ 6 files changed, 138 insertions(+), 27 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index 243ffc871042..ef90341f5cf7 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1900,26 +1900,35 @@ def get_sharding(sharding, shape): _check_divisibility(out_s, shape) return out_s -def str_short_aval(shape, dtype, mesh, spec, short_dtypes=False, - mesh_axis_types=False) -> str: +def str_short_aval(shape, dtype, mesh, spec, vma, + short_dtypes=False, mesh_axis_types=False) -> str: dt_str = dtypes.short_dtype_name(dtype) if short_dtypes else dtype.name dt_str = dt_str.replace('void', 'float0') shapestr = _get_shape_sharding_str(shape, spec) mesh_axes = f'({mesh._axis_types_dict})' if mesh_axis_types else '' - return f'{dt_str}[{shapestr}]{mesh_axes}' + vma = f"{{{','.join(i for i in vma)}}}" if vma else '' + return f'{dt_str}[{shapestr}]{vma}{mesh_axes}' + +def get_vma(vma, mesh): + for i in vma: + if mesh._name_to_type[i] != AxisType.Manual: + raise ValueError( + "Axes mentioned in `vma` field of ShapedArray should" + f" be of type `Manual`. Got axis: {i} of type {mesh._name_to_type[i]}") + return vma class ShapedArray(UnshapedArray): - __slots__ = ['shape', 'sharding', 'varying_manual_axes'] # inherits slots from parent + __slots__ = ['shape', 'sharding', 'vma'] # inherits slots from parent array_abstraction_level = 2 def __init__(self, shape, dtype, weak_type=False, *, sharding=None, - varying_manual_axes: frozenset[AxisName] = frozenset()): + vma: frozenset[AxisName] = frozenset()): self.shape = canonicalize_shape(shape) self.dtype = _dtype_object(dtype) self.weak_type = weak_type self.sharding = get_sharding(sharding, self.shape) if config.varying_axes_in_types.value: - self.varying_manual_axes = varying_manual_axes + self.vma = get_vma(vma, self.sharding.mesh) def update(self, shape=None, dtype=None, weak_type=None, **kwargs): if shape is None: @@ -1930,8 +1939,8 @@ def update(self, shape=None, dtype=None, weak_type=None, **kwargs): weak_type = self.weak_type if 'sharding' not in kwargs: kwargs['sharding'] = self.sharding - if 'varying_manual_axes' not in kwargs: - kwargs['varying_manual_axes'] = getattr(self, 'varying_manual_axes', + if 'vma' not in kwargs: + kwargs['vma'] = getattr(self, 'vma', frozenset()) return ShapedArray(shape, dtype, weak_type, **kwargs) @@ -1950,25 +1959,26 @@ def __eq__(self, other): and self.dtype == other.dtype and self.shape == other.shape and self.weak_type == other.weak_type and self.sharding == other.sharding - and (getattr(self, 'varying_manual_axes', frozenset()) == - getattr(other, 'varying_manual_axes', frozenset()))) + and (getattr(self, 'vma', frozenset()) == + getattr(other, 'vma', frozenset()))) def __hash__(self): # can use hash(self.dtype) and rely on the fact that numpy reuses base dtype # objects, e.g. `np.zeros(3).dtype is np.zeros(4).dtype`, or we can use # the unique character code via hash(self.dtype.char) return hash((self.shape, self.dtype, self.weak_type, self.sharding, - getattr(self, 'varying_manual_axes', frozenset()))) + getattr(self, 'vma', frozenset()))) def to_tangent_aval(self): return ShapedArray( self.shape, primal_dtype_to_tangent_dtype(self.dtype), self.weak_type, sharding=self.sharding, - varying_manual_axes=getattr(self, 'varying_manual_axes', frozenset())) + vma=getattr(self, 'vma', frozenset())) def str_short(self, short_dtypes=False, mesh_axis_types=False): return str_short_aval( self.shape, self.dtype, self.sharding.mesh, self.sharding.spec, + getattr(self, 'varying_manual_axes', frozenset()), short_dtypes, mesh_axis_types) def _len(self, ignored_tracer): @@ -2000,6 +2010,16 @@ def primal_dtype_to_tangent_dtype(primal_dtype): return primal_dtype +def standard_insert_pbroadcast(*args): + if not config.varying_axes_in_types.value: + return args + # TODO(yashkatariya): Move pbroadcast out of shard_map + from jax.experimental.shard_map import pbroadcast # type: ignore + in_vma = [get_aval(a).vma for a in args] + out_vma = frozenset.union(*in_vma) + return [pbroadcast(arg, tuple(n for n in out_vma if n not in src)) + if out_vma - src else arg for arg, src in zip(args, in_vma)] + # Dynamic shape stuff below here! We keep the abstract values distinct just so # as not to interfere with any static shape machinery. diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index f6ab848ccd5b..a6a0924c9c0f 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1140,6 +1140,7 @@ def mul(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.multiply: https://openxla.org/stablehlo/spec#multiply """ + x, y = core.standard_insert_pbroadcast(x, y) return mul_p.bind(x, y) @export @@ -1610,6 +1611,7 @@ def _convert_element_type( (sharding._is_concrete and getattr(operand, 'sharding', None) == sharding))): return operand else: + operand, = core.standard_insert_pbroadcast(operand) return convert_element_type_p.bind( operand, new_dtype=new_dtype, weak_type=bool(weak_type), sharding=sharding) @@ -3844,6 +3846,15 @@ def broadcasting_sharding_rule(name, *avals): f'{", ".join(map(str, map(tuple, specs)))}.') return NamedSharding(mesh, P(*result_specs)) +def standard_vma_rule(prim_name, *avals, **kwargs): + vma, *vmas = [a.vma for a in avals] + if not all(vma == vma_ for vma_ in vmas): + raise ValueError( + f'Primitive {prim_name} requires varying manual axes ' + f'to match, but got {[vma, *vmas]}. Please open an issue at ' + 'https://github.com/jax-ml/jax/issues and as a temporary ' + 'workaround pass the check_rep=False argument to shard_map') + return vma def naryop(result_dtype, accepted_dtypes, name, allow_extended_dtype=False, require_same_dtypes=True): @@ -3852,8 +3863,9 @@ def naryop(result_dtype, accepted_dtypes, name, allow_extended_dtype=False, require_same=require_same_dtypes) shape_rule = partial(broadcasting_shape_rule, name) sharding_rule = partial(broadcasting_sharding_rule, name) - prim = standard_primitive(shape_rule, dtype_rule, name, - sharding_rule=sharding_rule) + prim = standard_primitive( + shape_rule, dtype_rule, name, sharding_rule=sharding_rule, + vma_rule=partial(standard_vma_rule, name)) batching.defbroadcasting(prim) pe.def_trivial_padding(prim) return prim @@ -4704,7 +4716,8 @@ def _convert_element_type_bind_with_trace(trace, args, params): partial(standard_abstract_eval, convert_element_type_p, _convert_element_type_shape_rule, _convert_element_type_dtype_rule, _convert_element_type_weak_type_rule, - _convert_element_type_sharding_rule)) + _convert_element_type_sharding_rule, + partial(standard_vma_rule, convert_element_type_p.name))) ad.defjvp2(convert_element_type_p, _convert_element_type_jvp_rule) ad.primitive_transposes[convert_element_type_p] = _convert_element_type_transpose_rule batching.defvectorized(convert_element_type_p) diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 027ec8b801b9..b22a4cf56062 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -765,7 +765,7 @@ def linalg_primitive(result_dtype, accepted_dtypes, ranks, result_shape, name, else: prim.def_abstract_eval( partial(lax_utils.standard_abstract_eval, prim, shape_rule, dtype_rule, - lax_utils._standard_weak_type_rule, sharding_rule)) + lax_utils._standard_weak_type_rule, sharding_rule, None)) if supports_batching: batching.primitive_batchers[prim] = partial( batching.expand_dims_batcher, prim) diff --git a/jax/_src/lax/utils.py b/jax/_src/lax/utils.py index 206a8312ba8c..63088d665afd 100644 --- a/jax/_src/lax/utils.py +++ b/jax/_src/lax/utils.py @@ -19,6 +19,7 @@ from functools import partial from jax._src import core +from jax._src import config from jax._src import dispatch from jax._src import dtypes from jax._src import mesh as mesh_lib @@ -37,13 +38,13 @@ def _argnum_weak_type(*argnums): return lambda *args, **_: all(args[i].weak_type for i in argnums) def standard_primitive(shape_rule, dtype_rule, name, - weak_type_rule=None, sharding_rule=None): + weak_type_rule=None, sharding_rule=None, vma_rule=None): weak_type_rule = weak_type_rule or _standard_weak_type_rule prim = core.Primitive(name) prim.def_impl(partial(dispatch.apply_primitive, prim)) prim.def_abstract_eval( partial(standard_abstract_eval, prim, shape_rule, dtype_rule, - weak_type_rule, sharding_rule)) + weak_type_rule, sharding_rule, vma_rule)) return prim def _get_array_abstraction_level(a): return a.array_abstraction_level @@ -95,14 +96,14 @@ def call_shape_dtype_sharding_rule(prim, shape_rule, dtype_rule, sharding_rule, avals_str = ', '.join(i.str_short(short_dtypes=True) for i in avals) mesh = mesh_lib.empty_abstract_mesh if e.mesh is None else e.mesh out_aval_str = core.str_short_aval(out_shapes, out_dtypes, mesh, e.pspec, - short_dtypes=True) + frozenset(), short_dtypes=True) raise core.ShardingTypeError( f'{prim} operation with inputs: {avals_str} produces an illegally' f' sharded result: {out_aval_str}') from e return out_shapes, out_dtypes, out_shardings def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule, - sharding_rule, *avals, **kwargs): + sharding_rule, vma_rule, *avals, **kwargs): assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals assert not prim.multiple_results weak_type = weak_type_rule(*avals, **kwargs) @@ -112,8 +113,11 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule, out_shape, out_dtype, out_sharding = call_shape_dtype_sharding_rule( prim, shape_rule, dtype_rule, sharding_rule, False, *avals, **kwargs) + out_vma = (vma_rule(*avals, **kwargs) if config.varying_axes_in_types.value + else frozenset()) out_aval = core.ShapedArray( - out_shape, out_dtype, weak_type=weak_type, sharding=out_sharding) + out_shape, out_dtype, weak_type=weak_type, sharding=out_sharding, + vma=out_vma) core.check_avals_context_mesh([out_aval], prim.name) return out_aval elif least_specialized is core.DShapedArray: diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 66b70c6c2d34..c0306f0c5e91 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -189,7 +189,8 @@ def out_names_thunk(): raise e('shard_map out_specs') from None return tuple(map(_canonicalize_spec, out_specs_flat)) - if rewrite := check_rep: + rewrite = check_rep + if not config.varying_axes_in_types.value and rewrite: fun = _efficient_transpose_rewrite(fun, mesh, in_names_flat, out_names_thunk) try: @@ -577,7 +578,8 @@ def _shard_shaped_array(mesh: Mesh, auto: frozenset, names: AxisNames, for i, sz in enumerate(aval.shape)) manual_mesh = _as_manual_mesh(mesh, auto) new_sharding = NamedSharding(manual_mesh, aval.sharding.spec) - return aval.update(shape=new_shape, sharding=new_sharding) + vma = frozenset({n for ns in names.values() for n in ns}) + return aval.update(shape=new_shape, sharding=new_sharding, vma=vma) core.shard_aval_handlers[core.ShapedArray] = _shard_shaped_array def _unshard_shaped_array(mesh: Mesh, names: AxisNames, @@ -606,7 +608,7 @@ def _unshard_shaped_array(mesh: Mesh, names: AxisNames, new_mesh = (mesh.abstract_mesh if get_abstract_mesh().empty else get_abstract_mesh()) new_sharding = NamedSharding(new_mesh, out_spec) - return aval.update(shape=new_shape, sharding=new_sharding) + return aval.update(shape=new_shape, sharding=new_sharding, vma=frozenset()) core.unshard_aval_handlers[core.ShapedArray] = _unshard_shaped_array # Type-checking @@ -1069,7 +1071,41 @@ def _device_put_eager_rule(mesh, *xs, srcs, devices, copy_semantics): psum2_p = core.Primitive('psum2') psum2_p.multiple_results = True psum2_p.def_impl(lax_parallel.psum_p.impl) -psum2_p.def_effectful_abstract_eval(lax_parallel.psum_p.abstract_eval) + +def _psum2_abstract_eval(*args, axes, axis_index_groups): + if not config.varying_axes_in_types.value: + return lax_parallel.psum_p.abstract_eval( + *args, axes=axes, axis_index_groups=axis_index_groups) + + assert isinstance(axes, tuple) + lax_parallel._check_axis_names(axes) + arg_vma = [a.vma for a in args] + if any(not set(axes) & a for a in arg_vma): + raise ValueError( + "Collective psum must be applied to a device-varying " + f"type, but got {arg_vma} for collective acting " + f"over axis name {axes}. Please open an issue at " + "https://github.com/jax-ml/jax/issues, and as a temporary " + "workaround pass the check_rep=False argument to shard_map") + + named_axes = tuple(axis for axis in axes if not isinstance(axis, int)) + pos_axes = tuple(axis for axis in axes if isinstance(axis, int)) + if axis_index_groups is not None: + if len(pos_axes) != 0: + raise ValueError( + "axis_index_groups can only be used with reductions over " + f"named axes, but got: {axes}") + core.check_avals_context_mesh(args, 'all_reduce') + out_avals = [ + core.ShapedArray( + lax._reduce_op_shape_rule(arg, axes=pos_axes), arg.dtype, + sharding=lax._reduce_op_sharding_rule(arg, axes=pos_axes), + vma=frozenset(a for a in arg.vma if a not in named_axes)) + for arg in args + ] + return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes} +psum2_p.def_effectful_abstract_eval(_psum2_abstract_eval) + mlir.register_lowering(psum2_p, mlir._lowerings[lax_parallel.psum_p]) batching.fancy_primitive_batchers[psum2_p] = \ partial(lax_parallel._batched_reduction_collective, psum2_p, @@ -1088,10 +1124,26 @@ def pbroadcast(x, axis_name): xs, treedef = tree_flatten(x) ys = pbroadcast_p.bind(*xs, axes=axes, axis_index_groups=None) return tree_unflatten(treedef, ys) + pbroadcast_p = core.Primitive('pbroadcast') pbroadcast_p.multiple_results = True pbroadcast_p.def_impl(lambda *args, axes, axis_index_groups: args) -pbroadcast_p.def_abstract_eval(lambda *args, axes, axis_index_groups: args) + +def _pbroadcast_abstract_eval(*args, axes, axis_index_groups): + if not config.varying_axes_in_types.value: + return args + assert isinstance(axes, tuple) + arg_vma = [a.vma for a in args] + if any(set(axes) & a for a in arg_vma): + raise ValueError( + "Collective pbroadcast must be applied to a " + f"non-device-varying type, but got {arg_vma} for collective acting " + f"over axis name {axes}. Please open an issue at " + "https://github.com/jax-ml/jax/issues, and as a temporary " + "workaround pass the check_rep=False argument to shard_map") + return [a.update(vma=a.vma.union(frozenset(axes))) for a in args] +pbroadcast_p.def_abstract_eval(_pbroadcast_abstract_eval) + mlir.register_lowering(pbroadcast_p, lambda ctx, *x, axes, axis_index_groups: x) def _pbroadcast_batcher(vals_in, dims_in, *, axes, axis_index_groups): if any(type(axis) is int for axis in axes): raise NotImplementedError @@ -1140,7 +1192,7 @@ def _standard_check(prim, mesh, *in_rep, **__): # The standard check require args' and outputs' replications to be the same, # except for Nones which correspond to constants. in_rep_ = [r for r in in_rep if r is not None] - if in_rep_ and not in_rep_[:-1] == in_rep_[1:]: + if in_rep_ and in_rep_[:-1] != in_rep_[1:]: raise Exception(f"Primitive {prim} requires argument replication types " f"to match, but got {in_rep}. Please open an issue at " "https://github.com/jax-ml/jax/issues and as a temporary " diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index f8d5a11e842f..ce01b6e6e944 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -2685,6 +2685,28 @@ def test_pmax(self): )(x) # don't crash self.assertArraysEqual(y, np.array([6, 7], dtype=np.float32)) + @config.varying_axes_in_types(True) + def test_mul_with_vma_in_types(self): + mesh = jtu.create_mesh((2,), ('x',)) + x = np.arange(8.) + + def f(x): + self.assertEqual(x.aval.vma, frozenset({'x'})) + out = x * 2 + self.assertEqual(out.aval.vma, frozenset({'x'})) + return out + + f = jax.jit(shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))) + jaxpr = f.trace(x).jaxpr + self.assertIn("pbroadcast[axes=('x',)", str(jaxpr)) + out = f(x) + self.assertArraysEqual(out, x * 2) + + # TODO(yashkatariya): Enable grad test which requires adding psum_p support. + # def g(x, y): + # return jnp.sum(f(x, y)) + # print(jax.jit(jax.grad(g)).trace(x, y).jaxpr) + class FunSpec(NamedTuple): name: str From 37b5066d5bc2f0d3915050b7522f189bf61e125d Mon Sep 17 00:00:00 2001 From: Jacob Burnim Date: Fri, 21 Mar 2025 10:45:57 -0700 Subject: [PATCH 087/483] [Pallas] Fixes scalar prefetch in TPU interpret mode. --- jax/_src/pallas/mosaic/interpret.py | 42 ++++++++++++++--------- tests/pallas/tpu_pallas_interpret_test.py | 33 ++++++++++++++++++ 2 files changed, 59 insertions(+), 16 deletions(-) diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py index 3384026c1f5b..439ac98b2ac6 100644 --- a/jax/_src/pallas/mosaic/interpret.py +++ b/jax/_src/pallas/mosaic/interpret.py @@ -1282,17 +1282,20 @@ def f(*args, jaxpr): return jax.util.safe_map(read, jaxpr.outvars) -def _compute_start_indices(block_mapping, loop_idx, *args): - block_indices = ( - jax_core.jaxpr_as_fun(block_mapping.index_map_jaxpr)(*loop_idx, *args)) - if isinstance(block_mapping.indexing_mode, pallas_core.Blocked): - ret = tuple(i if b is pallas_core.mapped else b * i - for b, i in zip(block_mapping.block_shape, block_indices)) - elif isinstance(block_mapping.indexing_mode, pallas_core.Unblocked): - ret = block_indices - else: - raise RuntimeError(f"Unknown indexing mode: {block_mapping.indexing_mode}") - return ret +def _compute_start_indices( + block_mapping, loop_idx, *args, compiler_params, interpret_params): + jaxpr = block_mapping.index_map_jaxpr + block_indices = _interpret_jaxpr( + jaxpr.jaxpr, *jaxpr.consts, *loop_idx, *args, + compiler_params=compiler_params, interpret_params=interpret_params) + if isinstance(block_mapping.indexing_mode, pallas_core.Blocked): + ret = tuple(i if b is pallas_core.mapped else b * i + for b, i in zip(block_mapping.block_shape, block_indices)) + elif isinstance(block_mapping.indexing_mode, pallas_core.Unblocked): + ret = block_indices + else: + raise RuntimeError(f"Unknown indexing mode: {block_mapping.indexing_mode}") + return ret def _get_next_indices(grid, indices): next_indices = [] @@ -1412,6 +1415,7 @@ def interpret_pallas_call( input_buffer_ids = [] for i, var in enumerate( jaxpr.invars[grid_mapping.num_index_operands:][:grid_mapping.num_inputs]): + assert var.aval.dtype == input_args[i].dtype input_buffer_ids.append(callback.io_callback( _allocate_buffer, jax.ShapeDtypeStruct((), jnp.int16), @@ -1451,15 +1455,18 @@ def interpret_pallas_call( # Allocate buffers for non-HBM kernel arguments (e.g., scalars, inputs, # outputs, scratch). - kernel_buffer_ids = [] - for _, val in zip(jaxpr.invars[grid_mapping.slice_index_ops], scalars): - kernel_buffer_ids.append(callback.io_callback( + scalar_buffer_ids = [] + for var, val in zip(jaxpr.invars[grid_mapping.slice_index_ops], scalars): + assert var.aval.shape == val.shape + assert var.aval.dtype == val.dtype + scalar_buffer_ids.append(callback.io_callback( _allocate_buffer, jax.ShapeDtypeStruct((), jnp.int16), device_id, TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.SMEM], val, ordered=True)) + kernel_buffer_ids = scalar_buffer_ids.copy() for i, var in enumerate(jaxpr.invars[grid_mapping.num_index_operands:]): output_idx = i - grid_mapping.num_inputs is_input = i < grid_mapping.num_inputs @@ -1520,11 +1527,14 @@ def body(carry): ) with pallas_core.grid_env(local_grid_env): + start_indices = [ + _compute_start_indices( + bm, loop_idx, *scalar_buffer_ids, compiler_params=compiler_params, + interpret_params=interpret_params) + for bm in grid_mapping.block_mappings] # Copy slices of the input to the kernel buffers. # # TODO(jburnim): Only copy slices when the index mapping has changed? - start_indices = [_compute_start_indices(bm, loop_idx, *scalars) - for bm in grid_mapping.block_mappings] for j, var in enumerate(input_vars): if _is_any(var.aval.memory_space): continue diff --git a/tests/pallas/tpu_pallas_interpret_test.py b/tests/pallas/tpu_pallas_interpret_test.py index 9b8a5b46865d..5b729f0fe07e 100644 --- a/tests/pallas/tpu_pallas_interpret_test.py +++ b/tests/pallas/tpu_pallas_interpret_test.py @@ -20,6 +20,7 @@ from absl.testing import absltest from absl.testing import parameterized +import functools import jax from jax._src import test_util as jtu @@ -68,6 +69,38 @@ def matmul(x: jax.Array, y: jax.Array): z = matmul(x, y) np.testing.assert_allclose(z, x @ y, atol=1e-4) + def test_scalar_prefetch_example(self): + def dynamic_slice_kernel(indices, x_ref, o_ref): + del indices + o_ref[...] = x_ref[...] + + @functools.partial(jax.jit, static_argnums=(2,)) + def block_dynamic_slice(x, starts, sizes): + grid_spec = pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + grid=(1, 1), + in_specs=[pl.BlockSpec( + sizes, + lambda i, j, block_idx: (block_idx[0], block_idx[1]))], + out_specs=pl.BlockSpec(sizes, lambda *_: (0, 0)), + ) + + kernel = pl.pallas_call( + dynamic_slice_kernel, + grid_spec=grid_spec, + out_shape=jax.ShapeDtypeStruct(shape=sizes, dtype=x.dtype), + interpret=mosaic_interpret.TPUInterpretParams(), + ) + block_idx = jnp.array([starts[0] // sizes[0], starts[1] // sizes[1]]) + return kernel(block_idx, x) + + shape = (512, 512) + x = jnp.reshape(jnp.arange(np.prod(shape), dtype=jnp.int32), shape) + result = block_dynamic_slice(x, starts=jnp.array([128, 256]), sizes=(128, 128)) + ref = jax.lax.dynamic_slice(x, start_indices=(128, 256), slice_sizes=(128, 128)) + diff = jnp.max(jnp.abs(result - ref)) + np.testing.assert_allclose(result, ref) + def test_dynamic_grid_and_aliasing(self): self.skipTest('Broken pending fix to extra reads/writes of inputs/outputs') def kernel(s_ref, x_ref, o_ref): From 7dd78d97fad15f47295a25896833abafb92601e0 Mon Sep 17 00:00:00 2001 From: Ayaka Date: Fri, 21 Mar 2025 10:52:34 -0700 Subject: [PATCH 088/483] Add support for configurable error checking categories PiperOrigin-RevId: 739234594 --- jax/_src/config.py | 40 +++++++++- jax/_src/error_check.py | 149 ++++++++++++++++++++++++++++++++------ tests/error_check_test.py | 86 ++++++++++++++++++++++ 3 files changed, 250 insertions(+), 25 deletions(-) diff --git a/jax/_src/config.py b/jax/_src/config.py index cf6a07834a10..5b8b87be2095 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -245,7 +245,10 @@ def trace_context(): pgle_profiling_runs.value, enable_pgle.value, use_shardy_partitioner.value, - use_high_dynamic_range_gumbel.value) + use_high_dynamic_range_gumbel.value, + error_checking_behavior_nan.value, + error_checking_behavior_divide.value, + error_checking_behavior_oob.value) config = Config() @@ -1317,6 +1320,41 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: ), ) +# TODO(ayx): Move these 3 flags out of config once we have a user-level +# extension mechanism for adding contexts to which the jit cache is sensitive. +error_checking_behavior_nan = enum_state( + name='jax_error_checking_behavior_nan', + enum_values=['ignore', 'raise'], + default='ignore', + help=( + 'Specify the behavior when a NaN is encountered. Options are "ignore"' + ' or "raise".' + ), + include_in_jit_key=True, +) + +error_checking_behavior_divide = enum_state( + name='jax_error_checking_behavior_divide', + enum_values=['ignore', 'raise'], + default='ignore', + help=( + 'Specify the behavior when a divide by zero is encountered. Options are' + ' "ignore" or "raise".' + ), + include_in_jit_key=True, +) + +error_checking_behavior_oob = enum_state( + name='jax_error_checking_behavior_oob', + enum_values=['ignore', 'raise'], + default='ignore', + help=( + 'Specify the behavior when an out of bounds access is encountered.' + ' Options are "ignore" or "raise".' + ), + include_in_jit_key=True, +) + def _update_x64_global(val): jax_jit.global_state().enable_x64 = val diff --git a/jax/_src/error_check.py b/jax/_src/error_check.py index e78b9bc82115..9d493c1f351b 100644 --- a/jax/_src/error_check.py +++ b/jax/_src/error_check.py @@ -14,15 +14,18 @@ from __future__ import annotations +import contextlib import dataclasses from functools import partial import json import threading import traceback as tb_lib from types import TracebackType +from typing import Literal import warnings import jax +from jax._src import config from jax._src import core from jax._src import source_info_util from jax._src import traceback_util @@ -115,39 +118,56 @@ def __exit__(self, exc_type, exc_value, traceback): _error_storage.ref = self.old_ref -def set_error_if(pred: jax.Array, /, msg: str) -> None: +# TODO(ayx): Move all category-related logic into the jax.numpy integration once +# it is ready. This logic is specific to how jax.numpy decides when to call +# `set_error_if`, and doesn't belong in the core error-checking library itself. +# The responsibility for deciding whether to predicate an error should lie with +# the user or the higher-level library (like jax.numpy), not with +# `set_error_if`. +Category = Literal["nan", "divide", "oob"] + + +def _is_category_disabled( + category: Category | None, +) -> bool: + """Check if the error checking behavior for the given category is disabled.""" + if category is None: + return False + if category == "nan": + return config.error_checking_behavior_nan.value == "ignore" + if category == "divide": + return config.error_checking_behavior_divide.value == "ignore" + if category == "oob": + return config.error_checking_behavior_oob.value == "ignore" + raise ValueError(f"Invalid category: {category}") + + +def _set_error_if_with_category( + pred: jax.Array, + /, + msg: str, + category: Category | None = None, +) -> None: """Set the internal error state if any element of `pred` is `True`. - This function is used inside JAX computations to detect runtime errors without - immediately halting execution. When this function is traced (e.g., inside - :func:`jax.jit`), the corresponding error message and its traceback are - recorded. At execution time, if `pred` contains any `True` values, the error - state is set, but execution continues without interruption. The recorded error - can later be raised using :func:`raise_if_error`. - - If the error state has already been set, subsequent errors are ignored and - will not override the existing error. - - For multi-device environments, in explicit mode, users must call - :func:`error_checking_context` to initialize a new error tracking state that - matches the device mesh. In auto mode, implicit cross-device communication may - occur inside this function, which could impact performance. A warning is - issued in such cases. + This function is similar to :func:`set_error_if`, but it also takes a category + argument. The category can be "nan", "divide", or "oob". The error checking + behavior for each category can be configured using + :func:`set_error_checking_behavior`. If not provided, there will be no + category. - When exporting a function with `jax.export`, error checking must be explicitly - wrapped using :func:`wrap_for_export` before export and - :func:`unwrap_from_import` after import. - - Args: - pred: A JAX boolean array. If any element of `pred` is `True`, the internal - error state will be set. - msg: The corresponding error message to be raised later. + This function is intended for use in JAX internal APIs (e.g., `jax.numpy`) + to perform category-specific runtime checks tied to the operation being + performed. """ if _error_storage.ref is None: with core.eval_context(): _initialize_error_code_ref() assert _error_storage.ref is not None + if _is_category_disabled(category): + return + # Get the traceback. traceback = source_info_util.current().traceback assert traceback is not None @@ -199,6 +219,37 @@ def set_error_if(pred: jax.Array, /, msg: str) -> None: _error_storage.ref[...] = error_code +def set_error_if(pred: jax.Array, /, msg: str) -> None: + """Set the internal error state if any element of `pred` is `True`. + + This function is used inside JAX computations to detect runtime errors without + immediately halting execution. When this function is traced (e.g., inside + :func:`jax.jit`), the corresponding error message and its traceback are + recorded. At execution time, if `pred` contains any `True` values, the error + state is set, but execution continues without interruption. The recorded error + can later be raised using :func:`raise_if_error`. + + If the error state has already been set, subsequent errors are ignored and + will not override the existing error. + + For multi-device environments, in explicit mode, users must call + :func:`error_checking_context` to initialize a new error tracking state that + matches the device mesh. In auto mode, implicit cross-device communication may + occur inside this function, which could impact performance. A warning is + issued in such cases. + + When exporting a function with `jax.export`, error checking must be explicitly + wrapped using :func:`wrap_for_export` before export and + :func:`unwrap_from_import` after import. + + Args: + pred: A JAX boolean array. If any element of `pred` is `True`, the internal + error state will be set. + msg: The corresponding error message to be raised later. + """ + _set_error_if_with_category(pred, msg) + + def raise_if_error() -> None: """Raise an exception if the internal error state is set. @@ -355,3 +406,53 @@ def inner(*args, **kwargs): return out return inner + + +Behavior = Literal["ignore", "raise"] + + +class error_checking_behavior: + """A context manager to set the error checking behavior. + + If both `all` and a category are provided, the category will override the + `all` setting. + + When the error checking behavior is set to "ignore", all errors will be + ignored. When set to "raise", errors will be detected and recorded, but an + exception will not be raised immediately. Users must call + :func:`raise_if_error` to at the end of the computation to raise the + exception. + """ + + def __init__( + self, + *, + all: Behavior | None = None, + nan: Behavior | None = None, + divide: Behavior | None = None, + oob: Behavior | None = None, + ) -> None: + new_settings = {} + if all is not None: + new_settings["nan"] = new_settings["divide"] = new_settings["oob"] = all + if nan is not None: + new_settings["nan"] = nan + if divide is not None: + new_settings["divide"] = divide + if oob is not None: + new_settings["oob"] = oob + self.new_settings = new_settings + self.stack = contextlib.ExitStack() + + def __enter__(self): + config_flags = { + "nan": config.error_checking_behavior_nan, + "divide": config.error_checking_behavior_divide, + "oob": config.error_checking_behavior_oob, + } + for key, value in self.new_settings.items(): + self.stack.enter_context(config_flags[key](value)) + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.stack.close() diff --git a/tests/error_check_test.py b/tests/error_check_test.py index 69e292fb6704..7f75eeb629a0 100644 --- a/tests/error_check_test.py +++ b/tests/error_check_test.py @@ -337,6 +337,92 @@ def run_import(serialized): ): error_check.raise_if_error() + @parameterized.product(jit=[True, False]) + def test_error_category_nan_check(self, jit): + def f(x): + error_check._set_error_if_with_category( + jnp.isnan(x), "x is NaN", category="nan" + ) + return x + + if jit: + f = jax.jit(f) + + x = jnp.full((4,), jnp.nan, dtype=jnp.float32) + + with error_check.error_checking_behavior(nan="ignore"): + _ = f(x) + error_check.raise_if_error() # should not raise error + + with error_check.error_checking_behavior(nan="raise"): + _ = f(x) + with self.assertRaisesRegex(JaxValueError, "x is NaN"): + error_check.raise_if_error() + + @parameterized.product(jit=[True, False]) + def test_error_category_divide_check(self, jit): + def f(x, y): + error_check._set_error_if_with_category( + y == 0.0, "division by zero", category="divide" + ) + return x / y + + if jit: + f = jax.jit(f) + + x = jnp.arange(4, dtype=jnp.float32) + 1 + y = jnp.arange(4, dtype=jnp.float32) + + with error_check.error_checking_behavior(divide="ignore"): + _ = f(x, y) + error_check.raise_if_error() # should not raise error + + with error_check.error_checking_behavior(divide="raise"): + _ = f(x, y) + with self.assertRaisesRegex(JaxValueError, "division by zero"): + error_check.raise_if_error() + + @parameterized.product(jit=[True, False]) + def test_error_category_oob_check(self, jit): + def f(x, start_indices, slice_sizes): + error_check._set_error_if_with_category( + jnp.logical_or( + start_indices < 0, + start_indices + jnp.array(slice_sizes, dtype=jnp.int32) + >= jnp.array(x.shape, dtype=jnp.int32), + ), + "Out of bounds in dynamic_slice", + category="oob", + ) + y = jax.lax.dynamic_slice( + x, start_indices, slice_sizes, allow_negative_indices=False + ) + return y + + if jit: + f = jax.jit(f, static_argnums=(2,)) + + x = jnp.arange(12).reshape(3, 4) + start_indices = jnp.array([0, -1], dtype=jnp.int32) + slice_sizes = (3, 4) + + with error_check.error_checking_behavior(oob="ignore"): + _ = f(x, start_indices, slice_sizes) + error_check.raise_if_error() # should not raise error + + with error_check.error_checking_behavior(oob="raise"): + _ = f(x, start_indices, slice_sizes) + with self.assertRaisesRegex( + JaxValueError, "Out of bounds in dynamic_slice", + ): + error_check.raise_if_error() + + def test_error_category_invalid_category(self): + with self.assertRaisesRegex(ValueError, "Invalid category"): + error_check._set_error_if_with_category( + jnp.isnan(jnp.float32(1.0)), "x is NaN", category="invalid" + ) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 4fdce200300df181b3088dd0d114e5c759dbf63d Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 21 Mar 2025 11:16:46 -0700 Subject: [PATCH 089/483] Add logit soft-capping support to the ragged paged attention Pallas kernel. PiperOrigin-RevId: 739242412 --- .../pallas/ops/tpu/ragged_paged_attention.py | 12 +++ .../pallas/tpu_ragged_paged_attention_test.py | 94 ++++++++++++++++++- 2 files changed, 105 insertions(+), 1 deletion(-) diff --git a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py b/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py index 90b808282c22..60ac2e34f610 100644 --- a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py +++ b/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py @@ -81,6 +81,7 @@ def ref_ragged_paged_attention( *, sm_scale: float = 1.0, sliding_window: int | None = None, + soft_cap: float | None = None, mask_value: float = DEFAULT_MASK_VALUE, ): _, _, num_kv_heads, head_dim = k_pages.shape @@ -108,6 +109,8 @@ def ref_ragged_paged_attention( mask = q_span < kv_span if sliding_window is not None: mask = jnp.logical_or(mask, q_span - sliding_window >= kv_span) + if soft_cap is not None: + attn = soft_cap * jnp.tanh(attn / soft_cap) attn += jnp.where(mask, mask_value, 0.0) attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype) out = jnp.einsum("hqk,khd->qhd", attn, v).astype(queries.dtype) @@ -126,6 +129,7 @@ def validate_inputs_on_runtime( cu_q_lens: jax.Array, # i32[max_num_seqs + 1] num_seqs, # i32[1] sliding_window: int | None = None, + soft_cap: float | None = None, ): check_inputs_shapes( q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, num_seqs @@ -156,6 +160,8 @@ def validate_inputs_on_runtime( ) if sliding_window is not None and sliding_window <= 0: raise ValueError(f"{sliding_window=} must be positive.") + if soft_cap is not None and soft_cap == 0.0: + raise ValueError(f"{soft_cap=} must not be 0.0.") # Expect to run these checks during compile time. @@ -228,6 +234,7 @@ def ragged_paged_attention_kernel( *, sm_scale: float, sliding_window: int | None = None, + soft_cap: float | None = None, mask_value: float = DEFAULT_MASK_VALUE, ): num_q_per_blk, num_q_heads_per_blk, head_dim = q_ref.shape @@ -432,6 +439,8 @@ def init_scratch_ref(): if sliding_window is not None: causal_mask = jnp.logical_or(causal_mask, row_ids - sliding_window >= col_ids) + if soft_cap is not None: + qk = soft_cap * jnp.tanh(qk / soft_cap) qk += jnp.where(causal_mask, mask_value, 0.0) m_curr = jnp.max(qk, axis=1, keepdims=True) s_curr = jnp.exp(qk - m_curr) @@ -612,6 +621,7 @@ def can_be_xla_fully_tiled(x, packing): "num_queries_per_block", "vmem_limit_bytes", "sliding_window", + "soft_cap", ], ) def ragged_paged_attention( @@ -626,6 +636,7 @@ def ragged_paged_attention( *, sm_scale: float = 1.0, sliding_window: int | None = None, + soft_cap: float | None = None, mask_value: float = DEFAULT_MASK_VALUE, num_kv_pages_per_block: int = 16, num_queries_per_block: int = 128, @@ -719,6 +730,7 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_): ragged_paged_attention_kernel, sm_scale=sm_scale, sliding_window=sliding_window, + soft_cap=soft_cap, mask_value=mask_value, ), grid_spec=pltpu.PrefetchScalarGridSpec( diff --git a/tests/pallas/tpu_ragged_paged_attention_test.py b/tests/pallas/tpu_ragged_paged_attention_test.py index ba574a4ce98c..815c9dc6327f 100644 --- a/tests/pallas/tpu_ragged_paged_attention_test.py +++ b/tests/pallas/tpu_ragged_paged_attention_test.py @@ -52,6 +52,7 @@ def _test_ragged_paged_attention( max_num_batched_tokens=512, max_num_seq=8, sliding_window: int | None = None, + soft_cap: float | None = None, ): if not jtu.is_device_tpu_at_least(version=4): self.skipTest("Expect TPUv4+") @@ -104,6 +105,7 @@ def _test_ragged_paged_attention( cu_q_lens, num_seqs, sliding_window=sliding_window, + soft_cap=soft_cap, ) actual_num_q_tokens = cu_q_lens[num_seqs[0]] @@ -119,6 +121,7 @@ def _test_ragged_paged_attention( num_queries_per_block=num_queries_per_block, vmem_limit_bytes=vmem_limit_bytes, sliding_window=sliding_window, + soft_cap=soft_cap, )[: actual_num_q_tokens] expected = ref_ragged_paged_attention( @@ -130,6 +133,7 @@ def _test_ragged_paged_attention( cu_q_lens, num_seqs=num_seqs, sliding_window=sliding_window, + soft_cap=soft_cap, ) tols = { "float32": 0.15, @@ -272,7 +276,6 @@ def test_ragged_paged_attention_mixed(self, dtype): dtype=[jnp.float32, jnp.bfloat16], num_kv_pages_per_block=[4, 8], num_queries_per_block=[32, 64], - sliding_window=[None, 5, 128], ) def test_ragged_paged_attention_complex( self, @@ -281,8 +284,42 @@ def test_ragged_paged_attention_complex( dtype, num_kv_pages_per_block, num_queries_per_block, + ): + seq_lens = [] + for _ in range(num_seqs): + q_len = random.randint(1, 100) + kv_len = q_len + random.randint(0, 50) + seq_lens.append((q_len, kv_len)) + # TODO(jevinjiang): Support non-128 head_dim! + head_dim = 128 + page_size = 16 + num_pages = 1000 + + self._test_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + num_kv_pages_per_block=num_kv_pages_per_block, + num_queries_per_block=num_queries_per_block, + ) + + @parameterized.product( + num_kv_pages_per_block=[4, 8], + num_queries_per_block=[32, 64], + sliding_window=[None, 5, 128], + ) + def test_ragged_paged_attention_sliding_window( + self, + num_kv_pages_per_block, + num_queries_per_block, sliding_window: int | None, ): + num_seqs = 5 + num_heads = (4, 4) + dtype = jnp.float32 seq_lens = [] for _ in range(num_seqs): q_len = random.randint(1, 100) @@ -305,6 +342,41 @@ def test_ragged_paged_attention_complex( sliding_window=sliding_window, ) + @parameterized.product( + num_kv_pages_per_block=[4, 8], + num_queries_per_block=[32, 64], + soft_cap=[None, 50.0], + ) + def test_ragged_paged_attention_logit_soft_capping( + self, + num_kv_pages_per_block, + num_queries_per_block, + soft_cap: float | None, + ): + num_heads = (12, 2) + num_seqs = 2 + dtype = jnp.float32 + seq_lens = [] + for _ in range(num_seqs): + q_len = random.randint(1, 100) + kv_len = q_len + random.randint(0, 50) + seq_lens.append((q_len, kv_len)) + head_dim = 128 + page_size = 16 + num_pages = 1000 + + self._test_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + num_kv_pages_per_block=num_kv_pages_per_block, + num_queries_per_block=num_queries_per_block, + soft_cap=soft_cap, + ) + def test_ragged_paged_attention_sliding_window_should_be_positive(self): dtype = jnp.float32 seq_lens = [(192, 328), (128, 180), (64, 255)] @@ -335,5 +407,25 @@ def test_ragged_paged_attention_sliding_window_should_be_positive(self): sliding_window=-1, ) + def test_ragged_paged_attention_soft_cap_cannot_be_zero(self): + dtype = jnp.float32 + seq_lens = [(192, 328), (128, 180), (64, 255)] + num_heads = (32, 8) + head_dim = 128 + page_size = 16 + num_pages = 1000 + + with self.assertRaisesRegex(ValueError, "must not be 0.0"): + self._test_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + soft_cap=0.0, + ) + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From e23069b39cd747563abd53132eaee15290ce0ce2 Mon Sep 17 00:00:00 2001 From: Krishna Haridasan Date: Fri, 21 Mar 2025 11:42:14 -0700 Subject: [PATCH 090/483] Allow forcing pallas forward compatibility for some backends PiperOrigin-RevId: 739249745 --- jax/_src/interpreters/mlir.py | 13 ++++++++++--- jax/_src/xla_bridge.py | 3 +++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index f96f07be4149..a707981f5403 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -847,11 +847,18 @@ def is_forward_compat(self) -> bool: """Returns true if the lowering parameters are in forward compatibility mode. """ lowering_parameters = self.module_context.lowering_parameters - return ( - lowering_parameters.for_export - and not lowering_parameters.export_ignore_forward_compatibility + + check_platforms: Sequence[str] = ( + self.platforms or self.module_context.platforms + ) + force_forward_compat = any( + p in xb.FORCE_FORWARD_COMPAT_LOWERING_PLATFORMS for p in check_platforms ) + return ( + lowering_parameters.for_export or force_forward_compat + ) and not lowering_parameters.export_ignore_forward_compatibility + if not MYPY: class LoweringRule(Protocol): diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index be96deab81d8..72d88b9735b7 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -60,6 +60,9 @@ XlaBackend = xla_client.Client +# The platforms in this set will force forward compatibility for lowering. +FORCE_FORWARD_COMPAT_LOWERING_PLATFORMS: set[str] = set() + MIN_COMPUTE_CAPABILITY = 52 _DEFAULT_CPU_COLLECTIVES_IMPL = 'gloo' From 53e8eac7134a13c1d28de673e7e3a23b4a837aed Mon Sep 17 00:00:00 2001 From: Brian Zhao Date: Fri, 21 Mar 2025 12:12:05 -0700 Subject: [PATCH 091/483] Reverts be5713309521d5cf0d2252b9c8f1d38ab50952d1 PiperOrigin-RevId: 739258607 --- jax/_src/numpy/lax_numpy.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 16355695792d..96efc48062e1 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -49,16 +49,15 @@ from jax._src.lax.lax import (PrecisionLike,_array_copy, _sort_le_comparator, _sort_lt_comparator) from jax._src.lib import xla_client as xc +from jax._src.numpy.array_creation import (empty, empty_like, full, + ones, ones_like, zeros, zeros_like) from jax._src.numpy import indexing from jax._src.numpy import reductions from jax._src.numpy import tensor_contractions from jax._src.numpy import ufuncs from jax._src.numpy import util -from jax._src.numpy.array_creation import (empty, empty_like, full, - ones, ones_like, zeros, zeros_like) from jax._src.numpy.sorting import argsort, sort from jax._src.numpy.vectorize import vectorize -from jax._src.sharding_impls import SingleDeviceSharding from jax._src.typing import ( Array, ArrayLike, DType, DTypeLike, DeprecatedArg, DimSize, Shape ) @@ -66,7 +65,8 @@ NumpyComplexWarning, canonicalize_axis as _canonicalize_axis, ceil_of_ratio, safe_zip, set_module, unzip2) from jax.sharding import Sharding -from jax.tree_util import tree_flatten, tree_map +from jax._src.sharding_impls import SingleDeviceSharding +from jax.tree_util import tree_leaves, tree_map import numpy as np export = set_module('jax.numpy') @@ -5504,7 +5504,9 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, object = xc._xla.cuda_array_interface_to_buffer( cai=cai, gpu_backend=backend, device_id=device_id) - leaves, treedef = tree_flatten(object, is_leaf=lambda x: x is None) + object = tree_map(lambda leaf: leaf.__jax_array__() + if hasattr(leaf, "__jax_array__") else leaf, object) + leaves = tree_leaves(object, is_leaf=lambda x: x is None) if any(leaf is None for leaf in leaves): # Added Nov 16 2023 if deprecations.is_accelerated("jax-numpy-array-none"): @@ -5513,13 +5515,7 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, "None encountered in jnp.array(); this is currently treated as NaN. " "In the future this will result in an error.", FutureWarning, stacklevel=2) - leaves, treedef = tree_flatten(object) - leaves = [ - leaf - if (leaf_jax_array := getattr(leaf, "__jax_array__", None)) is None - else leaf_jax_array() - for leaf in leaves - ] + leaves = tree_leaves(object) if dtype is None: # Use lattice_result_type rather than result_type to avoid canonicalization. # Otherwise, weakly-typed inputs would have their dtypes canonicalized. @@ -5534,8 +5530,8 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, if not weak_type: dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True) # type: ignore[assignment] - object = treedef.unflatten(leaves) out: ArrayLike + if all(not isinstance(leaf, Array) for leaf in leaves): # TODO(jakevdp): falling back to numpy here fails to overflow for lists # containing large integers; see discussion in From 520b44fc5ca70d6bb5d70e539ac6e53b2e53072b Mon Sep 17 00:00:00 2001 From: Ayaka Date: Fri, 21 Mar 2025 12:50:37 -0700 Subject: [PATCH 092/483] Ensure traceback correctness in error checking in AOT mode This PR is similar to https://github.com/jax-ml/jax/pull/27329. The difference is that in AOT mode, the original traceback is exported as a string and appended to the error message when an exception is raised. PiperOrigin-RevId: 739270141 --- tests/error_check_test.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/tests/error_check_test.py b/tests/error_check_test.py index 7f75eeb629a0..af3f35c7ab62 100644 --- a/tests/error_check_test.py +++ b/tests/error_check_test.py @@ -305,6 +305,41 @@ def run_import(serialized): serialized = run_export() run_import(serialized) + def test_error_check_aot_includes_traceback(self): + def run_export(): + def function_that_triggers_error_for_traceback_test(x): + error_check.set_error_if( # This line must be included in the traceback + x <= 0, "x must be greater than 0" + ) + return x + 1 + + f = jax.jit( + error_check.wrap_for_export( + jax.jit(function_that_triggers_error_for_traceback_test) + ) + ) + x = jax.ShapeDtypeStruct((), jnp.float32) + serialized = jax.export.export(f)(x).serialize() + return serialized + + def run_import(serialized): + f = jax.export.deserialize(serialized).call + f = jax.jit(error_check.unwrap_from_import(jax.jit(f))) + x = jnp.float32(-3.0) + _ = f(x) + + msg = "" + try: + error_check.raise_if_error() + except JaxValueError as e: + msg = str(e) + + self.assertIn("function_that_triggers_error_for_traceback_test", msg) + self.assertIn("This line must be included in the traceback", msg) + + serialized = run_export() + run_import(serialized) + def test_error_check_aot_should_not_override_existing_error(self): def f1(x): error_check.set_error_if(x <= 0, "x must be greater than 0 in f1") From e71bcde543ba4db23f382f0aca7d9d6fe4227f06 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 21 Mar 2025 13:22:36 -0700 Subject: [PATCH 093/483] Remove some long-stale version guards. PiperOrigin-RevId: 739279729 --- jax/_src/tpu_custom_call.py | 10 +++------- jaxlib/xla/xla_client_test.py | 2 -- tests/array_test.py | 2 -- 3 files changed, 3 insertions(+), 11 deletions(-) diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index 4089e047f8b0..e37d5e064a26 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -484,13 +484,9 @@ def _lower_mosaic_module_to_asm( module_op = module.operation.clone() prev_allow_unregistered_dialects = ctx.allow_unregistered_dialects ctx.allow_unregistered_dialects = True - # TODO(apaszke): Remove once the minimum jaxlib version is at least 0.4.37. - if jax.version._version_as_tuple(jax.lib.__version__) < (0, 4, 37): - target_version = "" - else: - target_version = ( - f"target-version={ir_version}" if ir_version is not None else "" - ) + target_version = ( + f"target-version={ir_version}" if ir_version is not None else "" + ) try: pipeline = PassManager.parse( "builtin.module(mosaic-serde{serialize=true " + target_version + "})" diff --git a/jaxlib/xla/xla_client_test.py b/jaxlib/xla/xla_client_test.py index e228905637cb..5a2f3881f510 100644 --- a/jaxlib/xla/xla_client_test.py +++ b/jaxlib/xla/xla_client_test.py @@ -630,8 +630,6 @@ def testStatefulCustomCall(self): def testCustomCallLookup(self): if self.backend.platform != "cpu": self.skipTest("Test requires cpu platform") - if xla_client._version < 241: - self.skipTest("Test requires jaxlib version 241") self.assertTrue(_CUSTOM_CALLS_REGISTERED) xla_client.make_cpu_client() diff --git a/tests/array_test.py b/tests/array_test.py index 6100283cc032..5891db5a3e36 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -368,8 +368,6 @@ def test_different_devices_in_arrays_than_sharding(self): array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs, committed=True) def test_duplicated_devices_in_arrays(self): - if xc._version <= 274: - self.skipTest('Test requires jaxlib version 275') shape = (8, 2) mesh = jtu.create_mesh((1, 2), ('x', 'y')) # Sharding device ids = {0, 1} From ba5be78cdd218136506d5a11b10a793a8692aae2 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 21 Mar 2025 13:26:09 -0700 Subject: [PATCH 094/483] Remove symlinking of xla_client.py. Use a stub instead. Symlinking led to confusing behaviors because Python may believe there are two copies of the module. PiperOrigin-RevId: 739280690 --- jaxlib/BUILD | 9 +-------- jaxlib/tools/build_wheel.py | 2 +- jaxlib/xla_client.py | 18 ++++++++++++++++++ 3 files changed, 20 insertions(+), 9 deletions(-) create mode 100644 jaxlib/xla_client.py diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 2397639fddf2..5f693c5384df 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -50,8 +50,8 @@ py_library_providing_imports_info( "init.py", "lapack.py", "plugin_support.py", + "xla_client.py", ":version", - ":xla_client", ":xla_extension_py", ], data = [":ffi_headers"], @@ -92,13 +92,6 @@ symlink_files( flatten = True, ) -symlink_files( - name = "xla_client", - srcs = ["//jaxlib/xla:xla_client"], - dst = ".", - flatten = True, -) - symlink_files( name = "ffi_headers", srcs = ["@xla//xla/ffi/api:all_headers"], diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 8632468acb97..0e0ce077cb23 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -197,7 +197,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu): "__main__/jaxlib/gpu_sparse.py", "__main__/jaxlib/plugin_support.py", "__main__/jaxlib/version.py", - "__main__/jaxlib/xla_client.py", + "__main__/jaxlib/xla/xla_client.py", f"xla/xla/python/xla_extension.{pyext}", ], ) diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py new file mode 100644 index 000000000000..01b01ecf704e --- /dev/null +++ b/jaxlib/xla_client.py @@ -0,0 +1,18 @@ +# Copyright 2025 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from jaxlib.xla.xla_client import * # noqa: F403 +from jaxlib.xla.xla_client import _version # noqa: F401 +from jaxlib.xla.xla_client import _xla # noqa: F401 From 93f3e4aa19cd2b892ad2c788f2f1a2ebdb853ce6 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 21 Mar 2025 08:41:56 -0400 Subject: [PATCH 095/483] Increase the test timeout for tsan builds. Update the list of TSAN suppressions. Issue #27244 --- .github/workflows/tsan-suppressions.txt | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/.github/workflows/tsan-suppressions.txt b/.github/workflows/tsan-suppressions.txt index 7b713b2da194..296f4432e687 100644 --- a/.github/workflows/tsan-suppressions.txt +++ b/.github/workflows/tsan-suppressions.txt @@ -2,14 +2,11 @@ # are racing on a call to __register_frame_info(), but that function appears to be correctly locked internally. race:llvm::RuntimeDyldELF::registerEHFrames -# https://github.com/python/cpython/issues/128050 -race:partial_vectorcall_fallback - # https://github.com/openxla/xla/issues/20686 race:dnnl_sgemm -# https://github.com/python/cpython/issues/128130 -race_top:run_eval_code_obj +# https://github.com/python/cpython/issues/128050 +race:partial_vectorcall_fallback # Likely only happens when the process is crashing. race:dump_traceback @@ -18,19 +15,15 @@ race:dump_traceback # Fixed in Python 3.14, but not backported to 3.13. race:immortalize_interned race:_PyUnicode_InternMortal +race:_PyUnicode_InternImmortal # https://github.com/python/cpython/issues/128144 # Fixed in Python 3.14, but not backported to 3.13. race_top:PyMember_GetOne -# https://github.com/python/cpython/issues/129547 -race:type_get_annotations - - # https://github.com/python/cpython/issues/129748 race:mi_block_set_nextx - # Races because the LAPACK and BLAS in our scipy isn't TSAN instrumented. race:heevd_ffi race:gesdd_ffi @@ -65,3 +58,10 @@ race:gemm_oncopy # https://github.com/python/cpython/issues/130547 # race:split_keys_entry_added + +# https://github.com/python/cpython/issues/128130 +# race_top:run_eval_code_obj + +# https://github.com/python/cpython/issues/129547 +# Maybe fixed? +# race:type_get_annotations From 6b7744581b6810d1fab176994e591b8ccb4f6f5b Mon Sep 17 00:00:00 2001 From: Nicolas Castet Date: Fri, 21 Feb 2025 16:51:36 -0600 Subject: [PATCH 096/483] [Pallas] [1/3] Move communication primitives from mosaic to core --- jax/_src/pallas/core.py | 4 + jax/_src/pallas/mosaic/core.py | 11 +- jax/_src/pallas/mosaic/helpers.py | 4 +- jax/_src/pallas/mosaic/interpret.py | 8 +- jax/_src/pallas/mosaic/lowering.py | 30 +-- jax/_src/pallas/mosaic/primitives.py | 277 +-------------------------- jax/_src/pallas/primitives.py | 264 +++++++++++++++++++++++++ jax/experimental/pallas/__init__.py | 5 + jax/experimental/pallas/tpu.py | 14 +- 9 files changed, 314 insertions(+), 303 deletions(-) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 466f6037a8ef..389bbd3b0733 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -65,6 +65,10 @@ def __repr__(self): SEMAPHORE_INTERPRET_DTYPE = jnp.int16 SEMAPHORE_MAX_VALUE = jnp.iinfo(SEMAPHORE_INTERPRET_DTYPE).max +class semaphore_dtype(dtypes.extended): pass +class semaphore(semaphore_dtype): pass +class dma_semaphore(semaphore_dtype): pass +class barrier_semaphore(semaphore_dtype): pass @runtime_checkable class CompilerParams(Protocol): diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index 3e60e471dfa2..5d503779f092 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -112,11 +112,6 @@ def __call__(self, shape: tuple[int, ...], dtype: jnp.dtype): # A convenience function for constructing MemoryRef types. return pallas_core.MemoryRef(shape, dtype, self) -class semaphore_dtype(dtypes.extended): pass -class semaphore(semaphore_dtype): pass -class dma_semaphore(semaphore_dtype): pass -class barrier_semaphore(semaphore_dtype): pass - class AbstractSemaphoreTyRules: @staticmethod def pallas_interpret_element_aval(_) -> jax_core.ShapedArray: @@ -142,15 +137,15 @@ def __hash__(self) -> int: # TODO(sharadmv): implement dtype rules for AbstractSemaphoreTy class SemaphoreTy(AbstractSemaphoreTy): - type = semaphore + type = pallas_core.semaphore name = "sem" class DmaSemaphoreTy(AbstractSemaphoreTy): - type = dma_semaphore + type = pallas_core.dma_semaphore name = "dma_sem" class BarrierSemaphoreTy(AbstractSemaphoreTy): - type = barrier_semaphore + type = pallas_core.barrier_semaphore name = "barrier_sem" class SemaphoreType(enum.Enum): diff --git a/jax/_src/pallas/mosaic/helpers.py b/jax/_src/pallas/mosaic/helpers.py index 76421cec3340..24cd7cad6086 100644 --- a/jax/_src/pallas/mosaic/helpers.py +++ b/jax/_src/pallas/mosaic/helpers.py @@ -88,8 +88,8 @@ def signal_core(i): # Don't signal ourself @pl_helpers.when(core_id != i) def _(): - plm_primitives.semaphore_signal(sem, 1, core_index=i) + pl_primitives.semaphore_signal(sem, 1, core_index=i) for i in range(num_cores): signal_core(i) - plm_primitives.semaphore_wait(sem, num_cores - 1) + pl_primitives.semaphore_wait(sem, num_cores - 1) diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py index a731bfdfdae1..ba1a7b0017c4 100644 --- a/jax/_src/pallas/mosaic/interpret.py +++ b/jax/_src/pallas/mosaic/interpret.py @@ -943,9 +943,9 @@ def _device_coords_to_logical_id(device_coords, axis_sizes): def _device_id_to_logical(device_id, device_id_type, axis_sizes): if device_id is None: return None - if device_id_type == mosaic_primitives.DeviceIdType.MESH: + if device_id_type == primitives.DeviceIdType.MESH: return _device_coords_to_logical_id(device_id, axis_sizes) - elif device_id_type == mosaic_primitives.DeviceIdType.LOGICAL: + elif device_id_type == primitives.DeviceIdType.LOGICAL: return device_id else: raise ValueError(f'Unsupported device ID type: {device_id_type}') @@ -1223,7 +1223,7 @@ def f(*args, jaxpr): compiler_params['mosaic']['collective_id'], ordered=True) - elif prim is mosaic_primitives.semaphore_signal_p: + elif prim is primitives.semaphore_signal_p: sem, sem_transforms, inc, target_device_id, core_index = ( jax.tree.unflatten(eqn.params['args_tree'], deferred_invals())) target_device_id = _device_id_to_logical( @@ -1239,7 +1239,7 @@ def f(*args, jaxpr): ordered=True) out = [] - elif prim is mosaic_primitives.semaphore_wait_p: + elif prim is primitives.semaphore_wait_p: sem, sem_transforms, value = ( jax.tree.unflatten(eqn.params['args_tree'], deferred_invals())) callback.io_callback( diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 4efb2b276f56..3469ef4de952 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -192,12 +192,12 @@ def _memory_space_to_mosaic_attribute(memory_space: MemorySpace | None def _dtype_to_ir_type(dtype: jnp.dtype, is_kernel_boundary: bool = False) -> ir.Type: - if jnp.issubdtype(dtype, tpu_core.semaphore_dtype): - if jnp.issubdtype(dtype, tpu_core.dma_semaphore): + if jnp.issubdtype(dtype, pallas_core.semaphore_dtype): + if jnp.issubdtype(dtype, pallas_core.dma_semaphore): return ir.Type.parse("!tpu.dma_semaphore") - elif jnp.issubdtype(dtype, tpu_core.semaphore): + elif jnp.issubdtype(dtype, pallas_core.semaphore): return ir.Type.parse("!tpu.semaphore") - elif jnp.issubdtype(dtype, tpu_core.barrier_semaphore): + elif jnp.issubdtype(dtype, pallas_core.barrier_semaphore): return ir.Type.parse("!tpu.semaphore") else: raise NotImplementedError @@ -3291,7 +3291,7 @@ def _alloc_value( ) -> ir.Value: if isinstance(aval, pallas_core.AbstractMemoryRef): memspace = _memory_space_to_mosaic_attribute(aval.memory_space) - if jnp.issubdtype(aval.dtype, tpu_core.semaphore_dtype): + if jnp.issubdtype(aval.dtype, pallas_core.semaphore_dtype): assert aval.memory_space == TPUMemorySpace.SEMAPHORE memref_type = aval_to_ir_type( ctx.lowering_context.dynamic_shape_replacement_fn, @@ -3341,8 +3341,8 @@ def _run_scoped_lowering_rule(ctx: LoweringRuleContext, *consts, jaxpr): def _device_id_to_logical( ctx: LoweringRuleContext, device_id, - device_id_type: tpu_primitives.DeviceIdType): - if device_id_type is tpu_primitives.DeviceIdType.MESH: + device_id_type: primitives.DeviceIdType): + if device_id_type is primitives.DeviceIdType.MESH: # Mesh means we are passed the mesh coordinates for the device device_ids = tree_util.tree_leaves(device_id) mesh_strides = ctx.lowering_context.mesh_context.mesh_strides @@ -3357,7 +3357,7 @@ def _device_id_to_logical( for a, b in zip(device_ids, mesh_strides) ), ) - elif device_id_type is tpu_primitives.DeviceIdType.LOGICAL: + elif device_id_type is primitives.DeviceIdType.LOGICAL: return device_id raise NotImplementedError(f"Unsupported device id type: {device_id_type}") @@ -3373,13 +3373,13 @@ def _semaphore_read_lowering_rule( return tpu.sem_read(sem) -lowering_rules[tpu_primitives.semaphore_read_p] = _semaphore_read_lowering_rule +lowering_rules[primitives.semaphore_read_p] = _semaphore_read_lowering_rule def _semaphore_signal_lowering_rule( ctx: LoweringRuleContext, *args, args_tree, - device_id_type: tpu_primitives.DeviceIdType, + device_id_type: primitives.DeviceIdType, ): sem_aval, _, _, _, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in) sem, transforms, value, device_id, core_index = tree_util.tree_unflatten( @@ -3392,7 +3392,7 @@ def _semaphore_signal_lowering_rule( return [] -lowering_rules[tpu_primitives.semaphore_signal_p] = ( +lowering_rules[primitives.semaphore_signal_p] = ( _semaphore_signal_lowering_rule) @@ -3402,10 +3402,10 @@ def _semaphore_wait_lowering_rule(ctx: LoweringRuleContext, *args, args_tree): sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, transforms) tpu.sem_wait(sem, value) return [] -lowering_rules[tpu_primitives.semaphore_wait_p] = _semaphore_wait_lowering_rule +lowering_rules[primitives.semaphore_wait_p] = _semaphore_wait_lowering_rule def _dma_start_lowering_rule(ctx: LoweringRuleContext, *args, tree, - device_id_type: tpu_primitives.DeviceIdType): + device_id_type: primitives.DeviceIdType): ( src_ref, src_transforms, @@ -3445,7 +3445,7 @@ def _dma_start_lowering_rule(ctx: LoweringRuleContext, *args, tree, def _dma_wait_lowering_rule(ctx: LoweringRuleContext, *args, tree, - device_id_type: tpu_primitives.DeviceIdType): + device_id_type: primitives.DeviceIdType): del device_id_type (src, src_transforms, dst, transforms, sem, sem_transforms, _, _, _) = ( tree_util.tree_unflatten(tree, args) @@ -3477,7 +3477,7 @@ def _dma_wait_lowering_rule(ctx: LoweringRuleContext, *args, tree, def _device_id_lowering_rule(ctx: LoweringRuleContext): return tpu.device_id() -lowering_rules[tpu_primitives.device_id_p] = _device_id_lowering_rule +lowering_rules[primitives.device_id_p] = _device_id_lowering_rule def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): grid_names = ctx.lowering_context.grid_names diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py index fb0e0c2c55e3..106f342bace8 100644 --- a/jax/_src/pallas/mosaic/primitives.py +++ b/jax/_src/pallas/mosaic/primitives.py @@ -16,7 +16,6 @@ from __future__ import annotations import dataclasses -import enum from typing import Any import jax @@ -28,6 +27,7 @@ from jax._src import util from jax._src.interpreters import mlir from jax._src.pallas import core as pl_core +from jax._src.pallas import primitives from jax._src.pallas import utils as pallas_utils from jax._src.pallas.mosaic import core as tpu_core from jax._src.state import discharge as state_discharge @@ -160,255 +160,6 @@ def _roll(x, shift): mlir.register_lowering(roll_p, _roll_lowering_rule) -class DeviceIdType(enum.Enum): - MESH = "mesh" - LOGICAL = "logical" - - -def check_sem_avals( - sem_aval, sem_transforms_avals, name, allowed_semaphore_types=None -): - if allowed_semaphore_types is None: - allowed_semaphore_types = { - tpu_core.semaphore, - tpu_core.barrier_semaphore, - # For interpret mode. - pl_core.SEMAPHORE_INTERPRET_DTYPE, - } - if not isinstance(sem_aval, state.AbstractRef): - raise ValueError(f"Cannot {name} on a non-semaphore Ref: {sem_aval}") - sem_shape = sem_aval.shape - if sem_transforms_avals: - sem_shape = sem_transforms_avals[-1].get_indexer_shape() - if sem_shape: - raise ValueError(f"Cannot {name} on a non-()-shaped semaphore: {sem_shape}") - sem_dtype = sem_aval.dtype - if not any( - jnp.issubdtype(sem_dtype, sem_type) - for sem_type in allowed_semaphore_types - ): - raise ValueError( - f"Must {name} semaphores of the following types:" - f" {allowed_semaphore_types}. Got {sem_dtype}." - ) - - -def _transform_semaphore(ref_value, transforms, ref_aval): - """Helper function for indexing into a semaphore during state_discharge.""" - if ref_value.shape == ref_aval.shape: - return state_discharge.transform_array(ref_value, transforms) - elif len(ref_value.shape) == 0: - return ref_value - else: - raise ValueError( - f"Semaphore value shape {ref_value.shape} does not match aval shape" - f" {ref_aval.shape}" - ) - - -semaphore_read_p = jax_core.Primitive("semaphore_read") -semaphore_read_p.multiple_results = False - - -def semaphore_read(sem_or_view): - ref, transforms = _get_ref_and_transforms(sem_or_view) - args = [ref, transforms] - flat_args, args_tree = tree_util.tree_flatten(args) - return semaphore_read_p.bind(*flat_args, args_tree=args_tree) - -@semaphore_read_p.def_abstract_eval -def _semaphore_read_abstract_eval( - *avals, - args_tree, -): - sem_aval, sem_transforms_avals = tree_util.tree_unflatten(args_tree, avals) - check_sem_avals( - sem_aval, - sem_transforms_avals, - "read", - allowed_semaphore_types={ - tpu_core.dma_semaphore, - tpu_core.semaphore, - tpu_core.barrier_semaphore, - pl_core.SEMAPHORE_INTERPRET_DTYPE, - }, - ) - return jax_core.ShapedArray((), jnp.dtype("int32")) - -def _semaphore_read_discharge_rule(in_avals, - out_avals, - *flat_args, - args_tree): - del out_avals - [ref, transforms] = args_tree.unflatten(flat_args) - sem_value = _transform_semaphore(ref, transforms, in_avals[0]) - sem_value = sem_value.astype(jnp.int32) - return (None,) * len(in_avals), sem_value -state_discharge.register_discharge_rule(semaphore_read_p)( - _semaphore_read_discharge_rule -) - - -semaphore_signal_p = jax_core.Primitive('semaphore_signal') -semaphore_signal_p.multiple_results = True - - -def semaphore_signal( - sem_or_view, - inc: int | jax.Array = 1, - *, - device_id: int | jax.Array | None | tuple[int | jax.Array, ...] = None, - device_id_type: DeviceIdType = DeviceIdType.MESH, - core_index: int | jax.Array | None = None, -): - ref, transforms = _get_ref_and_transforms(sem_or_view) - inc = jnp.asarray(inc, dtype=jnp.int32) - args = [ref, transforms, inc, device_id, core_index] - flat_args, args_tree = tree_util.tree_flatten(args) - semaphore_signal_p.bind( - *flat_args, - args_tree=args_tree, - device_id_type=device_id_type, - ) - - -@semaphore_signal_p.def_abstract_eval -def _semaphore_signal_abstract_eval( - *avals, - args_tree, - device_id_type: DeviceIdType, -): - del device_id_type - ( - sem_aval, - sem_transforms_avals, - value_aval, - device_id_avals, - core_index_aval, - ) = tree_util.tree_unflatten(args_tree, avals) - check_sem_avals(sem_aval, sem_transforms_avals, "signal") - if value_aval.dtype != jnp.dtype("int32"): - raise ValueError("Must signal an int32 value.") - if device_id_avals is not None: - device_id_flat_avals = tree_util.tree_leaves(device_id_avals) - for aval in device_id_flat_avals: - if aval.dtype != jnp.dtype("int32"): - raise ValueError("`device_id`s must be an int32 value.") - return [] - - -def _semaphore_signal_pp_eqn(eqn: jax_core.JaxprEqn, - context: jax_core.JaxprPpContext, - settings: jax_core.JaxprPpSettings): - del settings - invars = eqn.invars - tree = eqn.params["args_tree"] - ( - sem, - sem_transforms, - value, - device_ids, - _, - ) = tree_util.tree_unflatten(tree, invars) - out = pp.concat([ - pp.text("semaphore_signal"), - pp.text(" "), - sp.pp_ref_transforms(context, sem, sem_transforms), - pp.text(" "), - pp.text(jax_core.pp_var(value, context)), - ]) - if device_ids is not None: - flat_device_ids = tree_util.tree_leaves(device_ids) - if not flat_device_ids: - return out - device_ids_pp = [pp.text(jax_core.pp_var(flat_device_ids[0], context))] - for device_id in flat_device_ids[1:]: - device_ids_pp.append(pp.text(" ")) - device_ids_pp.append(pp.text(jax_core.pp_var(device_id, context))) - out = pp.concat([out, pp.concat(device_ids_pp)]) - return out -jax_core.pp_eqn_rules[semaphore_signal_p] = _semaphore_signal_pp_eqn - - -def _semaphore_signal_discharge_rule(in_avals, - out_avals, - *flat_args, - args_tree, - device_id_type): - del out_avals, device_id_type - [ref, transforms, inc, device_id, core_index] = args_tree.unflatten(flat_args) - if device_id is not None: - raise NotImplementedError("Remote signal not implemented.") - if core_index is not None: - raise NotImplementedError("Multiple core support not implemented.") - sem_value = _transform_semaphore(ref, transforms, in_avals[0]) - inc = inc.astype(pl_core.SEMAPHORE_INTERPRET_DTYPE) - _, new_sem_value = state_discharge.transform_swap_array( - ref, transforms, sem_value + inc - ) - return (new_sem_value,) + (None,) * (len(in_avals) - 1), () -state_discharge.register_discharge_rule(semaphore_signal_p)( - _semaphore_signal_discharge_rule -) - - -semaphore_wait_p = jax_core.Primitive('semaphore_wait') -semaphore_wait_p.multiple_results = True - -def semaphore_wait(sem_or_view, dec: int | jax.Array = 1): - ref, transforms = _get_ref_and_transforms(sem_or_view) - dec = jnp.asarray(dec, dtype=jnp.int32) - args = [ref, transforms, dec] - flat_args, args_tree = tree_util.tree_flatten(args) - semaphore_wait_p.bind(*flat_args, args_tree=args_tree) - -@semaphore_wait_p.def_abstract_eval -def _semaphore_wait_abstract_eval(*avals, args_tree): - sem_aval, sem_transforms_avals, value_aval = tree_util.tree_unflatten( - args_tree, avals - ) - check_sem_avals(sem_aval, sem_transforms_avals, "wait") - if value_aval.dtype != jnp.dtype("int32"): - raise ValueError("Must wait an int32 value.") - return [] - -def _semaphore_wait_pp_eqn(eqn: jax_core.JaxprEqn, - context: jax_core.JaxprPpContext, - settings: jax_core.JaxprPpSettings): - del settings - invars = eqn.invars - tree = eqn.params["args_tree"] - ( - sem, - sem_transforms, - value, - ) = tree_util.tree_unflatten(tree, invars) - return pp.concat([ - pp.text("semaphore_wait"), - pp.text(" "), - sp.pp_ref_transforms(context, sem, sem_transforms), - pp.text(" "), - pp.text(jax_core.pp_var(value, context)), - ]) -jax_core.pp_eqn_rules[semaphore_wait_p] = _semaphore_wait_pp_eqn - -def _semaphore_wait_discharge_rule(in_avals, - out_avals, - *flat_args, - args_tree): - del out_avals - [ref, transforms, dec] = args_tree.unflatten(flat_args) - sem_value = _transform_semaphore(ref, transforms, in_avals[0]) - dec = dec.astype(pl_core.SEMAPHORE_INTERPRET_DTYPE) - _, new_sem_value = state_discharge.transform_swap_array( - ref, transforms, sem_value - dec - ) - return (new_sem_value,) + (None,) * (len(in_avals) - 1), () -state_discharge.register_discharge_rule(semaphore_wait_p)( - _semaphore_wait_discharge_rule -) - - @dataclasses.dataclass class AsyncCopyDescriptor: src_ref: Any @@ -420,7 +171,7 @@ class AsyncCopyDescriptor: src_sem: int | jax.Array | None src_sem_transforms: tuple[Transform, ...] | None device_id: int | jax.Array | None - device_id_type: DeviceIdType = DeviceIdType.MESH + device_id_type: primitives.DeviceIdType = primitives.DeviceIdType.MESH def __post_init__(self): if (self.src_sem is None) ^ (self.device_id is None): @@ -610,14 +361,14 @@ def dma_start_partial_discharge_rule(should_discharge, in_avals, out_avals, # TODO(justinfu): Verify that code only works in SPMD mode. axis_env = jax_core.get_axis_env() nonempty_axes = [name for name in axis_env.axis_sizes if name is not None] - if device_id_type == DeviceIdType.LOGICAL: + if device_id_type == primitives.DeviceIdType.LOGICAL: if len(nonempty_axes) > 1: raise NotImplementedError("Sharding with more than one named axis not " "implemented in dma_start_p for LOGICAL " "device_id_type.") shard_axis = nonempty_axes[0] my_axis = jax.lax.axis_index(shard_axis) - elif device_id_type == DeviceIdType.MESH: + elif device_id_type == primitives.DeviceIdType.MESH: device_id_len = 1 if isinstance(device_id, jax.Array): device_id_len = device_id.size @@ -667,7 +418,7 @@ def do_discharge_dst(dst_ref=dst_ref): def do_discharge_dst_sem(dst_sem=dst_sem): recv_size = jnp.minimum(updates.size, pl_core.SEMAPHORE_MAX_VALUE) recv_size = jnp.array(recv_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE) - dst_sem_value = _transform_semaphore( + dst_sem_value = primitives._transform_semaphore( dst_sem, dst_sem_transforms, dst_sem_aval ) _, ret = state_discharge.transform_swap_array( @@ -678,7 +429,7 @@ def do_discharge_dst_sem(dst_sem=dst_sem): def do_discharge_src_sem(src_sem=src_sem): send_size = jnp.minimum(local_src.size, pl_core.SEMAPHORE_MAX_VALUE) send_size = jnp.array(send_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE) - src_sem_value = _transform_semaphore( + src_sem_value = primitives._transform_semaphore( src_sem, src_sem_transforms, src_sem_aval ) _, ret = state_discharge.transform_swap_array( @@ -778,7 +529,7 @@ def dma_wait_partial_discharge_rule(should_discharge, updates = state_discharge.transform_array(dst_ref, dst_ref_transforms) copy_size = jnp.minimum(updates.size, pl_core.SEMAPHORE_MAX_VALUE) copy_size = jnp.array(copy_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE) - sem_value = _transform_semaphore(dst_sem, dst_sem_transforms, dst_sem_aval) + sem_value = primitives._transform_semaphore(dst_sem, dst_sem_transforms, dst_sem_aval) _, new_sem = state_discharge.transform_swap_array( dst_sem, dst_sem_transforms, sem_value - copy_size ) @@ -814,7 +565,7 @@ def make_async_copy(src_ref, dst_ref, sem): None, None, None, - DeviceIdType.MESH, + primitives.DeviceIdType.MESH, ) def async_copy(src_ref, dst_ref, sem): @@ -824,7 +575,7 @@ def async_copy(src_ref, dst_ref, sem): return copy_descriptor def make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id, - device_id_type: DeviceIdType = DeviceIdType.MESH): + device_id_type: primitives.DeviceIdType = primitives.DeviceIdType.MESH): """Creates a description of a remote copy operation. Copies data from src_ref on the current device to dst_ref on the device @@ -861,20 +612,12 @@ def make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id, ) def async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id, - device_id_type: DeviceIdType = DeviceIdType.MESH): + device_id_type: primitives.DeviceIdType = primitives.DeviceIdType.MESH): copy_descriptor = make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id, device_id_type) copy_descriptor.start() return copy_descriptor -device_id_p = jax_core.Primitive('device_id') - -@device_id_p.def_abstract_eval -def _device_id_abstract_eval(): - return jax_core.ShapedArray((), jnp.dtype("int32")) - -device_id = device_id_p.bind - get_barrier_semaphore_p = jax_core.Primitive('get_barrier_semaphore') @get_barrier_semaphore_p.def_abstract_eval diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index 3306649f24f3..5d3444ef719f 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -993,3 +993,267 @@ def _lower_fun(*lower_fun_args): return out[:num_return_values] return mlir.lower_fun(_lower_fun, multiple_results=True)(ctx, *args) + + +def _get_ref_and_transforms(ref): + if isinstance(ref, state.TransformedRef): + return ref.ref, ref.transforms + return ref, () + + +class DeviceIdType(enum.Enum): + MESH = "mesh" + LOGICAL = "logical" + + +def check_sem_avals( + sem_aval, sem_transforms_avals, name, allowed_semaphore_types=None +): + if allowed_semaphore_types is None: + allowed_semaphore_types = { + pallas_core.semaphore, + pallas_core.barrier_semaphore, + # For interpret mode. + pallas_core.SEMAPHORE_INTERPRET_DTYPE, + } + if not isinstance(sem_aval, state.AbstractRef): + raise ValueError(f"Cannot {name} on a non-semaphore Ref: {sem_aval}") + sem_shape = sem_aval.shape + if sem_transforms_avals: + sem_shape = sem_transforms_avals[-1].get_indexer_shape() + if sem_shape: + raise ValueError(f"Cannot {name} on a non-()-shaped semaphore: {sem_shape}") + # Uncomment when semaphore type works for Mosaic-GPU lowering + # sem_dtype = sem_aval.dtype + # if not any( + # jnp.issubdtype(sem_dtype, sem_type) + # for sem_type in allowed_semaphore_types + # ): + # raise ValueError( + # f"Must {name} semaphores of the following types:" + # f" {allowed_semaphore_types}." + # ) + + +def _transform_semaphore(ref_value, transforms, ref_aval): + """Helper function for indexing into a semaphore during state_discharge.""" + if ref_value.shape == ref_aval.shape: + return state_discharge.transform_array(ref_value, transforms) + elif len(ref_value.shape) == 0: + return ref_value + else: + raise ValueError( + f"Semaphore value shape {ref_value.shape} does not match aval shape" + f" {ref_aval.shape}" + ) + + +semaphore_read_p = jax_core.Primitive("semaphore_read") +semaphore_read_p.multiple_results = False + + +def semaphore_read(sem_or_view): + ref, transforms = _get_ref_and_transforms(sem_or_view) + args = [ref, transforms] + flat_args, args_tree = tree_util.tree_flatten(args) + return semaphore_read_p.bind(*flat_args, args_tree=args_tree) + +@semaphore_read_p.def_abstract_eval +def _semaphore_read_abstract_eval( + *avals, + args_tree, +): + sem_aval, sem_transforms_avals = tree_util.tree_unflatten(args_tree, avals) + check_sem_avals( + sem_aval, + sem_transforms_avals, + "read", + allowed_semaphore_types={ + pallas_core.dma_semaphore, + pallas_core.semaphore, + pallas_core.barrier_semaphore, + pallas_core.SEMAPHORE_INTERPRET_DTYPE, + }, + ) + return jax_core.ShapedArray((), jnp.dtype("int32")) + +def _semaphore_read_discharge_rule(in_avals, + out_avals, + *flat_args, + args_tree): + del out_avals + [ref, transforms] = args_tree.unflatten(flat_args) + sem_value = _transform_semaphore(ref, transforms, in_avals[0]) + sem_value = sem_value.astype(jnp.int32) + return (None,) * len(in_avals), sem_value +state_discharge.register_discharge_rule(semaphore_read_p)( + _semaphore_read_discharge_rule +) + + +semaphore_signal_p = jax_core.Primitive('semaphore_signal') +semaphore_signal_p.multiple_results = True + + +def semaphore_signal( + sem_or_view, + inc: int | jax.Array = 1, + *, + device_id: int | jax.Array | None | tuple[int | jax.Array, ...] = None, + device_id_type: DeviceIdType = DeviceIdType.MESH, + core_index: int | jax.Array | None = None, +): + ref, transforms = _get_ref_and_transforms(sem_or_view) + inc = jnp.asarray(inc, dtype=jnp.int32) + args = [ref, transforms, inc, device_id, core_index] + flat_args, args_tree = tree_util.tree_flatten(args) + semaphore_signal_p.bind( + *flat_args, + args_tree=args_tree, + device_id_type=device_id_type, + ) + + +@semaphore_signal_p.def_abstract_eval +def _semaphore_signal_abstract_eval( + *avals, + args_tree, + device_id_type: DeviceIdType, +): + del device_id_type + ( + sem_aval, + sem_transforms_avals, + value_aval, + device_id_avals, + core_index_aval, + ) = tree_util.tree_unflatten(args_tree, avals) + check_sem_avals(sem_aval, sem_transforms_avals, "signal") + if value_aval.dtype != jnp.dtype("int32"): + raise ValueError("Must signal an int32 value.") + if device_id_avals is not None: + device_id_flat_avals = tree_util.tree_leaves(device_id_avals) + for aval in device_id_flat_avals: + if aval.dtype != jnp.dtype("int32"): + raise ValueError("`device_id`s must be an int32 value.") + return [] + + +def _semaphore_signal_pp_eqn(eqn: jax_core.JaxprEqn, + context: jax_core.JaxprPpContext, + settings: jax_core.JaxprPpSettings): + del settings + invars = eqn.invars + tree = eqn.params["args_tree"] + ( + sem, + sem_transforms, + value, + device_ids, + _, + ) = tree_util.tree_unflatten(tree, invars) + out = pp.concat([ + pp.text("semaphore_signal"), + pp.text(" "), + sp.pp_ref_transforms(context, sem, sem_transforms), + pp.text(" "), + pp.text(jax_core.pp_var(value, context)), + ]) + if device_ids is not None: + flat_device_ids = tree_util.tree_leaves(device_ids) + if not flat_device_ids: + return out + device_ids_pp = [pp.text(jax_core.pp_var(flat_device_ids[0], context))] + for device_id in flat_device_ids[1:]: + device_ids_pp.append(pp.text(" ")) + device_ids_pp.append(pp.text(jax_core.pp_var(device_id, context))) + out = pp.concat([out, pp.concat(device_ids_pp)]) + return out +jax_core.pp_eqn_rules[semaphore_signal_p] = _semaphore_signal_pp_eqn + + +def _semaphore_signal_discharge_rule(in_avals, + out_avals, + *flat_args, + args_tree, + device_id_type): + del out_avals, device_id_type + [ref, transforms, inc, device_id, core_index] = args_tree.unflatten(flat_args) + if device_id is not None: + raise NotImplementedError("Remote signal not implemented.") + if core_index is not None: + raise NotImplementedError("Multiple core support not implemented.") + sem_value = _transform_semaphore(ref, transforms, in_avals[0]) + inc = inc.astype(pallas_core.SEMAPHORE_INTERPRET_DTYPE) + _, new_sem_value = state_discharge.transform_swap_array( + ref, transforms, sem_value + inc + ) + return (new_sem_value,) + (None,) * (len(in_avals) - 1), () +state_discharge.register_discharge_rule(semaphore_signal_p)( + _semaphore_signal_discharge_rule +) + + +semaphore_wait_p = jax_core.Primitive('semaphore_wait') +semaphore_wait_p.multiple_results = True + +def semaphore_wait(sem_or_view, dec: int | jax.Array = 1): + ref, transforms = _get_ref_and_transforms(sem_or_view) + dec = jnp.asarray(dec, dtype=jnp.int32) + args = [ref, transforms, dec] + flat_args, args_tree = tree_util.tree_flatten(args) + semaphore_wait_p.bind(*flat_args, args_tree=args_tree) + +@semaphore_wait_p.def_abstract_eval +def _semaphore_wait_abstract_eval(*avals, args_tree): + sem_aval, sem_transforms_avals, value_aval = tree_util.tree_unflatten( + args_tree, avals + ) + check_sem_avals(sem_aval, sem_transforms_avals, "wait") + if value_aval.dtype != jnp.dtype("int32"): + raise ValueError("Must wait an int32 value.") + return [] + +def _semaphore_wait_pp_eqn(eqn: jax_core.JaxprEqn, + context: jax_core.JaxprPpContext, + settings: jax_core.JaxprPpSettings): + del settings + invars = eqn.invars + tree = eqn.params["args_tree"] + ( + sem, + sem_transforms, + value, + ) = tree_util.tree_unflatten(tree, invars) + return pp.concat([ + pp.text("semaphore_wait"), + pp.text(" "), + sp.pp_ref_transforms(context, sem, sem_transforms), + pp.text(" "), + pp.text(jax_core.pp_var(value, context)), + ]) +jax_core.pp_eqn_rules[semaphore_wait_p] = _semaphore_wait_pp_eqn + +def _semaphore_wait_discharge_rule(in_avals, + out_avals, + *flat_args, + args_tree): + del out_avals + [ref, transforms, dec] = args_tree.unflatten(flat_args) + sem_value = _transform_semaphore(ref, transforms, in_avals[0]) + dec = dec.astype(pallas_core.SEMAPHORE_INTERPRET_DTYPE) + _, new_sem_value = state_discharge.transform_swap_array( + ref, transforms, sem_value - dec + ) + return (new_sem_value,) + (None,) * (len(in_avals) - 1), () +state_discharge.register_discharge_rule(semaphore_wait_p)( + _semaphore_wait_discharge_rule +) + +device_id_p = jax_core.Primitive('device_id') + +@device_id_p.def_abstract_eval +def _device_id_abstract_eval(): + return jax_core.ShapedArray((), jnp.dtype("int32")) + +device_id = device_id_p.bind diff --git a/jax/experimental/pallas/__init__.py b/jax/experimental/pallas/__init__.py index 1e0abacfc25f..ea58fae3d283 100644 --- a/jax/experimental/pallas/__init__.py +++ b/jax/experimental/pallas/__init__.py @@ -47,6 +47,7 @@ from jax._src.pallas.primitives import atomic_xchg as atomic_xchg from jax._src.pallas.primitives import atomic_xor as atomic_xor from jax._src.pallas.primitives import debug_print as debug_print +from jax._src.pallas.primitives import device_id as device_id from jax._src.pallas.primitives import dot as dot from jax._src.pallas.primitives import load as load from jax._src.pallas.primitives import max_contiguous as max_contiguous @@ -55,8 +56,12 @@ from jax._src.pallas.primitives import program_id as program_id from jax._src.pallas.primitives import reciprocal as reciprocal from jax._src.pallas.primitives import run_scoped as run_scoped +from jax._src.pallas.primitives import semaphore_read as semaphore_read +from jax._src.pallas.primitives import semaphore_signal as semaphore_signal +from jax._src.pallas.primitives import semaphore_wait as semaphore_wait from jax._src.pallas.primitives import store as store from jax._src.pallas.primitives import swap as swap +from jax._src.pallas.primitives import DeviceIdType as DeviceIdType from jax._src.pallas.utils import cdiv as cdiv from jax._src.pallas.utils import next_power_of_2 as next_power_of_2 from jax._src.pallas.utils import strides_from_shape as strides_from_shape diff --git a/jax/experimental/pallas/tpu.py b/jax/experimental/pallas/tpu.py index ecc9d0d15120..c81edaf76fa3 100644 --- a/jax/experimental/pallas/tpu.py +++ b/jax/experimental/pallas/tpu.py @@ -17,11 +17,11 @@ from jax._src.pallas.mosaic import core as core from jax._src.pallas.mosaic.core import ARBITRARY as ARBITRARY from jax._src.pallas.mosaic.core import create_tensorcore_mesh as create_tensorcore_mesh -from jax._src.pallas.mosaic.core import dma_semaphore as dma_semaphore +from jax._src.pallas.core import dma_semaphore as dma_semaphore from jax._src.pallas.mosaic.core import GridDimensionSemantics as GridDimensionSemantics from jax._src.pallas.mosaic.core import PARALLEL as PARALLEL from jax._src.pallas.mosaic.core import PrefetchScalarGridSpec as PrefetchScalarGridSpec -from jax._src.pallas.mosaic.core import semaphore as semaphore +from jax._src.pallas.core import semaphore as semaphore from jax._src.pallas.mosaic.core import SemaphoreType as SemaphoreType from jax._src.pallas.mosaic.core import TPUMemorySpace as TPUMemorySpace from jax._src.pallas.mosaic.core import TPUCompilerParams as TPUCompilerParams @@ -40,8 +40,8 @@ from jax._src.pallas.mosaic.primitives import async_remote_copy as async_remote_copy from jax._src.pallas.mosaic.primitives import bitcast as bitcast from jax._src.pallas.mosaic.primitives import delay as delay -from jax._src.pallas.mosaic.primitives import device_id as device_id -from jax._src.pallas.mosaic.primitives import DeviceIdType as DeviceIdType +from jax._src.pallas.primitives import device_id as device_id +from jax._src.pallas.primitives import DeviceIdType as DeviceIdType from jax._src.pallas.mosaic.primitives import get_barrier_semaphore as get_barrier_semaphore from jax._src.pallas.mosaic.primitives import make_async_copy as make_async_copy from jax._src.pallas.mosaic.primitives import make_async_remote_copy as make_async_remote_copy @@ -49,9 +49,9 @@ from jax._src.pallas.mosaic.primitives import prng_seed as prng_seed from jax._src.pallas.mosaic.primitives import repeat as repeat from jax._src.pallas.mosaic.primitives import roll as roll -from jax._src.pallas.mosaic.primitives import semaphore_read as semaphore_read -from jax._src.pallas.mosaic.primitives import semaphore_signal as semaphore_signal -from jax._src.pallas.mosaic.primitives import semaphore_wait as semaphore_wait +from jax._src.pallas.primitives import semaphore_read as semaphore_read +from jax._src.pallas.primitives import semaphore_signal as semaphore_signal +from jax._src.pallas.primitives import semaphore_wait as semaphore_wait from jax._src.pallas.mosaic.random import sample_block as sample_block from jax._src.pallas.mosaic.random import to_pallas_key as to_pallas_key From 2692c5ff98c6dfbcf45b9f8d26db4cc9c2a67a79 Mon Sep 17 00:00:00 2001 From: Praveen Narayanan Date: Fri, 21 Mar 2025 17:35:37 -0700 Subject: [PATCH 097/483] Lower lax.ragged_dot_general to chlo.ragged_dot in some cases on tpu. PiperOrigin-RevId: 739348011 --- jax/_src/lax/lax.py | 131 ++++++++++++++++++++++++++++++++++++------- tests/export_test.py | 4 +- tests/lax_test.py | 92 ++++++++++++++++++++++++++++-- 3 files changed, 200 insertions(+), 27 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index a6a0924c9c0f..80a469ab6a11 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -67,6 +67,7 @@ from jax._src.lib.mlir.dialects import chlo from jax._src.lib.mlir.dialects import hlo from jax._src.sharding_impls import (PmapSharding, NamedSharding, + ShardingContext, SPMDAxisContext, PartitionSpec as P, canonicalize_sharding) from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DTypeLike, Shape from jax._src.util import (NumpyComplexWarning, cache, canonicalize_axis, @@ -5378,15 +5379,26 @@ def _dot_general_batch_unpack_dims(batch_dims): core.pp_eqn_rules[dot_general_p] = _dot_general_pp_rule batching.ragged_prop_rules[dot_general_p] = _dot_general_ragged_prop_rule -def precision_attr(precision: Precision) -> ir.ArrayAttr: + +def _full_precision(precision: Precision) -> tuple[Precision, Precision]: if precision is None or isinstance(precision, (DotAlgorithm, DotAlgorithmPreset)): - full_precision = (Precision.DEFAULT, Precision.DEFAULT) + return (Precision.DEFAULT, Precision.DEFAULT) elif not isinstance(precision, tuple): - full_precision = (precision, precision) + return (precision, precision) else: - full_precision = precision + return precision + + +def precision_attr(precision: Precision) -> ir.ArrayAttr: return ir.ArrayAttr.get( - [hlo.PrecisionAttr.get(str(p)) for p in full_precision]) + [hlo.PrecisionAttr.get(str(p)) for p in _full_precision(precision)] + ) + + +def chlo_precision_attr(precision: Precision) -> ir.ArrayAttr: + return ir.ArrayAttr.get( + [chlo.PrecisionAttr.get(str(p)) for p in _full_precision(precision)] + ) def dot_algorithm_attr(precision: CanonicalPrecision, lhs_dtype: DTypeLike, @@ -5424,9 +5436,7 @@ def maybe_convert_dtype(input_dtype, target_dtypes): return lhs_dtype, rhs_dtype, out_type -def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers, - precision, preferred_element_type: np.dtype | None, - out_sharding, platform: str = "default"): +def _handle_dot_precision(ctx, lhs, rhs, precision, platform): def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes): fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2, dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz) @@ -5437,19 +5447,12 @@ def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes): if dtypes.float8_e8m0fnu is not None: fp8_dtypes += (dtypes.float8_e8m0fnu,) return _lhs_dtypes in fp8_dtypes and _rhs_dtypes in fp8_dtypes - del preferred_element_type # Implied by the output aval - lhs_aval, rhs_aval = ctx.avals_in + + # The *_ lets us reuse this for ragged_dot_general, which has group_sizes. + lhs_aval, rhs_aval, *_ = ctx.avals_in lhs_dtype, rhs_dtype = lhs_aval.dtype, rhs_aval.dtype aval_out, = ctx.avals_out accumulation_aval = aval_out - (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers - - dot_dnums = hlo.DotDimensionNumbers.get( - lhs_batching_dimensions=list(lhs_batch), - rhs_batching_dimensions=list(rhs_batch), - lhs_contracting_dimensions=list(lhs_contracting), - rhs_contracting_dimensions=list(rhs_contracting)) - algorithm_kwarg = {} if isinstance(precision, (DotAlgorithm, DotAlgorithmPreset)): # The CPU backend silently ignores the algorithm spec, so we check here to @@ -5507,7 +5510,22 @@ def maybe_convert_dtype(operand, operand_aval, target_dtype): core.ShapedArray(lhs_aval.shape, aval_out.dtype)) rhs = mlir.convert_hlo(ctx, rhs, rhs_aval, core.ShapedArray(rhs_aval.shape, aval_out.dtype)) + return lhs, rhs, accumulation_aval, algorithm_kwarg + +def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers, + precision, preferred_element_type: np.dtype | None, + out_sharding, platform: str = "default"): + del preferred_element_type # Implied by the output aval + lhs, rhs, accumulation_aval, algorithm_kwarg = _handle_dot_precision( + ctx, lhs, rhs, precision, platform + ) + (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers + dot_dnums = hlo.DotDimensionNumbers.get( + lhs_batching_dimensions=list(lhs_batch), + rhs_batching_dimensions=list(rhs_batch), + lhs_contracting_dimensions=list(lhs_contracting), + rhs_contracting_dimensions=list(rhs_contracting)) result = hlo.dot_general( mlir.aval_to_ir_type(accumulation_aval), lhs, @@ -5516,7 +5534,7 @@ def maybe_convert_dtype(operand, operand_aval, target_dtype): precision_config=precision_attr(precision), **algorithm_kwarg, ) - + aval_out, = ctx.avals_out result = mlir.lower_with_sharding_in_types(ctx, result, aval_out) if accumulation_aval.dtype != aval_out.dtype: result = mlir.convert_hlo(ctx, result, accumulation_aval, aval_out) @@ -6035,10 +6053,85 @@ def expand(x, dim, gs, *axes): ) +def _ragged_dot_general_lower( + ctx, + lhs, + rhs, + group_sizes, + *, + ragged_dot_dimension_numbers, + precision, + preferred_element_type: np.dtype | None, + group_offset: Array | None = None, + platform: str = 'default', +): + if group_offset is not None: + raise NotImplementedError('Unimplemented group_offset support.') + + # TODO(pravnar): Remove this once we have sharding support. + def use_default_lowering(): + axis_context = ctx.module_context.axis_context + return ( + isinstance(axis_context, SPMDAxisContext) + or isinstance(axis_context, ShardingContext) + and axis_context.num_devices > 1 + ) + if use_default_lowering(): + result = mlir.lower_fun(_ragged_dot_general_impl, multiple_results=False)( + ctx, lhs, rhs, group_sizes, + ragged_dot_dimension_numbers=ragged_dot_dimension_numbers, + precision=precision, + preferred_element_type=preferred_element_type, + group_offset=group_offset + ) + (aval_out,) = ctx.avals_out + return mlir.lower_with_sharding_in_types(ctx, result, aval_out) + + del preferred_element_type # Implied by the output aval + lhs, rhs, accumulation_aval, _ = _handle_dot_precision( + ctx, lhs, rhs, precision, platform + ) + (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = ( + ragged_dot_dimension_numbers.dot_dimension_numbers + ) + ragged_dot_dnums = chlo.RaggedDotDimensionNumbers.get( + lhs_batching_dimensions=list(lhs_batch), + rhs_batching_dimensions=list(rhs_batch), + lhs_contracting_dimensions=list(lhs_contracting), + rhs_contracting_dimensions=list(rhs_contracting), + lhs_ragged_dimensions=list( + ragged_dot_dimension_numbers.lhs_ragged_dimensions + ), + rhs_group_dimensions=list( + ragged_dot_dimension_numbers.rhs_group_dimensions + ), + ) + result = chlo.ragged_dot( + mlir.aval_to_ir_type(accumulation_aval), + lhs, + rhs, + group_sizes, + ragged_dot_dnums, + precision_config=chlo_precision_attr(precision), + ) + (aval_out,) = ctx.avals_out + result = mlir.lower_with_sharding_in_types(ctx, result, aval_out) + if accumulation_aval.dtype != aval_out.dtype: + result = mlir.convert_hlo(ctx, result, accumulation_aval, aval_out) + return [result] + + mlir.register_lowering(ragged_dot_general_p, mlir.lower_fun(_ragged_dot_general_impl, multiple_results=False)) +for platform in ['tpu']: + mlir.register_lowering( + ragged_dot_general_p, + partial(_ragged_dot_general_lower, platform=platform), + platform=platform, + ) + def _broadcast_in_dim_shape_rule(operand, *, shape, broadcast_dimensions, sharding): diff --git a/tests/export_test.py b/tests/export_test.py index 2b083f3121f4..0b78a29a8e6a 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -1903,8 +1903,8 @@ def f_jax(x): @jtu.parameterized_filterable( kwargs=[ - {"m": 5, "k": 4, "n": 3, "group_sizes": [5]}, - {"m": 10, "k": 9, "n": 8, "group_sizes": [3, 7]}, + {"m": 64, "k": 4, "n": 3, "group_sizes": [5]}, + {"m": 64, "k": 9, "n": 8, "group_sizes": [3, 7]}, ]) def test_ragged_dot(self, m, k, n, group_sizes): def f_jax(x, y, gs): diff --git a/tests/lax_test.py b/tests/lax_test.py index f7cca2c9b48f..40f2eb8f3588 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -4796,10 +4796,10 @@ class RaggedTest(jtu.JaxTestCase): @jtu.sample_product( [ - {'m': 5, 'k': 4, 'n': 3, 'num_groups': 1}, - {'m': 10, 'k': 9, 'n': 8, 'num_groups': 2}, + {'m': 64, 'k': 4, 'n': 3, 'num_groups': 1}, + {'m': 64, 'k': 9, 'n': 8, 'num_groups': 2}, ], - dtype=jtu.dtypes.numeric, + dtype=jtu.dtypes.all_floating, ) def test_ragged_dot(self, m, k, n, num_groups, dtype): """Tests ragged_dot. @@ -4810,6 +4810,8 @@ def test_ragged_dot(self, m, k, n, num_groups, dtype): Raises: SkipTest: in the case dtype is not supported. """ + if (dtype == np.float16): + raise SkipTest(f"unsupported dtype for ragged_dot: {dtype}") lhs_shape = (m, k) rhs_shape = (num_groups, k, n) @@ -4831,6 +4833,25 @@ def group_sizes(m, num_groups): self._CheckAgainstNumpy( lax_reference.ragged_dot, lax.ragged_dot, args_maker) + @parameterized.parameters( + { "m": 5, "k": 4, "n": 3, "num_groups": 1}, + { "m": 10, "k": 9, "n": 8, "num_groups": 2}, + ) + def test_ragged_dot_unsupported( + self, m, k, n, num_groups): + lhs_shape = (m, k) + rhs_shape = (num_groups, k, n) + group_sizes_shape = (num_groups,) + + args_maker = lambda: [ + jnp.ones(lhs_shape, dtype=jnp.float32), + jnp.ones(rhs_shape, dtype=jnp.float32), + jnp.ones(group_sizes_shape, dtype=jnp.int32), + ] + if jtu.test_device_matches(["tpu"]): + with self.assertRaises(jax.errors.JaxRuntimeError): + self._CompileAndCheck(lax.ragged_dot, args_maker) + @parameterized.parameters( { "lhs_shape": lhs_shape, @@ -5049,10 +5070,69 @@ def test_ragged_dot_general_shape_inference_success( lhs = jnp.ones(lhs_shape, dtype=jnp.float32) rhs = jnp.ones(rhs_shape, dtype=jnp.float32) group_sizes = jnp.ones(group_sizes_shape, dtype=jnp.int32) - self.assertEqual( - lax.ragged_dot_general(lhs, rhs, group_sizes, ragged_dnums).shape, - out_shape, + if jtu.test_device_matches(["tpu"]): + actual_shape = lax_internal._ragged_dot_general_shape_rule( + lhs, rhs, group_sizes, ragged_dot_dimension_numbers=ragged_dnums, + precision=jax.lax.Precision.DEFAULT, + preferred_element_type=jnp.float32, + ) + else: + actual_shape = lax.ragged_dot_general( + lhs, rhs, group_sizes, ragged_dnums + ).shape + self.assertEqual(actual_shape, out_shape) + + @parameterized.product( + batch_size=[3, 5], + m=[128, 1024], + k=[128, 1024], + n=[128, 1024], + num_groups=[2, 4], + ) + def test_ragged_dot_general_vmap( + self, batch_size: int, m: int, k: int, n: int, num_groups: int + ): + if (jtu.test_device_matches(["tpu"])): + raise SkipTest("batched ragged_dot not yet supported on TPU") + + lhs_shape = (batch_size, m, k) + rhs_shape = (batch_size, num_groups, k, n) + dtype = jnp.float32 + + def make_group_sizes(m, num_groups): + ends_no_final = jnp.sort(self.rng().choice(m, size=num_groups - 1)) + ends = jnp.concatenate( + [ends_no_final, jnp.array([m], dtype=ends_no_final.dtype)]) + starts = jnp.concatenate( + [jnp.zeros(1, dtype=ends_no_final.dtype), ends_no_final]) + return ends - starts + + rng = jtu.rand_small(self.rng()) + args_maker = lambda: [ + rng(lhs_shape, dtype), + rng(rhs_shape, dtype), + jnp.array([make_group_sizes(m, num_groups) for _ in range(batch_size)]), + ] + lhs, rhs, group_sizes = args_maker() + + out_dtype = jnp.float32 + precision = jax.lax.Precision.HIGHEST + ragged_dot = partial( + jax.lax.ragged_dot, + preferred_element_type=out_dtype, + precision=precision, ) + tol = 1e-5 + + batch_res = jax.vmap(ragged_dot)(lhs, rhs, group_sizes) + for i in range(batch_size): + # The ragged_dot does not zero out the output in the case sum(group_sizes) + # < m, hence we need to compare only the valid part of the output. + upper_bound = group_sizes[i].sum(axis=0) + ref_res = ragged_dot(lhs[i], rhs[i], group_sizes[i])[0:upper_bound, :] + self.assertArraysAllClose( + batch_res[i, 0:upper_bound, :], ref_res, rtol=tol, atol=tol + ) if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From 55e408471ceaf5f0ed0e10053331d919fa2540ec Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 21 Mar 2025 18:52:12 -0700 Subject: [PATCH 098/483] [JAX] [XLA:Python] Migrate xla_extension and its type stubs into jaxlib. Future changes will migrate many of its dependent modules. PiperOrigin-RevId: 739361786 --- jax/_src/lib/BUILD | 2 +- jaxlib/BUILD | 10 +- jaxlib/jax.bzl | 9 + jaxlib/tools/BUILD.bazel | 4 +- jaxlib/tools/build_wheel.py | 6 +- jaxlib/xla/BUILD | 111 +- jaxlib/xla/xla.cc | 965 +++++++++++++++ jaxlib/xla/xla_client.py | 2 +- jaxlib/xla/xla_extension/__init__.pyi | 1059 +++++++++++++++++ jaxlib/xla/xla_extension/config.pyi | 32 + jaxlib/xla/xla_extension/guard_lib.pyi | 46 + jaxlib/xla/xla_extension/ifrt_programs.pyi | 43 + jaxlib/xla/xla_extension/ifrt_proxy.pyi | 33 + jaxlib/xla/xla_extension/jax_jit.pyi | 76 ++ jaxlib/xla/xla_extension/mlir.pyi | 34 + jaxlib/xla/xla_extension/ops.pyi | 465 ++++++++ jaxlib/xla/xla_extension/pmap_lib.pyi | 83 ++ jaxlib/xla/xla_extension/profiler.pyi | 58 + jaxlib/xla/xla_extension/pytree.pyi | 158 +++ jaxlib/xla/xla_extension/sdy.pyi | 32 + .../xla/xla_extension/transfer_guard_lib.pyi | 39 + jaxlib/xla_extension.py | 17 + 22 files changed, 3268 insertions(+), 16 deletions(-) create mode 100644 jaxlib/xla/xla.cc create mode 100644 jaxlib/xla/xla_extension/__init__.pyi create mode 100644 jaxlib/xla/xla_extension/config.pyi create mode 100644 jaxlib/xla/xla_extension/guard_lib.pyi create mode 100644 jaxlib/xla/xla_extension/ifrt_programs.pyi create mode 100644 jaxlib/xla/xla_extension/ifrt_proxy.pyi create mode 100644 jaxlib/xla/xla_extension/jax_jit.pyi create mode 100644 jaxlib/xla/xla_extension/mlir.pyi create mode 100644 jaxlib/xla/xla_extension/ops.pyi create mode 100644 jaxlib/xla/xla_extension/pmap_lib.pyi create mode 100644 jaxlib/xla/xla_extension/profiler.pyi create mode 100644 jaxlib/xla/xla_extension/pytree.pyi create mode 100644 jaxlib/xla/xla_extension/sdy.pyi create mode 100644 jaxlib/xla/xla_extension/transfer_guard_lib.pyi create mode 100644 jaxlib/xla_extension.py diff --git a/jax/_src/lib/BUILD b/jax/_src/lib/BUILD index 1f4f41132e9e..aa2d9cba4973 100644 --- a/jax/_src/lib/BUILD +++ b/jax/_src/lib/BUILD @@ -45,6 +45,7 @@ py_library_providing_imports_info( "//jaxlib:cpu_feature_guard", "//jaxlib:utils", "//jaxlib/xla:xla_client", + "//jaxlib/xla:xla_extension", "//jaxlib/triton", "//jaxlib/mlir/_mlir_libs:register_jax_dialects", "//jaxlib/mlir:arithmetic_dialect", @@ -61,6 +62,5 @@ py_library_providing_imports_info( "//jaxlib/mlir:sparse_tensor_dialect", "//jaxlib/mlir:stablehlo_dialect", "//jaxlib/mlir:vector_dialect", - # xla_extension ]), ) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 5f693c5384df..52c945482222 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -29,13 +29,6 @@ package( default_visibility = ["//jax:internal"], ) -# This makes xla_extension module accessible from jax._src.lib. -genrule( - name = "xla_extension_py", - outs = ["xla_extension.py"], - cmd = "echo 'from xla.xla.python.xla_extension import *\n' > $@", -) - py_library_providing_imports_info( name = "jaxlib", srcs = [ @@ -51,8 +44,8 @@ py_library_providing_imports_info( "lapack.py", "plugin_support.py", "xla_client.py", + "xla_extension.py", ":version", - ":xla_extension_py", ], data = [":ffi_headers"], lib_rule = pytype_library, @@ -82,6 +75,7 @@ py_library_providing_imports_info( "//jaxlib/mosaic", "//jaxlib/triton", "//jaxlib/xla:xla_client", + "//jaxlib/xla:xla_extension", ], ) diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 4403915154bc..c6f55a86143f 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -610,3 +610,12 @@ def jax_py_test( if "PYTHONWARNINGS" not in env: env["PYTHONWARNINGS"] = "error" py_test(name = name, env = env, **kwargs) + +def if_oss(oss_value, google_value = []): + """Returns one of the arguments based on the non-configurable build env. + + Specifically, it does not return a `select`, and can be used to e.g. + compute elements of list attributes. + """ + _ = (google_value, oss_value) # buildifier: disable=unused-variable + return oss_value diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index afa5866e286d..2ddc9e90a702 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -62,11 +62,11 @@ py_binary( "//jaxlib", "//jaxlib:README.md", "//jaxlib:setup.py", + "//jaxlib/xla:xla_client.py", + "//jaxlib/xla:xla_extension", "@xla//xla/ffi/api:api.h", "@xla//xla/ffi/api:c_api.h", "@xla//xla/ffi/api:ffi.h", - "@xla//xla/python:xla_client.py", - "@xla//xla/python:xla_extension", ] + if_windows([ "//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll", ]), diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 0e0ce077cb23..9967fc14b9f9 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -110,7 +110,7 @@ def patch_copy_xla_extension_stubs(dst_dir): xla_extension_dir = os.path.join(dst_dir, "xla_extension") os.makedirs(xla_extension_dir) for stub_name in _XLA_EXTENSION_STUBS: - stub_path = r.Rlocation("xla/xla/python/xla_extension/" + stub_name) + stub_path = r.Rlocation("__main__/jaxlib/xla/xla_extension/" + stub_name) stub_path = str(stub_path) # Make pytype accept os.path.exists(stub_path). if stub_name in _OPTIONAL_XLA_EXTENSION_STUBS and not os.path.exists(stub_path): continue @@ -135,7 +135,7 @@ def verify_mac_libraries_dont_reference_chkstack(): if not _is_mac(): return nm = subprocess.run( - ["nm", "-g", r.Rlocation("xla/xla/python/xla_extension.so")], + ["nm", "-g", r.Rlocation("__main/jaxlib/xla/xla_extension.so")], capture_output=True, text=True, check=False, @@ -198,7 +198,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu): "__main__/jaxlib/plugin_support.py", "__main__/jaxlib/version.py", "__main__/jaxlib/xla/xla_client.py", - f"xla/xla/python/xla_extension.{pyext}", + f"__main__/jaxlib/xla/xla_extension.{pyext}", ], ) # This file is required by PEP-561. It marks jaxlib as package containing diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD index 41152d642fc8..3239ba703937 100644 --- a/jaxlib/xla/BUILD +++ b/jaxlib/xla/BUILD @@ -14,6 +14,7 @@ load( "//jaxlib:jax.bzl", + "if_oss", "nanobind_extension", "py_deps", "py_strict_library", @@ -35,6 +36,114 @@ package_group( ], ) +nanobind_extension( + name = "xla_extension", + srcs = ["xla.cc"], + pytype_deps = py_deps(["numpy"]), + pytype_srcs = glob(["xla_extension/*.pyi"]), + visibility = ["//visibility:public"], + deps = [ + "@com_google_absl//absl/base", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log:initialize", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@nanobind", + "@tsl//tsl/platform", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:literal", + "@xla//xla:shape_util", + "@xla//xla:types", + "@xla//xla:util", + "@xla//xla/backends/cpu/collectives:cpu_collectives", + "@xla//xla/ffi:ffi_api", + "@xla//xla/pjrt:exceptions", + "@xla//xla/pjrt:mlir_to_hlo", + "@xla//xla/pjrt:pjrt_api", + "@xla//xla/pjrt:pjrt_c_api_client", + "@xla//xla/pjrt:pjrt_client", + "@xla//xla/pjrt:pjrt_common", + "@xla//xla/pjrt:pjrt_compiler", + "@xla//xla/pjrt:pjrt_executable", + "@xla//xla/pjrt:pjrt_layout", + "@xla//xla/pjrt:status_casters", + "@xla//xla/pjrt/c:pjrt_c_api_hdrs", + "@xla//xla/pjrt/distributed", + "@xla//xla/pjrt/distributed:client", + "@xla//xla/pjrt/distributed:key_value_store_interface", + "@xla//xla/pjrt/distributed:protocol_proto_cc", + "@xla//xla/pjrt/distributed:service", + "@xla//xla/pjrt/plugin/xla_cpu:cpu_client_options", + "@xla//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client", + "@xla//xla/python:config", + "@xla//xla/python:custom_call_sharding", + "@xla//xla/python:dlpack", + "@xla//xla/python:guard_lib", + "@xla//xla/python:jax_jit", + "@xla//xla/python:logging", + "@xla//xla/python:mlir", + "@xla//xla/python:nb_absl_flat_hash_map", + "@xla//xla/python:nb_absl_span", + "@xla//xla/python:nb_class_ptr", + "@xla//xla/python:ops", + "@xla//xla/python:pjit", + "@xla//xla/python:pmap_lib", + "@xla//xla/python:pprof_profile_builder", + "@xla//xla/python:profiler", + "@xla//xla/python:py_client", + "@xla//xla/python:python_ref_manager", + "@xla//xla/python:pytree", + "@xla//xla/python:refine_polymorphic_shapes", + "@xla//xla/python:sdy", + "@xla//xla/python:traceback", + "@xla//xla/python:types", + "@xla//xla/python:util", + "@xla//xla/python:weakref_lru_cache", + "@xla//xla/python:xla_compiler", + "@xla//xla/python/ifrt", + "@xla//xla/python/ifrt:plugin_program", + "@xla//xla/python/ifrt:plugin_program_serdes", + "@xla//xla/python/ifrt_proxy/client:py_module", + "@xla//xla/python/pjrt_ifrt", + "@xla//xla/python/pjrt_ifrt:pjrt_attribute_map_util", + "@xla//xla/python/pjrt_ifrt:xla_ifrt", + "@xla//xla/tsl/concurrency:ref_count", + "@xla//xla/tsl/distributed_runtime/preemption:preemption_sync_manager", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:status", + "@xla//xla/tsl/platform:statusor", + "@xla//xla/tsl/platform/cloud:gcs_file_system", + "@xla//xla/tsl/python/lib/core:numpy", + ] + select({ + # gloo tcp transport only builds on linux + "@xla//xla/tsl:macos": [ + "@gloo//:transport_uv", + "@xla//xla/backends/cpu/collectives:gloo_collectives", + "@xla//xla/backends/cpu/collectives:gloo_kv_store", + ], + "@xla//xla/tsl:windows": [], + "//conditions:default": [ + "@gloo//:transport_tcp", + "@xla//xla/backends/cpu/collectives:gloo_collectives", + "@xla//xla/backends/cpu/collectives:gloo_kv_store", + "@xla//xla/python/transfer:py_socket_transfer", + ], + }) + select({ + # mpitrampoline does not build on windows + "@xla//xla/tsl:windows": [], + # we support MPI collectives only in OSS builds + "//conditions:default": if_oss(["@xla//xla/backends/cpu/collectives:mpi_collectives"]), + }), +) + pytype_strict_library( name = "xla_client", srcs = ["xla_client.py"], @@ -43,7 +152,7 @@ pytype_strict_library( deps = py_deps([ "numpy", "ml_dtypes", - ]) + ["@xla//xla/python:xla_extension"], + ]) + [":xla_extension"], ) py_strict_test( diff --git a/jaxlib/xla/xla.cc b/jaxlib/xla/xla.cc new file mode 100644 index 000000000000..5f39b9173b89 --- /dev/null +++ b/jaxlib/xla/xla.cc @@ -0,0 +1,965 @@ +/* Copyright 2019 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/casts.h" +#include "absl/container/flat_hash_map.h" +#include "absl/hash/hash.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "nanobind/nanobind.h" +#include "nanobind/nb_defs.h" +#include "nanobind/stl/function.h" // IWYU pragma: keep +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/set.h" // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "xla/backends/cpu/collectives/cpu_collectives.h" +#include "xla/pjrt/c/pjrt_c_api.h" +#include "xla/pjrt/distributed/client.h" +#include "xla/pjrt/distributed/distributed.h" +#include "xla/pjrt/distributed/protocol.pb.h" +#include "xla/pjrt/distributed/service.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/plugin/xla_cpu/cpu_client_options.h" +#include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/topology.h" +#include "xla/python/ifrt_proxy/client/py_module.h" +#include "xla/python/pjrt_ifrt/pjrt_attribute_map_util.h" +#include "xla/python/py_client.h" +#include "xla/python/py_program.h" +#include "xla/python/sdy.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/python/lib/core/numpy.h" // NOLINT + +#if defined(__linux__) +#include "gloo/transport/tcp/attr.h" +#include "gloo/transport/tcp/device.h" +#include "xla/backends/cpu/collectives/gloo_collectives.h" +#include "xla/backends/cpu/collectives/gloo_kv_store.h" +#include "xla/python/transfer/py_socket_transfer.h" +#elif defined(__APPLE__) +#include "gloo/transport/uv/device.h" +#include "xla/backends/cpu/collectives/gloo_collectives.h" // NOLINT +#include "xla/backends/cpu/collectives/gloo_kv_store.h" // NOLINT +#endif // defined(__linux__) + +#if !defined(_WIN32) && !defined(PLATFORM_GOOGLE) +#include "xla/backends/cpu/collectives/mpi_collectives.h" +#endif // !_WIN32 && !PLATFORM_GOOGLE + +#include "xla/pjrt/distributed/key_value_store_interface.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/pjrt_api.h" +#include "xla/pjrt/pjrt_c_api_client.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_common.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/python/config.h" +#include "xla/python/custom_call_sharding.h" +#include "xla/python/dlpack.h" +#include "xla/python/guard_lib.h" +#include "xla/python/jax_jit.h" +#include "xla/python/logging.h" // IWYU pragma: keep +#include "xla/python/mlir.h" +#include "xla/python/nb_absl_flat_hash_map.h" // IWYU pragma: keep +#include "xla/python/nb_absl_span.h" // IWYU pragma: keep +#include "xla/python/nb_class_ptr.h" +#include "xla/python/ops.h" +#include "xla/python/pjit.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/python/pjrt_ifrt/pjrt_executable.h" +#include "xla/python/pjrt_ifrt/pjrt_topology.h" +#include "xla/python/pmap_lib.h" +#include "xla/python/pprof_profile_builder.h" +#include "xla/python/profiler.h" +#include "xla/python/py_array.h" +#include "xla/python/py_compile_only_client.h" +#include "xla/python/py_device.h" +#include "xla/python/py_device_list.h" +#include "xla/python/py_executable.h" +#include "xla/python/py_memory_space.h" +#include "xla/python/python_ref_manager.h" +#include "xla/python/pytree.h" +#include "xla/python/sharding.h" +#include "xla/python/traceback.h" +#include "xla/python/weakref_lru_cache.h" +#include "xla/python/xla_compiler.h" +#include "xla/tsl/distributed_runtime/preemption/preemption_sync_manager.h" +#include "xla/tsl/platform/status.h" +#include "tsl/platform/platform.h" + +// TODO(phawkins): remove host_id properties after JAX is update to avoid them. + +namespace xla { +namespace { + +namespace nb = nanobind; + +bool IsOptimizedBuild() { +#if NDEBUG + return true; +#else + return false; +#endif // NDEBUG +} + +// Is*san reports whether the build is under that particular sanitizer. +bool IsAsan() { +#if defined(ADDRESS_SANITIZER) + return true; +#else // defined(ADDRESS_SANITIZER) + return false; +#endif +} + +bool IsMsan() { +#if defined(MEMORY_SANITIZER) + return true; +#else // defined(MEMORY_SANITIZER) + return false; +#endif +} + +bool IsTsan() { +#if defined(THREAD_SANITIZER) + return true; +#else // defined(THREAD_SANITIZER) + return false; +#endif +} + +// IsSanitized reports whether the build is under any sanitizer. +bool IsSanitized() { return IsAsan() || IsMsan() || IsTsan(); } + +} // namespace + +NB_MODULE(xla_extension, m) { + // Initialize ABSL logging because code within XLA uses it. +#ifndef PLATFORM_GOOGLE + InitializeAbslLogging(); +#endif // PLATFORM_GOOGLE + + // We seem to get a fair number of leak warnings from nanobind. It's unclear + // whether these are false positives or not. + nb::set_leak_warnings(false); + + tsl::ImportNumpy(); + + // Exceptions + nb::exception xla_runtime_error(m, "XlaRuntimeError", + PyExc_RuntimeError); + xla_runtime_error.attr("__doc__") = nb::str( + "Runtime errors thrown by the JAX runtime. While the JAX runtime may " + "raise other exceptions as well, most exceptions thrown by the runtime " + "are instances of this class."); + + // Types + nb::enum_(m, "PrimitiveType", nb::is_arithmetic()) + .value("PRIMITIVE_TYPE_INVALID", PRIMITIVE_TYPE_INVALID) + .value("PRED", PRED) + .value("S4", S4) + .value("S8", S8) + .value("S16", S16) + .value("S32", S32) + .value("S64", S64) + .value("U4", U4) + .value("U8", U8) + .value("U16", U16) + .value("U32", U32) + .value("U64", U64) + .value("F16", F16) + .value("F4E2M1FN", F4E2M1FN) + // TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. + // .value("F8E3M4", F8E3M4) + // .value("F8E4M3", F8E4M3) + .value("F8E8M0FNU", F8E8M0FNU) + .value("F8E4M3FN", F8E4M3FN) + .value("F8E4M3B11FNUZ", F8E4M3B11FNUZ) + .value("F8E4M3FNUZ", F8E4M3FNUZ) + .value("F8E5M2", F8E5M2) + .value("F8E5M2FNUZ", F8E5M2FNUZ) + .value("BF16", BF16) + .value("F32", F32) + .value("F64", F64) + .value("C64", C64) + .value("C128", C128) + .value("TUPLE", TUPLE) + .value("OPAQUE_TYPE", OPAQUE_TYPE) + .value("TOKEN", TOKEN); + + // Must be before PyClient.compile. + BuildXlaCompilerSubmodule(m); + + PyDevice::RegisterPythonType(m); + PyMemorySpace::RegisterPythonType(m); + PyClient::RegisterPythonTypes(m); + + nb::enum_(m, "ArrayCopySemantics", + nb::is_arithmetic()) + .value("ALWAYS_COPY", ifrt::ArrayCopySemantics::kAlwaysCopy) + .value("REUSE_INPUT", ifrt::ArrayCopySemantics::kReuseInput) + .value("DONATE_INPUT", ifrt::ArrayCopySemantics::kDonateInput); + + nb::class_(m, "PjRtLayout") + .def("__str__", &PjRtLayout::ToString) + .def("__eq__", [](const PjRtLayout& layout, + const PjRtLayout& other) { return layout == other; }) + .def("__hash__", + [](const PjRtLayout& layout) { return absl::HashOf(layout); }) + .def("_xla_layout", &PjRtLayout::xla_layout) + .def("__getstate__", + [](const PjRtLayout& layout) -> nb::tuple { + absl::StatusOr serialized = layout.Serialize(); + ThrowIfError(serialized.status()); + return nb::make_tuple( + nb::bytes(serialized->data(), serialized->size())); + }) + .def("__setstate__", [](PjRtLayout* self, nb::tuple t) { + nb::bytes serialized = nb::cast(t[0]); + absl::StatusOr> layout = + PjRtLayout::Deserialize( + absl::string_view(serialized.c_str(), serialized.size())); + ThrowIfError(layout.status()); + new (self) PjRtLayout((*layout)->xla_layout()); + }); + + jax::BuildWeakrefLRUCacheAPI(m); + + nb::class_ cpu_collectives(m, "CpuCollectives"); + + m.def( + "make_gloo_tcp_collectives", + [](std::shared_ptr distributed_client, + + std::optional hostname, + std::optional interface) + -> std::shared_ptr { +#if defined(__linux__) + std::shared_ptr kv_store = nullptr; + if (distributed_client != nullptr) { + kv_store = GetDistributedKeyValueStore(distributed_client, + /*key_prefix=*/"cpu:"); + } + auto gloo_kv_store = std::make_unique(kv_store); + auto tcp_attrs = gloo::transport::tcp::attr(); + if (hostname) { + tcp_attrs.hostname = *hostname; + } + if (interface) { + tcp_attrs.iface = *interface; + } + auto tcp_device = gloo::transport::tcp::CreateDevice(tcp_attrs); + return std::make_shared(std::move(gloo_kv_store), + std::move(tcp_device)); +#elif defined(__APPLE__) + std::shared_ptr kv_store = nullptr; + if (distributed_client != nullptr) { + kv_store = GetDistributedKeyValueStore(distributed_client, + /*key_prefix=*/"cpu:"); + } + auto gloo_kv_store = std::make_unique(kv_store); + auto uv_attrs = gloo::transport::uv::attr(); + if (hostname) { + uv_attrs.hostname = *hostname; + } + if (interface) { + uv_attrs.iface = *interface; + } + auto uv_device = gloo::transport::uv::CreateDevice(uv_attrs); + return std::make_shared(std::move(gloo_kv_store), + std::move(uv_device)); +#else // defined(__linux__) + throw xla::XlaRuntimeError( + "make_gloo_tcp_collectives only implemented for linux and macos"); +#endif // defined(__linux__) + }, + nb::arg("distributed_client"), nb::arg("hostname").none() = std::nullopt, + nb::arg("interface").none() = std::nullopt); + +#if !defined(_WIN32) && !defined(PLATFORM_GOOGLE) + nb::class_ mpi_collectives(m, "MpiCollectives", + cpu_collectives); + mpi_collectives.def("Init", &cpu::MpiCollectives::Init); + mpi_collectives.def("Finalize", &cpu::MpiCollectives::Finalize); + m.def("make_mpi_collectives", []() -> std::shared_ptr { + return std::make_shared(); + }); +#else // !_WIN32 && !PLATFORM_GOOGLE + m.def("make_mpi_collectives", + []() -> std::shared_ptr { + throw xla::XlaRuntimeError( + "make_mpi_collectives is not implemented for Windows"); + }); +#endif // !_WIN32 && !PLATFORM_GOOGLE + + m.def( + "get_tfrt_cpu_client", + [](bool asynchronous, + std::shared_ptr distributed_client, + int node_id, int num_nodes, + std::shared_ptr collectives, + std::optional num_devices) -> nb_class_ptr { + std::unique_ptr ifrt_client; + { + nb::gil_scoped_release gil_release; + xla::CpuClientOptions options; + + options.asynchronous = asynchronous; + options.collectives = std::move(collectives); + options.process_id = node_id; + options.cpu_device_count = num_devices; + std::unique_ptr client = + xla::ValueOrThrow(xla::GetXlaPjrtCpuClient(std::move(options))); + ifrt::PjRtClient::CreateOptions ifrt_options; + ifrt_options.pjrt_client = + std::shared_ptr(std::move(client)); + if (distributed_client != nullptr) { + ifrt_options.kv_store = + GetDistributedKeyValueStore(distributed_client, + /*key_prefix=*/"cpu:"); + ifrt_options.process_id = node_id; + ifrt_options.num_processes = num_nodes; + } + ifrt_client = + ValueOrThrow(ifrt::PjRtClient::Create(std::move(ifrt_options))); + } + return PyClient::Make(std::move(ifrt_client)); + }, + nb::arg("asynchronous") = true, nb::arg("distributed_client") = nullptr, + nb::arg("node_id") = 0, nb::arg("num_nodes") = 1, + nb::arg("collectives").none() = + std::shared_ptr(), + nb::arg("num_devices").none() = std::nullopt); + m.def("pjrt_plugin_loaded", [](std::string platform_name) -> bool { + absl::StatusOr pjrt_api = pjrt::PjrtApi(platform_name); + return pjrt_api.ok(); + }); + m.def( + "load_pjrt_plugin", + [](std::string platform_name, std::optional library_path, + std::optional c_api) -> nb::capsule { + if (library_path.has_value()) { + const PJRT_Api* api = xla::ValueOrThrow( + pjrt::LoadPjrtPlugin(platform_name, *library_path)); + return nb::capsule(absl::bit_cast(api), "pjrt_c_api"); + } + if (absl::string_view(c_api->name()) != "pjrt_c_api") { + throw nb::value_error( + "c_api argument to load_pjrt_plugin is not a pjrt_c_api " + "capsule."); + } + xla::ThrowIfError(pjrt::SetPjrtApi( + platform_name, static_cast(c_api->data()))); + return *c_api; + }, + nb::arg("platform_name"), nb::arg("library_path").none() = std::nullopt, + nb::arg("c_api").none() = std::nullopt); + m.def("pjrt_plugin_initialized", [](std::string platform_name) -> bool { + return xla::ValueOrThrow(pjrt::IsPjrtPluginInitialized(platform_name)); + }); + m.def("initialize_pjrt_plugin", [](std::string platform_name) { + return xla::ThrowIfError(pjrt::InitializePjrtPlugin(platform_name)); + }); + + m.def( + "get_c_api_client", + [](std::string platform_name, + const absl::flat_hash_map& options, + std::shared_ptr distributed_client) + -> nb_class_ptr { + std::unique_ptr ifrt_client; + { + nb::gil_scoped_release gil_release; + std::shared_ptr kv_store = nullptr; + if (distributed_client != nullptr) { + kv_store = GetDistributedKeyValueStore( + distributed_client, + /*key_prefix=*/absl::StrCat(platform_name, ":")); + } + std::unique_ptr c_api_client = xla::ValueOrThrow( + GetCApiClient(platform_name, options, kv_store)); + ifrt_client = ifrt::PjRtClient::Create(std::move(c_api_client)); + } + return PyClient::Make(std::move(ifrt_client)); + }, + nb::arg("platform_name"), + nb::arg("options") = absl::flat_hash_map(), + nb::arg("distributed_client").none() = nullptr); + // TODO(b/322357665): Delete this method after TPU plugin changes to use the + // standard registration. + m.def("get_default_c_api_topology", + [](std::string platform_name, std::string topology_name, + const absl::flat_hash_map& options) + -> std::shared_ptr { + return std::make_shared(xla::ValueOrThrow( + GetCApiTopology(platform_name, topology_name, options))); + }); + m.def("get_c_api_topology", + [](nb::capsule c_api, std::string topology_name, + const absl::flat_hash_map& options) + -> std::shared_ptr { + if (absl::string_view(c_api.name()) != "pjrt_c_api") { + throw nb::value_error( + "Argument to get_c_api_topology was not a pjrt_c_api capsule."); + } + return std::make_shared(xla::ValueOrThrow( + GetCApiTopology(static_cast(c_api.data()), + topology_name, options))); + }); + m.def("get_topology_for_devices", + [](const std::vector>& py_devices) { + if (py_devices.empty()) { + throw nb::value_error( + "get_topology_for_devices requires >= 1 devices."); + } + auto client = py_devices[0]->client(); + absl::InlinedVector ifrt_devices; + ifrt_devices.reserve(py_devices.size()); + for (const auto& py_device : py_devices) { + if (py_device->client().get() != client.get()) { + throw nb::value_error( + "devices passed to get_topology_for_devices come from " + "different clients."); + } + ifrt_devices.push_back(py_device->device()); + } + ifrt::DeviceListRef device_list = + client->ifrt_client()->MakeDeviceList(ifrt_devices); + return xla::ValueOrThrow( + client->ifrt_client()->GetTopologyForDevices(device_list)); + }); + + TF_CHECK_OK(PyArray::RegisterTypes(m)); + jax::PyDeviceList::Register(m); + jax::RegisterSharding(m); + + nb::class_(m, "CompiledMemoryStats") + .def_rw("generated_code_size_in_bytes", + &CompiledMemoryStats::generated_code_size_in_bytes) + .def_rw("argument_size_in_bytes", + &CompiledMemoryStats::argument_size_in_bytes) + .def_rw("output_size_in_bytes", + &CompiledMemoryStats::output_size_in_bytes) + .def_rw("alias_size_in_bytes", &CompiledMemoryStats::alias_size_in_bytes) + .def_rw("temp_size_in_bytes", &CompiledMemoryStats::temp_size_in_bytes) + .def_rw("host_generated_code_size_in_bytes", + &CompiledMemoryStats::host_generated_code_size_in_bytes) + .def_rw("host_argument_size_in_bytes", + &CompiledMemoryStats::host_argument_size_in_bytes) + .def_rw("host_output_size_in_bytes", + &CompiledMemoryStats::host_output_size_in_bytes) + .def_rw("host_alias_size_in_bytes", + &CompiledMemoryStats::host_alias_size_in_bytes) + .def_rw("host_temp_size_in_bytes", + &CompiledMemoryStats::host_temp_size_in_bytes) + .def_prop_ro("serialized_hlo_proto", + [](const CompiledMemoryStats& cms) -> nb::bytes { + return nb::bytes(cms.serialized_hlo_proto.data(), + cms.serialized_hlo_proto.size()); + }) + .def("__str__", &CompiledMemoryStats::DebugString); + + nb::class_(m, "ExecuteResults") + .def("__len__", [](PyExecuteResults& results) { return results.Size(); }) + .def("disassemble_into_single_device_arrays", + &PyExecuteResults::DisassembleIntoSingleDeviceArrays) + .def("disassemble_prefix_into_single_device_arrays", + &PyExecuteResults::DisassemblePrefixIntoSingleDeviceArrays) + .def("consume_with_handlers", &PyExecuteResults::ConsumeWithHandlers) + .def("consume_token", &PyExecuteResults::ConsumeToken); + + nb::class_(m, "LoadedExecutable") + .def_prop_ro("client", &PyLoadedExecutable::client) + .def("local_devices", &PyLoadedExecutable::AddressableDevices) + .def("size_of_generated_code_in_bytes", + &PyLoadedExecutable::SizeOfGeneratedCodeInBytes) + .def( + "get_compiled_memory_stats", + xla::ValueOrThrowWrapper(&PyLoadedExecutable::GetCompiledMemoryStats)) + .def("delete", &PyLoadedExecutable::Delete) + .def("execute_sharded_on_local_devices", + xla::ValueOrThrowWrapper( + &PyLoadedExecutable::ExecuteShardedOnLocalDevices), + nb::arg("arguments")) + .def("execute_sharded_on_local_devices_with_tokens", + xla::ValueOrThrowWrapper( + &PyLoadedExecutable::ExecuteShardedOnLocalDevicesWithTokens), + nb::arg("arguments")) + // TODO(parkers): Switch execute_sharded_on_local_devices* to this. + .def("execute_sharded", + xla::ValueOrThrowWrapper(&PyLoadedExecutable::ExecuteSharded), + nb::arg("arguments"), nb::arg("with_tokens") = false) + .def("hlo_modules", ValueOrThrowWrapper(&PyLoadedExecutable::HloModules)) + .def("get_output_memory_kinds", + xla::ValueOrThrowWrapper(&PyLoadedExecutable::GetOutputMemoryKinds)) + .def("get_output_shardings", &PyLoadedExecutable::GetOutputShardings) + .def("get_parameter_layouts", + xla::ValueOrThrowWrapper(&PyLoadedExecutable::GetParameterLayouts)) + .def("get_output_layouts", + xla::ValueOrThrowWrapper(&PyLoadedExecutable::GetOutputLayouts)) + .def("get_parameter_shardings", + &PyLoadedExecutable::GetParameterShardings) + .def("keep_alive", &PyLoadedExecutable::KeepAlive) + .def("cost_analysis", + [](const PyLoadedExecutable& self) { + auto map = ValueOrThrow(self.GetCostAnalysis()); + return ifrt::ToPjRtAttributeMap(std::move(map)); + }) + .def_prop_ro("traceback", &PyLoadedExecutable::traceback) + .def_prop_ro("fingerprint", [](PyLoadedExecutable* exec) -> nb::object { + if (exec->fingerprint().has_value()) { + return nb::bytes(exec->fingerprint()->data(), + exec->fingerprint()->size()); + } else { + return nb::none(); + } + }); + nb::class_ token(m, "Token"); + token.def("block_until_ready", + [](PyToken& self) { xla::ThrowIfError(self.Await()); }); + + nb::class_ sharded_token(m, "ShardedToken"); + sharded_token.def("block_until_ready", [](PyShardedToken& self) { + xla::ThrowIfError(self.Await()); + }); + sharded_token.def("get_token", &PyShardedToken::GetPyToken); + + m.def("buffer_to_dlpack_managed_tensor", + xla::ValueOrThrowWrapper(BufferToDLPackManagedTensor), + nb::arg("buffer"), nb::arg("stream").none() = nb::none()); + m.def( + "dlpack_managed_tensor_to_buffer", + [](const nb::capsule& tensor, nb_class_ptr device, + std::optional stream) { + return xla::ValueOrThrow(DLPackManagedTensorToBuffer( + tensor, device->device(), device->client(), stream)); + }, + nb::arg("dlpack"), nb::arg("device"), nb::arg("stream").none()); + // Legacy overload + m.def( + "dlpack_managed_tensor_to_buffer", + [](const nb::capsule& tensor, + std::optional> cpu_client, + std::optional> gpu_client) { + return xla::ValueOrThrow(DLPackManagedTensorToBuffer( + tensor, std::move(cpu_client), std::move(gpu_client))); + }, + nb::arg("dlpack"), nb::arg("cpu_backend").none() = nb::none(), + nb::arg("gpu_backend").none() = nb::none()); + m.def("cuda_array_interface_to_buffer", + xla::ValueOrThrowWrapper(CudaArrayInterfaceToBuffer), nb::arg("cai"), + nb::arg("gpu_backend").none() = nb::none(), + nb::arg("device_id").none() = nb::none()); + + jax::BuildConfigSubmodule(m); + BuildIfrtProgramsSubmodule(m); + BuildProfilerSubmodule(m); + BuildOpsSubmodule(m); + BuildPytreeSubmodule(m); + jax::BuildGuardSubmodule(m); + jax::BuildJaxjitSubmodule(m); + jax::BuildPmapSubmodule(m); + jax::BuildPjitSubmodule(m); + BuildTracebackSubmodule(m); + BuildMlirSubmodule(m); + BuildSdySubmodule(m); + BuildCustomCallShardingPybindAPI(m); +#if defined(__linux__) + aux::RegisterTransferServerTypes(m); +#endif // defined(__linux__) + + // The following uses python bindings for PyClient defined above using + // pybind11, and hence needs pybind11::module_ (not just nanobind::module_). + xla::ifrt::proxy::BuildIfrtProxySubmodule(m); + + nb::class_ preemption_sync_manager( + m, "PreemptionSyncManager"); + preemption_sync_manager + .def( + "initialize", + [](tsl::PreemptionSyncManager& manager, + DistributedRuntimeClient* client) { + tsl::CoordinationServiceAgent* agent = + xla::ValueOrThrow(client->GetCoordinationServiceAgent()); + xla::ThrowIfError(manager.Initialize(agent)); + }, + nb::arg("distributed_client")) + .def("reached_sync_point", + [](tsl::PreemptionSyncManager& manager, int step_counter) { + return manager.ReachedSyncPoint(step_counter); + }); + m.def("create_preemption_sync_manager", + []() { return tsl::CreatePreemptionSyncManager(); }); + + nb::class_ distributed_runtime_service( + m, "DistributedRuntimeService"); + distributed_runtime_service.def("shutdown", + &DistributedRuntimeService::Shutdown, + nb::call_guard()); + nb::class_ distributed_runtime_client( + m, "DistributedRuntimeClient"); + distributed_runtime_client + .def("connect", + [](DistributedRuntimeClient& self) { + nb::gil_scoped_release gil_release; + xla::ThrowIfError(self.Connect()); + }) + .def("shutdown", + [](DistributedRuntimeClient& self) { + nb::gil_scoped_release gil_release; + xla::ThrowIfError(self.Shutdown()); + }) + // This method assumes that the value is a Python string. Use + // `blocking_key_value_get_bytes()` if key_value_set() was called with a + // Python bytes object as its value. + .def( + "blocking_key_value_get", + [](DistributedRuntimeClient& client, std::string key, + int64_t timeout_in_ms) { + nb::gil_scoped_release gil_release; + return xla::ValueOrThrow(client.BlockingKeyValueGet( + key, absl::Milliseconds(timeout_in_ms))); + }, + nb::arg("key"), nb::arg("timeout_in_ms")) + // Same as `blocking_key_value_get()`, but retrieves the raw Python byte + // values explicitly. + .def( + "blocking_key_value_get_bytes", + [](DistributedRuntimeClient& client, std::string key, + int64_t timeout_in_ms) -> nb::bytes { + std::string result; + { + nb::gil_scoped_release gil_release; + result = xla::ValueOrThrow(client.BlockingKeyValueGet( + key, absl::Milliseconds(timeout_in_ms))); + } + return nb::bytes(result.data(), result.size()); + }, + nb::arg("key"), nb::arg("timeout_in_ms")) + .def( + "key_value_try_get", + [](DistributedRuntimeClient& client, std::string key) { + nb::gil_scoped_release gil_release; + return xla::ValueOrThrow(client.KeyValueTryGet(key)); + }, + nb::arg("key")) + .def( + "key_value_try_get_bytes", + [](DistributedRuntimeClient& client, std::string key) -> nb::bytes { + std::string result; + { + nb::gil_scoped_release gil_release; + result = xla::ValueOrThrow(client.KeyValueTryGet(key)); + } + return nb::bytes(result.data(), result.size()); + }, + nb::arg("key")) + .def( + "wait_at_barrier", + [](DistributedRuntimeClient& client, std::string barrier_id, + int64_t timeout_in_ms, + std::optional> process_ids) { + nb::gil_scoped_release gil_release; + xla::ThrowIfError(client.WaitAtBarrier( + barrier_id, absl::Milliseconds(timeout_in_ms), process_ids)); + }, + nb::arg("barrier_id"), nb::arg("timeout_in_ms"), + nb::arg("process_ids") = std::nullopt) + .def( + "get_live_nodes", + [](DistributedRuntimeClient& client, + std::vector process_ids) { + nb::gil_scoped_release gil_release; + return xla::ValueOrThrow(client.GetLiveNodes(process_ids)); + }, + nb::arg("process_ids")) + // The key must be a string, but the value can either be a Python string + // or bytes object. + // With Python string values, use `key_value_set()` and + // `blocking_key_value_get()`. + // With Python byte object values, use `key_value_set()` and + // `blocking_key_value_get_bytes()`. + .def( + "key_value_set", + [](DistributedRuntimeClient& client, absl::string_view key, + absl::string_view value, bool allow_overwrite) { + nb::gil_scoped_release gil_release; + xla::ThrowIfError(client.KeyValueSet(key, value, allow_overwrite)); + }, + nb::arg("key"), nb::arg("value"), nb::arg("allow_overwrite") = false) + // The key must be a string, but the value must a + // Python bytes object. + // Use `key_value_set_bytes()` and `blocking_key_value_get_bytes()`. + .def( + "key_value_set_bytes", + [](DistributedRuntimeClient& client, absl::string_view key, + nb::bytes value, bool allow_overwrite) { + nb::gil_scoped_release gil_release; + xla::ThrowIfError(client.KeyValueSet( + key, absl::string_view(value.c_str(), value.size()), + allow_overwrite)); + }, + nb::arg("key"), nb::arg("value"), nb::arg("allow_overwrite") = false) + // Assumes that all values in the directory are Python strings. + .def( + "key_value_dir_get", + [](DistributedRuntimeClient& client, absl::string_view key) { + nb::gil_scoped_release gil_release; + return xla::ValueOrThrow(client.KeyValueDirGet(key)); + }, + nb::arg("key")) + // Assumes that all values in the directory are Python byte objects. + // Same as `key_value_dir_get()`, but retrieves Python byte values + // explicitly. + .def( + "key_value_dir_get_bytes", + [](DistributedRuntimeClient& client, absl::string_view key) + -> std::vector> { + std::vector> result; + { + nb::gil_scoped_release gil_release; + result = xla::ValueOrThrow(client.KeyValueDirGet(key)); + } + // Convert std::string values to nb::bytes. + std::vector> kvs; + kvs.reserve(result.size()); + for (auto& kv : result) { + kvs.push_back( + std::pair(std::move(kv.first), + nb::bytes(kv.second.data(), kv.second.size()))); + } + return kvs; + }, + nb::arg("key")) + .def( + "key_value_delete", + [](DistributedRuntimeClient& client, absl::string_view key) { + nb::gil_scoped_release gil_release; + return xla::ThrowIfError(client.KeyValueDelete(key)); + }, + nb::arg("key")); + + m.def( + "get_distributed_runtime_service", + [](std::string address, int num_nodes, + std::optional heartbeat_interval, + std::optional max_missing_heartbeats, + std::optional cluster_register_timeout, + std::optional shutdown_timeout) + -> std::unique_ptr { + CoordinationServiceImpl::Options options; + options.num_nodes = num_nodes; + if (heartbeat_interval.has_value()) { + options.heartbeat_interval = absl::Seconds(*heartbeat_interval); + } + if (max_missing_heartbeats.has_value()) { + options.max_missing_heartbeats = *max_missing_heartbeats; + } + if (cluster_register_timeout.has_value()) { + options.cluster_register_timeout = + absl::Seconds(*cluster_register_timeout); + } + if (shutdown_timeout.has_value()) { + options.shutdown_timeout = absl::Seconds(*shutdown_timeout); + } + std::unique_ptr service = + xla::ValueOrThrow(GetDistributedRuntimeService(address, options)); + return service; + }, + nb::arg("address"), nb::arg("num_nodes"), + nb::arg("heartbeat_interval").none() = std::nullopt, + nb::arg("max_missing_heartbeats").none() = std::nullopt, + nb::arg("cluster_register_timeout").none() = std::nullopt, + nb::arg("shutdown_timeout").none() = std::nullopt); + + m.def( + "get_distributed_runtime_client", + [](std::string address, int node_id, std::optional rpc_timeout, + std::optional init_timeout, std::optional shutdown_timeout, + std::optional heartbeat_interval, + std::optional max_missing_heartbeats, + std::optional> + missed_heartbeat_callback, + std::optional shutdown_on_destruction, + std::optional use_compression) + -> std::shared_ptr { + bool compression = use_compression.value_or(false); + DistributedRuntimeClient::Options options; + options.node_id = node_id; + if (rpc_timeout.has_value()) { + options.rpc_timeout = absl::Seconds(*rpc_timeout); + } + if (init_timeout.has_value()) { + options.init_timeout = absl::Seconds(*init_timeout); + } + if (shutdown_timeout.has_value()) { + options.shutdown_timeout = absl::Seconds(*shutdown_timeout); + } + if (heartbeat_interval.has_value()) { + options.heartbeat_interval = absl::Seconds(*heartbeat_interval); + } + if (max_missing_heartbeats.has_value()) { + options.max_missing_heartbeats = *max_missing_heartbeats; + } + if (missed_heartbeat_callback.has_value()) { + options.missed_heartbeat_callback = + std::move(*missed_heartbeat_callback); + } + if (shutdown_on_destruction.has_value()) { + options.shutdown_on_destruction = *shutdown_on_destruction; + } + return GetDistributedRuntimeClient(address, options, compression); + }, + nb::arg("address"), nb::arg("node_id"), + nb::arg("rpc_timeout").none() = std::nullopt, + nb::arg("init_timeout").none() = std::nullopt, + nb::arg("shutdown_timeout").none() = std::nullopt, + nb::arg("heartbeat_interval").none() = std::nullopt, + nb::arg("max_missing_heartbeats").none() = std::nullopt, + nb::arg("missed_heartbeat_callback").none() = std::nullopt, + nb::arg("shutdown_on_destruction").none() = std::nullopt, + nb::arg("use_compression").none() = std::nullopt); + + m.def("collect_garbage", []() { GlobalPyRefManager()->CollectGarbage(); }); + + m.def("is_optimized_build", &IsOptimizedBuild); + + m.def("json_to_pprof_profile", xla::ValueOrThrowWrapper(JsonToPprofProfile), + "Encodes the JSON representation of a pprof Profile into its binary " + "protocol buffer encoding."); + m.def("pprof_profile_to_json", xla::ValueOrThrowWrapper(PprofProfileToJson), + "Decodes an uncompressed pprof Profile protocol buffer into a JSON " + "representation"); + + RegisterCompileOnlyClient(m); + nb::class_(m, "DeviceTopology") + .def("_make_compile_only_devices", + [](std::shared_ptr topology) { + if (!llvm::isa(*topology)) { + throw xla::XlaRuntimeError("Only PjRtTopologies are supported."); + } + return MakeCompileOnlyClient( + std::dynamic_pointer_cast(topology)) + ->Devices(); + }) + .def_prop_ro( + "platform", + [](ifrt::Topology& topology) { return topology.platform_name(); }) + .def_prop_ro( + "platform_version", + [](ifrt::Topology& topology) { return topology.platform_version(); }) + .def("serialize", + [](ifrt::Topology& topology) -> nb::bytes { + std::string serialized = ValueOrThrow(topology.Serialize()); + return nb::bytes(serialized.data(), serialized.size()); + }) + .def("__getattr__", + [](ifrt::Topology& topology, absl::string_view name) -> nb::object { + const auto& attrs = topology.Attributes().map(); + auto it = attrs.find(name); + if (it != attrs.end()) { + return std::visit([](auto&& v) { return nb::cast(v.value); }, + it->second); + } + throw nb::attribute_error( + absl::StrCat("Unknown attribute ", name).c_str()); + }); + + nb::class_(m, "Executable") + .def("hlo_modules", ValueOrThrowWrapper(&ifrt::Executable::GetHloModules)) + .def("get_output_memory_kinds", + xla::ValueOrThrowWrapper(&ifrt::Executable::GetOutputMemoryKinds)) + .def("get_output_shardings", &ifrt::Executable::GetOutputShardings) + .def("get_parameter_layouts", + ValueOrThrowWrapper(&ifrt::Executable::GetParameterLayouts)) + .def("get_output_layouts", + xla::ValueOrThrowWrapper(&ifrt::Executable::GetOutputLayouts)) + .def("get_parameter_shardings", &ifrt::Executable::GetParameterShardings) + .def("get_compiled_memory_stats", + xla::ValueOrThrowWrapper(&ifrt::Executable::GetCompiledMemoryStats)) + .def("serialize", + [](const ifrt::Executable& exec) -> nb::bytes { + std::string serialized = ValueOrThrow(exec.Serialize()); + return nb::bytes(serialized.data(), serialized.size()); + }) + .def("cost_analysis", [](const ifrt::Executable& exec) { + auto attrs = ValueOrThrow(exec.GetCostAnalysis()); + return ifrt::ToPjRtAttributeMap(std::move(attrs)); + }); + + m.def("is_asan", IsAsan); + m.def("is_msan", IsMsan); + m.def("is_tsan", IsTsan); + m.def("is_sanitized", IsSanitized); + + m.def( + "batched_device_put", + [](nb::object aval, nb::object sharding, std::vector xs, + std::vector dst_devices, bool committed, + bool force_copy, + PjRtClient::HostBufferSemantics host_buffer_semantics) -> nb::object { + return ValueOrThrow(PyArray::BatchedDevicePut( + aval, sharding, std::move(xs), std::move(dst_devices), committed, + force_copy, host_buffer_semantics, jax::GetEnableX64())); + }, + nb::arg("aval"), nb::arg("sharding"), nb::arg("xs"), nb::arg("devices"), + nb::arg("committed") = true, nb::arg("force_copy") = false, + nb::arg("host_buffer_semantics") = + PjRtClient::HostBufferSemantics::kImmutableZeroCopy); + m.def( + "reorder_shards", + [](PyArray x, nb::object dst_sharding, + ifrt::ArrayCopySemantics array_copy_semantics) { + return ValueOrThrow(PyArray::ReorderShards( + std::move(x), std::move(dst_sharding), array_copy_semantics)); + }, + nb::arg("x"), nb::arg("dst_sharding"), nb::arg("array_copy_semantics")); + + m.def("batched_block_until_ready", [](std::vector xs) { + ThrowIfError(PyArray::BatchedBlockUntilReady(std::move(xs))); + }); + + m.def("check_and_canonicalize_memory_kind", + &jax::CheckAndCanonicalizeMemoryKind, nb::arg("memory_kind").none(), + nb::arg("device_list")); +} // NOLINT(readability/fn_size) + +} // namespace xla diff --git a/jaxlib/xla/xla_client.py b/jaxlib/xla/xla_client.py index b6c5707d05dd..a111c14232de 100644 --- a/jaxlib/xla/xla_client.py +++ b/jaxlib/xla/xla_client.py @@ -19,7 +19,7 @@ import atexit from collections.abc import Mapping, Sequence import contextlib -import enum # pylint: disable=g-bad-import-order +import enum import gzip import inspect import logging diff --git a/jaxlib/xla/xla_extension/__init__.pyi b/jaxlib/xla/xla_extension/__init__.pyi new file mode 100644 index 000000000000..3a6435824b67 --- /dev/null +++ b/jaxlib/xla/xla_extension/__init__.pyi @@ -0,0 +1,1059 @@ +# Copyright 2021 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import annotations + +import enum +import inspect +import types +import typing +from typing import ( + Any, + Callable, + ClassVar, + Dict, + Iterator, + List, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + Union, + overload, +) + +import numpy as np + +from . import config as config +from . import guard_lib as guard_lib +from . import ifrt_programs as ifrt_programs +from . import ifrt_proxy as ifrt_proxy +from . import jax_jit as jax_jit +from . import mlir as mlir +from . import ops as ops +from . import pmap_lib as pmap_lib +from . import profiler as profiler +from . import pytree as pytree +from . import transfer_guard_lib as transfer_guard_lib + +custom_call_targets = Any +hlo_sharding_util = Any + +_LiteralSlice = Any +_Status = Any +_Dtype = Any +_XlaOpMetadata = Any + +_T = TypeVar("_T") + +class XlaRuntimeError(RuntimeError): + pass + +class PrimitiveType(enum.IntEnum): + PRIMITIVE_TYPE_INVALID: PrimitiveType + PRED: PrimitiveType + S2: PrimitiveType + S4: PrimitiveType + S8: PrimitiveType + S16: PrimitiveType + S32: PrimitiveType + S64: PrimitiveType + U2: PrimitiveType + U4: PrimitiveType + U8: PrimitiveType + U16: PrimitiveType + U32: PrimitiveType + U64: PrimitiveType + F4E2M1FN: PrimitiveType + F8E3M4: PrimitiveType + F8E4M3: PrimitiveType + F8E4M3FN: PrimitiveType + F8E4M3B11FNUZ: PrimitiveType + F8E4M3FNUZ: PrimitiveType + F8E5M2: PrimitiveType + F8E5M2FNUZ: PrimitiveType + F8E8M0FNU: PrimitiveType + BF16: PrimitiveType + F16: PrimitiveType + F32: PrimitiveType + F64: PrimitiveType + C64: PrimitiveType + C128: PrimitiveType + TUPLE: PrimitiveType + OPAQUE_TYPE: PrimitiveType + TOKEN: PrimitiveType + +# === BEGIN xla_compiler.cc + +class ArrayCopySemantics(enum.IntEnum): + ALWAYS_COPY: ArrayCopySemantics + REUSE_INPUT: ArrayCopySemantics + DONATE_INPUT: ArrayCopySemantics + +class Layout: + @overload + def __init__(self, minor_to_major: Tuple[int, ...]): ... + @overload + def __init__(self, minor_to_major: Tuple[int, ...], + tiling: Tuple[Tuple[int, ...], ...], + element_size_in_bits: int): ... + def minor_to_major(self) -> Tuple[int, ...]: ... + def tiling(self) -> Sequence[Tuple[int, ...]]: ... + def element_size_in_bits(self) -> int: ... + def to_string(self) -> str: ... + def __eq__(self, other: Layout) -> bool: ... + def __ne__(self, other: Layout) -> bool: ... + def __hash__(self) -> int: ... + +class Shape: + def __init__(self, s: str): ... + @staticmethod + def tuple_shape(shapes: Sequence[Shape]) -> Shape: ... + @staticmethod + def array_shape( + type: Union[np.dtype, PrimitiveType], + dims_seq: Any = ..., + layout_seq: Any = ..., + dynamic_dimensions: Optional[List[bool]] = ..., + ) -> Shape: ... + @staticmethod + def token_shape() -> Shape: ... + @staticmethod + def scalar_shape(type: Union[np.dtype, PrimitiveType]) -> Shape: ... + def dimensions(self) -> Tuple[int, ...]: ... + def layout(self) -> Layout: ... + def xla_element_type(self) -> PrimitiveType: ... + def element_type(self) -> np.dtype: ... + def numpy_dtype(self) -> np.dtype: ... + def is_tuple(self) -> bool: ... + def is_array(self) -> bool: ... + def is_token(self) -> bool: ... + def is_static(self) -> bool: ... + def is_dynamic(self) -> bool: ... + def is_dynamic_dimension(self, dimension: int) -> bool: ... + def set_dynamic_dimension(self, dimension: int, is_dynamic: bool) -> None: ... + def rank(self) -> int: ... + def to_serialized_proto(self) -> bytes: ... + def tuple_shapes(self) -> List[Shape]: ... + def leaf_count(self) -> int: ... + def with_major_to_minor_layout_if_absent(self) -> Shape: ... + def __eq__(self, other: Shape) -> bool: ... + def __ne__(self, other: Shape) -> bool: ... + def __hash__(self) -> int: ... + def __repr__(self) -> str: ... + +class ProgramShape: + def __init__(self, params: Sequence[Shape], result: Shape) -> None: ... + def parameter_shapes(self) -> List[Shape]: ... + def result_shape(self) -> Shape: ... + def __repr__(self) -> str: ... + +class ShapeIndex: + def __init__(self, indices: List[int]) -> ShapeIndex: ... + def __eq__(self, other: Shape) -> bool: ... + def __ne__(self, other: Shape) -> bool: ... + def __hash__(self) -> int: ... + def __repr__(self) -> str: ... + +class Literal: + def __init__(self, shape: Shape) -> Literal: ... + def __repr__(self) -> str: ... + def __array__( + self, dtype: Optional[np.dtype] = None, copy: Optional[bool] = None + ) -> np.ndarray: ... + def shape(self) -> Shape: ... + +class XlaComputation: + def __init__(self, serialized_hlo_module_proto: bytes) -> None: ... + def get_hlo_module(self) -> HloModule: ... + def program_shape(self) -> ProgramShape: ... + def as_serialized_hlo_module_proto(self) -> bytes: ... + def as_hlo_text(self, print_large_constants: bool = False) -> str: ... + def as_hlo_dot_graph(self) -> str: ... + def hash(self) -> int: ... + def as_hlo_module(self) -> HloModule: ... + +class HloPrintOptions: + def __init__(self) -> None: ... + @staticmethod + def short_parsable() -> HloPrintOptions: ... + @staticmethod + def canonical() -> HloPrintOptions: ... + @staticmethod + def fingerprint() -> HloPrintOptions: ... + print_large_constants: bool + print_metadata: bool + print_backend_config: bool + print_result_shape: bool + print_operand_shape: bool + print_operand_names: bool + print_ids: bool + print_extra_attributes: bool + print_program_shape: bool + print_percent: bool + print_control_dependencies: bool + compact_operands: bool + include_layout_in_shapes: bool + canonicalize_instruction_names: bool + canonicalize_computations: bool + indent_amount: int + is_in_nested_computation: bool + +class HloComputation: + def render_html(self) -> None: ... + +class HloModule: + spmd_output_sharding: Optional[OpSharding] + spmd_parameters_shardings: Optional[List[OpSharding]] + @property + def name(self) -> str: ... + def to_string(self, options: HloPrintOptions = ...) -> str: ... + def as_serialized_hlo_module_proto(self) -> bytes: ... + @staticmethod + def from_serialized_hlo_module_proto( + serialized_hlo_module_proto: bytes, + ) -> HloModule: ... + def computations(self) -> List[HloComputation]: ... + +class HloModuleGroup: + def __init__(self, name: str, modules: List[HloModule]) -> None: ... + @property + def name(self) -> str: ... + def to_string(self) -> str: ... + def to_modules(self) -> List[HloModule]: ... + +def hlo_module_to_dot_graph(hlo_module: HloModule) -> str: ... +def hlo_module_from_text(hlo_module_text: str) -> HloModule: ... +def hlo_module_cost_analysis( + client: Client, module: HloModule +) -> Dict[str, float]: ... + +class XlaOp: ... + +class XlaBuilder: + def __init__(self, name: str) -> None: ... + def Build(self, root: Optional[XlaOp] = ...) -> XlaComputation: ... + def GetShape(self, __op: XlaOp) -> Shape: ... + build = Build + def clear_op_metadata(self) -> None: ... + get_shape = GetShape + def get_program_shape(self, root: Optional[XlaOp] = ...) -> ProgramShape: ... + def is_constant(self, __op: XlaOp) -> bool: ... + def set_op_metadata(self, metadata: _XlaOpMetadata) -> None: ... + def set_sharding(self, sharding: OpSharding_Type) -> None: ... + def clear_sharding(self) -> None: ... + def setup_alias( + self, + __output_index: Sequence[int], + __param_number: int, + __param_index: Sequence[int], + ) -> None: ... + +class DeviceAssignment: + @staticmethod + def create(array: np.ndarray) -> DeviceAssignment: ... + def replica_count(self) -> int: ... + def computation_count(self) -> int: ... + def __repr__(self) -> str: ... + def serialize(self) -> bytes: ... + +class CompileOptions: + @staticmethod + def ParseFromString(s: bytes) -> CompileOptions: ... + def __init__(self) -> None: ... + def SerializeAsString(self) -> bytes: ... + argument_layouts: Optional[List[Shape]] + parameter_is_tupled_arguments: bool + executable_build_options: ExecutableBuildOptions + tuple_arguments: bool + num_replicas: int + num_partitions: int + profile_version: int + device_assignment: Optional[DeviceAssignment] + compile_portable_executable: bool + env_option_overrides: List[Tuple[str, str]] + +def register_custom_call_target( + fn_name: str, capsule: Any, platform: str, api_version: int = ..., +) -> _Status: ... +def register_custom_call_partitioner( + name: str, + prop_user_sharding: Callable, + partition: Callable, + infer_sharding_from_operands: Callable, + can_side_effecting_have_replicated_sharding: bool = ..., + c_api: Optional[Any] = ..., +) -> None: ... +def encode_inspect_sharding_callback(handler: Any) -> bytes: ... +def register_custom_call_as_batch_partitionable( + target_name: str, + c_api: Optional[Any] = ..., +) -> None: ... + +def register_custom_type_id(type_name: str, type_id: Any) -> None: ... + +class AutotuneCacheMode(enum.IntEnum): + UNSPECIFIED: AutotuneCacheMode + UPDATE: AutotuneCacheMode + READ: AutotuneCacheMode + +class DebugOptions: + def __repr__(self) -> str: ... + xla_cpu_enable_fast_math: bool + xla_cpu_fast_math_honor_infs: bool + xla_cpu_fast_math_honor_nans: bool + xla_cpu_fast_math_honor_division: bool + xla_cpu_fast_math_honor_functions: bool + xla_gpu_enable_fast_min_max: bool + xla_backend_optimization_level: int + xla_cpu_enable_xprof_traceme: bool + xla_llvm_disable_expensive_passes: bool + xla_test_all_input_layouts: bool + xla_disable_hlo_passes: str + xla_enable_hlo_passes_only: str + xla_force_host_platform_device_count: int + xla_dump_to: str + xla_dump_hlo_module_re: str + xla_dump_hlo_pass_re: str + xla_dump_hlo_as_text: bool + xla_dump_hlo_as_proto: bool + xla_dump_hlo_as_dot: bool + xla_dump_hlo_as_url: bool + xla_dump_hlo_as_html: bool + xla_dump_fusion_visualization: bool + xla_dump_hlo_snapshots: bool + xla_dump_max_hlo_modules: bool + xla_dump_module_metadata: bool + xla_dump_compress_protos: bool + xla_dump_hlo_as_long_text: bool + xla_dump_disable_metadata: bool + xla_dump_hlo_pipeline_re: str + xla_gpu_cuda_data_dir: str + xla_detailed_logging: bool + xla_enable_dumping: bool + xla_gpu_dump_autotune_results_to: str + xla_gpu_load_autotune_results_from: str + xla_gpu_dump_autotune_logs_to: str + xla_gpu_kernel_cache_file: str + xla_gpu_enable_llvm_module_compilation_parallelism: bool + xla_gpu_per_fusion_autotune_cache_dir: str + xla_gpu_experimental_autotune_cache_mode: AutotuneCacheMode + +class CompiledMemoryStats: + generated_code_size_in_bytes: int + argument_size_in_bytes: int + output_size_in_bytes: int + alias_size_in_bytes: int + temp_size_in_bytes: int + host_generated_code_size_in_bytes: int + host_argument_size_in_bytes: int + host_output_size_in_bytes: int + host_alias_size_in_bytes: int + host_temp_size_in_bytes: int + serialized_hlo_proto: bytes + def __str__(self) -> str: ... + +class ExecutableBuildOptions: + def __init__(self) -> None: ... + def __repr__(self) -> str: ... + result_layout: Optional[Shape] + fdo_profile: Optional[bytes] + num_replicas: int + num_partitions: int + debug_options: DebugOptions + device_assignment: Optional[DeviceAssignment] + use_spmd_partitioning: bool + use_auto_spmd_partitioning: bool + auto_spmd_partitioning_mesh_shape: List[int] + auto_spmd_partitioning_mesh_ids: List[int] + use_shardy_partitioner: bool + def compilation_environments_from_serialized_proto(self, serialized_proto: bytes) -> None: ... + +class PrecisionConfig_Precision(enum.IntEnum): + DEFAULT: int + HIGH: int + HIGHEST: int + + +class ResultAccuracy_Mode(enum.IntEnum): + DEFAULT: int + HIGHEST: int + TOLERANCE: int + +class ResultAccuracy: + mode: ResultAccuracy_Mode + atol: float + rtol: float + ulps: int + +class OpSharding_Type(enum.IntEnum): + REPLICATED: int + MAXIMAL: int + TUPLE: int + OTHER: int + MANUAL: int + UNKNOWN: int + +class OpSharding_ShardGroupType(enum.IntEnum): + AS: int + LIKE: int + +class OpSharding: + Type: typing.Type[OpSharding_Type] + type: OpSharding_Type + replicate_on_last_tile_dim: bool + last_tile_dims: Sequence[Type] + tile_assignment_dimensions: Sequence[int] + tile_assignment_devices: Sequence[int] + iota_reshape_dims: Sequence[int] + iota_transpose_perm: Sequence[int] + tuple_shardings: Sequence[OpSharding] + is_shard_group: bool + shard_group_id: int + ShardGroupType: typing.Type[OpSharding_ShardGroupType] + shard_group_type: OpSharding_ShardGroupType + def ParseFromString(self, s: bytes) -> None: ... + def SerializeToString(self) -> bytes: ... + def clone(self) -> OpSharding: ... + +class HloSharding: + @staticmethod + def from_proto(proto: OpSharding) -> HloSharding: ... + @staticmethod + def from_string(sharding: str) -> HloSharding: ... + @staticmethod + def tuple_sharding( + shape: Shape, shardings: Sequence[HloSharding] + ) -> HloSharding: ... + @staticmethod + def iota_tile( + dims: Sequence[int], + reshape_dims: Sequence[int], + transpose_perm: Sequence[int], + subgroup_types: Sequence[OpSharding.Type], + ) -> HloSharding: ... + @staticmethod + def replicate() -> HloSharding: ... + @staticmethod + def manual() -> HloSharding: ... + @staticmethod + def unknown() -> HloSharding: ... + @staticmethod + def subgroup_with_device_ordering( + tile_assignment: np.ndarray, + subgroup_types: Sequence[OpSharding.Type]) -> HloSharding: ... + def __eq__(self, other: HloSharding) -> bool: ... + def __hash__(self) -> int: ... + def __repr__(self) -> str: ... + def tile(self, shape: Shape) -> Shape: ... + def is_replicated(self) -> bool: ... + def is_manual(self) -> bool: ... + def is_unknown(self) -> bool: ... + def is_tiled(self) -> bool: ... + def is_maximal(self) -> bool: ... + def tuple_elements(self) -> List[HloSharding]: ... + def num_devices(self) -> int: ... + def num_dimensions(self) -> int: ... + def tile_assignment_dimensions(self) -> Sequence[int]: ... + def tile_assignment_devices(self) -> Sequence[int]: ... + def subgroup_types(self) -> Sequence[OpSharding.Type]: ... + def replicate_on_last_tile_dim(self) -> bool: ... + def to_proto(self) -> OpSharding: ... + +class FftType(enum.IntEnum): + FFT: FftType + IFFT: FftType + RFFT: FftType + IRFFT: FftType + +# === END xla_compiler.cc + +class Device: + id: int + host_id: int + process_index: int + platform: str + device_kind: str + client: Client + local_hardware_id: int | None + def __repr__(self) -> str: ... + def __str__(self) -> str: ... + def transfer_to_infeed(self, literal: _LiteralSlice): ... + def transfer_from_outfeed(self, shape: Shape): ... + def memory(self, kind: str) -> Memory: ... + def default_memory(self) -> Memory: ... + def addressable_memories(self) -> List[Memory]: ... + def live_buffers(self) -> List[Any]: ... + def memory_stats(self) -> Optional[Dict[str, int]]: ... + def get_stream_for_external_ready_events(self) -> int: ... + def __getattr__(self, name: str) -> Any: ... + +class Memory: + process_index: int + platform: str + kind: str + def __repr__(self) -> str: ... + def __str__(self) -> str: ... + def addressable_by_devices(self) -> List[Device]: ... + +class PjRtLayout: + def __str__(self) -> str: ... + def __eq__(self, other: PjRtLayout) -> bool: ... + def __hash__(self) -> int: ... + def __getstate__(self) -> Any: ... + def __setstate__(self, _: Any): ... + def _xla_layout(self) -> Layout: ... + +class GpuAllocatorConfig: + class Kind(enum.IntEnum): + DEFAULT: int + PLATFORM: int + BFC: int + CUDA_ASYNC: int + + def __init__( + self, + kind: Kind = ..., + memory_fraction: float = ..., + preallocate: bool = ..., + collective_memory_size: int = ..., + ) -> None: ... + +class HostBufferSemantics(enum.IntEnum): + IMMUTABLE_ONLY_DURING_CALL: HostBufferSemantics + IMMUTABLE_UNTIL_TRANSFER_COMPLETES: HostBufferSemantics + ZERO_COPY: HostBufferSemantics + +class Client: + platform: str + _raw_platform: str + platform_version: str + runtime_type: str + def device_count(self) -> int: ... + def local_device_count(self) -> int: ... + def devices(self) -> List[Device]: ... + def local_devices(self) -> List[Device]: ... + def _get_all_devices(self) -> List[Device]: ... + def device_from_local_hardware_id(self, int) -> Device: ... + def live_buffers(self) -> List[Any]: ... + def live_arrays(self) -> List[ArrayImpl]: ... + def live_executables(self) -> List[LoadedExecutable]: ... + def host_id(self) -> int: ... + def process_index(self) -> int: ... + def buffer_from_pyval( + self, + argument: Any, + device: Optional[Device] = ..., + force_copy: bool = ..., + host_buffer_semantics: HostBufferSemantics = ..., + ) -> ArrayImpl: ... + def compile( + self, + computation: Union[str, bytes], + compile_options: CompileOptions = ..., + host_callbacks: Sequence[Any] = ..., + ) -> LoadedExecutable: ... + def compile_ifrt_program( + self, + program: ifrt_programs.Program, + program_options: ifrt_programs.CompileOptions, + ) -> LoadedExecutable: ... + def serialize_executable(self, executable: LoadedExecutable) -> bytes: ... + def deserialize_executable( + self, + serialized: bytes, + options: Optional[CompileOptions], + host_callbacks: Sequence[Any] = ..., + ) -> LoadedExecutable: ... + def heap_profile(self) -> bytes: ... + def defragment(self) -> _Status: ... + def get_emit_python_callback_descriptor( + self, + callable: Callable, + operand_shapes: Sequence[Shape], + results_shapes: Sequence[Shape], + ) -> Tuple[Any, Any]: ... + def make_python_callback_from_host_send_and_recv( + self, + callable: Callable, + operand_shapes: Sequence[Shape], + result_shapes: Sequence[Shape], + send_channel_ids: Sequence[int], + recv_channel_ids: Sequence[int], + serializer: Optional[Callable] = ..., + ) -> Any: ... + def get_default_layout( + self, dtype: np.dtype, shard_shape: Sequence[int], device: Device + ) -> PjRtLayout: ... + def __getattr__(self, name: str) -> Any: ... + +class CpuCollectives: ... + +def make_gloo_tcp_collectives( + distributed_client: Optional[DistributedRuntimeClient] = ..., + hostname: Optional[str] = ..., + interface: Optional[str] = ..., +) -> CpuCollectives: ... + +class MpiCollectives(CpuCollectives): + def Init(self): ... + def Finalize(self): ... + +def make_mpi_collectives() -> MpiCollectives: ... + +def get_tfrt_cpu_client( + asynchronous: bool = ..., + distributed_client: Optional[DistributedRuntimeClient] = ..., + node_id: int = ..., + num_nodes: int = ..., + collectives: Optional[CpuCollectives] = ..., + num_devices: int | None = ..., +) -> Client: ... +def get_gpu_client( + asynchronous: bool = ..., + allocator_config: GpuAllocatorConfig = ..., + distributed_client: Optional[DistributedRuntimeClient] = ..., + node_id: int = ..., + num_nodes: int = ..., + allowed_devices: Optional[Any] = ..., + platform_name: Optional[str] = ..., + mock: Optional[bool] = ..., + mock_gpu_topology: Optional[str] = ..., +) -> Client: ... +def get_mock_gpu_client( + asynchronous: bool = ..., + allocator_config: GpuAllocatorConfig = ..., + distributed_client: Optional[DistributedRuntimeClient] = ..., + node_id: int = ..., + allowed_devices: Optional[Any] = ..., + platform_name: Optional[str] = ..., +) -> Client: ... +def get_c_api_client( + platform_name: str, + options: Dict[str, Union[str, int, List[int], float, bool]], + distributed_client: Optional[DistributedRuntimeClient] = ..., +) -> Client: ... +def get_default_c_api_topology( + platform_name: str, + topology_name: str, + options: Dict[str, Union[str, int, List[int], float]], +) -> DeviceTopology: ... +def get_c_api_topology( + c_api: Any, + topology_name: str, + options: Dict[str, Union[str, int, List[int], float]], +) -> DeviceTopology: ... +def get_topology_for_devices(devices: List[Device]) -> DeviceTopology: ... +def load_pjrt_plugin(platform_name: str, library_path: Optional[str], c_api: Optional[Any]) -> _Status: ... +def pjrt_plugin_loaded(plugin_name: str) -> bool: ... +def pjrt_plugin_initialized(plugin_name: str) -> bool: ... +def initialize_pjrt_plugin(platform_name: str) -> _Status: ... + +ArrayImpl = Any + +# TODO(phawkins): this type is problematic because it is not a subtype of +# jax.Array, and pytype notices. +# class ArrayImpl: +# def __init__(self, +# aval: Any, +# sharding: Any, +# arrays: Sequence[ArrayImpl], +# committed: bool, +# _skip_checks: bool = ...): ... +# def block_until_ready(self) -> ArrayImpl: ... +# def is_deleted(self) -> bool: ... +# def is_ready(self) -> bool: ... +# def delete(self): ... +# def unsafe_buffer_pointer(self) -> Any: ... +# def clone(self) -> ArrayImpl: ... +# def _copy_single_device_array_to_host_async(self): ... +# def _single_device_array_to_np_array_did_copy(self) -> tuple[np.ndarray, bool]: ... +# def on_device_size_in_bytes(self) -> int: ... +# def _fully_replicated_shard(self) -> ArrayImpl: ... +# __cuda_array_interface__: Dict[str, Any] +# dtype: np.dtype +# shape: Tuple[int, ...] +# _arrays: Any +# _npy_value: Any +# traceback: Traceback +# _HAS_DYNAMIC_ATTRIBUTES: bool = ... + +def batched_copy_array_to_devices_with_sharding( + arrays: Sequence[ArrayImpl], + devices: Sequence[List[Device]], + sharding: Sequence[Any], + array_copy_semantics: Sequence[ArrayCopySemantics], +) -> Sequence[ArrayImpl]: ... + +def batched_block_until_ready(x: Sequence[ArrayImpl]) -> None: ... + +def batched_device_put( + aval: Any, + sharding: Any, + shards: Sequence[Any], + devices: List[Device], + committed: bool = True, +) -> ArrayImpl: ... + +def reorder_shards( + x: ArrayImpl, + dst_sharding: Any, + array_copy_semantics: ArrayCopySemantics, +) -> ArrayImpl: ... + +def check_and_canonicalize_memory_kind( + memory_kind: Optional[str], device_list: DeviceList +) -> Optional[str]: ... +def array_result_handler( + aval: Any, sharding: Any, committed: bool, _skip_checks: bool = ... +) -> Callable: ... + +class Token: + def block_until_ready(self): ... + +class ShardedToken: + def block_until_ready(self): ... + def get_token(self, device_id: int): ... + +class ExecuteResults: + def __len__(self) -> int: ... + def disassemble_into_single_device_arrays(self) -> List[List[ArrayImpl]]: ... + def disassemble_prefix_into_single_device_arrays( + self, n: int + ) -> List[List[ArrayImpl]]: ... + def consume_with_handlers(self, handlers: List[Callable]) -> List[Any]: ... + def consume_token(self) -> ShardedToken: ... + +class LoadedExecutable: + client: Client + def local_devices(self) -> List[Device]: ... + def size_of_generated_code_in_bytes(self) -> int: ... + def delete(self) -> None: ... + def execute(self, arguments: Sequence[ArrayImpl]) -> List[ArrayImpl]: ... + def execute_with_token( + self, arguments: Sequence[ArrayImpl] + ) -> Tuple[List[ArrayImpl], Token]: ... + def execute_sharded_on_local_devices( + self, arguments: Sequence[List[ArrayImpl]] + ) -> List[List[ArrayImpl]]: ... + def execute_sharded_on_local_devices_with_tokens( + self, arguments: Sequence[List[ArrayImpl]] + ) -> Tuple[List[List[ArrayImpl]], ShardedToken]: ... + def execute_sharded( + self, arguments: Sequence[List[ArrayImpl]], with_tokens: bool = ... + ) -> ExecuteResults: ... + def hlo_modules(self) -> List[HloModule]: ... + def get_output_memory_kinds(self) -> List[List[str]]: ... + def get_compiled_memory_stats(self) -> CompiledMemoryStats: ... + def get_output_shardings(self) -> Optional[List[OpSharding]]: ... + def get_parameter_shardings(self) -> Optional[List[OpSharding]]: ... + def get_parameter_layouts(self) -> List[Layout]: ... + def get_output_layouts(self) -> List[Layout]: ... + def keep_alive(self) -> None: ... + def cost_analysis(self) -> Dict[str, Any]: ... + traceback: Traceback + fingerprint: Optional[bytes] + +class Executable: + def hlo_modules(self) -> List[HloModule]: ... + def get_output_memory_kinds(self) -> List[List[str]]: ... + def get_output_shardings(self) -> Optional[List[OpSharding]]: ... + def get_parameter_shardings(self) -> Optional[List[OpSharding]]: ... + def get_parameter_layouts(self) -> List[Layout]: ... + def get_output_layouts(self) -> List[Layout]: ... + def get_compiled_memory_stats(self) -> CompiledMemoryStats: ... + def serialize(self) -> str: ... + def cost_analysis(self) -> Dict[str, Any]: ... + +class DeviceTopology: + platform: str + platform_version: str + def _make_compile_only_devices(self) -> List[Device]: ... + def serialize(self) -> bytes: ... + def __getattr__(self, name: str) -> Any: ... + +def buffer_to_dlpack_managed_tensor( + buffer: ArrayImpl, stream: int | None = None +) -> Any: ... +@overload +def dlpack_managed_tensor_to_buffer( + tensor: Any, device: Device, stream: int | None +) -> ArrayImpl: ... +@overload +def dlpack_managed_tensor_to_buffer( # Legacy overload + tensor: Any, + cpu_backend: Optional[Client] = ..., + gpu_backend: Optional[Client] = ..., +) -> ArrayImpl: ... + +def cuda_array_interface_to_buffer( + cai: Dict[str, Union[ + str, int, None, + Tuple[int, ...], Tuple[int, bool], + List[Tuple[str, str]], + List[Tuple[str, str, Tuple[int, ...]]]] + ], + gpu_backend: Optional[Client] = ..., + device_id: int | None = None, +) -> ArrayImpl: ... + +# === BEGIN py_traceback.cc + +class Frame: + file_name: str + function_name: str + function_line_start: int + line_num: int + def __init__(self, + file_name: str, + function_name: str, + function_line_start: int, + line_num: int): ... + def __repr__(self) -> str: ... + +class Traceback: + enabled: ClassVar[bool] + @staticmethod + def get_traceback() -> Traceback: ... + @staticmethod + def traceback_from_frames(frames: Sequence[Frame]) -> Any: ... + frames: Sequence[Frame] + def __str__(self) -> str: ... + def as_python_traceback(self) -> Any: ... + def raw_frames(self) -> Tuple[List[types.CodeType], List[int]]: ... + @staticmethod + def code_addr2line(code: types.CodeType, lasti: int) -> int: ... + @staticmethod + def code_addr2location( + code: types.CodeType, lasti: int + ) -> Tuple[int, int, int, int]: ... + +def replace_thread_exc_traceback(traceback: Any): ... + +# === END py_traceback.cc + +class DistributedRuntimeService: + def shutdown(self) -> None: ... + +class DistributedRuntimeClient: + def connect(self) -> _Status: ... + def shutdown(self) -> _Status: ... + def blocking_key_value_get(self, key: str, timeout_in_ms: int) -> _Status: ... + def blocking_key_value_get_bytes( + self, key: str, timeout_in_ms: int + ) -> _Status: ... + def key_value_try_get(self, key: str) -> _Status: ... + def key_value_try_get_bytes(self, key: str) -> _Status: ... + def key_value_dir_get(self, key: str) -> _Status: ... + def key_value_dir_get_bytes(self, key: str) -> _Status: ... + def key_value_set(self, key: str, value: str, + allow_overwrite: bool = False) -> _Status: ... + def key_value_set_bytes(self, key: str, value: bytes, + allow_overwrite: bool = False) -> _Status: ... + def key_value_delete(self, key: str) -> _Status: ... + def wait_at_barrier( + self, barrier_id: str, timeout_in_ms: int, process_ids: Optional[List[int]] + ) -> _Status: ... + def get_live_nodes(self, process_ids: List[int]) -> _Status: ... + +def get_distributed_runtime_service( + address: str, + num_nodes: int, + heartbeat_interval: Optional[int] = ..., + max_missing_heartbeats: Optional[int] = ..., + cluster_register_timeout: Optional[int] = ..., + shutdown_timeout: Optional[int] = ..., +) -> DistributedRuntimeService: ... +def get_distributed_runtime_client( + address: str, + node_id: int, + rpc_timeout: Optional[int] = ..., + init_timeout: Optional[int] = ..., + shutdown_timeout: Optional[int] = ..., + heartbeat_interval: Optional[int] = ..., + max_missing_heartbeats: Optional[int] = ..., + missed_heartbeat_callback: Optional[Any] = ..., + shutdown_on_destruction: Optional[bool] = ..., + use_compression: Optional[bool] = ..., +) -> DistributedRuntimeClient: ... + +class PreemptionSyncManager: + def initialize(self, client: DistributedRuntimeClient) -> _Status: ... + def reached_sync_point(self, step_counter: int) -> bool: ... + +def create_preemption_sync_manager() -> PreemptionSyncManager: ... +def collect_garbage() -> None: ... +def is_optimized_build() -> bool: ... +def json_to_pprof_profile(json: str) -> bytes: ... +def pprof_profile_to_json(proto: bytes) -> str: ... + +class PmapFunction: + def __call__(self, *args, **kwargs) -> Any: ... + def __getstate__(self) -> Any: ... + def __setstate__(self, Any): ... + __signature__: inspect.Signature + def _cache_size(self) -> int: ... + def _cache_clear(self) -> None: ... + +def weakref_lru_cache( + cache_context_fn: Callable, call: Callable, maxsize=... +) -> WeakrefLRUCache: ... + +class DeviceList: + def __init__(self, device_assignment: Tuple[Device, ...]): ... + def __hash__(self) -> int: ... + def __eq__(self, other: Any) -> bool: ... + def __ne__(self, other: Any) -> bool: ... + def __len__(self) -> int: ... + def __getitem__(self, index: Any) -> Any: ... + def __iter__(self) -> Iterator[Device]: ... + def __str__(self) -> str: ... + def __repr__(self) -> str: ... + def __getstate__(self) -> Any: ... + def __setstate__(self, state: Any): ... + @property + def is_fully_addressable(self) -> bool: ... + @property + def addressable_device_list(self) -> DeviceList: ... + @property + def default_memory_kind(self) -> Optional[str]: ... + @property + def memory_kinds(self) -> Tuple[str, ...]: ... + +class Sharding: ... + +class NamedSharding(Sharding): + def __init__( + self, + mesh: Any, + spec: Any, + *, + memory_kind: Optional[str] = None, + _manual_axes: frozenset[Any] = frozenset(), + _logical_device_ids: tuple[int, ...] | None = None, + ): ... + mesh: Any + spec: Any + _memory_kind: Optional[str] + _internal_device_list: DeviceList + _manual_axes: frozenset[Any] + _logical_device_ids: tuple[int, ...] | None + +class SingleDeviceSharding(Sharding): + def __init__(self, device: Device, *, memory_kind: Optional[str] = None): ... + _device: Device + _memory_kind: Optional[str] + _internal_device_list: DeviceList + +class PmapSharding(Sharding): + def __init__( + self, devices: Sequence[Any], sharding_spec: pmap_lib.ShardingSpec + ): ... + devices: List[Any] + sharding_spec: pmap_lib.ShardingSpec + _internal_device_list: DeviceList + +class GSPMDSharding(Sharding): + def __init__( + self, + devices: Sequence[Device], + op_sharding: Union[OpSharding, HloSharding], + *, + memory_kind: Optional[str] = None, + _device_list: Optional[DeviceList] = None, + ): ... + _devices: Tuple[Device, ...] + _hlo_sharding: HloSharding + _memory_kind: Optional[str] + _internal_device_list: DeviceList + +class PjitFunction: + def __call__(self, *args, **kwargs) -> Any: ... + +class PjitFunctionCache: + def __init__(self, capacity: int = ...): ... + def __getstate__(self) -> Any: ... + def __setstate__(self, Any): ... + def size(self) -> int: ... + def capacity(self) -> int: ... + def clear(self): ... + @staticmethod + def clear_all(): ... + +def pjit( + function_name: str, + fun: Optional[Callable], + cache_miss: Callable, + static_argnums: Sequence[int], + static_argnames: Sequence[str], + global_cache_key: Any, + pytree_registry: pytree.PyTreeRegistry, + shard_arg_fallback: Callable, + cache: Optional[PjitFunctionCache] = ..., +) -> PjitFunction: ... + +class HloPassInterface: + @property + def name(self) -> str: ... + def is_pass_pipeline(self) -> bool: ... + def run(self, module: HloModule) -> bool: ... + def run_on_module_group(self, module_group: HloModuleGroup) -> bool: ... + +class HloDCE(HloPassInterface): + def __init__(self) -> None: ... + +class CallInliner(HloPassInterface): + def __init__(self) -> None: ... + +class FlattenCallGraph(HloPassInterface): + def __init__(self) -> None: ... + +class TupleSimplifer(HloPassInterface): + def __init__(self) -> None: ... + +class WeakrefLRUCacheInfo: + @property + def hits(self) -> int: ... + @property + def misses(self) -> int: ... + @property + def maxsize(self) -> int: ... + @property + def currsize(self) -> int: ... + +class WeakrefLRUCache: + def __call__(self, weakref_key: Any, *args, **kwargs) -> Any: ... + def cache_keys(self) -> list[Any]: ... + def cache_info(self) -> WeakrefLRUCacheInfo: ... + def cache_clear(self): ... + +def is_asan() -> bool: ... +def is_msan() -> bool: ... +def is_tsan() -> bool: ... +def is_sanitized() -> bool: ... + +class TransferConnection: + + def address(self) -> str: ... + + def _pull_flat(self, uuid, backend, avals_flat) -> list[Any]: ... + +class TransferServer: + def _await_pull_flat(self, uuid, args: list[ArrayImpl]): ... + + def connect(self, address: str) -> TransferConnection: ... + +def start_transfer_server(client: Client, address: str = "", transport_addresses: list[str] = [], max_num_parallel_copies: int = 0, transfer_size: int = 0) -> TransferServer: ... diff --git a/jaxlib/xla/xla_extension/config.pyi b/jaxlib/xla/xla_extension/config.pyi new file mode 100644 index 000000000000..535554559180 --- /dev/null +++ b/jaxlib/xla/xla_extension/config.pyi @@ -0,0 +1,32 @@ +# Copyright 2024 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Any, Generic, TypeVar + +unset: object + +_T = TypeVar('_T') + +class Config(Generic[_T]): + def __init__(self, value: _T, include_in_jit_key: bool = False): ... + + @property + def value(self) -> _T: ... + + def get_local(self) -> Any: ... + def get_global(self) -> _T: ... + def set_local(self, value: Any) -> None: ... + def swap_local(self, value: Any) -> Any: ... + def set_global(self, value: _T) -> None: ... diff --git a/jaxlib/xla/xla_extension/guard_lib.pyi b/jaxlib/xla/xla_extension/guard_lib.pyi new file mode 100644 index 000000000000..cfa8b0c5fa5e --- /dev/null +++ b/jaxlib/xla/xla_extension/guard_lib.pyi @@ -0,0 +1,46 @@ +# Copyright 2024 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Any, List, Optional + +class TransferGuardLevel: + ALLOW: Any + LOG: Any + DISALLOW: Any + LOG_EXPLICIT: Any + DISALLOW_EXPLICIT: Any + +class GarbageCollectionGuardLevel: + ALLOW: Any + LOG: Any + FATAL: Any + +class GuardState: + host_to_device: Optional[TransferGuardLevel] + device_to_device: Optional[TransferGuardLevel] + device_to_host: Optional[TransferGuardLevel] + + explicit_device_put: bool + explicit_device_get: bool + + garbage_collect_array: Optional[GarbageCollectionGuardLevel] + +def global_state() -> GuardState: ... +def thread_local_state() -> GuardState: ... + +class _TestingScopedLogSink: + def __enter__(self) -> _TestingScopedLogSink: ... + def __exit__(self, *args, **kwargs) -> None: ... + def logs(self) -> List[str]: ... diff --git a/jaxlib/xla/xla_extension/ifrt_programs.pyi b/jaxlib/xla/xla_extension/ifrt_programs.pyi new file mode 100644 index 000000000000..bcee365e5732 --- /dev/null +++ b/jaxlib/xla/xla_extension/ifrt_programs.pyi @@ -0,0 +1,43 @@ +# Copyright 2024 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Any, Sequence, Union + +from jax.jaxlib.xla import xla_extension + +class Program: ... + +class CompileOptions: ... + +def make_hlo_program(mlir_module: Union[str, bytes]) -> Program: ... + +def make_colocated_python_program( + name : str, + picked_function: bytes, + devices: Sequence[xla_extension.Device] | xla_extension.DeviceList, + input_avals: Sequence[Any], + output_avals: Sequence[Any], +) -> Program: ... + +def make_plugin_program(data: Union[str, bytes]) -> Program: ... + +def make_colocated_python_compile_options() -> CompileOptions: ... + +def make_xla_compile_options( + compile_options: xla_extension.CompileOptions, + host_callbacks: Sequence[Any] +) -> CompileOptions: ... + +def make_plugin_compile_options() -> CompileOptions: ... diff --git a/jaxlib/xla/xla_extension/ifrt_proxy.pyi b/jaxlib/xla/xla_extension/ifrt_proxy.pyi new file mode 100644 index 000000000000..3b5de7aa97c9 --- /dev/null +++ b/jaxlib/xla/xla_extension/ifrt_proxy.pyi @@ -0,0 +1,33 @@ +# Copyright 2024 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Any, Optional, Callable + +from jax.jaxlib.xla import xla_extension + +_Status = Any +Client = xla_extension.Client + + +class ClientConnectionOptions: + on_disconnect: Optional[Callable[[_Status], None]] = None + on_connection_update: Optional[Callable[[str], None]] = None + connection_timeout_in_seconds: Optional[int] = None + + +def get_client( + proxy_server_address: str, + options: ClientConnectionOptions +) -> Client: ... diff --git a/jaxlib/xla/xla_extension/jax_jit.pyi b/jaxlib/xla/xla_extension/jax_jit.pyi new file mode 100644 index 000000000000..1f78d283333c --- /dev/null +++ b/jaxlib/xla/xla_extension/jax_jit.pyi @@ -0,0 +1,76 @@ +# Copyright 2021 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Any, Callable, Optional, Sequence, Tuple + +import numpy as np +from jax.jaxlib.xla import xla_extension + +from . import pytree + +Client = xla_extension.Client +Device = xla_extension.Device + + +class JitState: + disable_jit: Optional[bool] + enable_x64: Optional[bool] + default_device: Optional[Any] + extra_jit_context: Optional[Any] + post_hook: Optional[Callable[..., Any]] + +def global_state() -> JitState: ... +def thread_local_state() -> JitState: ... + +def get_enable_x64() -> bool: ... +def set_thread_local_state_initialization_callback( + function: Callable[[], None]): ... + +def swap_thread_local_state_disable_jit( + value: Optional[bool]) -> Optional[bool]: ... + +class ArgSignature: + dtype: np.dtype + shape: Tuple[int, ...] + weak_type: bool + +def _ArgSignatureOfValue( + __arg: Any, + __jax_enable_x64: bool) -> ArgSignature: ... + +def _is_float0(__arg: Any) -> bool: ... + + +class ArgumentSignature: + static_args: Sequence[Any] + static_arg_names: Sequence[str] + dynamic_arg_names: Sequence[str] + dynamic_arg_treedefs: Sequence[pytree.PyTreeDef] + + def __eq__(self, value, /): ... + def __ne__(self, value, /): ... + def __hash__(self, /): ... + def __str__(self): ... + def __repr__(self): ... + + +def parse_arguments( + positional_args: Sequence[Any], + keyword_args: Sequence[Any], + kwnames: Tuple[str, ...], + static_argnums: Sequence[int], + static_argnames: Sequence[str], + pytree_registry: pytree.PyTreeRegistry, +) -> tuple[ArgumentSignature, Sequence[Any]]: ... diff --git a/jaxlib/xla/xla_extension/mlir.pyi b/jaxlib/xla/xla_extension/mlir.pyi new file mode 100644 index 000000000000..95eeae660c0c --- /dev/null +++ b/jaxlib/xla/xla_extension/mlir.pyi @@ -0,0 +1,34 @@ +# Copyright 2021 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Union +from . import XlaComputation + +def xla_computation_to_mlir_module(computation: XlaComputation) -> str: ... +def mlir_module_to_xla_computation( + mlir_module: Union[bytes, str], + use_tuple_args: bool = ..., + return_tuple: bool = ..., +) -> XlaComputation: ... +def mhlo_to_stablehlo(mlir_module: Union[bytes, str]) -> bytes: ... +def stablehlo_to_mhlo(mlir_module: Union[bytes, str]) -> bytes: ... +def serialize_portable_artifact(mlir_module: str, target: str) -> bytes: ... +def deserialize_portable_artifact(mlir_module: bytes) -> str: ... +def refine_polymorphic_shapes( + mlir_module: Union[bytes, str], + enable_shape_assertions: bool = ..., + validate_static_shapes: bool = ..., + enable_shardy: bool = ..., +) -> bytes: ... diff --git a/jaxlib/xla/xla_extension/ops.pyi b/jaxlib/xla/xla_extension/ops.pyi new file mode 100644 index 000000000000..ff55de3a5cdc --- /dev/null +++ b/jaxlib/xla/xla_extension/ops.pyi @@ -0,0 +1,465 @@ +# Copyright 2021 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import enum +from typing import Any, Optional, Sequence, overload + +from jax.jaxlib.xla import xla_extension + +FftType = xla_extension.FftType +XlaBuilder = xla_extension.XlaBuilder +XlaComputation = xla_extension.XlaComputation +XlaOp = xla_extension.XlaOp +PrecisionConfig_Precision = xla_extension.PrecisionConfig_Precision +PrimitiveType = xla_extension.PrimitiveType +Shape = xla_extension.Shape +ShapeIndex = xla_extension.ShapeIndex +ResultAccuracy = xla_extension.ResultAccuracy + +_ChannelHandle = Any +_ConvDimensionNumbers = Any +_DotDimensionNumbers = Any +_Layout = Any +_LiteralSlice = Any +_GatherDimensionNumbers = Any +_PaddingConfig = Any +_ReplicaGroup = Any +_ScatterDimensionNumbers = Any + +class TriangularSolveOptions_Transpose(enum.IntEnum): + TRANSPOSE_INVALID: int + NO_TRANSPOSE: int + TRANSPOSE: int + ADJOINT: int + +class RandomAlgorithm(enum.IntEnum): + RNG_DEFAULT: int + RNG_THREE_FRY: int + RNG_PHILOX: int + +class CustomCallSchedule(enum.IntEnum): + SCHEDULE_NONE: int + SCHEDULE_LATEST: int + SCHEDULE_EARLIEST: int + +# TODO(b/189822916): Remove this enum when all clients are migrated to the +# status-returning API. +class CustomCallApiVersion(enum.IntEnum): + API_VERSION_ORIGINAL: int + API_VERSION_STATUS_RETURNING: int + API_VERSION_STATUS_RETURNING_UNIFIED: int + API_VERSION_TYPED_FFI: int + +def AfterAll(builder: XlaBuilder, tokens: Sequence[XlaOp]) -> XlaOp: ... +def AllGather( + operand: XlaOp, + all_gather_dimension: int, + shard_count: int, + replica_groups: Sequence[_ReplicaGroup] = ..., + channel_id: Optional[_ChannelHandle] = ..., + shape_with_layout: Optional[_Layout] = ..., + use_global_device_ids: Optional[bool] = ...) -> XlaOp: ... +def AllReduce( + operand: XlaOp, + computation: XlaComputation, + replica_groups: Sequence[_ReplicaGroup] = ..., + channel_id: Optional[_ChannelHandle] = ..., + shape_with_layout: Optional[_Layout] = ...) -> XlaOp: ... +def ApproxTopK( + builder: XlaBuilder, + operands: Sequence[XlaOp], + init_values: Sequence[XlaOp], + top_k: int, + reduction_dim: int, + comparator: XlaComputation, + recall_target: Optional[float], + aggregate_to_topk: Optional[bool], + reduction_input_size_override: Optional[int]) -> XlaOp: ... +def ApproxTopKFallback( + builder: XlaBuilder, + operands: Sequence[XlaOp], + init_values: Sequence[XlaOp], + top_k: int, + reduction_dim: int, + comparator: XlaComputation, + recall_target: Optional[float], + aggregate_to_topk: Optional[bool], + reduction_input_size_override: Optional[int]) -> XlaOp: ... +def ApproxTopKReductionOutputSize( + input_size: int, + rank: int, + top_k: int, + recall_target: float, + aggregate_to_topk: Optional[bool] = ..., + input_size_override: Optional[int] = ...) -> tuple[int, int]: ... +def ReduceScatter( + operand: XlaOp, + computation: XlaComputation, + scatter_dimension: int, + shard_count: int, + replica_groups: Sequence[_ReplicaGroup] = ..., + channel_id: Optional[_ChannelHandle] = ..., + layout: Optional[_Layout] = ..., + use_global_device_ids: Optional[bool] = ...) -> XlaOp: ... +def AllToAll( + operand: XlaOp, + split_dimension: int, + concat_dimension: int, + split_count: int, + replica_groups: Sequence[_ReplicaGroup] = ..., + layout: Optional[_Layout] = ..., + channel_id: Optional[_ChannelHandle] = ...) -> XlaOp: ... +def BitcastConvertType(operand: XlaOp, + new_element_type: PrimitiveType) -> XlaOp: ... +def Broadcast(operand: XlaOp, sizes: Sequence[int]) -> XlaOp: ... +def BroadcastInDim(operand: XlaOp, + shape: Sequence[int], + broadcast_dimensions: Sequence[int]) -> XlaOp: ... +def Call(builder: XlaBuilder, + computation: XlaComputation, + operands: Sequence[XlaOp]) -> XlaOp: ... +def Cholesky(a: XlaOp, lower: bool = ...) -> XlaOp: ... +def Clamp(min: XlaOp, operand: XlaOp, max: XlaOp) -> XlaOp: ... +def Collapse(operand: XlaOp, dimensions: Sequence[int]) -> XlaOp: ... +def CollectivePermute( + operand: XlaOp, + source_target_pairs: Sequence[tuple[int, int]], + channel_id: Optional[_ChannelHandle] = ..., + inplace: bool = ...) -> XlaOp: ... +def ConcatInDim(builder: XlaBuilder, + operands: Sequence[XlaOp], + dimension: int) -> XlaOp: ... +@overload +def Conditional(branch_index: XlaOp, + branch_computations: Sequence[XlaComputation], + branch_operands: Sequence[XlaOp]) -> XlaOp: ... +@overload +def Conditional( + predicate: XlaOp, + true_operand: XlaOp, + true_computation: XlaComputation, + false_operand: XlaOp, + false_computation: XlaComputation) -> XlaOp: ... + +def Constant(builder: XlaBuilder, value: _LiteralSlice) -> XlaOp: ... +def ConstantLiteral(builder: XlaBuilder, value: _LiteralSlice) -> XlaOp: ... +def ConvGeneralDilated( + lhs: XlaOp, + rhs: XlaOp, + window_strides: Sequence[int], + padding: Sequence[tuple[int, int]], + lhs_dilation: Sequence[int], + rhs_dilation: Sequence[int], + dimension_numbers: _ConvDimensionNumbers, + feature_group_count: int = ..., + batch_group_count: int = ..., + precision_config: Optional[PrecisionConfig_Precision] = ..., + preferred_element_type: Optional[PrimitiveType] = ..., + window_reversal: Optional[Sequence[bool]] = ...) -> XlaOp: ... +def ConvertElementType( + operand: XlaOp, + new_element_type: PrimitiveType) -> XlaOp: ... +def CreateToken(builder: XlaBuilder) -> XlaOp: ... +def CrossReplicaSum( + operand: XlaOp, + replica_groups: Sequence[_ReplicaGroup] = ...) -> XlaOp: ... +def CustomCall( + builder: XlaBuilder, + call_target_name: bytes, + operands: Sequence[XlaOp], + shape: Shape, + opaque: bytes = ..., + has_side_effect: bool = ..., + schedule: CustomCallSchedule = ..., + api_version: CustomCallApiVersion = ...) -> XlaOp: ... +def CustomCallWithLayout( + builder: XlaBuilder, + call_target_name: bytes, + operands: Sequence[XlaOp], + shape_with_layout: Shape, + operand_shapes_with_layout: Sequence[Shape], + opaque: bytes = ..., + has_side_effect: bool = ..., + schedule: CustomCallSchedule = ..., + api_version: CustomCallApiVersion = ...) -> XlaOp: ... +def CustomCallWithAliasing( + builder: XlaBuilder, + call_target_name: bytes, + operands: Sequence[XlaOp], + shape_with_layout: Shape, + operand_shapes_with_layout: Sequence[Shape], + opaque: bytes = ..., + has_side_effect: bool = ..., + output_operand_aliasing: Sequence[tuple[ShapeIndex, tuple[int, ShapeIndex]]] = ..., + literal: _LiteralSlice = ..., + schedule: CustomCallSchedule = ..., + api_version: CustomCallApiVersion = ...) -> XlaOp: ... +def Dot( + lhs: XlaOp, + rhs: XlaOp, + precision_config: Optional[PrecisionConfig_Precision] = ..., + preferred_element_type: Optional[PrimitiveType] = ...) -> XlaOp: ... +def DotGeneral( + lhs: XlaOp, + rhs: XlaOp, + dimensions_numbers: _DotDimensionNumbers, + precision_config: Optional[PrecisionConfig_Precision] = ..., + preferred_element_type: Optional[PrimitiveType] = ...) -> XlaOp: ... +def DynamicReshape( + operand: XlaOp, + dim_sizes: Sequence[XlaOp], + new_size_bounds: Sequence[int], + dims_are_dynamic: Sequence[bool]) -> XlaOp: ... +def DynamicSlice( + operand: XlaOp, + start_indices: Sequence[XlaOp], + slice_sizes: Sequence[int]) -> XlaOp: ... +def DynamicUpdateSlice( + operand: XlaOp, + update: XlaOp, + start_indices: Sequence[XlaOp]) -> XlaOp: ... +def Eigh( + a: XlaOp, + lower: bool = ..., + max_iter: int = ..., + epsilon: float = ..., + sort_eigenvalues: bool = ...) -> tuple[XlaOp, XlaOp]: ... +def Fft( + operand: XlaOp, + fft_type: FftType, + fft_length: Sequence[int]) -> XlaOp: ... +def Gather( + a: XlaOp, + start_indices: XlaOp, + dimension_numbers: _GatherDimensionNumbers, + slice_sizes: Sequence[int], + indices_are_sorted: bool = ...) -> XlaOp: ... +def GetDimensionSize(operand: XlaOp, index: int) -> XlaOp: ... +def GetTupleElement(tuple_data: XlaOp, index: int) -> XlaOp: ... +def InfeedWithToken( + token: XlaOp, + shape: Shape, + config: Optional[str] = ...) -> XlaOp: ... +@overload +def Iota(builder: XlaBuilder, shape: Shape, iota_dimension: int) -> XlaOp: ... +@overload +def Iota(builder: XlaBuilder, type: PrimitiveType, size: int) -> XlaOp: ... +def LU(a: XlaOp) -> tuple[XlaOp, XlaOp, XlaOp]: ... +def Map( + builder: XlaBuilder, + operands: Sequence[XlaOp], + computation: XlaComputation, + dimensions: Sequence[int], + static_operands: Sequence[XlaOp] = ...) -> XlaOp: ... +def MultiCollectivePermute( + operands: Sequence[XlaOp], + source_target_pairs: Sequence[tuple[int, int]], + channel_id: Optional[_ChannelHandle] = ..., + inplace: bool = ...) -> XlaOp: ... +def NextAfter(__from: XlaOp, to: XlaOp) -> XlaOp: ... +def OutfeedWithToken( + operand: XlaOp, + token: XlaOp, + shape_with_layout: Shape, + outfeed_config: Optional[str] = ...) -> XlaOp: ... +def Pad( + operand: XlaOp, + padding_value: XlaOp, + padding_config: _PaddingConfig) -> XlaOp: ... +def Parameter( + builder: XlaBuilder, + parameter_number: int, + shape: Shape, + name: str = ..., + replicated_at_leaf_buffers: Sequence[bool] = ...) -> XlaOp: ... +def ProductOfElementaryHouseholderReflectors(a: XlaOp, taus: XlaOp) -> XlaOp: ... +def QR(a: XlaOp, full_matrices: bool) -> tuple[XlaOp, XlaOp]: ... +def QrDecomposition(a: XlaOp) -> tuple[XlaOp, XlaOp]: ... +def Reduce( + builder: XlaBuilder, + operands: Sequence[XlaOp], + init_values: Sequence[XlaOp], + computation: XlaComputation, + dimensions_to_reduce: Sequence[int]) -> XlaOp: ... +def ReducePrecision( + operand: XlaOp, + exponent_bits: int, + mantissa_bits: int) -> XlaOp: ... +@overload +def ReduceWindowWithGeneralPadding( + operand: XlaOp, + init_value: XlaOp, + computation: XlaComputation, + window_dimensions: Sequence[int], + window_strides: Sequence[int], + base_dilations: Sequence[int], + window_dilations: Sequence[int], + padding: Sequence[tuple[int, int]]) -> XlaOp: ... +@overload +def ReduceWindowWithGeneralPadding( + operands: Sequence[XlaOp], + init_values: Sequence[XlaOp], + computation: XlaComputation, + window_dimensions: Sequence[int], + window_strides: Sequence[int], + base_dilations: Sequence[int], + window_dilations: Sequence[int], + padding: Sequence[tuple[int, int]]) -> XlaOp: ... +def ReplicaId(builder: XlaBuilder) -> XlaOp: ... +def Reshape(operand: XlaOp, new_sizes: Sequence[int]) -> XlaOp: ... +def Rev(operand: XlaOp, dimensions: Sequence[int]) -> XlaOp: ... +def RngBitGenerator( + algorithm: RandomAlgorithm, + initial_state: XlaOp, + shape: Shape) -> XlaOp: ... +def RngNormal(mu: XlaOp, sigma: XlaOp, shape: Shape) -> XlaOp: ... +def RngUniform(a: XlaOp, b: XlaOp, shape: Shape) -> XlaOp: ... +@overload +def Scatter( + input: XlaOp, + scatter_indices: XlaOp, + updates: XlaOp, + update_computation: XlaComputation, + dimension_numbers: _ScatterDimensionNumbers, + indices_are_sorted: bool = ..., + unique_indices: bool = ...) -> XlaOp: ... +@overload +def Scatter( + inputs: Sequence[XlaOp], + scatter_indices: XlaOp, + updates: Sequence[XlaOp], + update_computation: XlaComputation, + dimension_numbers: _ScatterDimensionNumbers, + indices_are_sorted: bool = ..., + unique_indices: bool = ...) -> XlaOp: ... +def Select(pred: XlaOp, on_true: XlaOp, on_false: XlaOp) -> XlaOp: ... +def SelectAndScatterWithGeneralPadding( + operand: XlaOp, + select: XlaComputation, + window_dimensions: Sequence[int], + window_strides: Sequence[int], + padding: Sequence[tuple[int, int]], + source: XlaOp, + init_value: XlaOp, + scatter: XlaComputation) -> XlaOp: ... +def Slice( + operand: XlaOp, + start_indices: Sequence[int], + limit_indices: Sequence[int], + strides: Sequence[int]) -> XlaOp: ... +def SliceInDim( + operand: XlaOp, + start_index: int, + limit_index: int, + stride: int, + dimno: int) -> XlaOp: ... +def Sort( + builder: XlaBuilder, + operands: Sequence[XlaOp], + comparator: Optional[XlaComputation] = ..., + dimension: int = ..., + is_stable: bool = ...) -> XlaOp: ... +def SVD( + a: XlaOp, + max_iter: int = ..., + epsilon: float = ...) -> tuple[XlaOp, XlaOp, XlaOp]: ... +def TopK(input: XlaOp, k: int) -> XlaOp: ... +def Transpose(operand: XlaOp, permutation: Sequence[int]) -> XlaOp: ... +def TriangularSolve( + a: XlaOp, + b: XlaOp, + left_side: bool, + lower: bool, + unit_diagonal: bool, + transpose_a: TriangularSolveOptions_Transpose) -> XlaOp: ... +def Tuple(builder: XlaBuilder, elements: Sequence[XlaOp]) -> XlaOp: ... +def While( + condition: XlaComputation, + body: XlaComputation, + init: XlaOp) -> XlaOp: ... + + +def Igamma(a: XlaOp, x: XlaOp) -> XlaOp: ... +def Igammac(a: XlaOp, x: XlaOp) -> XlaOp: ... +def IgammaGradA(a: XlaOp, x: XlaOp) -> XlaOp: ... +def RandomGammaGrad(a: XlaOp, x: XlaOp) -> XlaOp: ... +def RegularizedIncompleteBeta(a: XlaOp, b: XlaOp, x: XlaOp) -> XlaOp: ... +def Zeta(a: XlaOp, q: XlaOp) -> XlaOp: ... + +def Eq(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Ne(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Ge(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Gt(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Lt(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Le(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Add(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Sub(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Mul(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Div(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Rem(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Max(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Min(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def And(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Or(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Xor(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def ShiftLeft(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def ShiftRightArithmetic(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def ShiftRightLogical(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Atan2(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Pow(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... +def Complex(lhs: XlaOp, rhs: XlaOp, broadcast_dimensions: Sequence[int] = ...) -> XlaOp: ... + +def Not(__arg: XlaOp) -> XlaOp: ... +def PopulationCount(__arg: XlaOp) -> XlaOp: ... +def Clz(__arg: XlaOp) -> XlaOp: ... +def Abs(__arg: XlaOp) -> XlaOp: ... +def Exp(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... +def Expm1(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... +def Floor(__arg: XlaOp) -> XlaOp: ... +def Ceil(__arg: XlaOp) -> XlaOp: ... +def Round(__arg: XlaOp) -> XlaOp: ... +def Log(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... +def Log1p(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... +def Sign(__arg: XlaOp) -> XlaOp: ... +def Cos(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... +def OptimizationBarrier(__arg: XlaOp) -> XlaOp: ... +def Sin(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... +def Tan(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... +def Tanh(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... +def IsFinite(__arg: XlaOp) -> XlaOp: ... +def Neg(__arg: XlaOp) -> XlaOp: ... +def Sqrt(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... +def Rsqrt(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... +def Cbrt(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... +def Square(__arg: XlaOp) -> XlaOp: ... +def Reciprocal(__arg: XlaOp) -> XlaOp: ... +def Erfc(__arg: XlaOp) -> XlaOp: ... +def Erf(operand: XlaOp, result_accuracy: ResultAccuracy = ...) -> XlaOp: ... +def ErfInv(__arg: XlaOp) -> XlaOp: ... +def Lgamma(__arg: XlaOp) -> XlaOp: ... +def Digamma(__arg: XlaOp) -> XlaOp: ... +def BesselI0e(__arg: XlaOp) -> XlaOp: ... +def BesselI1e(__arg: XlaOp) -> XlaOp: ... +def Acos(__arg: XlaOp) -> XlaOp: ... +def Asin(__arg: XlaOp) -> XlaOp: ... +def Atan(__arg: XlaOp) -> XlaOp: ... +def Acosh(__arg: XlaOp) -> XlaOp: ... +def Asinh(__arg: XlaOp) -> XlaOp: ... +def Atanh(__arg: XlaOp) -> XlaOp: ... +def Cosh(__arg: XlaOp) -> XlaOp: ... +def Sinh(__arg: XlaOp) -> XlaOp: ... +def Real(__arg: XlaOp) -> XlaOp: ... +def Imag(__arg: XlaOp) -> XlaOp: ... +def Conj(__arg: XlaOp) -> XlaOp: ... diff --git a/jaxlib/xla/xla_extension/pmap_lib.pyi b/jaxlib/xla/xla_extension/pmap_lib.pyi new file mode 100644 index 000000000000..8733d6c27b21 --- /dev/null +++ b/jaxlib/xla/xla_extension/pmap_lib.pyi @@ -0,0 +1,83 @@ +# Copyright 2021 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import inspect +from typing import Any, Callable, Sequence, Iterable, Tuple + +from . import pytree + +_AvalDimSharding = Any +_MeshDimAssignment = Any + +class NoSharding: + def __init__(self) -> None: ... + def __repr__(self) -> str: ... + def __eq__(self, __other: Any) -> bool: ... + +class Chunked: + @property + def chunks(self) -> Sequence[int]: ... + def __init__(self, __chunks: Sequence[int]) -> None: ... + def __repr__(self) -> str: ... + def __eq__(self, __other: Any) -> bool: ... + +class Unstacked: + @property + def size(self) -> int: ... + def __init__(self, __sz: int) -> None: ... + def __repr__(self) -> str: ... + def __eq__(self, __other: Any) -> bool: ... + +class ShardedAxis: + @property + def axis(self) -> int: ... + def __init__(self, __axis: int) -> None: ... + def __repr__(self) -> str: ... + def __eq__(self, __other: ShardedAxis) -> bool: ... + +class Replicated: + @property + def replicas(self) -> int: ... + def __init__(self, __replicas: int) -> None: ... + def __repr__(self) -> str: ... + def __eq__(self, __other: Replicated) -> bool: ... + +class ShardingSpec: + def __init__(self, + sharding: Iterable[_AvalDimSharding], + mesh_mapping: Iterable[_MeshDimAssignment]) -> None: ... + @property + def sharding(self) -> Tuple[_AvalDimSharding, ...]: ... + @property + def mesh_mapping(self) -> Tuple[_MeshDimAssignment]: ... + def __eq__(self, __other: ShardingSpec) -> bool: ... + def __hash__(self) -> int: ... + + _HAS_DYNAMIC_ATTRIBUTES = True + +class PmapFunction: + def __call__(self, *args, **kwargs) -> Any: ... + def __getstate__(self) -> Any: ... + def __setstate__(self, Any): ... + __signature__: inspect.Signature + def _cache_size(self) -> int: ... + def _cache_clear(self) -> None: ... + def _debug_cache_keys(self) -> str: ... + +def pmap(fun: Callable[..., Any], + cache_miss: Callable[..., Any], + static_argnums: Sequence[int], + shard_arg_fallback: Callable[..., Any], + pytree_registry: pytree.PyTreeRegistry) -> PmapFunction: ... diff --git a/jaxlib/xla/xla_extension/profiler.pyi b/jaxlib/xla/xla_extension/profiler.pyi new file mode 100644 index 000000000000..7610ce1000bf --- /dev/null +++ b/jaxlib/xla/xla_extension/profiler.pyi @@ -0,0 +1,58 @@ +# Copyright 2021 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from types import TracebackType +from typing import Any, Optional, Type, Union, List, Tuple + +_Status = Any + +class ProfilerServer: ... +def start_server(port: int) -> ProfilerServer: ... + +def register_plugin_profiler(c_api: Any) -> None: ... + +def get_profiled_instructions_proto(tensorboard_dir: str) -> bytes: ... +def get_instructins_profile(tensorboard_dir: str) -> List[Tuple[str, float]]: ... +def get_fdo_profile( + xspace: bytes, as_textproto: bool = ... +) -> Union[bytes, str]: ... + +class ProfilerSession: + def __init__(self, options: Optional[ProfileOptions] = ...) -> None: ... + def stop(self) -> bytes: ... + def export(self, xspace: bytes, tensorboard_dir: str) -> _Status:... + +class ProfileOptions: + include_dataset_ops: bool + host_tracer_level: int + python_tracer_level: int + enable_hlo_proto: bool + start_timestamp_ns: int + duration_ms: int + repository_path: str + +def aggregate_profiled_instructions(profiles: List[bytes], percentile: int) -> str: ... + +class TraceMe: + def __init__(self, name: str, **kwargs: Any) -> None: ... + def __enter__(self) -> TraceMe: ... + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + exc_tb: Optional[TracebackType]) -> Optional[bool]:... + def set_metadata(self, **kwargs): ... + @staticmethod + def is_enabled() -> bool: ... diff --git a/jaxlib/xla/xla_extension/pytree.pyi b/jaxlib/xla/xla_extension/pytree.pyi new file mode 100644 index 000000000000..bfbad5de89d5 --- /dev/null +++ b/jaxlib/xla/xla_extension/pytree.pyi @@ -0,0 +1,158 @@ +# Copyright 2021 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import ( + Any, + Callable, + Hashable, + Iterable, + List, + Optional, + Sequence, + Tuple, + Type, + TypeVar, +) + +_T = TypeVar("_T") + +version: int + +class PyTreeRegistry: + def __init__( + self, + *, + enable_none: bool = ..., + enable_tuple: bool = ..., + enable_namedtuple: bool = ..., + enable_list: bool = ..., + enable_dict: bool = ... + ): ... + def flatten( + self, + tree: Any, + leaf_predicate: Optional[Callable[[Any], bool]] = ..., + ) -> Tuple[List[Any], PyTreeDef]: ... + def flatten_one_level( + self, tree: Any + ) -> Optional[Tuple[Iterable[Any], Any]]: ... + def flatten_one_level_with_keys( + self, tree: Any + ) -> Optional[Tuple[Iterable[_KeyLeafPair], Any]]: ... + def flatten_with_path( + self, + tree: Any, + leaf_predicate: Optional[Callable[[Any], bool]] = ..., + ) -> Tuple[List[Tuple[_KeyPath, Any]], PyTreeDef]: ... + def register_node( + self, + __type: Type[_T], + to_iterable: Callable[[_T], Tuple[_Children, _AuxData]], + from_iterable: Callable[[_AuxData, _Children], _T], + to_iterable_with_keys: ( + Callable[[_T], Tuple[_KeyLeafPairs, _AuxData]] | None + ) = ..., + ) -> Any: ... + def register_dataclass_node( + self, __type: Type[_T], meta_fields: List[str], data_fields: List[str] + ) -> Any: ... + +def default_registry() -> PyTreeRegistry: ... +def tuple(registry: PyTreeRegistry, arg0: Sequence[PyTreeDef]) -> PyTreeDef: ... +def all_leaves(registry: PyTreeRegistry, arg0: Iterable[Any]) -> bool: ... + +class SequenceKey(Hashable): + idx: int + __match_args__: tuple = ... + def __init__(self, idx: int): ... + def __str__(self) -> str: ... + def __repr__(self) -> str: ... + def __hash__(self) -> int: ... + def __getstate__(self) -> Any: ... + def __setstate__(self, state: Any): ... + def __eq__(self, __other: Any) -> bool: ... + +class DictKey(Hashable): + key: Hashable + __match_args__: tuple = ... + def __init__(self, key: Hashable): ... + def __str__(self) -> str: ... + def __repr__(self) -> str: ... + def __hash__(self) -> int: ... + def __getstate__(self) -> Any: ... + def __setstate__(self, state: Any): ... + def __eq__(self, __other: Any) -> bool: ... + +class GetAttrKey(Hashable): + name: str + __match_args__: tuple = ... + def __init__(self, name: str): ... + def __str__(self) -> str: ... + def __repr__(self) -> str: ... + def __hash__(self) -> int: ... + def __getstate__(self) -> Any: ... + def __setstate__(self, state: Any): ... + def __eq__(self, __other: Any) -> bool: ... + +class FlattenedIndexKey(Hashable): + key: int + __match_args__: tuple = ... + def __init__(self, key: int): ... + def __str__(self) -> str: ... + def __repr__(self) -> str: ... + def __hash__(self) -> int: ... + def __getstate__(self) -> Any: ... + def __setstate__(self, state: Any): ... + def __eq__(self, __other: Any) -> bool: ... + +class PyTreeDef: + def unflatten(self, __leaves: Iterable[Any]) -> Any: ... + def flatten_up_to(self, __xs: Any) -> List[Any]: ... + def compose(self, __inner: PyTreeDef) -> PyTreeDef: ... + def walk( + self, + __f_node: Callable[[Any, Any], Any], + __f_leaf: Optional[Callable[[_T], Any]], + leaves: Iterable[Any], + ) -> Any: ... + def from_iterable_tree(self, __xs: Any): ... + def node_data(self) -> Optional[Tuple[Type, Any]]: ... + def children(self) -> List[PyTreeDef]: ... + @staticmethod + def make_from_node_data_and_children( + registry: PyTreeRegistry, + node_data: Optional[Tuple[Type, Any]], + children: Iterable[PyTreeDef], + ) -> PyTreeDef: ... + + num_leaves: int + num_nodes: int + def __repr__(self) -> str: ... + def __eq__(self, __other: PyTreeDef) -> bool: ... + def __ne__(self, __other: PyTreeDef) -> bool: ... + def __hash__(self) -> int: ... + def __getstate__(self) -> Any: ... + def __setstate__(self, state: Any): ... + def serialize_using_proto(self) -> bytes: ... + @staticmethod + def deserialize_using_proto( + registry: PyTreeRegistry, data: bytes + ) -> PyTreeDef: ... + +_Children = TypeVar("_Children", bound=Iterable[Any]) +_KeyLeafPair = TypeVar("_KeyLeafPair", bound=Tuple[Any, Any]) +_KeyLeafPairs = TypeVar("_KeyLeafPairs", bound=Iterable[_KeyLeafPair]) +_KeyPath = TypeVar("_KeyPath", bound=Tuple[Any, ...]) +_AuxData = TypeVar("_AuxData", bound=Hashable) diff --git a/jaxlib/xla/xla_extension/sdy.pyi b/jaxlib/xla/xla_extension/sdy.pyi new file mode 100644 index 000000000000..34714e5c0219 --- /dev/null +++ b/jaxlib/xla/xla_extension/sdy.pyi @@ -0,0 +1,32 @@ +# Copyright 2021 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from mlir import ir + +def sdy_round_trip_export_pipeline( + module: ir.module +) -> str: ... + +def sdy_round_trip_import_shardings( + module: ir.module +) -> str: ... + +def get_mesh( + module: ir.module +) -> tuple[tuple[str, int], ...]: ... + +def lowered_with_shardy( + module: ir.module +) -> bool: ... diff --git a/jaxlib/xla/xla_extension/transfer_guard_lib.pyi b/jaxlib/xla/xla_extension/transfer_guard_lib.pyi new file mode 100644 index 000000000000..091e1e10a742 --- /dev/null +++ b/jaxlib/xla/xla_extension/transfer_guard_lib.pyi @@ -0,0 +1,39 @@ +# Copyright 2022 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Any, List, Optional + +class TransferGuardLevel: + ALLOW: Any + LOG: Any + DISALLOW: Any + LOG_EXPLICIT: Any + DISALLOW_EXPLICIT: Any + +class TransferGuardState: + host_to_device: Optional[TransferGuardLevel] + device_to_device: Optional[TransferGuardLevel] + device_to_host: Optional[TransferGuardLevel] + + explicit_device_put: bool + explicit_device_get: bool + +def global_state() -> TransferGuardState: ... +def thread_local_state() -> TransferGuardState: ... + +class _TestingScopedLogSink: + def __enter__(self) -> _TestingScopedLogSink: ... + def __exit__(self, *args, **kwargs) -> None: ... + def logs(self) -> List[str]: ... diff --git a/jaxlib/xla_extension.py b/jaxlib/xla_extension.py new file mode 100644 index 000000000000..e4fc7e96a1ab --- /dev/null +++ b/jaxlib/xla_extension.py @@ -0,0 +1,17 @@ +# Copyright 2025 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from jaxlib.xla.xla_extension import * # noqa: F403 +from jaxlib.xla.xla_extension import sdy # noqa: F401 From 396e389001ce3d6f6e3f1bc944245868968539f3 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Fri, 21 Mar 2025 22:02:58 -0700 Subject: [PATCH 099/483] [pallas] Add `_zeros[_like]` and `_ones[_like]` utility functions in Triton lowering. PiperOrigin-RevId: 739395754 --- jax/_src/pallas/triton/lowering.py | 40 +++++++++++++++++++----------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 64bf635a34ed..bc7144f376b4 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -1120,7 +1120,7 @@ def inner(ctx: LoweringRuleContext, *args: ir.Value) -> ir.Value: def _minus(x: ir.Value) -> ir.Value: if tt_dialect.PointerType.isinstance(_element_type(x.type)): raise NotImplementedError(f"unsupported type: {x.type}") - return _sub(_full(x.type, 0), x) + return _sub(_zeros_like(x), x) def _add(x: ir.Value, y: ir.Value): @@ -1377,7 +1377,7 @@ def _broadcast_to_rule(ctx: LoweringRuleContext, x, shape: Sequence[int]): @register_lowering(lax.integer_pow_p) def _integer_pow_rule(ctx: LoweringRuleContext, x, *, y: int): if y == 0: - return _full(x.type, 1) + return _ones_like(x) is_reciprocal = y < 0 if is_reciprocal: @@ -1397,7 +1397,7 @@ def _integer_pow_rule(ctx: LoweringRuleContext, x, *, y: int): acc = _cast(acc, x_aval.dtype, out_aval.dtype) if is_reciprocal: signed = jnp.issubdtype(out_aval.dtype, jnp.signedinteger) - return _truediv(_full(acc.type, 1), acc, signed=signed) + return _truediv(_ones_like(acc), acc, signed=signed) else: return acc @@ -1518,6 +1518,22 @@ def _full(t: ir.Type, v: object) -> ir.Type: return result +def _zeros(t: ir.Type) -> ir.Value: + return _full(t, 0) + + +def _zeros_like(x: ir.Value) -> ir.Value: + return _full(x.type, 0) + + +def _ones(t: ir.Type) -> ir.Value: + return _full(t, 1) + + +def _ones_like(x: ir.Value) -> ir.Value: + return _full(x.type, 1) + + def _splat(x: ir.value, shape: Sequence[int]) -> ir.Value: if ir.RankedTensorType.isinstance(x.type): raise TypeError("cannot splat a tensor") @@ -1556,7 +1572,7 @@ def _int_int_cast(src: ir.Value, dst_type: ir.Type, signed: bool) -> ir.Value: dst_element_type = ir.IntegerType(_element_type(dst_type)) assert src_element_type != dst_element_type if dst_element_type.width == 1: - return _not_equal(src, _full(src.type, 0), signed=signed) + return _not_equal(src, _zeros_like(src), signed=signed) if src_element_type.width == dst_element_type.width: return arith_dialect.bitcast(dst_type, src) @@ -1576,7 +1592,7 @@ def _float_int_cast( raise NotImplementedError(f"cannot cast {src} tp {dst_type}") dst_element_type = ir.IntegerType(_element_type(dst_type)) if dst_element_type.width == 1: - return _not_equal(src, _full(src.type, 0), signed=signed) + return _not_equal(src, _zeros_like(src), signed=signed) else: # We clamp the float value to the min/max integer destination value # in order to match JAX/XLA casting behavior. Note that this differs @@ -1679,7 +1695,7 @@ def _ir_cast(src: ir.Value, dst_type: ir.Type, *, return tt_dialect.ptr_to_int(dst_type, src) elif dst_element_type.width == 1: x = _ir_cast(src, ir.IntegerType.get_signless(64), signed=signed) - zero = _full(x.type, 0) + zero = _zeros_like(x) return _ir_cast(_not_equal(x, zero, signed=signed), dst_type, signed=signed) if isinstance( src_element_type, ir.IntegerType @@ -1802,7 +1818,7 @@ def _compute_offsets_from_indices( # Use 64-bit indexing when offset might be >= 2**32 bytes. offset_eltype = ir.IntegerType.get_signless(64 if full_size > 2**32 else 32) if indexer_shape: - offsets = _full(ir.RankedTensorType.get(indexer_shape, offset_eltype), 0) + offsets = _zeros(ir.RankedTensorType.get(indexer_shape, offset_eltype)) else: offsets = _ir_constant(0, offset_eltype) @@ -2074,7 +2090,7 @@ def _masked_load_lowering_rule( offsets = _ir_cast(offsets, ir.IntegerType.get_signless(32), signed=False) in_msb = _mod(offsets, _full(offsets.type, 2), signed=False) if jaxlib_version < (0, 5, 2): - in_msb = arith_dialect.xori(in_msb, _full(in_msb.type, 1)) + in_msb = arith_dialect.xori(in_msb, _ones_like(in_msb)) shift = _mul(in_msb, _full(in_msb.type, 4)) shift = _ir_cast(shift, values.type, signed=False) values = arith_dialect.shrui(values, shift) @@ -2280,7 +2296,7 @@ def _dot_general_lowering( m, _ = a_type.shape _, n = b_type.shape - acc = _full(ir.RankedTensorType.get([m, n], _dtype_to_ir_type(acc_dtype)), 0) + acc = _zeros(ir.RankedTensorType.get([m, n], _dtype_to_ir_type(acc_dtype))) if precision == lax.DotAlgorithmPreset.BF16_BF16_F32_X3: bf16 = _dtype_to_ir_type(jnp.bfloat16) @@ -2297,11 +2313,7 @@ def _dot_general_lowering( acc = tt_dialect.dot(a_bf16, b_err0, acc) # If `a_err0` will be zero and `b` is infinite, then `acc` may contain # `NaN`s (as `0 * inf = NaN`), and vice versa. - acc = arith_dialect.select( - _is_nan(acc), - _full(ir.RankedTensorType.get([m, n], _dtype_to_ir_type(acc_dtype)), 0), - acc, - ) + acc = arith_dialect.select(_is_nan(acc), _zeros_like(acc), acc) a, b = a_bf16, b_bf16 acc = tt_dialect.dot(a, b, acc, input_precision=input_precision) From fd0ac0229ff8006e3105615f0837d2f224ff1095 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Fri, 21 Mar 2025 22:13:40 -0700 Subject: [PATCH 100/483] [mosaic_gpu] Add `cupti_no_finalize` profiler mode. PiperOrigin-RevId: 739397564 --- jax/experimental/mosaic/gpu/profiler.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/jax/experimental/mosaic/gpu/profiler.py b/jax/experimental/mosaic/gpu/profiler.py index 32b3edf7caf9..011b921d728e 100644 --- a/jax/experimental/mosaic/gpu/profiler.py +++ b/jax/experimental/mosaic/gpu/profiler.py @@ -98,7 +98,7 @@ def run(*args, **kwargs): return outs, float(elapsed) -def _measure_cupti(f, aggregate): +def _measure_cupti(f, aggregate, *, finalize=True): if not isinstance(f, (stages.Wrapped, stages.Compiled)): f = jax.jit(f) @@ -108,7 +108,7 @@ def wrapper(*args, **kwargs): try: results = jax.block_until_ready(f(*args, **kwargs)) finally: - timings = mosaic_gpu_lib._mosaic_gpu_ext._cupti_get_timings() + timings = mosaic_gpu_lib._mosaic_gpu_ext._cupti_get_timings(finalize) if not timings: return results, None @@ -133,6 +133,7 @@ def measure(f: Callable, *, mode: str = "events", aggregate: bool = True mode: The mode of operation. Possible values are: - "cupti", for CUPTI-based profiling. + - "cupti_no_finalize", as above, but CUPTI left attached to the process. - "events", for CUDA events-based profiling. The two modes use different measurement methodologies and should not be @@ -175,10 +176,12 @@ def measure(f: Callable, *, mode: str = "events", aggregate: bool = True In an attempt to minimize the second effect, internally the events-based implementation may execute ``f`` more than once to "warm up" and exclude compilation time from the measurement. - """ + """ # fmt: skip match mode: case "cupti": return _measure_cupti(f, aggregate) + case "cupti_no_finalize": + return _measure_cupti(f, aggregate, finalize=False) case "events": if not aggregate: raise ValueError(f"{aggregate=} is not supported with {mode=}") From 74977938d8355b41c389b146f4c71f205ceff3ec Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Fri, 21 Mar 2025 22:32:43 -0700 Subject: [PATCH 101/483] [pallas] Add support for `DotAlgorithmPreset.BF16_BF16_F32_X{6,9}` in Triton lowering. PiperOrigin-RevId: 739400359 --- jax/_src/pallas/triton/lowering.py | 50 ++++++++++++++++++++++-------- tests/pallas/pallas_test.py | 9 ++++-- 2 files changed, 44 insertions(+), 15 deletions(-) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index bc7144f376b4..0077ec55ace8 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -2218,6 +2218,14 @@ def _transpose_lowering(ctx: LoweringRuleContext, x, *, permutation): _TF32_PRECISIONS = (lax.Precision.HIGH, lax.Precision.DEFAULT) +def _as_bf16(x): + return _ir_cast(x, _dtype_to_ir_type(jnp.bfloat16), signed=False) + + +def _as_f32(x): + return _ir_cast(x, _dtype_to_ir_type(jnp.float32), signed=False) + + @register_lowering(lax.dot_general_p) def _dot_general_lowering( ctx: LoweringRuleContext, @@ -2258,6 +2266,8 @@ def _dot_general_lowering( | lax.DotAlgorithmPreset.BF16_BF16_BF16 | lax.DotAlgorithmPreset.BF16_BF16_F32 | lax.DotAlgorithmPreset.BF16_BF16_F32_X3 + | lax.DotAlgorithmPreset.BF16_BF16_F32_X6 + | lax.DotAlgorithmPreset.BF16_BF16_F32_X9 ): input_precision = None case _: @@ -2298,20 +2308,34 @@ def _dot_general_lowering( _, n = b_type.shape acc = _zeros(ir.RankedTensorType.get([m, n], _dtype_to_ir_type(acc_dtype))) - if precision == lax.DotAlgorithmPreset.BF16_BF16_F32_X3: - bf16 = _dtype_to_ir_type(jnp.bfloat16) - f32 = _dtype_to_ir_type(jnp.float32) - as_bf16 = lambda x: _ir_cast(x, bf16, signed=False) - as_f32 = lambda x: _ir_cast(x, f32, signed=False) - - a_bf16 = as_bf16(a) - b_bf16 = as_bf16(b) - a_err0 = as_bf16(_sub(a, as_f32(a_bf16))) - b_err0 = as_bf16(_sub(b, as_f32(b_bf16))) + if precision in ( + lax.DotAlgorithmPreset.BF16_BF16_F32_X3, + lax.DotAlgorithmPreset.BF16_BF16_F32_X6, + lax.DotAlgorithmPreset.BF16_BF16_F32_X9, + ): + a_bf16 = _as_bf16(a) + b_bf16 = _as_bf16(b) + a_err0 = _sub(a, _as_f32(a_bf16)) + b_err0 = _sub(b, _as_f32(b_bf16)) + a_err0_bf16 = _as_bf16(a_err0) + b_err0_bf16 = _as_bf16(b_err0) + a_err1_bf16 = _as_bf16(_sub(a_err0, _as_f32(a_err0_bf16))) + b_err1_bf16 = _as_bf16(_sub(b_err0, _as_f32(b_err0_bf16))) # Accumulate the smallest values first to reduce the numeric error. - acc = tt_dialect.dot(a_err0, b_bf16, acc) - acc = tt_dialect.dot(a_bf16, b_err0, acc) - # If `a_err0` will be zero and `b` is infinite, then `acc` may contain + if precision == lax.DotAlgorithmPreset.BF16_BF16_F32_X9: + acc = tt_dialect.dot(a_err1_bf16, b_err0_bf16, acc) + acc = tt_dialect.dot(a_err1_bf16, b_err1_bf16, acc) + acc = tt_dialect.dot(a_err0_bf16, b_err1_bf16, acc) + if precision in ( + lax.DotAlgorithmPreset.BF16_BF16_F32_X6, + lax.DotAlgorithmPreset.BF16_BF16_F32_X9, + ): + acc = tt_dialect.dot(a_err1_bf16, b_bf16, acc) + acc = tt_dialect.dot(a_bf16, b_err1_bf16, acc) + acc = tt_dialect.dot(a_err0_bf16, b_err0_bf16, acc) + acc = tt_dialect.dot(a_err0_bf16, b_bf16, acc) + acc = tt_dialect.dot(a_bf16, b_err0_bf16, acc) + # If `a` rounding error is zero and `b` is `inf` then `acc` may contain # `NaN`s (as `0 * inf = NaN`), and vice versa. acc = arith_dialect.select(_is_nan(acc), _zeros_like(acc), acc) a, b = a_bf16, b_bf16 diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 0ce68a5c023c..6f52a7afb1bf 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -703,6 +703,8 @@ def f(x): ("float32", jax.lax.DotAlgorithmPreset.F16_F16_F32), ("float32", jax.lax.DotAlgorithmPreset.BF16_BF16_F32), ("float32", jax.lax.DotAlgorithmPreset.BF16_BF16_F32_X3), + ("float32", jax.lax.DotAlgorithmPreset.BF16_BF16_F32_X6), + ("float32", jax.lax.DotAlgorithmPreset.BF16_BF16_F32_X9), ("float32", jax.lax.DotAlgorithmPreset.TF32_TF32_F32), ("float32", jax.lax.DotAlgorithmPreset.TF32_TF32_F32_X3), ("float32", jax.lax.DotAlgorithmPreset.F32_F32_F32), @@ -733,9 +735,12 @@ def dot_kernel(x_ref, y_ref, o_ref): preferred_element_type=jnp.float32, ) if dtype == "bfloat16" or precision in ( - jax.lax.Precision.HIGHEST, jax.lax.DotAlgorithmPreset.F32_F32_F32 + jax.lax.Precision.HIGHEST, + jax.lax.DotAlgorithmPreset.F32_F32_F32, + jax.lax.DotAlgorithmPreset.BF16_BF16_F32_X6, + jax.lax.DotAlgorithmPreset.BF16_BF16_F32_X9, ): - atol = 0 + atol = 5e-6 elif precision in ( jax.lax.DotAlgorithmPreset.BF16_BF16_F32_X3, jax.lax.DotAlgorithmPreset.TF32_TF32_F32_X3, From d4745b9bd81b49e2a7a8938ea98516296d54635f Mon Sep 17 00:00:00 2001 From: jax authors Date: Sat, 22 Mar 2025 01:54:12 -0700 Subject: [PATCH 102/483] Reverts ad21b62bfec5560d4c612ed3c8412eb2d240468b PiperOrigin-RevId: 739431800 --- tests/pgle_test.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/tests/pgle_test.py b/tests/pgle_test.py index 7dabd809d95e..7f9ea598d51b 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -65,11 +65,7 @@ def testPGLEProfilerGetFDOProfile(self): jax.jit, in_shardings=NamedSharding(mesh, PartitionSpec('x')), out_shardings=NamedSharding(mesh, PartitionSpec('x')), - compiler_options={ - 'xla_gpu_enable_latency_hiding_scheduler': 'True', - # Make sure that matmul is not emitted as Triton GEMM. - 'xla_gpu_enable_triton_gemm': 'False', - }, + compiler_options={'xla_gpu_enable_latency_hiding_scheduler': 'True'}, ) def f(x, y): return x @ y @@ -97,8 +93,6 @@ def testPGLEProfilerGetFDOProfileLarge(self): compiler_options = { 'xla_gpu_enable_latency_hiding_scheduler': 'True', - # Make sure that matmul is not emitted as Triton GEMM. - 'xla_gpu_enable_triton_gemm': 'False', } # TODO(b/37664749): Remove this flag once the bug is fixed. compiler_options['xla_gpu_enable_command_buffer'] = '' @@ -327,11 +321,7 @@ def testPassingFDOProfile(self): jax.jit, in_shardings=NamedSharding(mesh, PartitionSpec('x')), out_shardings=NamedSharding(mesh, PartitionSpec('x')), - compiler_options={ - 'xla_gpu_enable_latency_hiding_scheduler': 'True', - # Make sure that matmul is not emitted as Triton GEMM. - 'xla_gpu_enable_triton_gemm': 'False', - }, + compiler_options={'xla_gpu_enable_latency_hiding_scheduler': 'True'}, ) def f(x, y): return x @ y From 34aa5e69477f74d5e1d5e2945c7fd23f72c6dd6e Mon Sep 17 00:00:00 2001 From: jax authors Date: Sat, 22 Mar 2025 04:21:51 -0700 Subject: [PATCH 103/483] Update XLA dependency to use revision http://github.com/openxla/xla/commit/1b26f2c8502a7d180ce959d0e6546c91ef820b02. PiperOrigin-RevId: 739453338 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 00f985cdf352..305cf14c1045 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "469329ec36be093fd71d29e4518402300e04aeec" -XLA_SHA256 = "9de006d7b51c36057898c81111fa9723b59f024eced067572fe5f6b1df63abdd" +XLA_COMMIT = "1b26f2c8502a7d180ce959d0e6546c91ef820b02" +XLA_SHA256 = "9492831de7840a3977eb8fcad34f2673e1bd8871cb060f9b6ee93f622956b896" def repo(): tf_http_archive( From a092df90ba7868f86e71cdaed245bb1abd77f1d4 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Sat, 22 Mar 2025 20:40:14 +0000 Subject: [PATCH 104/483] fix a linearize-of-remat-of-while_loop-fixpoint bug We were using the original unknown-carries-in rather than the fixpoint-updated ones. --- jax/_src/lax/control_flow/loops.py | 41 +++++++++++++++++++++++++----- tests/lax_control_flow_test.py | 12 +++++++++ 2 files changed, 46 insertions(+), 7 deletions(-) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 3084fa722977..33e2d2cbb0c8 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -1438,7 +1438,29 @@ def _join_while_effects(body_jaxpr, cond_jaxpr, body_nconsts, cond_nconsts def _while_loop_abstract_eval(*avals, cond_jaxpr, body_jaxpr, body_nconsts, cond_nconsts): - del avals + cond_consts_avals, body_consts_avals, in_avals = \ + util.split_list(avals, [cond_nconsts, body_nconsts]) + + if len(cond_jaxpr.in_avals) != len(cond_consts_avals) + len(in_avals): + raise core.JaxprTypeError( + f"while_loop {len(cond_jaxpr.in_avals)=} but {len(cond_consts_avals) + len(in_avals)=}") + if len(body_jaxpr.in_avals) != len(body_consts_avals) + len(in_avals): + raise core.JaxprTypeError( + f"while_loop {len(body_jaxpr.in_avals)=} but {len(body_consts_avals) + len(in_avals)=}") + # TODO(mattjj): check body carry type + # TODO(mattjj): make these typecompat checks work with bints + # if not all(_map(core.typecompat, [*cond_consts_avals, *in_avals], cond_jaxpr.in_avals)): # type: ignore + # cond_avals = [*cond_consts_avals, *in_avals] + # a1, a2 = next((a1, a2) for a1, a2 in zip(cond_avals, cond_jaxpr.in_avals) + # if not core.typecompat(a1, a2)) + # raise core.JaxprTypeError(f"while_loop cond function input type error: {a1} != {a2}") + # if not all(_map(core.typecompat, [*body_consts_avals, *in_avals], body_jaxpr.in_avals)): # type: ignore + # body_avals = [*body_consts_avals, *in_avals] + # a1, a2 = next((a1, a2) for a1, a2 in zip(body_avals, body_jaxpr.in_avals) + # if not core.typecompat(a1, a2)) + # raise core.JaxprTypeError(f"while_loop body function input type error: {a1} != {a2}") + + joined_effects = _join_while_effects(body_jaxpr, cond_jaxpr, body_nconsts, cond_nconsts) disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(joined_effects) @@ -1679,7 +1701,7 @@ def _while_partial_eval_custom(saveable, unks_in, inst_in, eqn): assert False, "Fixpoint not reached" assert not num_res body_jaxpr_known = core.ClosedJaxpr(jaxpr_known_, body_jaxpr.consts) - del jaxpr_known_, carry_uk_out, num_res + del jaxpr_known_, carry_uk_out, num_res, unks_in # Instantiate all inputs (b/c jaxpr_staged will take all inputs). new_inst = [x for x, inst in zip(eqn.invars, inst_in) @@ -1701,6 +1723,7 @@ def _while_partial_eval_custom(saveable, unks_in, inst_in, eqn): del cond_uk # Build the known eqn. + unks_in = [*cond_consts_uk, *body_consts_uk, *carry_uk] # fixpoint carry_uk ins_known, _ = partition_list(unks_in, eqn.invars) out_binders_known, _ = partition_list(carry_uk, eqn.outvars) params_known = dict(cond_jaxpr=cond_jaxpr_known, body_jaxpr=body_jaxpr_known, @@ -1711,6 +1734,11 @@ def _while_partial_eval_custom(saveable, unks_in, inst_in, eqn): eqn_known = pe.new_jaxpr_eqn(ins_known, out_binders_known, while_p, params_known, effects_known, eqn.source_info, eqn.ctx) + # Typecheck known eqn. + _while_loop_abstract_eval( + *[v.aval for v in eqn_known.invars], cond_jaxpr=cond_jaxpr_known, + body_jaxpr=body_jaxpr_known, body_nconsts=params_known['body_nconsts'], + cond_nconsts=params_known['cond_nconsts']) # Staged eqn is same as input eqn. eqn_staged = eqn @@ -1798,8 +1826,7 @@ def fun(*args): cond_block.arguments[i] for i in range(len(flat_loop_carry_types)) ] cond_args = mlir.unflatten_ir_values_like_types(flat_cond_args, loop_carry_types) - # Remove tokens from cond args - cond_args = cond_args[num_tokens:] + cond_args = cond_args[num_tokens:] # Remove tokens from cond args x, _, z = util.split_list(cond_args, [cond_nconsts, body_nconsts]) cond_consts = [ mlir.ir_constant(xla.canonicalize_dtype(x)) for x in cond_jaxpr.consts @@ -1861,8 +1888,9 @@ def fun(*args): partial(_pred_bcast_select_hlo, ctx, pred_aval, body_pred), new_z, z, body_jaxpr.out_avals) - hlo.return_([*mlir.flatten_ir_values(out_tokens), *mlir.flatten_ir_values(x), *mlir.flatten_ir_values(y), - *mlir.flatten_ir_values(new_z)]) + hlo.return_([*mlir.flatten_ir_values(out_tokens), + *mlir.flatten_ir_values(x), *mlir.flatten_ir_values(y), + *mlir.flatten_ir_values(new_z)]) outputs = mlir.unflatten_ir_values_like_types(while_op.results, loop_carry_types) tokens, _, _, z = util.split_list(outputs, [num_tokens, cond_nconsts, body_nconsts]) @@ -1976,7 +2004,6 @@ def new_cond(*consts_refs_carry): batching.fancy_primitive_batchers[while_p] = _while_loop_batching_rule pe.partial_eval_jaxpr_custom_rules[while_p] = _while_partial_eval_custom mlir.register_lowering(while_p, _while_lowering) -core.custom_typechecks[while_p] = _while_typecheck state_discharge.register_partial_discharge_rule(while_p)(_while_partial_discharge_rule) diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 3871a87a7a3e..fcc7fd99ee13 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -3066,6 +3066,18 @@ def g(): leak() self.assertEqual(base, nbufs()) + def test_grad_remat_while_fixpoint(self): + @jax.remat + def f(x, y): + def cond(_): + return False + def body(c): + x, y = c + return (y, x) + x, y = jax.lax.while_loop(cond, body, (x, y)) + return x + y + jax.linearize(f, 1., 2.) # don't crash + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From 540541a3d32fe6678aa1c208a4f5a9a697b92e2c Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 23 Mar 2025 03:39:34 -0700 Subject: [PATCH 105/483] Update XLA dependency to use revision http://github.com/openxla/xla/commit/5ed2ff5c07868d1a7486f4040f8b38936640268e. PiperOrigin-RevId: 739649983 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 305cf14c1045..3e3d636f0f43 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "1b26f2c8502a7d180ce959d0e6546c91ef820b02" -XLA_SHA256 = "9492831de7840a3977eb8fcad34f2673e1bd8871cb060f9b6ee93f622956b896" +XLA_COMMIT = "5ed2ff5c07868d1a7486f4040f8b38936640268e" +XLA_SHA256 = "08d175c57d0db599ad57b8fa820ca2f2a6d2808578d53dba421e3af4edb0bccf" def repo(): tf_http_archive( From 5d79df7e67cfc8b253c817f90af81393ea256763 Mon Sep 17 00:00:00 2001 From: Jesse Perla Date: Sun, 23 Mar 2025 15:03:49 -0700 Subject: [PATCH 106/483] Add identity activation Fix typo --- docs/jax.nn.rst | 1 + jax/_src/nn/functions.py | 19 +++++++++++++++++++ jax/nn/__init__.py | 1 + tests/nn_test.py | 8 +++++++- 4 files changed, 28 insertions(+), 1 deletion(-) diff --git a/docs/jax.nn.rst b/docs/jax.nn.rst index adb13f89903d..2e2e9644d50d 100644 --- a/docs/jax.nn.rst +++ b/docs/jax.nn.rst @@ -40,6 +40,7 @@ Activation functions glu squareplus mish + identity Other functions --------------- diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index 7df0a638e566..ee0643e116f9 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -54,6 +54,25 @@ def __repr__(self): # activations +@jax.jit +def identity(x: ArrayLike) -> Array: + r"""Identity activation function. + + Returns the argument unmodified. + + Args: + x : input array + + Returns: + The argument `x` unmodified. + + Examples: + >>> jax.nn.identity(jax.numpy.array([-2., -1., -0.5, 0, 0.5, 1., 2.])) + Array([-2. , -1. , -0.5, 0. , 0.5, 1. , 2. ], dtype=float32) + + """ + numpy_util.check_arraylike("identity", x) + return jnp.asarray(x) @custom_jvp @jax.jit diff --git a/jax/nn/__init__.py b/jax/nn/__init__.py index 3f08e1c0fd12..10f11f829abe 100644 --- a/jax/nn/__init__.py +++ b/jax/nn/__init__.py @@ -35,6 +35,7 @@ standardize as standardize, one_hot as one_hot, relu as relu, + identity as identity, relu6 as relu6, dot_product_attention as dot_product_attention, scaled_dot_general as scaled_dot_general, diff --git a/tests/nn_test.py b/tests/nn_test.py index 1a1670444ef8..e46843186c02 100644 --- a/tests/nn_test.py +++ b/tests/nn_test.py @@ -543,7 +543,7 @@ def gelu_reference(x): (jnp.float32, jnp.bfloat16, jnp.float16), (partial(nn.gelu, approximate=False), partial(nn.gelu, approximate=True), - nn.relu, nn.softplus, nn.sparse_plus, nn.sigmoid, nn.squareplus, nn.mish))) + nn.relu, nn.identity, nn.softplus, nn.sparse_plus, nn.sigmoid, nn.squareplus, nn.mish))) def testDtypeMatchesInput(self, dtype, fn): x = jnp.zeros((), dtype=dtype) out = fn(x) @@ -831,6 +831,12 @@ def testVarianceScalingError(self): ): initializer(rng, shape) + def testIdentity(self): + x = jnp.array([1., 2., 3.]) + self.assertAllClose(nn.identity(x), x, check_dtypes=False) + grad = jax.grad(nn.identity)(6.0) + self.assertEqual(grad, 1.) + def testAccidentalUpcasting(self): rng = random.PRNGKey(0) shape = (4, 4) From 5b0a767d83cf28b41dd1c2207eb56010bcb594d7 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Sun, 23 Mar 2025 21:33:01 -0700 Subject: [PATCH 107/483] [jax] Add `ndim` and `size` properties to `TransformedRef`. Without these implementations, `ndim` and `size` were retrieved from the underlying, non-transformed reference and were inconsistent with `TransformedRef.shape`. PiperOrigin-RevId: 739802491 --- jax/_src/state/types.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index 057242f4c1ac..fa9d0cb9fb16 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -266,6 +266,9 @@ def dtype(self): assert dtype is not None return dtype + ndim = property(lambda self: len(self.shape)) + size = property(lambda self: math.prod(self.shape)) + @property def at(self) -> RefIndexer: return RefIndexer(self) From a2475a66c50c148fbe4dcafd61b917c6435d1e4a Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Mon, 24 Mar 2025 02:06:10 -0700 Subject: [PATCH 108/483] [pallas] Add support for `split` (into two equal parts) in Triton lowering. PiperOrigin-RevId: 739855323 --- jax/_src/pallas/triton/lowering.py | 15 +++++++++++++++ tests/pallas/pallas_test.py | 14 ++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 0077ec55ace8..d7f6e4695229 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -1798,6 +1798,21 @@ def _concatenate_lowering_rule(ctx: LoweringRuleContext, *args, dimension): ) +@register_lowering(lax.split_p) +def _split_lowering_rule(ctx: LoweringRuleContext, x, *, sizes, axis): + pass + # TODO(cjfj): Add support for larger powers of 2. + if len(sizes) != 2: + raise NotImplementedError("Only splitting into two parts is supported.") + if sizes[0] != sizes[1]: + raise NotImplementedError("Only equal-sized splits are supported.") + (x_aval,) = ctx.avals_in + shape = x_aval.shape + x = _reshape(x, shape[:axis] + (2, sizes[0]) + shape[axis + 1 :]) + permutation = tuple(d for d in range(len(shape) + 1) if d != axis) + (axis,) + return tt_dialect.split(tt_dialect.trans(x, permutation)) + + def _compute_offsets_from_indices( block_info: BlockInfo, nd_indexer: NDIndexer ) -> ir.Value: diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 6f52a7afb1bf..0b16260ff25a 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -843,6 +843,20 @@ def dot_kernel(x_ref, y_ref, o_ref): self.assertAllClose(dot_kernel(x, y), expected) + @parameterized.parameters( + ((32,), 0), ((32, 64), 0), ((32, 16), 1), ((32, 16, 2), 1) + ) + def test_split(self, shape, axis): + x = jax.random.normal(jax.random.key(0), shape) + expected = jnp.split(x, 2, axis) + + @functools.partial(self.pallas_call, out_shape=expected) + def kernel(x_ref, o0_ref, o1_ref): + o0_ref[()], o1_ref[()] = jnp.split(x_ref[()], 2, axis) + + self.assertAllClose(kernel(x), expected) + + class PallasCallInterpretTest(PallasCallTest): INTERPRET = True From 4da1faf5b6cc8c1e99b3abf6de5f5889f0dc43dd Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 24 Mar 2025 02:49:43 -0700 Subject: [PATCH 109/483] Move PGLE documentation to JAX docs. PiperOrigin-RevId: 739865595 --- docs/gpu_performance_tips.md | 144 ++++++++++++++++++++++++++++++++++- 1 file changed, 142 insertions(+), 2 deletions(-) diff --git a/docs/gpu_performance_tips.md b/docs/gpu_performance_tips.md index bf032dccff88..737486485736 100644 --- a/docs/gpu_performance_tips.md +++ b/docs/gpu_performance_tips.md @@ -1,6 +1,6 @@ # GPU performance tips - + This document focuses on performance tips for neural network workloads @@ -58,7 +58,147 @@ training on Nvidia GPUs](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta * **--xla_gpu_triton_gemm_any** Use the Triton-based GEMM (matmul) emitter for any GEMM that it supports. The default value is False. -### Communication flags +## Communication tips + +### Auto and manual PGLE + +The Profile Guided Latency Estimator (PGLE) workflow measures the actual running time +of compute and collectives, the the profile information is fed back into XLA compiler +for a better scheduling decision. + +The Profile Guided Latency Estimator can be used manually or automatically. In the auto mode +JAX will collect profile information and recompile a module in a single run. While +in manual mode you need to run a task twice, the first time to collect and save profiles +and the second to compile and run with provided data. + +### Auto PGLE +The auto PGLE can be turned on by setting the following environment variables: + +Mandatory: +```bash +export JAX_ENABLE_PGLE=true + +# For JAX version <= 0.5.0 make sure to include: +export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true" +``` + +Optional: +```bash +export JAX_PGLE_PROFILING_RUNS=3 +export JAX_PGLE_AGGREGATION_PERCENTILE=85 + +# Right now the auto PGLE profile collection doesn't work with command buffer. +# If the command buffer is enabled, Auto PGLE will disable it during profile +# colletion and enable it back after the recompilation. If you need to have a +# consistent command buffer logic with and with PGLE profile you can disable it +# manually: +export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_command_buffer=''" +``` + +Or in the JAX this can be set as the following: + +``` +import jax +from jax._src import config + +with config.enable_pgle(True), config.pgle_profiling_runs(1): + # Run with the profiler collecting performance information. + train_step() + # Automatically re-compile with PGLE profile results + train_step() + ... +``` + +You can control amount of reruns used to collect profile data by changing `JAX_PGLE_PROFILING_RUNS`. +Increasing this parameter would lead to better profile information, but it will also increase the +amount of non-optimized training steps. + +Decreasing the `JAX_PGLE_AGGREGATION_PERCENTILE` parameter might help in case when performance between steps is too noisy to filter out a non-relevant measures. + +**Attention:** Auto PGLE doesn't work for pre-compiled modules. Since JAX need to recompile the module during execution the auto PGLE will not work neither for AoT nor for the following case: + +``` +import jax +from jax._src import config + +train_step_compiled = train_step().lower().compile() + +with config.enable_pgle(True), config.pgle_profiling_runs(1): + train_step_compiled() + # No effect since module was pre-compiled. + train_step_compiled() +``` + +### Manual PGLE + +If you still want to use a manual Profile Guided Latency Estimator the workflow in XLA/GPU is: + +- 1. Run your workload once, with async collectives and latency hiding scheduler enabled. + +You could do so by setting: + +```bash +export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true" +``` + +- 2. Collect and post process a profile by using JAX profiler, saving the extracted instruction latencies into a binary protobuf file. + +```python +import os +from etils import epath +import jax +from jax.experimental import profiler as exp_profiler + +# Define your profile directory +profile_dir = 'gs://my_bucket/profile' +jax.profiler.start_trace(profile_dir) + +# run your workflow +# for i in range(10): +# train_step() + +# Stop trace +jax.profiler.stop_trace() +profile_dir = epath.Path(profile_dir) +directories = profile_dir.glob('plugins/profile/*/') +directories = [d for d in directories if d.is_dir()] +rundir = directories[-1] +logging.info('rundir: %s', rundir) + +# Post process the profile +fdo_profile = exp_profiler.get_profiled_instructions_proto(os.fspath(rundir)) + +# Save the profile proto to a file. +dump_dir = rundir / 'profile.pb' +dump_dir.parent.mkdir(parents=True, exist_ok=True) +dump_dir.write_bytes(fdo_profile) + +``` + +After this step, you will get a `profile.pb` file under the `rundir` printed in the code. + +- 3. Run the workload again feeding that file into the compilation. + +You need to pass the `profile.pb` file to the `--xla_gpu_pgle_profile_file_or_directory_path` flag. + +```bash + export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_pgle_profile_file_or_directory_path=/path/to/profile/profile.pb" +``` + +To enable logging in the XLA and check if the profile is good, set the logging level to include `INFO`: + +```bash +export TF_CPP_MIN_LOG_LEVEL=0 +``` + +Run the real workflow, if you found these loggings in the running log, it means the profiler is used in the latency hiding scheduler: + +``` +2023-07-21 16:09:43.551600: I external/xla/xla/service/gpu/gpu_hlo_schedule.cc:478] Using PGLE profile from /tmp/profile/plugins/profile/2023_07_20_18_29_30/profile.pb +2023-07-21 16:09:43.551741: I external/xla/xla/service/gpu/gpu_hlo_schedule.cc:573] Found profile, using profile guided latency estimator +``` + +#### Flags * **--xla_gpu_enable_latency_hiding_scheduler** This flag enables latency hiding schedulers to overlap asynchronous communication with computation efficiently. From 0c38368bce53e5aab7a9ad3e1fc858668035874a Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Mon, 24 Mar 2025 03:27:15 -0700 Subject: [PATCH 110/483] [mosaic_gpu] Add `Cupti` profiler class. PiperOrigin-RevId: 739874654 --- jax/experimental/mosaic/gpu/profiler.py | 59 +++++++++++++++---------- 1 file changed, 35 insertions(+), 24 deletions(-) diff --git a/jax/experimental/mosaic/gpu/profiler.py b/jax/experimental/mosaic/gpu/profiler.py index 011b921d728e..5b278468b98c 100644 --- a/jax/experimental/mosaic/gpu/profiler.py +++ b/jax/experimental/mosaic/gpu/profiler.py @@ -17,7 +17,7 @@ import itertools import json import math -from typing import Callable, ParamSpec, TypeVar +from typing import Callable, ParamSpec, TypeAlias, TypeVar import warnings import jax @@ -98,30 +98,44 @@ def run(*args, **kwargs): return outs, float(elapsed) -def _measure_cupti(f, aggregate, *, finalize=True): - if not isinstance(f, (stages.Wrapped, stages.Compiled)): - f = jax.jit(f) +Timings: TypeAlias = list[tuple[str, float]] | float | None - def wrapper(*args, **kwargs): - jax.block_until_ready(f(*args, **kwargs)) # Warmup. - mosaic_gpu_lib._mosaic_gpu_ext._cupti_init() - try: - results = jax.block_until_ready(f(*args, **kwargs)) - finally: - timings = mosaic_gpu_lib._mosaic_gpu_ext._cupti_get_timings(finalize) - if not timings: - return results, None - elif aggregate: - return results, sum(item[1] for item in timings) - else: - return results, timings +@dataclasses.dataclass(frozen=True, kw_only=True) +class Cupti: + """CUPTI-based profiler.""" - return wrapper + # If `True`, detach CUPTI from the process after measurement. + finalize: bool = True + def measure( + self, f: Callable[P, T], *, aggregate: bool = True + ) -> Callable[P, tuple[T, Timings]]: + if not isinstance(f, (stages.Wrapped, stages.Compiled)): + f = jax.jit(f) -def measure(f: Callable, *, mode: str = "events", aggregate: bool = True -) -> Callable: + def wrapper(*args: P.args, **kwargs: P.kwargs): + jax.block_until_ready(f(*args, **kwargs)) # Warmup. + ext = mosaic_gpu_lib._mosaic_gpu_ext + ext._cupti_init() + try: + results = jax.block_until_ready(f(*args, **kwargs)) + finally: + timings = ext._cupti_get_timings(self.finalize) + + if not timings: + return results, None + elif aggregate: + return results, sum(item[1] for item in timings) + else: + return results, timings + + return wrapper + + +def measure( + f: Callable[P, T], *, mode: str = "events", aggregate: bool = True +) -> Callable[P, tuple[T, Timings]]: """Sets up a function ``f`` for profiling on GPU. ``measure`` is a higher-order function that augments the argument ``f`` to @@ -133,7 +147,6 @@ def measure(f: Callable, *, mode: str = "events", aggregate: bool = True mode: The mode of operation. Possible values are: - "cupti", for CUPTI-based profiling. - - "cupti_no_finalize", as above, but CUPTI left attached to the process. - "events", for CUDA events-based profiling. The two modes use different measurement methodologies and should not be @@ -179,9 +192,7 @@ def measure(f: Callable, *, mode: str = "events", aggregate: bool = True """ # fmt: skip match mode: case "cupti": - return _measure_cupti(f, aggregate) - case "cupti_no_finalize": - return _measure_cupti(f, aggregate, finalize=False) + return Cupti().measure(f, aggregate=aggregate) case "events": if not aggregate: raise ValueError(f"{aggregate=} is not supported with {mode=}") From a3e6c6ef61bd0193ea5977e2dce7b6e861e48f52 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 24 Mar 2025 04:02:58 -0700 Subject: [PATCH 111/483] [Mosaic GPU] Add support for f16 Blackwell MMA accumulation Very importantly, this also includes support for loading the packed accumulator from TMEM. PiperOrigin-RevId: 739883035 --- jax/experimental/mosaic/gpu/tcgen05.py | 59 +++++++++++++++++--------- tests/mosaic/gpu_test.py | 53 ++++++++++++++++------- 2 files changed, 77 insertions(+), 35 deletions(-) diff --git a/jax/experimental/mosaic/gpu/tcgen05.py b/jax/experimental/mosaic/gpu/tcgen05.py index 3330500cd6dc..ac3b80b93689 100644 --- a/jax/experimental/mosaic/gpu/tcgen05.py +++ b/jax/experimental/mosaic/gpu/tcgen05.py @@ -197,7 +197,7 @@ def mma( ), a_mk, b_nk, - d_type=ir.F32Type.get(), + d_type=d.dtype, m=m_group_elems, n=n_group_elems, collective=collective, @@ -327,7 +327,7 @@ def tmem_relinquish_alloc_permit(): has_side_effects=True, ) -def tmem_load(tmem_addr, shape, num): +def tmem_load(tmem_addr, shape, num, packing: int = 1): if num.bit_count() != 1 or num > 128: raise ValueError(f"num must be a power of 2 and <= 128, got: {num}") match shape: @@ -345,12 +345,18 @@ def tmem_load(tmem_addr, shape, num): num_out_regs *= num i32 = ir.IntegerType.get_signless(32) out_regs = ",".join("$" + str(i) for i in range(num_out_regs)) + if packing == 1: + pack_mod = "" + elif packing == 2: + pack_mod = ".pack::16b" + else: + raise ValueError(f"Unsupported packing: {packing}") regs = llvm.inline_asm( ir.Type.parse( "!llvm.struct<(" + ",".join("i32" for _ in range(num_out_regs)) + ")>" ), [tmem_addr], - f"tcgen05.ld.sync.aligned.{shape}.x{num}.b32 {{{out_regs}}}, [${num_out_regs}];", + f"tcgen05.ld.sync.aligned.{shape}.x{num}{pack_mod}.b32 {{{out_regs}}}, [${num_out_regs}];", "=r," * num_out_regs + "r", has_side_effects=True, ) @@ -521,9 +527,9 @@ def __getitem__(self, *idxs): raise NotImplementedError("Slicing of TMEM not impelmented yet") if self.shape[1] % 8: raise NotImplementedError - if self.dtype != ir.F32Type.get(): - raise NotImplementedError(self.dtype) - layout = _m128_256bit_32bit_layout(self.shape) + if utils.bitwidth(self.dtype) not in {16, 32}: + raise NotImplementedError(f"Unsupported dtype: {self.dtype}") + layout = _m128_layout(self.shape) regs_shape = layout.registers_shape(self.shape) if self.layout == TMEMLayout(elements_in_tile=(TMEM_ROWS, 8)): # load_32xcols returns a 4xN array, but the FA tiling we use here tiles @@ -556,20 +562,28 @@ def __getitem__(self, *idxs): ) return fa.FragmentedArray(_registers=registers, _layout=layout, _is_signed=None) + def _load_32xcols(base_addr, cols, dtype): # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16256b i32 = ir.IntegerType.get_signless(32) - assert cols % 8 == 0 - cols_per_num_tile = 8 - load_shape = "16x256b" - num = cols // 8 + packing = 32 // utils.bitwidth(dtype) + if packing == 1: + load_shape = "16x256b" # 8 columns * 32 bits = 256 bits + cols_per_num_tile = 8 * packing + elif packing == 2: + load_shape = "16x128b" # 8 columns * 16 bits = 128 bits + cols_per_num_tile = 4 * packing + else: + raise NotImplementedError(packing) + assert cols % cols_per_num_tile == 0 + num = cols // cols_per_num_tile if num <= 32: num_tiling = num elif num == 64: num_tiling = 32 else: raise NotImplementedError(num) - vector_regs = np.ndarray((4, num), dtype=object) + vector_regs = np.ndarray((4, cols // 8), dtype=object) # We load 16 lanes at a time, but need 32 in total. for row_group in range(2): addr_row = arith.addi(base_addr, arith.constant(i32, (row_group * 16) << 16)) @@ -579,17 +593,24 @@ def _load_32xcols(base_addr, cols, dtype): addr_row, arith.constant(i32, num_tiling * num_group * cols_per_num_tile), ) - regs += tmem_load(addr_row_col, load_shape, num_tiling) - regs = [llvm.bitcast(dtype, r) for r in regs] - undef = llvm.mlir_undef(ir.VectorType.get((2,), dtype)) - for r_low, r_high, idx in zip(regs[::2], regs[1::2], np.ndindex(num, 2)): - high_undef = llvm.insertelement(undef, r_low, utils.c(0, i32)) - vreg = llvm.insertelement(high_undef, r_high, utils.c(1, i32)) - vector_regs[idx[1] + 2 * row_group, idx[0]] = vreg + regs += tmem_load(addr_row_col, load_shape, num_tiling, packing) + if packing == 1: + regs = [llvm.bitcast(dtype, r) for r in regs] + undef = llvm.mlir_undef(ir.VectorType.get((2,), dtype)) + for r_low, r_high, idx in zip(regs[::2], regs[1::2], np.ndindex(cols // 8, 2), strict=True): + high_undef = llvm.insertelement(undef, r_low, utils.c(0, i32)) + vreg = llvm.insertelement(high_undef, r_high, utils.c(1, i32)) + vector_regs[idx[1] + 2 * row_group, idx[0]] = vreg + else: + assert packing == 2 + regs = [llvm.bitcast(ir.VectorType.get((2,), dtype), r) for r in regs] + for vreg, idx in zip(regs, np.ndindex(cols // 8, 2), strict=True): + vector_regs[idx[1] + 2 * row_group, idx[0]] = vreg return vector_regs -def _m128_256bit_32bit_layout(shape: tuple[int, ...]): +# Like WGMMA_LAYOUT, only each warp holds a 32xN strip instead of 16xN. +def _m128_layout(shape: tuple[int, ...]): if len(shape) != 2: raise ValueError(f"Shape {shape} is not 2D") if shape[0] % 128 != 0 or shape[1] % 8 != 0: diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index e7bd7fad3798..478064188750 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -912,15 +912,41 @@ def setUp(self): lhs_transpose=(False, True), rhs_transpose=(False, True), in_jax_dtype=(jnp.float16, jnp.bfloat16), # TODO(apaszke): f32 - out_jax_dtype=(jnp.float32,), # TODO(apaszke): f16 accumulation + out_jax_dtype=(jnp.float16, jnp.float32,), m=(128,), # TODO(apaszke): 64, 192, 256 n=(64, 128, 256, 512), # TODO(apaszke): 192, other non-power-of-2 - k_steps=(1, 2), swizzle=(32, 64, 128,), - rhs_transpose_tiles=(False, True), + ) + def test_mma_basic(self, *args, **kwargs): + self._basic_mma_test( + *args, + **kwargs, + k_steps=2, # Reducing to 1 can be helpful while debugging. + lhs_transpose_tiles=False, + rhs_transpose_tiles=False, + ) + + @parameterized.product( + lhs_transpose=(False, True), + rhs_transpose=(False, True), + in_jax_dtype=(jnp.float16,), + out_jax_dtype=(jnp.float32,), + m=(128,), + n=(128, 512), + swizzle=(32, 64, 128,), lhs_transpose_tiles=(False, True), + rhs_transpose_tiles=(False, True), ) - def test_mma_basic( + def test_mma_transposed_tiles(self, *args, **kwargs): + if not kwargs["lhs_transpose_tiles"] and not kwargs["rhs_transpose_tiles"]: + self.skipTest("This is already tested in test_mma_basic") + self._basic_mma_test( + *args, + **kwargs, + k_steps=2, # Reducing to 1 can be helpful while debugging. + ) + + def _basic_mma_test( self, m, n, @@ -981,16 +1007,10 @@ def kernel(ctx, lhs, rhs, out, scratch): barriers[2].wait(for_tensor_core=True) acc[:].store_untiled(out) - in_finfo = jnp.finfo(in_jax_dtype) - exponent_bits, mantissa_bits = in_finfo.nexp, in_finfo.nmant - def quantize(x): - # Quantize the input to avoid rounding when feeding the TensorCore - return jax.lax.reduce_precision(x, exponent_bits, mantissa_bits) - x_shape = (k, m) if lhs_transpose else (m, k) - x = quantize(self.prng.uniform(-1, 1, x_shape)).astype(in_jax_dtype) + x = self.prng.uniform(-1, 1, x_shape).astype(in_jax_dtype) y_shape = (n, k) if rhs_transpose else (k, n) - y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(in_jax_dtype) + y = self.prng.uniform(-1, 1, y_shape).astype(in_jax_dtype) out_shape = jax.ShapeDtypeStruct((m, n), out_jax_dtype) if rhs_transpose_tiles: rhs_smem_shape = ( @@ -1015,14 +1035,15 @@ def quantize(x): )(x, y) x32, y32 = x.astype(np.float32), y.astype(np.float32) ref = (x32.T if lhs_transpose else x32) @ (y32.T if rhs_transpose else y32) - atol = 2e-2 if out_jax_dtype == jnp.float16 else 5e-6 - np.testing.assert_allclose(z, ref, atol=atol) + atol = 2e-2 if out_jax_dtype == jnp.float16 else 2e-5 + rtol = 8e-4 if out_jax_dtype == jnp.float16 else 1e-7 + np.testing.assert_allclose(z, ref, atol=atol, rtol=rtol) @parameterized.product( lhs_transpose=(False, True), rhs_transpose=(False, True), - in_jax_dtype=(jnp.float16,), # TODO(apaszke): f32 - out_jax_dtype=(jnp.float32,), # TODO(apaszke): f16 accumulation + in_jax_dtype=(jnp.float16,), + out_jax_dtype=(jnp.float32,), m=(256,), # TODO(apaszke): 64, 192, 256 n=(128, 256, 512), # TODO(apaszke): 192, other non-power-of-2 k_steps=(1, 2), From 01904363e4a9d2f721c0e2193ef4b199a4ffe9ac Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 24 Mar 2025 04:16:40 -0700 Subject: [PATCH 112/483] Update XLA dependency to use revision http://github.com/openxla/xla/commit/9a8dd0796bcfeb00e4e6d09d74726db5c7d4a003. PiperOrigin-RevId: 739886693 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 3e3d636f0f43..996ee511f835 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "5ed2ff5c07868d1a7486f4040f8b38936640268e" -XLA_SHA256 = "08d175c57d0db599ad57b8fa820ca2f2a6d2808578d53dba421e3af4edb0bccf" +XLA_COMMIT = "9a8dd0796bcfeb00e4e6d09d74726db5c7d4a003" +XLA_SHA256 = "4e3248d37a1b0598de3e93e8e46ede060578bc45bfbdfaf24d91ab598543b770" def repo(): tf_http_archive( From 381f11090e702fa9403e178a6699017c62d24453 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 24 Mar 2025 09:26:06 -0400 Subject: [PATCH 113/483] Reenable tsan suppression, mark some tests as thread-unsafe. --- .github/workflows/tsan-suppressions.txt | 6 +++--- .github/workflows/tsan.yaml | 1 + tests/cache_key_test.py | 2 ++ tests/pjit_test.py | 1 + 4 files changed, 7 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tsan-suppressions.txt b/.github/workflows/tsan-suppressions.txt index 296f4432e687..bdffddc58ca0 100644 --- a/.github/workflows/tsan-suppressions.txt +++ b/.github/workflows/tsan-suppressions.txt @@ -24,6 +24,9 @@ race_top:PyMember_GetOne # https://github.com/python/cpython/issues/129748 race:mi_block_set_nextx +# https://github.com/python/cpython/issues/128130 +race_top:run_eval_code_obj + # Races because the LAPACK and BLAS in our scipy isn't TSAN instrumented. race:heevd_ffi race:gesdd_ffi @@ -59,9 +62,6 @@ race:gemm_oncopy # https://github.com/python/cpython/issues/130547 # race:split_keys_entry_added -# https://github.com/python/cpython/issues/128130 -# race_top:run_eval_code_obj - # https://github.com/python/cpython/issues/129547 # Maybe fixed? # race:type_get_annotations diff --git a/.github/workflows/tsan.yaml b/.github/workflows/tsan.yaml index 6c97b7347ceb..cd59c0bf45e0 100644 --- a/.github/workflows/tsan.yaml +++ b/.github/workflows/tsan.yaml @@ -13,6 +13,7 @@ on: - main paths: - '**/workflows/tsan.yaml' + - '**/workflows/tsan-suppressions.txt' jobs: tsan: diff --git a/tests/cache_key_test.py b/tests/cache_key_test.py index 2faa4dbaf9d4..ed80c7060e4c 100644 --- a/tests/cache_key_test.py +++ b/tests/cache_key_test.py @@ -163,6 +163,8 @@ def test_different_computations(self): cache_key.get(computation2, devices, compile_options, backend), ) + # TODO(phawkins): this test flakes if test concurrency is enabled. + @jtu.thread_unsafe_test() def test_custom_partitioning_ptr_removal(self): def _partition(mesh, arg_shapes, result_shape): arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 6fdfa62887b9..2033126759e4 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -3438,6 +3438,7 @@ def f(x): pjit(f)(inp) self.assertEqual(count(), 1) + @jtu.thread_unsafe_test() # count_pjit_cpp_cache_miss is not thread-safe def test_pjit_no_global_cache_hit_axis_resources(self): mesh = jtu.create_mesh((1,), ('x',)) s = NamedSharding(mesh, P('x')) From c6525bc58f0ec1507c83a4c3f149208a1b60368f Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Mon, 24 Mar 2025 08:10:33 -0700 Subject: [PATCH 114/483] [Mosaic GPU][NFC] Fix documentation of `WGMMA_LAYOUT`. `TiledLayout` has no notion of partitioning over warpgroups, and each warp holds `16 x 8` elements. PiperOrigin-RevId: 739942481 --- jax/experimental/mosaic/gpu/fragmented_array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 4bbfd0dd8afe..c2b61c6d5bfe 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -537,7 +537,7 @@ def linear_thread_idxs(self): # The tiled layout is equivalent to one described here in PTX documentation: # https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-d -# In this layout, we partition the 64x8 tiles over 4 warpgroups into 16x8 tiles. +# In this layout, we partition the 64x8 tiles over 4 warps into 16x8 tiles. # Then, we further split the 16x8 tiles into 8x8 submatrices which are the unit # of data that is split across a warp. Since 8*8 = 64, but a warp has only 32 # threads, we vectorize pairs of elements along columns. From b6b5d952392960ceb78401302cb9d620719407b1 Mon Sep 17 00:00:00 2001 From: Jacob Burnim Date: Thu, 20 Mar 2025 16:25:51 -0700 Subject: [PATCH 115/483] [Pallas] In TPU interpret mode, add initial barrier for kernels without one. --- jax/_src/pallas/mosaic/interpret.py | 33 ++++++++++++++++--- .../tpu_pallas_interpret_distributed_test.py | 16 --------- 2 files changed, 28 insertions(+), 21 deletions(-) diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py index 439ac98b2ac6..2e31d0fba7cf 100644 --- a/jax/_src/pallas/mosaic/interpret.py +++ b/jax/_src/pallas/mosaic/interpret.py @@ -451,6 +451,7 @@ class SharedMemory: num_devices: int clocks: list[VectorClock] barrier: threading.Barrier + clean_up_barrier: threading.Barrier # (memory_space, buffer_id, device_id) -> NumPy array # TODO(jburnim): Handle Megacore. @@ -502,18 +503,35 @@ def _initialize_shared_memory(device_id, num_devices, *, interpret_params): interpret_params=interpret_params, num_devices=num_devices, clocks=[make_vector_clock(num_devices) for _ in range(num_devices)], - barrier=threading.Barrier(num_devices)) + barrier=threading.Barrier( + num_devices, action=_update_clocks_for_global_barrier), + clean_up_barrier=threading.Barrier( + num_devices, action=_clear_shared_memory)) assert _shared_memory.num_devices == num_devices global races races = RaceDetectionState(num_devices=num_devices) +def _update_clocks_for_global_barrier(): + shared_memory = _get_shared_memory() + with shared_memory.lock: + # Set the vector clock for device 0 to the max over all device clocks. + for c in shared_memory.clocks[1:]: + update_vector_clock(shared_memory.clocks[0], c) + # Set all other device vector clocks to the max over all the clocks. + for c in shared_memory.clocks[1:]: + update_vector_clock(c, shared_memory.clocks[0]) + +def _barrier(device_id): + device_id = int(device_id) + shared_memory = _get_shared_memory() + if shared_memory.num_devices > 1: + shared_memory.barrier.wait() + def _clean_up_shared_memory(device_id): device_id = int(device_id) shared_memory = _get_shared_memory() - shared_memory.barrier.wait() - if device_id == 0: - _clear_shared_memory() + shared_memory.clean_up_barrier.wait() def _validate(device_id): device_id = int(device_id) @@ -1359,7 +1377,7 @@ def interpret_pallas_call( input_output_aliases: tuple[tuple[int, int], ...], grid_mapping: GridMapping, mesh: pallas_core.Mesh | None, - compiler_params: Any, + compiler_params: mosaic_core.TPUCompilerParams, cost_estimate: CostEstimate, out_avals: tuple[jax_core.AbstractValue, ...], interpret_params: TPUInterpretParams, @@ -1499,6 +1517,11 @@ def interpret_pallas_call( var.aval.shape, var.aval.dtype, interpret_params), ordered=True)) + if compiler_params.get('mosaic', {}).get('collective_id', None) is None: + # The kernel doesn't specify its own barrier semaphore, so we do a global + # barrier before running the first iteration of the kernel. + callback.io_callback(_barrier, (), device_id, ordered=True) + _, input_ids, kernel_output_ids, _ = split_list( kernel_buffer_ids, [grid_mapping.num_index_operands, num_inputs, grid_mapping.num_outputs]) diff --git a/tests/pallas/tpu_pallas_interpret_distributed_test.py b/tests/pallas/tpu_pallas_interpret_distributed_test.py index 518c16ed2109..1ed139e9e867 100644 --- a/tests/pallas/tpu_pallas_interpret_distributed_test.py +++ b/tests/pallas/tpu_pallas_interpret_distributed_test.py @@ -18,8 +18,6 @@ contains only tests that use shard_map. """ -import functools - from absl.testing import absltest from absl.testing import parameterized @@ -1017,19 +1015,6 @@ def test_race_detection(self): input_arr = jax.device_put(input_arr, sharding) def kernel(src_dst_ids_ref, x_ref, o_ref, send_sem, recv_sem): - # Barrier with all devices before doing any DMAs. - barrier_sem = pltpu.get_barrier_semaphore() - @functools.partial(jax.lax.fori_loop, 0, num_devices, init_val=None) - def _(i, _): - pltpu.semaphore_signal( - barrier_sem, - inc=1, - device_id=(jnp.int32(i),), - device_id_type=pltpu.DeviceIdType.MESH, - ) - return None - pltpu.semaphore_wait(barrier_sem, num_devices) - # Send the specified DMAs. my_id = lax.axis_index('x') src_dst_ids = src_dst_ids_ref[:] @@ -1076,7 +1061,6 @@ def run(src_dst_ids): ], out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), scratch_shapes=[pltpu.SemaphoreType.DMA, pltpu.SemaphoreType.DMA], - compiler_params=pltpu.TPUCompilerParams(collective_id=0), interpret=mosaic_interpret.TPUInterpretParams( dma_execution_mode='eager', detect_races=True, From 788ad8c6a2ded930a2cf4379780a749b680c5ba0 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 24 Mar 2025 09:02:09 -0700 Subject: [PATCH 116/483] Change `python-tag` to `python_tag` to conform to the new setuptools version. PiperOrigin-RevId: 739958612 --- jaxlib/tools/build_gpu_plugin_wheel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jaxlib/tools/build_gpu_plugin_wheel.py b/jaxlib/tools/build_gpu_plugin_wheel.py index 667807b51197..d52cc7da36e8 100644 --- a/jaxlib/tools/build_gpu_plugin_wheel.py +++ b/jaxlib/tools/build_gpu_plugin_wheel.py @@ -81,7 +81,7 @@ def write_setup_cfg(sources_path, cpu): [bdist_wheel] plat_name={tag} -python-tag=py3 +python_tag=py3 """ ) From c1f65c3e1f045106e090e41d74b6968fed824b1c Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 24 Mar 2025 09:27:50 -0700 Subject: [PATCH 117/483] Update CUDA version in Bazel configs to 12.8, and CUDNN version to 9.8. PiperOrigin-RevId: 739967341 --- .bazelrc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.bazelrc b/.bazelrc index fb938169b3c0..642fb15ed541 100644 --- a/.bazelrc +++ b/.bazelrc @@ -141,8 +141,8 @@ build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain build:cuda --@local_config_cuda//:enable_cuda # Default hermetic CUDA and CUDNN versions. -build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2" -build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1" +build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.8.0" +build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.8.0" build:cuda --@local_config_cuda//cuda:include_cuda_libs=true # This config is used for building targets with CUDA libraries from stubs. From 198d7bb9c29bbf20dab893739cf546a6e78f4c18 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Mon, 24 Mar 2025 09:29:55 -0700 Subject: [PATCH 118/483] [pallas] Add support for `split` into any power-of-two equal parts in Triton lowering. PiperOrigin-RevId: 739968019 --- jax/_src/pallas/triton/lowering.py | 23 +++++++++++++++-------- tests/pallas/pallas_test.py | 15 ++++++++++----- 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index d7f6e4695229..a0883ea589b0 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -1802,15 +1802,22 @@ def _concatenate_lowering_rule(ctx: LoweringRuleContext, *args, dimension): def _split_lowering_rule(ctx: LoweringRuleContext, x, *, sizes, axis): pass # TODO(cjfj): Add support for larger powers of 2. - if len(sizes) != 2: - raise NotImplementedError("Only splitting into two parts is supported.") - if sizes[0] != sizes[1]: + num_parts = len(sizes) + if num_parts != pallas_utils.next_power_of_2(num_parts): + raise NotImplementedError("Only power-of-2 num parts supported.") + if any(size != sizes[0] for size in sizes): raise NotImplementedError("Only equal-sized splits are supported.") - (x_aval,) = ctx.avals_in - shape = x_aval.shape - x = _reshape(x, shape[:axis] + (2, sizes[0]) + shape[axis + 1 :]) - permutation = tuple(d for d in range(len(shape) + 1) if d != axis) + (axis,) - return tt_dialect.split(tt_dialect.trans(x, permutation)) + + def split_into_2(x): + shape = ir.RankedTensorType(x.type).shape + x = _reshape(x, shape[:axis] + [2, shape[axis] // 2] + shape[axis + 1 :]) + permutation = tuple(d for d in range(len(shape) + 1) if d != axis) + (axis,) + return tuple(tt_dialect.split(tt_dialect.trans(x, permutation))) + + x_parts = (x,) + while len(x_parts) < num_parts: + x_parts = sum(map(split_into_2, x_parts), ()) + return x_parts def _compute_offsets_from_indices( diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 0b16260ff25a..9e5130b8f449 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -844,15 +844,20 @@ def dot_kernel(x_ref, y_ref, o_ref): self.assertAllClose(dot_kernel(x, y), expected) @parameterized.parameters( - ((32,), 0), ((32, 64), 0), ((32, 16), 1), ((32, 16, 2), 1) + ((32,), 2, 0), ((32, 64), 4, 0), ((32, 16), 8, 1), ((32, 16, 2), 16, 1) ) - def test_split(self, shape, axis): + def test_split(self, shape, num_parts, axis): + if jtu.test_device_matches(["tpu"]) and shape[axis] == num_parts: + self.skipTest("TPU doesn't support fully split axis.") + x = jax.random.normal(jax.random.key(0), shape) - expected = jnp.split(x, 2, axis) + expected = jnp.split(x, num_parts, axis) @functools.partial(self.pallas_call, out_shape=expected) - def kernel(x_ref, o0_ref, o1_ref): - o0_ref[()], o1_ref[()] = jnp.split(x_ref[()], 2, axis) + def kernel(x_ref, *o_ref): + x_parts = jnp.split(x_ref[()], num_parts, axis) + for o_ref, x_part in zip(o_ref, x_parts): + o_ref[...] = x_part self.assertAllClose(kernel(x), expected) From a2f22cc1dec0c02d5d1f0213af4c731a008775bd Mon Sep 17 00:00:00 2001 From: Gleb Pobudzey Date: Mon, 24 Mar 2025 10:46:30 -0700 Subject: [PATCH 119/483] [Mosaic GPU] Adding a primitive to load from memrefs *with* a specified layout. PiperOrigin-RevId: 739995908 --- jax/_src/pallas/mosaic_gpu/primitives.py | 96 +++++++++++++++++++++++- jax/experimental/pallas/mosaic_gpu.py | 1 + tests/pallas/mosaic_gpu_test.py | 66 ++++++++++++++++ 3 files changed, 162 insertions(+), 1 deletion(-) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index edfae55fb288..a27137964349 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -21,7 +21,7 @@ import enum import itertools import math -from typing import Any, Literal +from typing import Any, Literal, Optional import jax from jax._src import core as jax_core @@ -31,6 +31,7 @@ from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith as arith_dialect from jax._src.lib.mlir.dialects import llvm as llvm_dialect +from jax._src.lib.mlir.dialects import memref as memref_dialect from jax._src.lib.mlir.dialects import nvvm as nvvm_dialect from jax._src.pallas import core as pallas_core from jax._src.pallas.mosaic_gpu import core as gpu_core @@ -62,6 +63,99 @@ def _check_ref( ) +load_p = jax_core.Primitive("load") + +@load_p.def_effectful_abstract_eval +def _load_abstract_eval(src, *avals_flat, args_tree, layout): + del layout # Unused. + + transforms = args_tree.unflatten(avals_flat) + return ( + jax_core.ShapedArray(transforms[-1].get_indexer_shape(), src.dtype), + {state.ReadEffect(0)}, + ) + +@lowering.register_lowering_rule(load_p, mgpu.ThreadSemantics.Lane) +def _load_p_lowering_rule( + ctx: lowering.LoweringRuleContext, x_ref, *leaves, args_tree, layout +): + if not isinstance(x_ref, ir.Value) or not ir.MemRefType.isinstance(x_ref.type): + raise TypeError(f"Can only load from references (got {x_ref}).") + + x_aval = ctx.avals_in[0] + + transforms = jax.tree.unflatten(args_tree, leaves) + x_ref, transforms = lowering._handle_reshaping(x_ref, transforms) + x_ref, transforms = lowering._handle_indexing(x_ref, transforms) + + if layout is not None: + layout = layout.to_mgpu() + + match transforms: + case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)): + if tiling != (64, swizzle // x_aval.dtype.itemsize): + raise NotImplementedError("Tiling does not fit swizzle") + return mgpu.FragmentedArray.load_tiled( + x_ref, is_signed=mgpu_utils.is_signed(x_aval.dtype), swizzle=swizzle, + layout=layout + ) + case (): + # Handle scalar indexing. + if not ctx.avals_out[0].shape: + is_signed = mgpu_utils.is_signed(x_aval.dtype) + val = memref_dialect.load(x_ref, []) + return mgpu.FragmentedArray.splat(val, shape=(), layout=layout, is_signed=is_signed) + match layout: + case mgpu.WGMMARowFragLayout(): + return mgpu.FragmentedArray.load_wgmma_row( + x_ref, is_signed=mgpu_utils.is_signed(x_aval.dtype) + ) + case mgpu.WGStridedFragLayout(shape=shape, vec_size=vec_size): + ref_ty = ir.MemRefType(x_ref.type) + if shape != tuple(ref_ty.shape): + raise ValueError( + f"Unsupported shape {shape}, (expected {tuple(ref_ty.shape)})" + ) + + return mgpu.FragmentedArray.load_strided( + x_ref, is_signed=mgpu_utils.is_signed(x_aval.dtype), vec_size=vec_size, + ) + case None: + return mgpu.FragmentedArray.load_strided( + x_ref, is_signed=mgpu_utils.is_signed(x_aval.dtype) + ) + case _: + raise NotImplementedError(f"Unsupported layout: {layout}") + case _: + raise NotImplementedError(f"Unsupported transforms: {transforms}") + + +def load( + src: _Ref, idx, *, layout: Optional[Layout | ParameterizedLayout] = None +) -> mgpu.FragmentedArray: + """ Loads a ref (SMEM or GMEM) into a FragmentedArray with the specified layout. + + Args: + src: The reference to copy from. + idx: The index to load from. + layout: The optional layout to use for the returned FragmentedArray. + + Returns: + A FragmentedArray containing the loaded data in the specified layout. + """ + src, src_transforms = state_primitives.get_ref_and_transforms( + src, idx, "load", force_trailing_indexer=True, + ) + flat_src_transforms, src_transforms_treedef = tree_util.tree_flatten( + src_transforms + ) + return load_p.bind( + src, + *flat_src_transforms, + args_tree=src_transforms_treedef, + layout=layout + ) + copy_smem_to_gmem_p = jax_core.Primitive("copy_smem_to_gmem") copy_smem_to_gmem_p.multiple_results = True diff --git a/jax/experimental/pallas/mosaic_gpu.py b/jax/experimental/pallas/mosaic_gpu.py index aab58d092190..d5acb9b131ad 100644 --- a/jax/experimental/pallas/mosaic_gpu.py +++ b/jax/experimental/pallas/mosaic_gpu.py @@ -40,6 +40,7 @@ from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem as copy_smem_to_gmem from jax._src.pallas.mosaic_gpu.primitives import Layout as Layout from jax._src.pallas.mosaic_gpu.primitives import layout_cast as layout_cast +from jax._src.pallas.mosaic_gpu.primitives import load as load from jax._src.pallas.mosaic_gpu.primitives import set_max_registers as set_max_registers from jax._src.pallas.mosaic_gpu.primitives import wait_smem_to_gmem as wait_smem_to_gmem from jax._src.pallas.mosaic_gpu.primitives import wgmma as wgmma diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 408b8bdf5713..d31d1c9d41b2 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -634,6 +634,72 @@ def kernel(x_ref, o_ref, barrier_ref): x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128) np.testing.assert_array_equal(f(x), np.stack([x, x], axis=0)) + @parameterized.product( + src_memory_space=[plgpu.SMEM, plgpu.GMEM], + layout=[ + plgpu.Layout.WGMMA_ROW, + plgpu.Layout.WG_STRIDED((128,), vec_size=1), + None, + ], + ) + def test_load_to_layout_with_indexing(self, src_memory_space, layout): + def kernel(x_ref, o_ref): + for i in range(2): + x = plgpu.load(x_ref, (i,), layout=layout) + o_ref[i, ...] = x + + in_spec = pl.BlockSpec(memory_space=src_memory_space) + out_spec = plgpu.GPUBlockSpec( + (2, 128), lambda: (0, 0), memory_space=plgpu.SMEM, + ) + f = pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct([2, 128], jnp.float32), + in_specs=(in_spec,), + out_specs=out_spec, + ) + x = jnp.arange(2 * 128, dtype=jnp.float32).reshape(2, 128) + np.testing.assert_array_equal(f(x), x) + + @parameterized.product(src_memory_space=[plgpu.SMEM, plgpu.GMEM]) + def test_load_row_input_to_wgmma_with_transforms(self, src_memory_space): + m, k, n = 64, 128, 192 + key1, key2 = jax.random.split(jax.random.key(42), 2) + a = jax.random.uniform(key1, shape=(m,), dtype=jnp.float16) + b = jax.random.uniform(key2, shape=(k, n), dtype=jnp.float16) + def kernel(x_ref, y_ref, o_ref): + x = plgpu.load(x_ref, (), layout=plgpu.Layout.WGMMA_ROW) + x = lax.broadcast_in_dim(x, (m, k), [0]) + + def compute(acc_ref): + plgpu.wgmma(acc_ref, x, y_ref) + return acc_ref[...] + + out = pl.run_scoped(compute, plgpu.ACC((m, n), jnp.float32)) + o_ref[...] = out + + out_spec = plgpu.GPUBlockSpec( + (m, n), lambda: (0, 0), memory_space=plgpu.SMEM, + ) + f = pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct([m, n], jnp.float32), + in_specs=( + pl.BlockSpec(memory_space=src_memory_space), + plgpu.GPUBlockSpec( + (k, n), + lambda: (0, 0), + transforms=( + plgpu.TilingTransform((64, 64)), + plgpu.SwizzleTransform(128), + ) + )), + out_specs=out_spec, + ) + + out_ref = jnp.broadcast_to(a[:, None], (m, k)) @ b + np.testing.assert_allclose(f(a, b), out_ref, rtol=1e-3) + def test_indexing_before_transpose(self): def kernel(x_ref, o_ref, barrier_ref): for i in range(2): From 92f231e875118cb114e25c8517eb5aed53729066 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 24 Mar 2025 11:31:18 -0700 Subject: [PATCH 120/483] Delay the unflattening in `jnp.array` Reverts 53e8eac7134a13c1d28de673e7e3a23b4a837aed PiperOrigin-RevId: 740012608 --- jax/_src/numpy/lax_numpy.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 96efc48062e1..16355695792d 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -49,15 +49,16 @@ from jax._src.lax.lax import (PrecisionLike,_array_copy, _sort_le_comparator, _sort_lt_comparator) from jax._src.lib import xla_client as xc -from jax._src.numpy.array_creation import (empty, empty_like, full, - ones, ones_like, zeros, zeros_like) from jax._src.numpy import indexing from jax._src.numpy import reductions from jax._src.numpy import tensor_contractions from jax._src.numpy import ufuncs from jax._src.numpy import util +from jax._src.numpy.array_creation import (empty, empty_like, full, + ones, ones_like, zeros, zeros_like) from jax._src.numpy.sorting import argsort, sort from jax._src.numpy.vectorize import vectorize +from jax._src.sharding_impls import SingleDeviceSharding from jax._src.typing import ( Array, ArrayLike, DType, DTypeLike, DeprecatedArg, DimSize, Shape ) @@ -65,8 +66,7 @@ NumpyComplexWarning, canonicalize_axis as _canonicalize_axis, ceil_of_ratio, safe_zip, set_module, unzip2) from jax.sharding import Sharding -from jax._src.sharding_impls import SingleDeviceSharding -from jax.tree_util import tree_leaves, tree_map +from jax.tree_util import tree_flatten, tree_map import numpy as np export = set_module('jax.numpy') @@ -5504,9 +5504,7 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, object = xc._xla.cuda_array_interface_to_buffer( cai=cai, gpu_backend=backend, device_id=device_id) - object = tree_map(lambda leaf: leaf.__jax_array__() - if hasattr(leaf, "__jax_array__") else leaf, object) - leaves = tree_leaves(object, is_leaf=lambda x: x is None) + leaves, treedef = tree_flatten(object, is_leaf=lambda x: x is None) if any(leaf is None for leaf in leaves): # Added Nov 16 2023 if deprecations.is_accelerated("jax-numpy-array-none"): @@ -5515,7 +5513,13 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, "None encountered in jnp.array(); this is currently treated as NaN. " "In the future this will result in an error.", FutureWarning, stacklevel=2) - leaves = tree_leaves(object) + leaves, treedef = tree_flatten(object) + leaves = [ + leaf + if (leaf_jax_array := getattr(leaf, "__jax_array__", None)) is None + else leaf_jax_array() + for leaf in leaves + ] if dtype is None: # Use lattice_result_type rather than result_type to avoid canonicalization. # Otherwise, weakly-typed inputs would have their dtypes canonicalized. @@ -5530,8 +5534,8 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, if not weak_type: dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True) # type: ignore[assignment] + object = treedef.unflatten(leaves) out: ArrayLike - if all(not isinstance(leaf, Array) for leaf in leaves): # TODO(jakevdp): falling back to numpy here fails to overflow for lists # containing large integers; see discussion in From b89cf0de91cdaec2388f6b9b2dc07d17a18d5b99 Mon Sep 17 00:00:00 2001 From: Zac Mustin Date: Mon, 24 Mar 2025 11:40:41 -0700 Subject: [PATCH 121/483] Stop using mesh and `*_specs` in roofline tests. These args are optional, so not specifying them in our tests will make them simpler and easier to read. This change is a no-op. PiperOrigin-RevId: 740015584 --- tests/roofline_test.py | 93 +++++++++++------------------------------- 1 file changed, 24 insertions(+), 69 deletions(-) diff --git a/tests/roofline_test.py b/tests/roofline_test.py index f94f5a328c46..98f6176c22a0 100644 --- a/tests/roofline_test.py +++ b/tests/roofline_test.py @@ -18,7 +18,6 @@ from absl.testing import absltest import jax -from jax._src import mesh from jax._src import test_util as jtu from jax.experimental import roofline import jax.lax as lax @@ -465,11 +464,7 @@ def collective_matmul(a, b): ) def test_unary_ops(self, f, dtype): data = jnp.zeros((3, 8), dtype=dtype) - out, result = roofline.roofline( - f, - in_specs=(P()), - out_specs=P(), - )(data) + out, result = roofline.roofline(f)(data) with self.subTest("flops"): self.assertEqual(result.unfused_flops, 3 * 8) with self.subTest("hbm_bytes"): @@ -495,12 +490,9 @@ def test_binary_ops(self): lambda a, b: jnp.minimum(a, b), lambda a, b: jnp.maximum(a, b), ]: - out, result = roofline.roofline( - f, - mesh=mesh.AbstractMesh((), ()), - in_specs=(P(), P()), - out_specs=P(), - )(jnp.zeros((3, 8), dtype=int), jnp.ones((3, 8), dtype=int)) + out, result = roofline.roofline(f)( + jnp.zeros((3, 8), dtype=int), jnp.ones((3, 8), dtype=int) + ) self.assertEqual(result.unfused_flops, 3 * 8) self.assertEqual( result.unfused_hbm_bytes, @@ -515,12 +507,7 @@ def test_broadcast(self): (2.0, jnp.ones((3, 8))), (jnp.zeros((3, 8)), 2.0), ]: - _, result = roofline.roofline( - lambda a, b: a + b, - mesh=mesh.AbstractMesh((), ()), - in_specs=(P(), P()), - out_specs=P(), - )(left, right) + _, result = roofline.roofline(lambda a, b: a + b)(left, right) self.assertEqual(result.unfused_flops, 3 * 8) def test_nested(self): @@ -531,27 +518,21 @@ def g(x): return g(x) + g(y) - _, result = roofline.roofline( - f, - mesh=mesh.AbstractMesh((), ()), - in_specs=(P(), P()), - out_specs=P(), - )(jnp.zeros((11, 4), dtype=int), jnp.ones((11, 4), dtype=int)) + _, result = roofline.roofline(f)( + jnp.zeros((11, 4), dtype=int), jnp.ones((11, 4), dtype=int) + ) self.assertEqual(result.unfused_flops, 3 * (11 * 4)) def test_no_mesh(self): - _, result = roofline.roofline( - lambda a, b: a + b, - in_specs=(P(), P()), - out_specs=P(), - )(jnp.zeros((3, 8), dtype=int), jnp.ones((3, 8), dtype=int)) + _, result = roofline.roofline(lambda a, b: a + b)( + jnp.zeros((3, 8), dtype=int), jnp.ones((3, 8), dtype=int) + ) self.assertEqual(result.unfused_flops, 3 * 8) def test_no_specs(self): - _, result = roofline.roofline( - lambda a, b: a + b, - mesh=mesh.AbstractMesh((), ()), - )(jnp.zeros((3, 8), dtype=int), jnp.ones((3, 8), dtype=int)) + _, result = roofline.roofline(lambda a, b: a + b)( + jnp.zeros((3, 8), dtype=int), jnp.ones((3, 8), dtype=int) + ) self.assertEqual(result.unfused_flops, 3 * 8) def test_no_mesh_and_no_specs(self): @@ -561,12 +542,9 @@ def test_no_mesh_and_no_specs(self): self.assertEqual(result.unfused_flops, 3 * 8) def test_dot_general(self): - _, result = roofline.roofline( - lambda a, b: a @ b, - mesh=mesh.AbstractMesh((), ()), - in_specs=(P(), P()), - out_specs=P(), - )(jnp.zeros((3, 7), dtype=int), jnp.ones((7, 5), dtype=int)) + _, result = roofline.roofline(lambda a, b: a @ b)( + jnp.zeros((3, 7), dtype=int), jnp.ones((7, 5), dtype=int) + ) self.assertEqual(result.unfused_flops, 2 * 3 * 7 * 5) self.assertEqual( result.unfused_hbm_bytes, self._bytes_per_word * (3 * 7 + 7 * 5 + 3 * 5) @@ -631,12 +609,7 @@ def test_conv_general_dilated_unfused_hbm_bytes( feature_group_count=feature_group_count, ) - _, result = roofline.roofline( - conv, - mesh=mesh.AbstractMesh((), ()), - in_specs=(P(), P()), - out_specs=P(), - )(input_data, kernel_data) + _, result = roofline.roofline(conv)(input_data, kernel_data) expected_input_size = batch * num_input_channels * iw * ih expected_kernel_size = num_output_channels * num_input_features * kw * kh @@ -677,12 +650,7 @@ def test_conv_general_dilated_padding_string_unfused_hbm_bytes( lhs=a, rhs=b, window_strides=(1, 1), padding=padding ) - _, result = roofline.roofline( - conv, - mesh=mesh.AbstractMesh((), ()), - in_specs=(P(), P()), - out_specs=P(), - )(input_data, kernel_data) + _, result = roofline.roofline(conv)(input_data, kernel_data) expected_input_size = 1 * 1 * 10 * 20 expected_kernel_size = 1 * 1 * 3 * 3 @@ -702,12 +670,7 @@ def test_conv_general_dilated_padding_string_valid_unfused_hbm_bytes(self): lhs=a, rhs=b, window_strides=(1, 1), padding="VALID" ) - _, result = roofline.roofline( - conv, - mesh=mesh.AbstractMesh((), ()), - in_specs=(P(), P()), - out_specs=P(), - )(input_data, kernel_data) + _, result = roofline.roofline(conv)(input_data, kernel_data) expected_input_size = 1 * 1 * 10 * 20 expected_kernel_size = 1 * 1 * 3 * 3 @@ -725,12 +688,7 @@ def test_conv_general_dilated_padding_string_valid_unfused_hbm_bytes(self): self.assertEqual(result.unfused_hbm_bytes, expected_unfused_hbm_bytes) def test_reduce_sum_no_axis(self): - _, result = roofline.roofline( - lambda x: jnp.sum(x), - mesh=mesh.AbstractMesh((), ()), - in_specs=(P()), - out_specs=P(), - )(jnp.zeros((11, 4))) + _, result = roofline.roofline(lambda x: jnp.sum(x))(jnp.zeros((11, 4))) self.assertEqual(result.unfused_flops, 11 * 4 - 1) self.assertEqual( result.unfused_hbm_bytes, self._bytes_per_word * (11 * 4 + 1) @@ -743,12 +701,9 @@ def test_reduce_sum_with_axis(self): ([0, 1], 11 * 4 - 1, 11 * 4 + 1), ([], 0, 11 * 4 + 11 * 4), ]: - _, result = roofline.roofline( - lambda x: jnp.sum(x, axis=axis), - mesh=mesh.AbstractMesh((), ()), - in_specs=(P()), - out_specs=P(), - )(jnp.zeros((11, 4))) + _, result = roofline.roofline(lambda x: jnp.sum(x, axis=axis))( + jnp.zeros((11, 4)) + ) self.assertEqual(result.unfused_flops, expected_flops) self.assertEqual( result.unfused_hbm_bytes, self._bytes_per_word * expected_memory From 94846941e30c239e29b1f67ef0652567653b9ed3 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 24 Mar 2025 14:56:10 -0400 Subject: [PATCH 122/483] Fix mac wheel build. The xla_extension move introduced an incorrect path. --- jaxlib/tools/build_wheel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 9967fc14b9f9..fcc811789c19 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -135,7 +135,7 @@ def verify_mac_libraries_dont_reference_chkstack(): if not _is_mac(): return nm = subprocess.run( - ["nm", "-g", r.Rlocation("__main/jaxlib/xla/xla_extension.so")], + ["nm", "-g", r.Rlocation("__main__/jaxlib/xla/xla_extension.so")], capture_output=True, text=True, check=False, From 7e42539653d33ec995487b683794c0bc86f7199b Mon Sep 17 00:00:00 2001 From: Zac Mustin Date: Mon, 24 Mar 2025 11:55:42 -0700 Subject: [PATCH 123/483] Create `_FMA_FLOPS_FACTOR` to be used in roofline `dot` (and later `conv)`. This change is a no-op made for convenience for follow-up changes. PiperOrigin-RevId: 740020625 --- jax/experimental/roofline/rooflines.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/jax/experimental/roofline/rooflines.py b/jax/experimental/roofline/rooflines.py index 1edd1e0649b1..bc8d65e966dd 100644 --- a/jax/experimental/roofline/rooflines.py +++ b/jax/experimental/roofline/rooflines.py @@ -36,6 +36,8 @@ from jax.experimental import shard_map +_FMA_FLOPS_FACTOR = 2 + for prim in it.chain( ad_util.__dict__.values(), ann.__dict__.values(), @@ -156,7 +158,7 @@ def _dot_general_roofline( (lhs_contract, _), (lhs_batch, _) = dimension_numbers flops = ( - 2 + _FMA_FLOPS_FACTOR * lhs.size * rhs.size / np.prod([lhs.shape[i] for i in lhs_contract]) From 13862ec10b0e3eaccb090822f103f1ae34b6e5b0 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 24 Mar 2025 09:13:49 -0400 Subject: [PATCH 124/483] Small cleanup to pretty-printer. Kidger's reimplementation of this code notes that the break mode and indent are unused in the _fits function (https://github.com/patrick-kidger/wadler_lindig/blob/851379b8f55e2bb98ea2c81905863f90f9606f0a/wadler_lindig/_wadler_lindig.py#L166). We can make the same optimization here. --- jax/_src/pretty_printer.py | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/jax/_src/pretty_printer.py b/jax/_src/pretty_printer.py index e8fdff497445..d02b6d9962e0 100644 --- a/jax/_src/pretty_printer.py +++ b/jax/_src/pretty_printer.py @@ -201,26 +201,20 @@ def __init__(self, child: Doc, *, foreground: Color | None = None, # non-recursive formulation using an explicit stack, necessary because Python # doesn't have a tail recursion optimization. -def _fits(doc: Doc, width: int, agenda: list[tuple[int, _BreakMode, Doc]] - ) -> bool: +def _fits(doc: Doc, width: int) -> bool: + agenda = [doc] while width >= 0 and len(agenda) > 0: - i, m, doc = agenda.pop() + doc = agenda.pop() if isinstance(doc, _NilDoc): pass elif isinstance(doc, _TextDoc): width -= len(doc.text) elif isinstance(doc, _ConcatDoc): - agenda.extend((i, m, d) for d in reversed(doc.children)) + agenda.extend(reversed(doc.children)) elif isinstance(doc, _BreakDoc): - if m == _BreakMode.BREAK: - return True width -= len(doc.text) - elif isinstance(doc, _NestDoc): - agenda.append((i + doc.n, m, doc.child)) - elif isinstance(doc, _GroupDoc): - agenda.append((i, _BreakMode.FLAT, doc.child)) - elif isinstance(doc, _ColorDoc) or isinstance(doc, _SourceMapDoc): - agenda.append((i, m, doc.child)) + elif isinstance(doc, (_NestDoc, _GroupDoc, _ColorDoc, _SourceMapDoc)): + agenda.append(doc.child) else: raise ValueError("Invalid document ", doc) @@ -372,8 +366,7 @@ def _format( elif isinstance(doc, _GroupDoc): # In Lindig's paper, _fits is passed the remainder of the document. # I'm pretty sure that's a bug and we care only if the current group fits! - if (_sparse(doc) - and _fits(doc, width - k, [(i, _BreakMode.FLAT, doc.child)])): + if (_sparse(doc) and _fits(doc, width - k)): agenda.append(_State(i, _BreakMode.FLAT, doc.child, color, source)) else: agenda.append(_State(i, _BreakMode.BREAK, doc.child, color, source)) From 7e235e3aee527d3a4c6f6cc0b633c175303e5c46 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 24 Mar 2025 12:43:28 -0700 Subject: [PATCH 125/483] jax.test_util: improve type annotations --- jax/_src/public_test_util.py | 13 ++++++++----- jax/_src/test_util.py | 19 +++++++++++-------- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/jax/_src/public_test_util.py b/jax/_src/public_test_util.py index 455a3b98cce2..59ddb73dc9e1 100644 --- a/jax/_src/public_test_util.py +++ b/jax/_src/public_test_util.py @@ -14,6 +14,7 @@ from functools import partial import operator +from typing import Any, TypeAlias from jax._src import api from jax._src import config @@ -32,7 +33,7 @@ EPS = 1e-4 -def _dtype(x): +def _dtype(x: Any) -> np.dtype: if hasattr(x, 'dtype'): return x.dtype elif type(x) in _dtypes.python_scalar_dtypes: @@ -40,8 +41,9 @@ def _dtype(x): else: return np.asarray(x).dtype +ToleranceDict: TypeAlias = dict[np.dtype, int | float] -_default_tolerance = { +_default_tolerance: ToleranceDict = { _dtypes.float0: 0, np.dtype(np.bool_): 0, np.dtype(_dtypes.int4): 0, @@ -76,7 +78,7 @@ def default_tolerance(): return _default_tolerance -default_gradient_tolerance = { +default_gradient_tolerance: ToleranceDict = { np.dtype(_dtypes.float8_e4m3b11fnuz): 1e-1, np.dtype(_dtypes.float8_e4m3fn): 1e-1, np.dtype(_dtypes.float8_e4m3fnuz): 1e-1, @@ -104,7 +106,7 @@ def default_tolerance(): _default_tolerance[np.dtype(_dtypes.float4_e2m1fn)] = 1e0 default_gradient_tolerance[np.dtype(_dtypes.float4_e2m1fn)] = 1e0 -def is_python_scalar(val): +def is_python_scalar(val: Any) -> bool: return not isinstance(val, np.generic) and isinstance(val, (bool, int, float, complex)) def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''): @@ -151,7 +153,8 @@ def maybe_upcast(x): # value errors. It should not do that. np.testing.assert_allclose(a, b, **kw, err_msg=err_msg) -def tolerance(dtype, tol=None): + +def tolerance(dtype: np.dtype, tol: int | float | ToleranceDict | None = None) -> int | float: tol = {} if tol is None else tol if not isinstance(tol, dict): return tol diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 3a18d12e9d4b..c3c4a934dd0e 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -51,6 +51,7 @@ from jax._src import lib as _jaxlib from jax._src import monitoring from jax._src import test_warning_util +from jax._src.typing import ArrayLike, DTypeLike from jax._src import xla_bridge from jax._src import util from jax._src import mesh as mesh_lib @@ -59,7 +60,7 @@ from jax._src.numpy.util import promote_dtypes, promote_dtypes_inexact from jax._src.public_test_util import ( # noqa: F401 _assert_numpy_allclose, _check_dtypes_match, _default_tolerance, _dtype, check_close, check_grads, - check_jvp, check_vjp, default_gradient_tolerance, default_tolerance, rand_like, tolerance) + check_jvp, check_vjp, default_gradient_tolerance, default_tolerance, rand_like, tolerance, ToleranceDict) from jax._src.util import unzip2 from jax.tree_util import tree_all, tree_flatten, tree_map, tree_unflatten import numpy as np @@ -131,10 +132,10 @@ def sanitize_test_name(s: str) -> str: return kSanitizeNameRE.sub("_", s) -def num_float_bits(dtype): +def num_float_bits(dtype: DTypeLike) -> int: return _dtypes.finfo(_dtypes.canonicalize_dtype(dtype)).bits -def to_default_dtype(arr): +def to_default_dtype(arr: ArrayLike) -> np.ndarray: """Convert a value to an array with JAX's default dtype. This is generally used for type conversions of values returned by numpy functions, @@ -145,7 +146,7 @@ def to_default_dtype(arr): dtype = _dtypes._default_types.get(arr.dtype.kind) return arr.astype(_dtypes.canonicalize_dtype(dtype)) if dtype else arr -def with_jax_dtype_defaults(func, use_defaults=True): +def with_jax_dtype_defaults(func: Callable[..., Any], use_defaults: bool = True): """Return a version of a function with outputs that match JAX's default dtypes. This is generally used to wrap numpy functions within tests, in order to make @@ -168,7 +169,7 @@ def wrapped(*args, **kwargs): return tree_map(f, result, use_defaults) return wrapped -def is_sequence(x): +def is_sequence(x: Any) -> bool: try: iter(x) except TypeError: @@ -176,14 +177,16 @@ def is_sequence(x): else: return True -def _normalize_tolerance(tol): +def _normalize_tolerance(tol: int | float | ToleranceDict | None) -> ToleranceDict: tol = tol or 0 if isinstance(tol, dict): return {np.dtype(k): v for k, v in tol.items()} else: return dict.fromkeys(_default_tolerance, tol) -def join_tolerance(tol1, tol2): +def join_tolerance( + tol1: int | float | ToleranceDict | None, + tol2: int | float | ToleranceDict | None) -> ToleranceDict: tol1 = _normalize_tolerance(tol1) tol2 = _normalize_tolerance(tol2) out = tol1 @@ -192,7 +195,7 @@ def join_tolerance(tol1, tol2): return out -def check_eq(xs, ys, err_msg=''): +def check_eq(xs: Any, ys: Any, err_msg: str = '') -> None: assert_close = partial(_assert_numpy_allclose, err_msg=err_msg) tree_all(tree_map(assert_close, xs, ys)) From f5a4d1a85c41a42ed8fb389259a241513970ff9a Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 24 Mar 2025 12:47:34 -0700 Subject: [PATCH 126/483] Enable `jax` wheel testing via Bazel. Remove jax dependencies from the Bazel test targets for `:build_jaxlib=false` and `:build_jaxlib=wheel`. `internal_test_util` is removed from the `jax` wheel. To use this package in Bazel py_test, we need to copy it to the unpacked wheel folder. This is done by adding `wheel_deps` value to `py_import` Jax targets. This change concludes ML Wheels design implementation in JAX and enables testing of all wheels via Bazel command. PiperOrigin-RevId: 740037952 --- BUILD.bazel | 44 ++++++++++++++++++++++++++ WORKSPACE | 1 + jax/BUILD | 41 ++++++++++++++++--------- jaxlib/jax.bzl | 83 +++++++++++++++++++++++++++++++++----------------- 4 files changed, 126 insertions(+), 43 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index eb43d7ec0fd8..5700fcef2e77 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -12,11 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +load( + "@xla//third_party/py:py_import.bzl", + "py_import", +) load("@xla//third_party/py:python_wheel.bzl", "collect_data_files", "transitive_py_deps") load( "//jaxlib:jax.bzl", "jax_source_package", "jax_wheel", + "py_deps", ) collect_data_files( @@ -98,3 +103,42 @@ jax_source_package( source_package_binary = ":build_wheel", source_package_name = "jax", ) + +genrule( + name = "internal_test_util_sources", + srcs = [ + "//jax:internal_export_back_compat_test_util", + "//jax:internal_test_harnesses", + "//jax:internal_test_util", + "//jax:internal_export_back_compat_test_data", + ], + outs = ["internal_test_util_sources.zip"], + cmd = "$(location @bazel_tools//tools/zip:zipper) c $@ $(SRCS)", + tools = ["@bazel_tools//tools/zip:zipper"], +) + +COMMON_DEPS = py_deps([ + "absl/testing", + "numpy", + "ml_dtypes", + "scipy", + "opt_einsum", + "hypothesis", + "cloudpickle", +]) + +py_import( + name = "jax_py_import", + wheel = ":jax_wheel", + wheel_deps = [":internal_test_util_sources"], + deps = COMMON_DEPS, +) + +# This target is used to add internal test util sources to the jax wheel. +# This is needed for the tests that depend on jax and use internal test util sources. +py_import( + name = "jax_wheel_with_internal_test_util", + wheel = "@pypi_jax//:whl", + wheel_deps = [":internal_test_util_sources"], + deps = COMMON_DEPS, +) diff --git a/WORKSPACE b/WORKSPACE index 129488281ea9..a6968446a1ec 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -16,6 +16,7 @@ python_init_repositories( "3.13-ft": "//build:requirements_lock_3_13_ft.txt", }, local_wheel_inclusion_list = [ + "jax-*", "jaxlib*", "jax_cuda*", "jax-cuda*", diff --git a/jax/BUILD b/jax/BUILD index 12eae4afdcf7..5d37a8987445 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -167,22 +167,30 @@ py_library( ], ), visibility = [":internal"], - deps = [ - ":jax", - ] + py_deps("numpy"), + deps = if_building_jaxlib( + if_building = [ + ":jax", + ], + if_not_building = [], + if_not_building_for_cpu = [], + ) + py_deps("numpy"), ) py_library( name = "internal_test_harnesses", srcs = ["_src/internal_test_util/test_harnesses.py"], visibility = [":internal"] + jax_internal_test_harnesses_visibility, - deps = [ - ":ad_util", - ":config", - ":jax", - ":test_util", - "//jax/_src/lib", - ] + py_deps("numpy"), + deps = if_building_jaxlib( + if_building = [ + ":ad_util", + ":config", + ":jax", + ":test_util", + "//jax/_src/lib", + ], + if_not_building = [], + if_not_building_for_cpu = [], + ) + py_deps("numpy"), ) py_library( @@ -191,15 +199,18 @@ py_library( visibility = [ ":internal", ] + jax_internal_export_back_compat_test_util_visibility, - deps = [ - ":jax", - ":test_util", - ] + py_deps("numpy"), + deps = if_building_jaxlib( + if_building = [ + ":jax", + ":test_util", + ], + if_not_building = [], + if_not_building_for_cpu = [], + ) + py_deps("numpy"), ) py_library( name = "internal_export_back_compat_test_data", - testonly = 1, srcs = glob([ "_src/internal_test_util/export_back_compat_test_data/*.py", "_src/internal_test_util/export_back_compat_test_data/pallas/*.py", diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index c6f55a86143f..9b8c861404c2 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -31,7 +31,6 @@ load("@xla//xla/tsl/platform:build_config_root.bzl", _tf_cuda_tests_tags = "tf_c cc_proto_library = _cc_proto_library cuda_library = _cuda_library rocm_library = _rocm_library -pytype_test = native.py_test nanobind_extension = _pybind_extension if_cuda_is_configured = _if_cuda_is_configured if_rocm_is_configured = _if_rocm_is_configured @@ -64,6 +63,18 @@ PLATFORM_TAGS_DICT = { ("Windows", "AMD64"): ("win", "amd64"), } +_GPU_PYPI_WHEEL_DEPS = [ + "//:jax_wheel_with_internal_test_util", + "@pypi_jaxlib//:pkg", + "@pypi_jax_cuda12_plugin//:pkg", + "@pypi_jax_cuda12_pjrt//:pkg", +] + +_CPU_PYPI_WHEEL_DEPS = [ + "//:jax_wheel_with_internal_test_util", + "@pypi_jaxlib//:pkg", +] + # TODO(vam): remove this once zstandard builds against Python 3.13 def get_zstandard(): if HERMETIC_PYTHON_VERSION == "3.13" or HERMETIC_PYTHON_VERSION == "3.13-ft": @@ -223,39 +234,50 @@ ALL_BACKENDS = ["cpu", "gpu", "tpu"] def if_building_jaxlib( if_building, - if_not_building = [ - "@pypi_jaxlib//:pkg", - "@pypi_jax_cuda12_plugin//:pkg", - "@pypi_jax_cuda12_pjrt//:pkg", - ], - if_not_building_for_cpu = ["@pypi_jaxlib//:pkg"], - if_py_import = [ - "//jaxlib/tools:jaxlib_py_import", - "//jaxlib/tools:jax_cuda_plugin_py_import", - "//jaxlib/tools:jax_cuda_pjrt_py_import", - ], - if_py_import_for_cpu = [ - "//jaxlib/tools:jaxlib_py_import", - ]): + if_not_building = _GPU_PYPI_WHEEL_DEPS, + if_not_building_for_cpu = _CPU_PYPI_WHEEL_DEPS): """Adds jaxlib and jaxlib cuda plugin wheels as dependencies instead of depending on sources. This allows us to test prebuilt versions of jaxlib wheels against the rest of the JAX codebase. Args: if_building: the source code targets to depend on in case we don't depend on the jaxlib wheels - if_not_building: the jaxlib wheels to depend on including gpu-specific plugins in case of + if_not_building: the wheels to depend on including gpu-specific plugins in case of gpu-enabled builds - if_not_building_for_cpu: the jaxlib wheels to depend on in case of cpu-only builds - if_py_import: the py_import targets to depend on in case of gpu-enabled builds - if_py_import_for_cpu: the py_import targets to depend on in case of cpu-only builds + if_not_building_for_cpu: the wheels to depend on in case of cpu-only builds """ return select({ "//jax:enable_jaxlib_build": if_building, "//jax_plugins/cuda:disable_jaxlib_for_cpu_build": if_not_building_for_cpu, "//jax_plugins/cuda:disable_jaxlib_for_cuda12_build": if_not_building, - "//jax_plugins/cuda:enable_py_import_for_cpu_build": if_py_import_for_cpu, - "//jax_plugins/cuda:enable_py_import_for_cuda12_build": if_py_import, + "//conditions:default": [], + }) + +def _get_test_deps(deps): + jaxlib_build_deps = [ + "//jaxlib/cuda:gpu_only_test_deps", + "//jaxlib/rocm:gpu_only_test_deps", + "//jax_plugins:gpu_plugin_only_test_deps", + ] + + gpu_py_imports = [ + "//:jax_py_import", + "//jaxlib/tools:jaxlib_py_import", + "//jaxlib/tools:jax_cuda_plugin_py_import", + "//jaxlib/tools:jax_cuda_pjrt_py_import", + ] + cpu_py_imports = [ + "//:jax_py_import", + "//jaxlib/tools:jaxlib_py_import", + ] + + return select({ + "//jax:enable_jaxlib_build": jaxlib_build_deps + deps, + "//jax_plugins/cuda:disable_jaxlib_for_cpu_build": _CPU_PYPI_WHEEL_DEPS, + "//jax_plugins/cuda:disable_jaxlib_for_cuda12_build": _GPU_PYPI_WHEEL_DEPS, + "//jax_plugins/cuda:enable_py_import_for_cpu_build": cpu_py_imports, + "//jax_plugins/cuda:enable_py_import_for_cuda12_build": gpu_py_imports, }) # buildifier: disable=function-docstring @@ -308,14 +330,10 @@ def jax_multiplatform_test( srcs = srcs, args = test_args, env = env, - deps = [ + deps = _get_test_deps([ "//jax", "//jax:test_util", - ] + deps + if_building_jaxlib([ - "//jaxlib/cuda:gpu_only_test_deps", - "//jaxlib/rocm:gpu_only_test_deps", - "//jax_plugins:gpu_plugin_only_test_deps", - ]), + ] + deps), data = data, shard_count = test_shards, tags = test_tags, @@ -609,7 +627,16 @@ def jax_py_test( env = dict(env) if "PYTHONWARNINGS" not in env: env["PYTHONWARNINGS"] = "error" - py_test(name = name, env = env, **kwargs) + deps = kwargs.get("deps", []) + kwargs.pop("deps") + test_deps = _get_test_deps(deps) + py_test(name = name, env = env, deps = test_deps, **kwargs) + +def pytype_test(name, **kwargs): + deps = kwargs.get("deps", []) + kwargs.pop("deps") + test_deps = _get_test_deps(deps) + native.py_test(name = name, deps = test_deps, **kwargs) def if_oss(oss_value, google_value = []): """Returns one of the arguments based on the non-configurable build env. From 13b6e01acf84f9ee3d314e77de819960c52e3faa Mon Sep 17 00:00:00 2001 From: Michael Whittaker Date: Mon, 24 Mar 2025 13:00:57 -0700 Subject: [PATCH 127/483] Increased tolerance in failing xla client tests. PiperOrigin-RevId: 740041921 --- jaxlib/xla/xla_client_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jaxlib/xla/xla_client_test.py b/jaxlib/xla/xla_client_test.py index 5a2f3881f510..7de905d9ec41 100644 --- a/jaxlib/xla/xla_client_test.py +++ b/jaxlib/xla/xla_client_test.py @@ -1420,7 +1420,7 @@ def testDotGeneral(self): (([2], [1]), ([0], [0]))) ops.DotGeneral( ops.Constant(c, lhs), ops.Constant(c, rhs), dimension_numbers) - self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=1e-6) + self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=2e-6) def testDotGeneralWithDotDimensionNumbersProto(self): c = self._NewComputation() @@ -1436,7 +1436,7 @@ def testDotGeneralWithDotDimensionNumbersProto(self): ops.DotGeneral( ops.Constant(c, lhs), ops.Constant(c, rhs), dimension_numbers) - self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=1e-6) + self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=2e-6) def testDotGeneralWithPrecisionConfig(self): c = self._NewComputation() @@ -1453,7 +1453,7 @@ def testDotGeneralWithPrecisionConfig(self): ops.Constant(c, rhs), dimension_numbers, precision_config=config) - self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=1e-6) + self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=2e-6) def testConvGeneralDilatedF32(self): c = self._NewComputation() From 9f3eb3e232bdf9355f4cd02cf91592da9b065850 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 24 Mar 2025 13:06:22 -0700 Subject: [PATCH 128/483] Migrate more modules of xla/python to jax. PiperOrigin-RevId: 740043785 --- .../mlir/_mlir_libs/register_jax_dialects.cc | 8 +- jaxlib/xla/BUILD | 409 +++- jaxlib/xla/config.cc | 343 ++++ jaxlib/xla/config.h | 34 + jaxlib/xla/custom_call_sharding.cc | 343 ++++ jaxlib/xla/custom_call_sharding.h | 28 + jaxlib/xla/dlpack.cc | 699 +++++++ jaxlib/xla/dlpack.h | 57 + jaxlib/xla/jax_jit.cc | 495 +++++ jaxlib/xla/jax_jit.h | 265 +++ jaxlib/xla/mlir.cc | 251 +++ jaxlib/xla/mlir.h | 28 + jaxlib/xla/pjit.cc | 1402 ++++++++++++++ jaxlib/xla/pjit.h | 27 + jaxlib/xla/pmap_lib.cc | 1180 ++++++++++++ jaxlib/xla/pmap_lib.h | 37 + jaxlib/xla/sdy.cc | 143 ++ jaxlib/xla/sdy.h | 28 + jaxlib/xla/weakref_lru_cache.cc | 400 ++++ jaxlib/xla/weakref_lru_cache.h | 28 + jaxlib/xla/xla.cc | 20 +- jaxlib/xla/xla_compiler.cc | 1639 +++++++++++++++++ jaxlib/xla/xla_compiler.h | 28 + 23 files changed, 7868 insertions(+), 24 deletions(-) create mode 100644 jaxlib/xla/config.cc create mode 100644 jaxlib/xla/config.h create mode 100644 jaxlib/xla/custom_call_sharding.cc create mode 100644 jaxlib/xla/custom_call_sharding.h create mode 100644 jaxlib/xla/dlpack.cc create mode 100644 jaxlib/xla/dlpack.h create mode 100644 jaxlib/xla/jax_jit.cc create mode 100644 jaxlib/xla/jax_jit.h create mode 100644 jaxlib/xla/mlir.cc create mode 100644 jaxlib/xla/mlir.h create mode 100644 jaxlib/xla/pjit.cc create mode 100644 jaxlib/xla/pjit.h create mode 100644 jaxlib/xla/pmap_lib.cc create mode 100644 jaxlib/xla/pmap_lib.h create mode 100644 jaxlib/xla/sdy.cc create mode 100644 jaxlib/xla/sdy.h create mode 100644 jaxlib/xla/weakref_lru_cache.cc create mode 100644 jaxlib/xla/weakref_lru_cache.h create mode 100644 jaxlib/xla/xla_compiler.cc create mode 100644 jaxlib/xla/xla_compiler.h diff --git a/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc b/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc index 64f84965b8e2..1ba6fd9375df 100644 --- a/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc +++ b/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc @@ -2,7 +2,6 @@ // This module is called by mlir/__init__.py during initialization. #include -#include "shardy/integrations/c/passes.h" #include "mlir-c/Dialect/Arith.h" #include "mlir-c/Dialect/Func.h" #include "mlir-c/Dialect/GPU.h" @@ -15,13 +14,14 @@ #include "mlir-c/Dialect/Vector.h" #include "mlir-c/Transforms.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "shardy/integrations/c/passes.h" #include "jaxlib/mosaic/gpu/integrations/c/passes.h" namespace nb = nanobind; -#define REGISTER_DIALECT(name) \ - MlirDialectHandle name##_dialect = mlirGetDialectHandle__##name##__(); \ - mlirDialectHandleInsertDialect(name##_dialect, registry) +#define REGISTER_DIALECT(name) \ + MlirDialectHandle name##_dialect = mlirGetDialectHandle__##name##__(); \ + mlirDialectHandleInsertDialect(name##_dialect, registry) NB_MODULE(register_jax_dialects, m) { m.doc() = "Registers upstream MLIR dialects used by JAX."; diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD index 3239ba703937..592d9d1c24f3 100644 --- a/jaxlib/xla/BUILD +++ b/jaxlib/xla/BUILD @@ -43,6 +43,16 @@ nanobind_extension( pytype_srcs = glob(["xla_extension/*.pyi"]), visibility = ["//visibility:public"], deps = [ + ":config", + ":custom_call_sharding", + ":dlpack", + ":jax_jit", + ":mlir", + ":pjit", + ":pmap_lib", + ":sdy", + ":weakref_lru_cache", + ":xla_compiler", "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/hash", @@ -83,31 +93,21 @@ nanobind_extension( "@xla//xla/pjrt/distributed:service", "@xla//xla/pjrt/plugin/xla_cpu:cpu_client_options", "@xla//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client", - "@xla//xla/python:config", - "@xla//xla/python:custom_call_sharding", - "@xla//xla/python:dlpack", "@xla//xla/python:guard_lib", - "@xla//xla/python:jax_jit", "@xla//xla/python:logging", - "@xla//xla/python:mlir", "@xla//xla/python:nb_absl_flat_hash_map", "@xla//xla/python:nb_absl_span", "@xla//xla/python:nb_class_ptr", "@xla//xla/python:ops", - "@xla//xla/python:pjit", - "@xla//xla/python:pmap_lib", "@xla//xla/python:pprof_profile_builder", "@xla//xla/python:profiler", "@xla//xla/python:py_client", "@xla//xla/python:python_ref_manager", "@xla//xla/python:pytree", "@xla//xla/python:refine_polymorphic_shapes", - "@xla//xla/python:sdy", "@xla//xla/python:traceback", "@xla//xla/python:types", "@xla//xla/python:util", - "@xla//xla/python:weakref_lru_cache", - "@xla//xla/python:xla_compiler", "@xla//xla/python/ifrt", "@xla//xla/python/ifrt:plugin_program", "@xla//xla/python/ifrt:plugin_program_serdes", @@ -144,6 +144,395 @@ nanobind_extension( }), ) +cc_library( + name = "config", + srcs = ["config.cc"], + hdrs = ["config.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla/python:python_ref_manager", + "@xla//xla/tsl/platform:logging", + ], +) + +cc_library( + name = "custom_call_sharding", + srcs = ["custom_call_sharding.cc"], + hdrs = ["custom_call_sharding.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@nanobind", + "@xla//xla:shape_util", + "@xla//xla:util", + "@xla//xla/hlo/ir:hlo", + "@xla//xla/hlo/utils:hlo_sharding_util", + "@xla//xla/pjrt:status_casters", + "@xla//xla/pjrt/c:pjrt_c_api_custom_partitioner_extension_hdrs", + "@xla//xla/pjrt/c:pjrt_c_api_hdrs", + "@xla//xla/pjrt/c:pjrt_c_api_helpers", + "@xla//xla/python:custom_call_batch_partitioner", + "@xla//xla/python:custom_partition_callback", + "@xla//xla/python:inspect_sharding", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:statusor", + ], +) + +cc_library( + name = "dlpack", + srcs = ["dlpack.cc"], + hdrs = ["dlpack.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@dlpack", + "@llvm-project//llvm:Support", + "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:shape_util", + "@xla//xla:status_macros", + "@xla//xla:util", + "@xla//xla/pjrt:exceptions", + "@xla//xla/pjrt:pjrt_client", + "@xla//xla/pjrt:pjrt_common", + "@xla//xla/pjrt:pjrt_compiler", + "@xla//xla/pjrt:pjrt_layout", + "@xla//xla/python:nb_class_ptr", + "@xla//xla/python:py_client", + "@xla//xla/python:python_ref_manager", + "@xla//xla/python:traceback", + "@xla//xla/python:types", + "@xla//xla/python:util", + "@xla//xla/python/ifrt", + "@xla//xla/python/pjrt_ifrt", + "@xla//xla/tsl/platform:errors", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:statusor", + ], +) + +cc_library( + name = "jax_jit", + srcs = ["jax_jit.cc"], + hdrs = ["jax_jit.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@nanobind", + "@tsl//tsl/profiler/lib:traceme", + "@xla//third_party/python_runtime:headers", # build_cleaner: keep + "@xla//xla/pjrt:pjrt_client", + "@xla//xla/pjrt:pjrt_layout", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python:nb_absl_inlined_vector", + "@xla//xla/python:nb_absl_span", + "@xla//xla/python:nb_helpers", + "@xla//xla/python:py_client", + "@xla//xla/python:python_ref_manager", + "@xla//xla/python:pytree", + "@xla//xla/python:types", + "@xla//xla/tsl/platform:logging", + ], +) + +cc_library( + name = "mlir", + srcs = ["mlir.cc"], + hdrs = ["mlir.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:BytecodeWriter", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FuncExtensions", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ReconcileUnrealizedCasts", + "@llvm-project//mlir:Support", + "@nanobind", + "@stablehlo//:stablehlo_serialization", + "@xla//xla/hlo/builder:xla_computation", + "@xla//xla/hlo/translate:stablehlo", + "@xla//xla/mlir_hlo:mhlo_passes", + "@xla//xla/pjrt:mlir_to_hlo", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python:refine_polymorphic_shapes", + "@xla//xla/service/llvm_ir:llvm_util", + "@xla//xla/tsl/platform:errors", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:statusor", + ], +) + +cc_library( + name = "pjit", + srcs = ["pjit.cc"], + hdrs = ["pjit.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":config", + ":jax_jit", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@nanobind", + "@tsl//tsl/profiler/lib:traceme", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:shape_util", + "@xla//xla:util", + "@xla//xla/pjrt:exceptions", + "@xla//xla/pjrt:lru_cache", + "@xla//xla/python:guard_lib", + "@xla//xla/python:nb_class_ptr", + "@xla//xla/python:nb_helpers", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:py_client", + "@xla//xla/python:python_ref_manager", + "@xla//xla/python:pytree", + "@xla//xla/python:traceback", + "@xla//xla/python/ifrt", + "@xla//xla/tsl/concurrency:ref_count", + "@xla//xla/tsl/platform:env", + "@xla//xla/tsl/platform:errors", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:statusor", + ], +) + +cc_library( + name = "pmap_lib", + srcs = ["pmap_lib.cc"], + hdrs = ["pmap_lib.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":config", + ":jax_jit", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@nanobind", + "@tsl//tsl/profiler/lib:traceme", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:status_macros", + "@xla//xla:util", + "@xla//xla:xla_data_proto_cc", + "@xla//xla/pjrt:exceptions", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python:nb_class_ptr", + "@xla//xla/python:nb_helpers", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:py_client", + "@xla//xla/python:python_ref_manager", + "@xla//xla/python:pytree", + "@xla//xla/python:traceback", + "@xla//xla/python:types", + "@xla//xla/python/ifrt", + "@xla//xla/tsl/concurrency:ref_count", + "@xla//xla/tsl/platform:env", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:statusor", + "@xla//xla/tsl/python/lib/core:numpy", + ], +) + +cc_library( + name = "sdy", + srcs = ["sdy.cc"], + hdrs = ["sdy.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:BytecodeWriter", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@nanobind", + "@shardy//shardy/dialect/sdy/ir:dialect", + "@shardy//shardy/dialect/sdy/transforms/import:passes", + "@xla//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo", + "@xla//xla/mlir_hlo:all_passes", + "@xla//xla/pjrt:mlir_to_hlo", + "@xla//xla/pjrt:status_casters", + "@xla//xla/service/spmd/shardy:constants", + "@xla//xla/service/spmd/shardy:utils", + "@xla//xla/service/spmd/shardy/sdy_round_trip:import_shardy_attrs", + "@xla//xla/service/spmd/shardy/sdy_round_trip:pipelines", + "@xla//xla/tsl/framework/mlir:status_scoped_diagnostic_handler", + ], +) + +cc_library( + name = "weakref_lru_cache", + srcs = ["weakref_lru_cache.cc"], + hdrs = ["weakref_lru_cache.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@nanobind", + "@xla//third_party/python_runtime:headers", + "@xla//xla/pjrt:lru_cache", + "@xla//xla/tsl/platform:logging", + ], +) + +cc_library( + name = "xla_compiler", + srcs = ["xla_compiler.cc"], + hdrs = ["xla_compiler.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":dlpack", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@nanobind", + "@xla//xla:array", + "@xla//xla:debug_options_flags", + "@xla//xla:literal", + "@xla//xla:shape_util", + "@xla//xla:util", + "@xla//xla:xla_data_proto_cc", + "@xla//xla:xla_proto_cc", + "@xla//xla/client:executable_build_options", + "@xla//xla/ffi", + "@xla//xla/ffi:ffi_api", + "@xla//xla/ffi/api:c_api", + "@xla//xla/hlo/builder:xla_builder", + "@xla//xla/hlo/builder:xla_computation", + "@xla//xla/hlo/ir:hlo", + "@xla//xla/hlo/ir:hlo_module_group", + "@xla//xla/hlo/parser:hlo_parser", + "@xla//xla/hlo/pass:hlo_pass", + "@xla//xla/hlo/transforms/simplifiers:flatten_call_graph", + "@xla//xla/hlo/transforms/simplifiers:hlo_dce", + "@xla//xla/hlo/transforms/simplifiers:tuple_simplifier", + "@xla//xla/pjrt:compile_options_proto_cc", + "@xla//xla/pjrt:exceptions", + "@xla//xla/pjrt:pjrt_executable", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python:nb_absl_span", + "@xla//xla/python:nb_helpers", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:py_client", + "@xla//xla/python:types", + "@xla//xla/service:call_inliner", + "@xla//xla/service:computation_placer", + "@xla//xla/service:custom_call_target_registry", + "@xla//xla/service:hlo_graph_dumper", + "@xla//xla/service:hlo_module_config", + "@xla//xla/service:hlo_proto_cc", + "@xla//xla/service:name_uniquer", + "@xla//xla/tsl/lib/strings:proto_serialization", + "@xla//xla/tsl/platform:env", + "@xla//xla/tsl/platform:errors", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:statusor", + ], +) + pytype_strict_library( name = "xla_client", srcs = ["xla_client.py"], diff --git a/jaxlib/xla/config.cc b/jaxlib/xla/config.cc new file mode 100644 index 000000000000..b5bc5830acbf --- /dev/null +++ b/jaxlib/xla/config.cc @@ -0,0 +1,343 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/config.h" + +#include + +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "xla/python/python_ref_manager.h" +#include "xla/tsl/platform/logging.h" + +namespace jax { + +namespace nb = nanobind; + +// Singleton object used to represet "value not set" in thread-local configs. +nb::object UnsetObject() { + return nb::steal(PyObject_CallObject( + reinterpret_cast(&PyBaseObject_Type), nullptr)); +} + +// Each configuration object has: +// * a global value, and +// * a thread-local value. +// When querying the state of a config, the thread-local value is used if it is +// set. Otherwise, the global value is used. + +// This class represents all of the thread-local configuration state for a +// thread. +class ThreadLocalConfigState { + public: + ThreadLocalConfigState(); + ~ThreadLocalConfigState(); + + static ThreadLocalConfigState& Instance() { + thread_local auto state = std::make_unique(); + return *state; + } + + nb::object Get(int key) { + DCHECK_GE(key, 0); + return key >= entries_.size() ? nb::object() : entries_[key]; + } + + void Set(int key, nb::object value); + + private: + friend class GlobalConfigState; + + // These values are accessed in one of two ways: + // * The owning thread reads or writes them, while holding the GIL, or, under + // free-threading, while the owning thread is in ATTACHED gc state. + // * Other threads may read or clear values while performing a garbarge + // collection. + // No locking is needed because a GC thread cannot run concurrently with other + // Python threads; even under free-threading Python uses a stop-the-world GC. + std::vector entries_; +}; + +// This class represents all of the global configuration state. +// TODO(phawkins): to support free-threading, we will need to add locking to +// this class. +class GlobalConfigState { + public: + static GlobalConfigState& Instance() { + static auto state = new GlobalConfigState(); + return *state; + } + + nb::object Get(int key) const; + void Set(int key, nb::object value); + + // Adds or removes a thread-local state from the set of thread-local states. + void AddThreadLocalState(ThreadLocalConfigState* state) { + absl::MutexLock lock(&mu_); + thread_local_states_.insert(state); + } + void RemoveThreadLocalState(ThreadLocalConfigState* state) { + absl::MutexLock lock(&mu_); + thread_local_states_.erase(state); + } + + // Python GC helpers. These are called from the tp_traverse and tp_clear + // methods of the Config class. + int tp_traverse(int key, PyObject* self, visitproc visit, void* arg); + int tp_clear(int key, PyObject* self); + + // Returns the singleton object representing "value not set". + const nb::object& unset() const { return unset_; } + + // Returns the set of keys that should be included in the jit key. + absl::Span include_in_jit_key() const { + return include_in_jit_key_; + } + + private: + friend class Config; + + // The set of thread-local states. This is used during garbarge collection to + // visit thread-local values. + absl::Mutex mu_; + absl::flat_hash_set thread_local_states_ + ABSL_GUARDED_BY(mu_); + std::vector entries_; + std::vector include_in_jit_key_; + nb::object unset_ = UnsetObject(); +}; + +ThreadLocalConfigState::ThreadLocalConfigState() { + GlobalConfigState::Instance().AddThreadLocalState(this); +} + +ThreadLocalConfigState::~ThreadLocalConfigState() { + // It's important that we remove the thread-local state before we access + // entries_. This ensures that accesses to entries_ are ordered with respect + // any garbage collection. + GlobalConfigState::Instance().RemoveThreadLocalState(this); + // We do not hold the GIL, so we must use deferred destruction. + xla::GlobalPyRefManager()->AddGarbage(absl::MakeSpan(entries_)); +} + +void ThreadLocalConfigState::Set(int key, nb::object value) { + DCHECK_GE(key, 0); + if (key >= entries_.size()) { + entries_.resize(key + 1); + } + std::swap(entries_[key], value); +} + +nb::object GlobalConfigState::Get(int key) const { + DCHECK_GE(key, 0); + DCHECK_LT(key, entries_.size()); + return entries_[key]; +} + +void GlobalConfigState::Set(int key, nb::object value) { + DCHECK_GE(key, 0); + DCHECK_LT(key, entries_.size()); + std::swap(entries_[key], value); +} + +int GlobalConfigState::tp_traverse(int key, PyObject* self, visitproc visit, + void* arg) { + DCHECK_GE(key, 0); + if (key < entries_.size()) { + PyObject* value = entries_[key].ptr(); + Py_VISIT(value); + } + absl::MutexLock lock(&mu_); + for (const auto* state : thread_local_states_) { + if (key < state->entries_.size()) { + PyObject* value = state->entries_[key].ptr(); + Py_VISIT(value); + } + } + return 0; +} + +int GlobalConfigState::tp_clear(int key, PyObject* self) { + if (key < entries_.size()) { + nb::object tmp; + std::swap(entries_[key], tmp); + } + // We destroy the python objects outside of the lock out of an abundance of + // caution. + std::vector to_destroy; + absl::MutexLock lock(&mu_); + to_destroy.reserve(thread_local_states_.size()); + for (auto* state : thread_local_states_) { + if (key < state->entries_.size()) { + nb::object tmp; + std::swap(state->entries_[key], tmp); + to_destroy.push_back(std::move(tmp)); + } + } + return 0; +} + +// A Config object represents a configurable object with both global and +// thread-local state. This class is wrapped using nanobind and exposed to +// Python. +class Config { + public: + Config(nb::object value, bool include_in_jit_key); + + // Returns the thread-local value if it is set, otherwise the global value. + nb::object Get(); + + // Returns the global value. + nb::object GetGlobal(); + + // Sets the global value. + void SetGlobal(nb::object value); + + // Returns the thread-local value. + nb::object GetLocal(); + + // Sets the thread-local value. May be `unset`. + void SetLocal(nb::object value); + + // Swaps the thread-local value with `value`. Returns the previous value. + // Either may be `unset`. + nb::object SwapLocal(nb::object value); + + // This class doesn't actually hold any data, but it's the only type + // known to Python. We pretend that this object owns both the global and any + // thread-local values corresponding to this key. + static int tp_traverse(PyObject* self, visitproc visit, void* arg); + static int tp_clear(PyObject* self); + static PyType_Slot slots_[]; + + private: + int key_; +}; + +Config::Config(nb::object value, bool include_in_jit_key) { + auto& instance = GlobalConfigState::Instance(); + key_ = instance.entries_.size(); + instance.entries_.push_back(std::move(value)); + if (include_in_jit_key) { + instance.include_in_jit_key_.push_back(key_); + } +} + +nb::object Config::GetLocal() { + nb::object result = ThreadLocalConfigState::Instance().Get(key_); + if (!result.is_valid()) { + return GlobalConfigState::Instance().unset(); + } + return result; +} + +nb::object Config::GetGlobal() { + return GlobalConfigState::Instance().Get(key_); +} + +nb::object Config::Get() { + nb::object local = ThreadLocalConfigState::Instance().Get(key_); + if (local.is_valid()) { + return local; + } + return GetGlobal(); +} + +void Config::SetLocal(nb::object value) { + const auto& instance = GlobalConfigState::Instance(); + if (value.ptr() == instance.unset().ptr()) { + value = nb::object(); + } + ThreadLocalConfigState::Instance().Set(key_, std::move(value)); +} + +nb::object Config::SwapLocal(nb::object value) { + const auto& global_instance = GlobalConfigState::Instance(); + auto& instance = ThreadLocalConfigState::Instance(); + auto result = instance.Get(key_); + if (value.ptr() == global_instance.unset().ptr()) { + value = nb::object(); + } + instance.Set(key_, std::move(value)); + if (!result.is_valid()) { + return global_instance.unset(); + } + return result; +} + +void Config::SetGlobal(nb::object value) { + GlobalConfigState::Instance().Set(key_, value); +} + +/* static */ int Config::tp_traverse(PyObject* self, visitproc visit, + void* arg) { + Config* c = nb::inst_ptr(self); + // For the purposes of GC, we pretend that this object owns both the global + // and any thread-local values corresponding to this key. + return GlobalConfigState::Instance().tp_traverse(c->key_, self, visit, arg); +} + +/* static */ int Config::tp_clear(PyObject* self) { + Config* c = nb::inst_ptr(self); + return GlobalConfigState::Instance().tp_clear(c->key_, self); +} + +PyType_Slot Config::slots_[] = { + {Py_tp_traverse, reinterpret_cast(Config::tp_traverse)}, + {Py_tp_clear, reinterpret_cast(Config::tp_clear)}, + {0, nullptr}, +}; + +void BuildConfigSubmodule(nanobind::module_& m) { + nb::module_ config_module = m.def_submodule("config", "Config library"); + + config_module.attr("unset") = GlobalConfigState::Instance().unset(); + + nb::class_ config(config_module, "Config", + nb::type_slots(Config::slots_), nb::is_generic()); + config.def(nb::init(), nb::arg("value").none(), + nb::arg("include_in_jit_key") = false); + config.def_prop_ro("value", &Config::Get); + config.def("get_local", &Config::GetLocal); + config.def("get_global", &Config::GetGlobal); + config.def("set_local", &Config::SetLocal, nb::arg("value").none()); + config.def("swap_local", &Config::SwapLocal, nb::arg("value").none()); + config.def("set_global", &Config::SetGlobal, nb::arg("value").none()); +} + +std::vector JitConfigs() { + auto& instance = GlobalConfigState::Instance(); + auto& thread_local_instance = ThreadLocalConfigState::Instance(); + std::vector result; + result.reserve(instance.include_in_jit_key().size()); + for (int i : instance.include_in_jit_key()) { + nb::object local = thread_local_instance.Get(i); + if (local.is_valid()) { + result.push_back(std::move(local)); + } else { + result.push_back(instance.Get(i)); + } + } + return result; +} + +} // namespace jax diff --git a/jaxlib/xla/config.h b/jaxlib/xla/config.h new file mode 100644 index 000000000000..40847bf4a370 --- /dev/null +++ b/jaxlib/xla/config.h @@ -0,0 +1,34 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_CONFIG_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_CONFIG_H_ + +#include + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" + +namespace jax { + +// Returns the set of configuration values that should be included in the JIT +// cache key. +std::vector JitConfigs(); + +void BuildConfigSubmodule(nanobind::module_& m); + +} // namespace jax + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_CONFIG_H_ diff --git a/jaxlib/xla/custom_call_sharding.cc b/jaxlib/xla/custom_call_sharding.cc new file mode 100644 index 000000000000..f88bc93e3af3 --- /dev/null +++ b/jaxlib/xla/custom_call_sharding.cc @@ -0,0 +1,343 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "jaxlib/xla/custom_call_sharding.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/tuple.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/utils/hlo_sharding_util.h" +#include "xla/pjrt/c/pjrt_c_api.h" +#include "xla/pjrt/c/pjrt_c_api_custom_partitioner_extension.h" +#include "xla/pjrt/c/pjrt_c_api_helpers.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/custom_call_batch_partitioner.h" +#include "xla/python/custom_partition_callback.h" +#include "xla/python/inspect_sharding.h" +#include "xla/shape.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" + +namespace xla { + +namespace nb = ::nanobind; + +class PyCustomCallPartitionerCallbacks { + public: + PyCustomCallPartitionerCallbacks(nb::object prop_user_sharding, + nb::object partition, + nb::object infer_sharding_from_operands) + : prop_user_sharding_(prop_user_sharding), + partition_(partition), + infer_sharding_from_operands_(infer_sharding_from_operands) { + callbacks_.version = 0; + callbacks_.private_data = this; + callbacks_.dtor = +[](JAX_CustomCallPartitioner_Callbacks* self) { + delete GetSelfPtr(self); + }; + callbacks_.partition = +[](JAX_CustomCallPartitioner_Callbacks* self, + JAX_CustomCallPartitioner_Partition_Args* args) { + jax::PopulateResults(GetSelfPtr(self)->CallPartition(args), args); + }; + callbacks_.infer_sharding = + +[](JAX_CustomCallPartitioner_Callbacks* self, + JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args) { + jax::PopulateResults( + GetSelfPtr(self)->CallInferShardingFromOperands(args), args); + }; + callbacks_.propagate_user_sharding = + +[](JAX_CustomCallPartitioner_Callbacks* self, + JAX_CustomCallPartitioner_PropagateUserSharding_Args* args) { + jax::PopulateResults( + GetSelfPtr(self)->CallPropagateUserSharding(args), args); + }; + } + + absl::StatusOr< + std::tuple, xla::HloSharding>> + CallPartition(JAX_CustomCallPartitioner_Partition_Args* args) const { + if (args->header.api_version != 0) { + return absl::InternalError("API version mismatch."); + } + TF_ASSIGN_OR_RETURN(auto args_tuple, jax::ReadArgs(args)); + std::vector shapes = std::move(std::get<0>(args_tuple)); + std::vector> shardings = + std::move(std::get<1>(args_tuple)); + xla::Shape result_shape = std::move(std::get<2>(args_tuple)); + std::optional result_sharding = + std::move(std::get<3>(args_tuple)); + absl::string_view backend_config = std::move(std::get<4>(args_tuple)); + + { + nb::gil_scoped_acquire gil; + try { + auto py_result = + partition_(shapes, shardings, result_shape, result_sharding, + nb::bytes(backend_config.data(), backend_config.size())); + try { + auto [ir, arg_shardings, result_sharding] = nb::cast< + std::tuple, HloSharding>>( + py_result); + if (arg_shardings.size() != args->num_args) { + return xla::Internal( + "Shardings returned from partitioning: lengths must match: %d " + "vs %d", + arg_shardings.size(), args->num_args); + } + return std::make_tuple(std::string(ir.c_str(), ir.size()), + std::move(arg_shardings), + std::move(result_sharding)); + } catch (const nb::cast_error& e) { + return xla::Internal( + "Shardings returned from partitioning: expected " + "Tuple[bytes, List[HloSharding], HloSharding] got: %s", + nb::cast(nb::repr(py_result))); + } + } catch (const nb::python_error& e) { + return xla::Internal("custom_partitioner: %s", e.what()); + } + } + } + + absl::StatusOr> CallInferShardingFromOperands( + JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args) const { + if (args->header.api_version != 0) { + return absl::InternalError("API version mismatch."); + } + TF_ASSIGN_OR_RETURN(auto args_tuple, jax::ReadArgs(args)); + std::vector arg_shapes = std::move(std::get<0>(args_tuple)); + std::vector> arg_shardings = + std::move(std::get<1>(args_tuple)); + xla::Shape result_shape = std::move(std::get<2>(args_tuple)); + absl::string_view backend_config = std::move(std::get<3>(args_tuple)); + + std::optional result; + nb::gil_scoped_acquire gil; + try { + auto py_result = infer_sharding_from_operands_( + arg_shapes, arg_shardings, result_shape, + nb::bytes(backend_config.data(), backend_config.size())); + if (py_result.is_none()) { + return std::nullopt; + } + return nb::cast(py_result); + } catch (const nb::python_error& e) { + return xla::Internal("custom_partitioner: %s", e.what()); + } + } + + absl::StatusOr CallPropagateUserSharding( + JAX_CustomCallPartitioner_PropagateUserSharding_Args* args) const { + if (args->header.api_version != 0) { + return absl::InternalError("API version mismatch."); + } + TF_ASSIGN_OR_RETURN(auto args_tuple, jax::ReadArgs(args)); + xla::HloSharding result_sharding = std::move(std::get<0>(args_tuple)); + xla::Shape result_shape = std::move(std::get<1>(args_tuple)); + absl::string_view backend_config = std::move(std::get<2>(args_tuple)); + + nb::gil_scoped_acquire gil; + try { + // TODO(parkers): expand this API to handle the `user` sharding. + // The user is used when the custom call returns a Tuple and + // the user is a get-tuple-element. In this case we must update only + // part of the sharding spec. + auto result = nb::cast(prop_user_sharding_( + result_sharding, result_shape, + nb::bytes(backend_config.data(), backend_config.size()))); + return result; + } catch (const nb::python_error& e) { + return xla::Internal("custom_partitioner: %s", e.what()); + } + } + + JAX_CustomCallPartitioner_Callbacks* callbacks() { return &callbacks_; } + + private: + static PyCustomCallPartitionerCallbacks* GetSelfPtr( + JAX_CustomCallPartitioner_Callbacks* callbacks) { + return reinterpret_cast( + callbacks->private_data); + } + + JAX_CustomCallPartitioner_Callbacks callbacks_; + nb::object prop_user_sharding_; + nb::object partition_; + nb::object infer_sharding_from_operands_; +}; + +namespace { + +void CallInspectSharding(void* obj, JAX_InspectSharding_Callback_Args* args) { + std::optional arg = jax::InspectShardingReadArgs(args); + if (!arg.has_value()) { + return; + } + try { + nb::gil_scoped_acquire gil; + nb::handle(reinterpret_cast(obj))(*std::move(arg)); + } catch (const nb::python_error& e) { + jax::InspectShardingSetError(args, std::string(e.what())); + } +} + +} // namespace + +void BuildCustomCallShardingPybindAPI(nb::module_& m) { + m.def( + "register_custom_call_partitioner", + [](std::string name, nb::object prop_user_sharding, nb::object partition, + nb::object infer_sharding_from_operands, + bool can_side_effecting_have_replicated_sharding, + std::optional c_api) { + auto* c_fns = + (new PyCustomCallPartitionerCallbacks(prop_user_sharding, partition, + infer_sharding_from_operands)) + ->callbacks(); + c_fns->can_side_effecting_have_replicated_sharding = + can_side_effecting_have_replicated_sharding; + if (!c_api.has_value()) { + RegisterCustomCallPartitioner( + name, jax::CreateCApiCustomCallPartitioner(c_fns)); + return; + } + + if (absl::string_view(c_api->name()) != "pjrt_c_api") { + throw absl::InvalidArgumentError( + "Argument to register_custom_call_partitioner was not a " + "pjrt_c_api capsule."); + } + auto* c_api_value = static_cast(c_api->data()); + PJRT_Custom_Partitioner_Extension* extension = + pjrt::FindExtension( + c_api_value, + PJRT_Extension_Type::PJRT_Extension_Type_Custom_Partitioner); + if (extension == nullptr) { + return; + } + PJRT_Register_Custom_Partitioner_Args args; + args.struct_size = PJRT_Register_Custom_Partitioner_Args_STRUCT_SIZE; + args.name = name.c_str(); + args.name_size = name.size(); + args.callbacks = c_fns; + PJRT_Error* error = + reinterpret_cast( + extension) + ->register_custom_partitioner(&args); + std::unique_ptr error_ptr( + error, pjrt::MakeErrorDeleter(c_api_value)); + ThrowIfError(pjrt::PjrtErrorToStatus(error_ptr.get(), c_api_value)); + }, + R"(Registers a partitioner for a custom-call operation. + +Args: + name: custom_call_target to match. + prop_user_sharding: Custom backwards sharding propagation rule. + Takes result sharding and returns the instruction sharding. + partition: Lowering rule. Takes operand and result shardings and returns + a generated HLO and sharding specs. The spmd lowerer first reshards + to match the returned sharding specs and then inserts the generated hlo. + infer_sharding_from_operands: Custom forwards sharding propagation rule. + Takes operand sharding and returns the instruction sharding. + can_side_effecting_have_replicated_sharding: Side effecting ops are not + allowed to have replicated sharding. Pass true to disable this check. + c_api: Optional `PJRT_Api*` if it is called with a plugin. This is safe to + call on plugins that do not implement the custom partitioner extension +)", + nb::arg("name"), nb::arg("prop_user_sharding"), nb::arg("partition"), + nb::arg("infer_sharding_from_operands"), + nb::arg("can_side_effecting_have_replicated_sharding") = false, + nb::arg("c_api").none() = std::nullopt); + m.def("encode_inspect_sharding_callback", + [](nb::object handler) -> nb::bytes { + JAX_InspectSharding_Callback cb; + cb.call = &CallInspectSharding; + cb.data = handler.ptr(); + char bytes[sizeof(JAX_InspectSharding_Callback)]; + std::memcpy(&bytes, &cb, sizeof(JAX_InspectSharding_Callback)); + return nb::bytes(bytes, sizeof(JAX_InspectSharding_Callback)); + }); + + nb::module_ hlo_sharding_util_m = m.def_submodule( + "hlo_sharding_util", "Utilities for manipulating HloSharding."); + hlo_sharding_util_m.def( + "PartiallyReplicateTiledShardingOnDims", + [](const HloSharding& sharding, std::vector dims) { + return hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( + sharding, dims); + }); + + m.def( + "register_custom_call_as_batch_partitionable", + [](std::string target_name, std::optional c_api) { + if (!c_api.has_value()) { + RegisterCustomCallPartitioner( + target_name, std::make_unique()); + return; + } + if (absl::string_view(c_api->name()) != "pjrt_c_api") { + throw absl::InvalidArgumentError( + "Argument to register_custom_call_partitioner was not a " + "pjrt_c_api capsule."); + } + auto* c_api_value = static_cast(c_api->data()); + PJRT_Custom_Partitioner_Extension* extension = + pjrt::FindExtension( + c_api_value, + PJRT_Extension_Type::PJRT_Extension_Type_Custom_Partitioner); + if (extension == nullptr) { + return; + } + PJRT_Register_Batch_Partitionable_Args args; + args.struct_size = PJRT_Register_Batch_Partitionable_Args_STRUCT_SIZE; + args.name = target_name.c_str(); + args.name_size = target_name.size(); + PJRT_Error* error = extension->register_batch_partitionable(&args); + std::unique_ptr error_ptr( + error, pjrt::MakeErrorDeleter(c_api_value)); + ThrowIfError(pjrt::PjrtErrorToStatus(error_ptr.get(), c_api_value)); + }, + R"(Registers a custom call as batch partitionable. + +If a custom call is "batch partitionable", it means that it can be trivially +partitioned on some number of (leading) dimensions, with the same call being +executed independently on each shard of data. If the data are sharded on +non-batch dimensions, partitioning will re-shard the data to be replicated on +the non-batch dimensions. + +Args: + target_name: the target name of the batch partitionable custom call. + c_api: optional `PJRT_Api*` to support registration via a PJRT plugin. +)", + nb::arg("target_name"), nb::arg("c_api").none() = std::nullopt); +} + +} // namespace xla diff --git a/jaxlib/xla/custom_call_sharding.h b/jaxlib/xla/custom_call_sharding.h new file mode 100644 index 000000000000..c3470901f53e --- /dev/null +++ b/jaxlib/xla/custom_call_sharding.h @@ -0,0 +1,28 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_CUSTOM_CALL_SHARDING_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_CUSTOM_CALL_SHARDING_H_ + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" + +namespace xla { + +void BuildCustomCallShardingPybindAPI(nanobind::module_& m); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_CUSTOM_CALL_SHARDING_H_ diff --git a/jaxlib/xla/dlpack.cc b/jaxlib/xla/dlpack.cc new file mode 100644 index 000000000000..f6605a36f02b --- /dev/null +++ b/jaxlib/xla/dlpack.cc @@ -0,0 +1,699 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/dlpack.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" +#include "include/dlpack/dlpack.h" +#include "llvm/Support/Casting.h" +#include "nanobind/nanobind.h" +#include "nanobind/ndarray.h" +#include "xla/layout.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_common.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/python/pjrt_ifrt/pjrt_device.h" +#include "xla/python/py_array.h" +#include "xla/python/py_client.h" +#include "xla/python/python_ref_manager.h" +#include "xla/python/traceback.h" +#include "xla/python/types.h" +#include "xla/python/util.h" +#include "xla/shape_util.h" +#include "xla/status_macros.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" + +namespace nb = nanobind; + +namespace xla { +namespace { + +const char* const kDlTensorCapsuleName = "dltensor"; + +struct DLPackTensor { + ~DLPackTensor(); + + // `buffer_reference` is populated if we have shared (read-only) access. + nb::object buffer_reference; + + // `external_reference` is always populated. + std::unique_ptr external_reference; + + std::vector shape; + std::vector strides; + DLManagedTensor tensor; +}; + +DLPackTensor::~DLPackTensor() { + if (buffer_reference) { + GlobalPyRefManager()->AddGarbage( + absl::MakeSpan(&buffer_reference, /*size=*/1)); + } +} + +void DLPackTensorDeleter(DLManagedTensor* t) { + if (t) { + delete static_cast(t->manager_ctx); + } +} + +absl::StatusOr PrimitiveTypeToDLDataType(PrimitiveType type) { + switch (type) { + case S8: + return DLDataType{kDLInt, 8, 1}; + case S16: + return DLDataType{kDLInt, 16, 1}; + case S32: + return DLDataType{kDLInt, 32, 1}; + case S64: + return DLDataType{kDLInt, 64, 1}; + case U8: + return DLDataType{kDLUInt, 8, 1}; + case U16: + return DLDataType{kDLUInt, 16, 1}; + case U32: + return DLDataType{kDLUInt, 32, 1}; + case U64: + return DLDataType{kDLUInt, 64, 1}; + case F4E2M1FN: + return DLDataType{kDLFloat4_e2m1fn, 4, 1}; + case F8E3M4: + return DLDataType{kDLFloat8_e3m4, 8, 1}; + case F8E4M3: + return DLDataType{kDLFloat8_e4m3, 8, 1}; + case F8E4M3B11FNUZ: + return DLDataType{kDLFloat8_e4m3b11fnuz, 8, 1}; + case F8E4M3FN: + return DLDataType{kDLFloat8_e4m3fn, 8, 1}; + case F8E4M3FNUZ: + return DLDataType{kDLFloat8_e4m3fnuz, 8, 1}; + case F8E5M2: + return DLDataType{kDLFloat8_e5m2, 8, 1}; + case F8E5M2FNUZ: + return DLDataType{kDLFloat8_e5m2fnuz, 8, 1}; + case F8E8M0FNU: + return DLDataType{kDLFloat8_e8m0fnu, 8, 1}; + case BF16: + return DLDataType{kDLBfloat, 16, 1}; + case F16: + return DLDataType{kDLFloat, 16, 1}; + case F32: + return DLDataType{kDLFloat, 32, 1}; + case F64: + return DLDataType{kDLFloat, 64, 1}; + case PRED: + return DLDataType{kDLBool, 8, 1}; + case C64: + return DLDataType{kDLComplex, 64, 1}; + case C128: + return DLDataType{kDLComplex, 128, 1}; + default: + return Unimplemented("XLA type %s has no DLPack equivalent", + PrimitiveType_Name(type)); + } +} + +absl::StatusOr DLDataTypeToPrimitiveType(DLDataType type) { + if (type.lanes != 1) { + return Unimplemented("DLPack types with lanes != 1 not implemented, got %d", + type.lanes); + } + switch (type.code) { + case kDLBool: + switch (type.bits) { + case 8: + return PRED; + default: + return Unimplemented( + "Only 8-bit DLPack booleans are supported, got %d bits", + type.bits); + } + case kDLInt: + switch (type.bits) { + case 8: + return S8; + case 16: + return S16; + case 32: + return S32; + case 64: + return S64; + default: + return Unimplemented( + "Invalid or unsupported DLPack integer width: %d bits", + type.bits); + } + case kDLUInt: + switch (type.bits) { + case 8: + return U8; + case 16: + return U16; + case 32: + return U32; + case 64: + return U64; + default: + return Unimplemented( + "Invalid or unsupported DLPack unsigned integer width: %d bits", + type.bits); + } + case kDLFloat4_e2m1fn: + if (type.bits == 4) { + return F4E2M1FN; + } + return Unimplemented( + "Invalid or unsupported DLPack float4_e2m1fn width: %d bits", + type.bits); + case kDLFloat8_e3m4: + if (type.bits == 8) { + return F8E3M4; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e3m4 width: %d bits", + type.bits); + case kDLFloat8_e4m3: + if (type.bits == 8) { + return F8E4M3; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e4m3 width: %d bits", + type.bits); + case kDLFloat8_e4m3b11fnuz: + if (type.bits == 8) { + return F8E4M3B11FNUZ; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e4m3b11fnuz width: %d bits", + type.bits); + case kDLFloat8_e4m3fn: + if (type.bits == 8) { + return F8E4M3FN; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e4m3fn width: %d bits", + type.bits); + case kDLFloat8_e4m3fnuz: + if (type.bits == 8) { + return F8E4M3FNUZ; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e4m3fnuz width: %d bits", + type.bits); + case kDLFloat8_e5m2: + if (type.bits == 8) { + return F8E5M2; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e5m2 width: %d bits", + type.bits); + case kDLFloat8_e5m2fnuz: + if (type.bits == 8) { + return F8E5M2FNUZ; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e5m2fnuz width: %d bits", + type.bits); + case kDLFloat8_e8m0fnu: + if (type.bits == 8) { + return F8E8M0FNU; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e8m0fnu width: %d bits", + type.bits); + case kDLBfloat: + if (type.bits == 16) { + return BF16; + } + return Unimplemented( + "Invalid or unsupported DLPack bfloat width: %d bits", type.bits); + case kDLFloat: + switch (type.bits) { + case 16: + return F16; + case 32: + return F32; + case 64: + return F64; + default: + return Unimplemented( + "Invalid or unsupported DLPack float width: %d bits", type.bits); + } + case kDLComplex: + switch (type.bits) { + case 64: + return C64; + case 128: + return C128; + default: + return Unimplemented( + "Invalid or unsupported DLPack complex width: %d bits", + type.bits); + } + default: + return Unimplemented("Unknown or invalid DLPack type code %d", type.code); + } +} + +absl::StatusOr> StridesToLayout( + absl::Span dims, absl::Span strides) { + CHECK_EQ(dims.size(), strides.size()); + std::vector minor_to_major(dims.size()); + std::iota(minor_to_major.begin(), minor_to_major.end(), 0); + absl::c_sort(minor_to_major, [&](int a, int b) { + if (strides[a] < strides[b]) { + return true; + } + if (strides[a] > strides[b]) { + return false; + } + // If two dimensions have the same stride, prefer the major-to-minor + // interpretation of the ordering, since that's what JAX wants. + return b < a; + }); + int64_t stride = 1; + for (int64_t d : minor_to_major) { + if (dims[d] > 1 && strides[d] != stride) { + return Unimplemented( + "Only DLPack tensors with trivial (compact) striding are supported; " + "i.e., tensors whose striding represents a transposition of the " + "underlying buffer but not broadcasting. Dimensions were: [%s], " + "strides were [%s].", + absl::StrJoin(dims, ","), absl::StrJoin(strides, ",")); + } + stride *= dims[d]; + } + return minor_to_major; +} + +absl::StatusOr DLDeviceTypeForDevice(const PjRtDevice& device) { + if (device.client()->platform_id() == CpuId()) { + return kDLCPU; + } else if (device.client()->platform_id() == CudaId()) { + return kDLCUDA; + } else if (device.client()->platform_id() == RocmId()) { + return kDLROCM; + } + return InvalidArgument("Device %s cannot be used as a DLPack device.", + device.DebugString()); +} + +absl::StatusOr DLDeviceForDevice(const PjRtDevice& device) { + DLDevice context; + TF_ASSIGN_OR_RETURN(context.device_type, DLDeviceTypeForDevice(device)); + context.device_id = device.local_hardware_id().value(); + return context; +} + +absl::StatusOr DeviceForDLDevice(const PjRtClient* cpu_client, + const PjRtClient* gpu_client, + const DLDevice& context) { + switch (context.device_type) { + case kDLCPU: + if (cpu_client == nullptr) { + return InvalidArgument( + "DLPack tensor is on CPU, but no CPU backend was provided."); + } + TF_RET_CHECK(cpu_client->platform_id() == CpuId()); + return cpu_client->LookupAddressableDevice( + xla::PjRtLocalDeviceId(context.device_id)); + case kDLCUDA: + if (gpu_client == nullptr) { + return InvalidArgument( + "DLPack tensor is on GPU, but no GPU backend was provided."); + } + TF_RET_CHECK(gpu_client->platform_id() == CudaId()); + return gpu_client->LookupAddressableDevice( + xla::PjRtLocalDeviceId(context.device_id)); + case kDLROCM: + if (gpu_client == nullptr) { + return InvalidArgument( + "DLPack tensor is on GPU, but no GPU backend was provided."); + } + TF_RET_CHECK(gpu_client->platform_id() == RocmId()); + return gpu_client->LookupAddressableDevice( + xla::PjRtLocalDeviceId(context.device_id)); + default: + return InvalidArgument("Unknown/unsupported DLPack device type %d", + context.device_type); + } +} + +absl::Status VerifyDType(const DLTensor& dl_tensor) { + if (dl_tensor.dtype.bits % 8 != 0) { + return InvalidArgument( + "Unsupported DLPack tensor dtype: bits should be a multiple of 8, got " + "%d", + dl_tensor.dtype.bits); + } + + if (dl_tensor.dtype.lanes != 1) { + return InvalidArgument( + "Unsupported DLPack tensor dtype: lanes should be equal to 1, got %d", + dl_tensor.dtype.lanes); + } + + return absl::OkStatus(); +} + +absl::StatusOr> GetByteStrides(const DLTensor& dl_tensor) { + TF_RETURN_IF_ERROR(VerifyDType(dl_tensor)); + + // Convert element strides from the number of elements to the number of bytes. + std::vector strides; + strides.reserve(dl_tensor.ndim); + for (int i = 0; i < dl_tensor.ndim; ++i) { + strides.push_back(dl_tensor.strides[i] * dl_tensor.dtype.bits / 8); + } + return strides; +} + +absl::StatusOr> MakePjrtBuffer( + PjRtDevice& device, ::DLManagedTensor* dlmt, const Shape& shape, + PrimitiveType element_type, absl::Span dimensions, + std::optional stream = std::nullopt) { + std::function on_delete_callback; + if (dlmt->deleter) { + on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); }; + } + + // First try to create a view. + void* data = + static_cast(dlmt->dl_tensor.data) + dlmt->dl_tensor.byte_offset; + auto result = device.client()->CreateViewOfDeviceBuffer( + data, shape, *device.default_memory_space(), on_delete_callback, stream); + + // If that fails with invalid argument, it's possibly because of the incorrect + // alignment. If we're on CPU, we can create a copy of buffer. + if (result.status().code() == absl::StatusCode::kInvalidArgument && + dlmt->dl_tensor.device.device_type == kDLCPU) { + LOG(WARNING) << "DLPack buffer is not aligned (data at: " << data + << "). Creating a copy."; + + // Convert tensor strides (expressed in number of elements) to byte strides. + std::optional> byte_strides; + if (dlmt->dl_tensor.strides) { + TF_ASSIGN_OR_RETURN(byte_strides, GetByteStrides(dlmt->dl_tensor)); + } + + TF_ASSIGN_OR_RETURN(auto* memory_space, device.default_memory_space()); + + // Create a copy. + result = device.client()->BufferFromHostBuffer( + data, element_type, dimensions, byte_strides, + PjRtClient::HostBufferSemantics::kMutableZeroCopy, on_delete_callback, + memory_space, /*device_layout=*/nullptr); + } + return result; +} + +} // namespace + +absl::StatusOr BufferToDLPackManagedTensor( + nb::handle py_buffer, std::optional stream) { + ifrt::Array* ifrt_array = nb::cast(py_buffer).ifrt_array(); + if (ifrt_array == nullptr) { + return Unimplemented( + "BufferToDLPackManagedTensor called on deleted array."); + } + auto* arr = llvm::dyn_cast_or_null(ifrt_array); + if (arr == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + PjRtBuffer* pjrt_buffer = arr->pjrt_buffers().front().get(); + + if (pjrt_buffer->IsTuple()) { + return Unimplemented( + "BufferToDLPackManagedTensor is not implemented for tuple " + "buffers."); + } + if (pjrt_buffer->has_dynamic_dimensions()) { + return Unimplemented("DynamicShape is not implemented in DLPack."); + } + + auto pack = std::make_unique(); + DLTensor& dt = pack->tensor.dl_tensor; + { + // AcquireExternalReference may block; there are no API guarantees. + GlobalPyRefManager()->CollectGarbage(); + nb::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN(pack->external_reference, + pjrt_buffer->AcquireExternalReference()); + if (stream) { + TF_RETURN_IF_ERROR( + pack->external_reference->WaitUntilBufferReadyOnStream(*stream)); + } else { + TF_RETURN_IF_ERROR( + AwaitBuffersReady(absl::MakeConstSpan(&ifrt_array, 1))); + } + } + pack->buffer_reference = nb::borrow(py_buffer); + + dt.data = pack->external_reference->OpaqueDeviceMemoryDataPointer(); + pack->tensor.manager_ctx = pack.get(); + pack->tensor.deleter = DLPackTensorDeleter; + TF_ASSIGN_OR_RETURN(dt.device, DLDeviceForDevice(*pjrt_buffer->device())); + dt.device.device_id = pjrt_buffer->device()->local_hardware_id().value(); + dt.ndim = pjrt_buffer->dimensions().size(); + TF_ASSIGN_OR_RETURN(dt.dtype, + PrimitiveTypeToDLDataType(pjrt_buffer->element_type())); + + pack->shape = std::vector(pjrt_buffer->dimensions().begin(), + pjrt_buffer->dimensions().end()); + + // TODO(b/327524065): use PjRtLayout directly instead of xla::Layout + Layout xla_layout = pjrt_buffer->layout()->xla_layout(); + pack->strides = StridesForShape(pjrt_buffer->element_type(), + pjrt_buffer->dimensions(), xla_layout); + + dt.shape = reinterpret_cast(pack->shape.data()); + dt.strides = reinterpret_cast(pack->strides.data()); + dt.byte_offset = 0; + + // We cannot use nanobind's capsule object constructor because we need to + // detect if the capsule name has been changed in the deleter, but nanobind + // hides the underlying Python object from the deleter. + nb::capsule capsule = nb::steal( + PyCapsule_New(&pack.release()->tensor, kDlTensorCapsuleName, + [](PyObject* obj) noexcept { + DLManagedTensor* dlmt = static_cast( + PyCapsule_GetPointer(obj, kDlTensorCapsuleName)); + if (dlmt) { + DLPackTensorDeleter(dlmt); + } else { + // The tensor has been deleted. Clear any error from + // PyCapsule_GetPointer. + PyErr_Clear(); + } + })); + if (!capsule.ptr()) { + throw nb::python_error(); + } + return capsule; +} + +absl::StatusOr DLPackManagedTensorToBuffer( + const nb::capsule& tensor, std::optional> cpu_client, + std::optional> gpu_client) { + // TODO(hyeontaek): This is a potential target for an IFRT client to multiplex + // multiple PjRt clients. Devices from these PjRt clients could be expressed + // as a unified set of IFRT devices. + auto* cpu_pjrt_client = cpu_client ? (*cpu_client)->pjrt_client() : nullptr; + auto* gpu_pjrt_client = gpu_client ? (*gpu_client)->pjrt_client() : nullptr; + + if (absl::string_view(tensor.name()) != kDlTensorCapsuleName) { + return InvalidArgument( + "DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". " + "Note that a DLPack tensor may be consumed at most once.", + absl::string_view(tensor.name())); + } + DLManagedTensor* dlmt = static_cast(tensor.data()); + if (dlmt->dl_tensor.ndim < 0) { + return InvalidArgument( + "Number of dimensions in DLManagedTensor must be nonnegative, got %d", + dlmt->dl_tensor.ndim); + } + TF_ASSIGN_OR_RETURN(PjRtDevice * device, + DeviceForDLDevice(cpu_client ? cpu_pjrt_client : nullptr, + gpu_client ? gpu_pjrt_client : nullptr, + dlmt->dl_tensor.device)); + absl::Span dimensions( + reinterpret_cast(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim); + TF_ASSIGN_OR_RETURN(PrimitiveType element_type, + DLDataTypeToPrimitiveType(dlmt->dl_tensor.dtype)); + + std::vector minor_to_major; + if (dlmt->dl_tensor.strides && + absl::c_find(dimensions, 0) == dimensions.end()) { + absl::Span strides( + reinterpret_cast(dlmt->dl_tensor.strides), + dlmt->dl_tensor.ndim); + TF_ASSIGN_OR_RETURN(minor_to_major, StridesToLayout(dimensions, strides)); + } else { + minor_to_major.resize(dlmt->dl_tensor.ndim); + std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0); + } + Shape shape = ShapeUtil::MakeShapeWithDenseLayout(element_type, dimensions, + minor_to_major); + + // Raise an error if the resulting PjRtBuffer would have a non-default layout. + // TODO(skyewm): we do this because JAX doesn't currently have good support + // for non-default layouts, and will return wrong results if a non-default + // layout is passed to a computation expecting default layouts. Remove this + // special case when non-default layouts are better supported by JAX. + TF_ASSIGN_OR_RETURN(Layout default_layout, device->client()->GetDefaultLayout( + element_type, dimensions)); + if (shape.layout() != default_layout) { + return Unimplemented( + "from_dlpack got array with non-default layout with minor-to-major " + "dimensions (%s), expected (%s)", + absl::StrJoin(shape.layout().minor_to_major(), ","), + absl::StrJoin(default_layout.minor_to_major(), ",")); + } + + std::function on_delete_callback; + if (dlmt->deleter) { + on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); }; + } + + TF_ASSIGN_OR_RETURN( + auto pjrt_buffer, + MakePjrtBuffer(*device, dlmt, shape, element_type, dimensions)); + + // We have taken ownership of the array inside the capsule; make sure the + // capsule it cannot be used again. + PyCapsule_SetName(tensor.ptr(), "used_dltensor"); + PyCapsule_SetDestructor(tensor.ptr(), nullptr); + // TODO(phawkins): simplify the expression below once we know cpu_client is + // always non-null. + auto client = (cpu_client && device->client() == cpu_pjrt_client) + ? std::move(*cpu_client) + : std::move(*gpu_client); + auto* ifrt_client = + llvm::dyn_cast_or_null(client->ifrt_client()); + if (ifrt_client == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + TF_ASSIGN_OR_RETURN(auto ifrt_array, + ifrt_client->CreatePjRtArray(std::move(pjrt_buffer))); + return PyArray::MakeFromSingleDeviceArray(std::move(client), Traceback::Get(), + std::move(ifrt_array), false, true); +} + +absl::StatusOr DLPackManagedTensorToBuffer( + const nb::capsule& tensor, ifrt::Device* ifrt_device, + nb_class_ptr client, std::optional stream) { + ifrt::PjRtDevice* device = + llvm::dyn_cast_or_null(ifrt_device); + if (device == nullptr) { + throw XlaRuntimeError( + "DLPack is supported for PjRt-compatible backends only."); + } + if (!device->IsAddressable()) { + throw XlaRuntimeError( + "DLPack is only supported for devices addressable by the current " + "process."); + } + if (absl::string_view(tensor.name()) != kDlTensorCapsuleName) { + return InvalidArgument( + "DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". " + "Note that a DLPack tensor may be consumed at most once.", + absl::string_view(tensor.name())); + } + DLManagedTensor* dlmt = static_cast(tensor.data()); + if (dlmt->dl_tensor.ndim < 0) { + return InvalidArgument( + "Number of dimensions in DLManagedTensor must be nonnegative, got %d", + dlmt->dl_tensor.ndim); + } + absl::Span dimensions( + reinterpret_cast(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim); + TF_ASSIGN_OR_RETURN(PrimitiveType element_type, + DLDataTypeToPrimitiveType(dlmt->dl_tensor.dtype)); + + std::vector minor_to_major; + if (dlmt->dl_tensor.strides && + absl::c_find(dimensions, 0) == dimensions.end()) { + absl::Span strides( + reinterpret_cast(dlmt->dl_tensor.strides), + dlmt->dl_tensor.ndim); + TF_ASSIGN_OR_RETURN(minor_to_major, StridesToLayout(dimensions, strides)); + } else { + minor_to_major.resize(dlmt->dl_tensor.ndim); + std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0); + } + Shape shape = ShapeUtil::MakeShapeWithDenseLayout(element_type, dimensions, + minor_to_major); + + TF_ASSIGN_OR_RETURN(auto pjrt_buffer, + MakePjrtBuffer(*device->pjrt_device(), dlmt, shape, + element_type, dimensions, stream)); + + // We have taken ownership of the array inside the capsule; make sure the + // capsule it cannot be used again. + PyCapsule_SetName(tensor.ptr(), "used_dltensor"); + PyCapsule_SetDestructor(tensor.ptr(), nullptr); + + auto* ifrt_client = + llvm::dyn_cast_or_null(client->ifrt_client()); + if (ifrt_client == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + TF_ASSIGN_OR_RETURN(auto ifrt_array, + ifrt_client->CreatePjRtArray(std::move(pjrt_buffer))); + return PyArray::MakeFromSingleDeviceArray(std::move(client), Traceback::Get(), + std::move(ifrt_array), false, true); +} + +absl::StatusOr PrimitiveTypeToNbDLDataType( + PrimitiveType type) { + TF_ASSIGN_OR_RETURN(DLDataType dl_type, PrimitiveTypeToDLDataType(type)); + + nanobind::dlpack::dtype nb_type; + nb_type.lanes = dl_type.lanes; + nb_type.bits = dl_type.bits; + nb_type.code = dl_type.code; + + return nb_type; +} + +} // namespace xla diff --git a/jaxlib/xla/dlpack.h b/jaxlib/xla/dlpack.h new file mode 100644 index 000000000000..5d7fd7c10bf8 --- /dev/null +++ b/jaxlib/xla/dlpack.h @@ -0,0 +1,57 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DLPACK_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_DLPACK_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "nanobind/nanobind.h" +#include "nanobind/ndarray.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/python/py_client.h" + +namespace xla { + +// If take_ownership is true, ownership of the buffer is handed to DLPack, and +// the receiver may mutate the buffer as they see fit. Otherwise PjRt retains +// ownership of the buffer and it should be immutable. +// +// stream, if set, is a GPU stream, e.g. cudaStream_t for CUDA GPUs, that should +// be synchronized to the buffer as per +// https://dmlc.github.io/dlpack/latest/python_spec.html#python-specification-for-dlpack. +absl::StatusOr BufferToDLPackManagedTensor( + nanobind::handle buffer, std::optional stream); + +absl::StatusOr DLPackManagedTensorToBuffer( + const nanobind::capsule& tensor, + std::optional> cpu_client, + std::optional> gpu_client); + +absl::StatusOr DLPackManagedTensorToBuffer( + const nanobind::capsule& tensor, ifrt::Device* device, + nb_class_ptr client, std::optional stream); + +// Converts a PrimitiveType to the nanobind specific implementation of +// DLDataType. +absl::StatusOr PrimitiveTypeToNbDLDataType( + PrimitiveType type); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DLPACK_H_ diff --git a/jaxlib/xla/jax_jit.cc b/jaxlib/xla/jax_jit.cc new file mode 100644 index 000000000000..754272a078ed --- /dev/null +++ b/jaxlib/xla/jax_jit.cc @@ -0,0 +1,495 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This files implements the `jax.jit` dispatch and just-in-time feature. +// +// In a nutshell, `Jit(f)` returns a callable that will dispatch (i.e. forward +// based on passed arguments dtypes/shapes/identity) the execution to a +// just-in-time compiled XLA Executable. All of that is done in C++ for +// performance reasons. +// +// This file contains the utilities to: +// (a) inspect arguments and describe their structure, dtype/shapes, etc. +// (b) keep a mapping from function signatures to compiled XLA Executables. + +#include "jaxlib/xla/jax_jit.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/attributes.h" +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/nb_absl_inlined_vector.h" // IWYU pragma: keep +#include "xla/python/nb_absl_span.h" // IWYU pragma: keep +#include "xla/python/py_values.h" +#include "xla/python/pytree.h" +#include "xla/python/sharding.h" +#include "xla/python/types.h" +#include "xla/tsl/platform/logging.h" +#include "tsl/profiler/lib/traceme.h" + +namespace jax { + +namespace nb = nanobind; + +// TODO(phawkins): Add support for Tracers. +// TODO(jblespiau): Use absl absl::Status. + +namespace { + +// `thread_local_state.extra_jit_context` is set from Python. It's done when +// loading the Python jax modules on the main-thread. For other threads, we +// need to initialize the field the first time we access `thread_local_state`. +nb::object& initialize_local_state = *new nb::object(); + +} // namespace + +JitState& GlobalJitState() { + // Protected by the GIL. + static JitState& global_state = *new JitState(); + return global_state; +} + +JitState& ThreadLocalJitState() { + // TODO(phawkins): Google style guide forbids thread-local values with + // non-trivial destructors. + ABSL_CONST_INIT thread_local JitState thread_local_state; // NOLINT + DCHECK(PyGILState_Check()); + if (thread_local_state.extra_jit_context == std::nullopt) { + CHECK(initialize_local_state.ptr() != nullptr); + // Avoids reentrant calls to the initialization function. + thread_local_state.extra_jit_context = nb::none(); + initialize_local_state(); + } + return thread_local_state; +} + +bool GetDisableJit() { + auto& global_state = GlobalJitState(); + auto& thread_local_state = ThreadLocalJitState(); + CHECK(global_state.disable_jit.has_value()); + return thread_local_state.disable_jit.value_or(*global_state.disable_jit); +} + +bool GetEnableX64() { + auto& global_state = GlobalJitState(); + auto& thread_local_state = ThreadLocalJitState(); + CHECK(global_state.enable_x64.has_value()); + return thread_local_state.enable_x64.value_or(*global_state.enable_x64); +} + +std::optional GetDefaultDevice() { + auto& global_state = GlobalJitState(); + auto& thread_local_state = ThreadLocalJitState(); + return thread_local_state.default_device.has_value() + ? thread_local_state.default_device + : global_state.default_device; +} + +std::optional GetPostHook() { + auto& global_state = GlobalJitState(); + auto& thread_local_state = ThreadLocalJitState(); + return thread_local_state.post_hook.has_value() ? thread_local_state.post_hook + : global_state.post_hook; +} + +static std::string OptionalDebugString( + const std::optional optional) { + if (optional.has_value()) { + return nb::cast(nb::str(optional.value())); + } else { + return "None"; + } +} + +std::string ArgumentSignature::DebugString() const { + auto py_object_formatter = [](std::string* out, const nb::object& o) { + out->append(nb::cast(nb::str(o))); + }; + auto treedef_formatter = [](std::string* out, const xla::PyTreeDef& d) { + out->append(d.ToString()); + }; + return absl::StrFormat( + "static args (positional + keyword): [%s], " + "static arg keyword names: [%s], " + "dynamic arg signatures (positional + keyword): [%s]" + "dynamic arg shardings: [%s]", + absl::StrJoin(static_args, ",", py_object_formatter), + absl::StrJoin(static_arg_names, ",", py_object_formatter), + absl::StrJoin(dynamic_arg_names, ",", py_object_formatter), + absl::StrJoin(dynamic_arg_treedefs, "| ", treedef_formatter)); +} + +bool ArgumentSignature::operator==(const ArgumentSignature& other) const { + if (dynamic_arg_treedefs != other.dynamic_arg_treedefs) { + return false; + } + auto object_ptr_equality = [](nb::handle a, nb::handle b) { + return a.ptr() == b.ptr(); + }; + if (!absl::c_equal(dynamic_arg_names, other.dynamic_arg_names, + object_ptr_equality)) { + return false; + } + if (!absl::c_equal(static_arg_names, other.static_arg_names, + object_ptr_equality)) { + return false; + } + return absl::c_equal( + static_args, other.static_args, + [](const nb::object& a, const nb::object& b) { + try { + return a.type().ptr() == b.type().ptr() && a.equal(b); + } catch (const nb::python_error& e) { + throw std::invalid_argument(absl::StrCat( + "static arguments should be comparable using __eq__." + "The following error was raised when comparing two objects of " + "types ", + nb::cast(nb::str(a.type())), " and ", + nb::cast(nb::str(b.type())), + ". The error was:\n", e.what())); + } + }); +} + +std::string CallSignature::DebugString() const { + auto py_object_formatter = [](std::string* out, const nb::object& o) { + out->append(nb::cast(nb::str(o))); + }; + auto signature_formatter = [](std::string* out, + const xla::PyArgSignature& s) { + out->append(s.DebugString()); + }; + auto layout_formatter = [](std::string* out, + const std::shared_ptr& l) { + if (l != nullptr) { + out->append(l->ToString()); + } else { + out->append("None"); + } + }; + auto bool_formatter = [](std::string* out, bool o) { + out->append(o ? "true" : "false"); + }; + return absl::StrFormat( + "arg signature: %s\n" + "dynamic arg signatures (positional + keyword): %s\n" + "dynamic arg shardings: %s\n" + "dynamic arg layouts: %s\n" + "committed args: %s\n" + "device: %s\n" + "default_device: %s\n" + "jax_enable_x64: %d\n" + "global_extra_jit_context: %s\n" + "thread_local_extra_jit_context: %s\n" + "configs: %s\n", + arg_signature.DebugString(), + absl::StrJoin(dynamic_arg_signatures, ", ", signature_formatter), + absl::StrJoin(dynamic_arg_shardings, ", ", py_object_formatter), + absl::StrJoin(dynamic_arg_layouts, ", ", layout_formatter), + absl::StrJoin(committed_args, ",", bool_formatter), + device != nullptr ? device->DebugString() : "nullptr", + OptionalDebugString(default_device), jax_enable_x64, + OptionalDebugString(global_extra_jit_context), + OptionalDebugString(thread_local_extra_jit_context), + absl::StrJoin(configs, ", ", py_object_formatter)); +} + +bool CallSignature::operator==(const CallSignature& other) const { + if (arg_signature != other.arg_signature) { + return false; + } + if (dynamic_arg_signatures != other.dynamic_arg_signatures) { + return false; + } + if (device != other.device) { + return false; + } + if (jax_enable_x64 != other.jax_enable_x64) { + return false; + } + if (committed_args != other.committed_args) { + return false; + } + return + // `==` on py:objects is the Python `is`. We need equal. + absl::c_equal(dynamic_arg_shardings, other.dynamic_arg_shardings, + ShardingEqual) && + absl::c_equal(dynamic_arg_layouts, other.dynamic_arg_layouts, + [](const std::shared_ptr& a, + const std::shared_ptr& b) { + return (a && b) ? *a == *b : a == b; + }) && + (global_extra_jit_context.has_value() == + other.global_extra_jit_context.has_value()) && + (!global_extra_jit_context.has_value() || + global_extra_jit_context->equal(*other.global_extra_jit_context)) && + (default_device.has_value() == other.default_device.has_value()) && + (!default_device.has_value() || + default_device->equal(*other.default_device)) && + (thread_local_extra_jit_context.has_value() == + other.thread_local_extra_jit_context.has_value()) && + (!thread_local_extra_jit_context.has_value() || + thread_local_extra_jit_context->equal( + *other.thread_local_extra_jit_context)) && + configs.size() == other.configs.size() && + absl::c_equal( + configs, other.configs, + [](const nb::object& a, const nb::object& b) { return a.equal(b); }); +} + +// Filter out static arguments, flatten and concatenate other arguments (i.e. +// dynamic positional and keyword arguments), filling `arguments` in place. +absl::Status ParseArguments( + absl::Span positional_args, + absl::Span keyword_args, nb::handle kwnames, + absl::Span static_argnums, + absl::Span static_argnames, + xla::PyTreeRegistry* pytree_registry, ArgumentSignature& signature, + absl::InlinedVector& flat_dynamic_args) { + tsl::profiler::TraceMe traceme("ParseArguments"); + + DCHECK(absl::c_all_of(static_argnames, [](const nb::str& name) { + return PyUnicode_CHECK_INTERNED(name.ptr()); + })); + + flat_dynamic_args.reserve(positional_args.size() + keyword_args.size()); + if (static_argnums.empty()) { + signature.dynamic_arg_treedefs.reserve(positional_args.size()); + + // Positional arguments. + for (int i = 0; i < positional_args.size(); ++i) { + signature.dynamic_arg_treedefs.emplace_back(pytree_registry); + xla::PyTreeDef& pytree_def = signature.dynamic_arg_treedefs.back(); + pytree_def.Flatten(nb::handle(positional_args[i]), flat_dynamic_args); + } + } else { + signature.dynamic_arg_treedefs.reserve(positional_args.size()); + + // Positional arguments. + int num_positional_args = positional_args.size(); + for (int i = 0; i < positional_args.size(); ++i) { + if (std::find_if(static_argnums.begin(), static_argnums.end(), + [i, num_positional_args](int t) { + return t >= 0 ? i == t : i == t + num_positional_args; + }) == static_argnums.end()) { + signature.dynamic_arg_treedefs.emplace_back(pytree_registry); + xla::PyTreeDef& pytree_def = signature.dynamic_arg_treedefs.back(); + pytree_def.Flatten(positional_args[i], flat_dynamic_args); + } else { + signature.static_args.emplace_back( + nb::borrow(positional_args[i])); + } + } + } + + // Keyword arguments. + if (!keyword_args.empty()) { + std::vector> kwargs(keyword_args.size()); + // We first intern the keys, then sort them (by name, as in the Python path) + // (see also xla::PyTreeDef::Flatten) and then create the signatures. + // TODO(jblespiau): We should be able to sort the keys by interned-key + // pointers, but this requires the Python compilation to do the same. + for (int i = 0; i < keyword_args.size(); ++i) { + // Intern the key if not already interned. + PyObject* key = PyTuple_GET_ITEM(kwnames.ptr(), i); + Py_INCREF(key); + if (!PyUnicode_CHECK_INTERNED(key)) { + PyUnicode_InternInPlace(&key); + } + kwargs[i].first = key; + kwargs[i].second = keyword_args[i]; + } + + std::sort(kwargs.begin(), kwargs.end(), + [](const std::pair& a, + const std::pair& b) { + return a.first < b.first; + }); + auto kwarg_is_static = [&](nb::handle name) { + for (const auto& kw : static_argnames) { + if (kw.ptr() == name.ptr()) return true; + } + return false; + }; + + signature.dynamic_arg_names.reserve(keyword_args.size()); + for (int i = 0; i < keyword_args.size(); ++i) { + if (kwarg_is_static(kwargs[i].first)) { + signature.static_arg_names.push_back( + nb::steal(kwargs[i].first)); + signature.static_args.push_back( + nb::borrow(kwargs[i].second)); + } else { + signature.dynamic_arg_names.push_back( + nb::steal(kwargs[i].first)); + signature.dynamic_arg_treedefs.emplace_back(pytree_registry); + xla::PyTreeDef& pytree_def = signature.dynamic_arg_treedefs.back(); + pytree_def.Flatten(nb::handle(kwargs[i].second.ptr()), + flat_dynamic_args); + } + } + } + return absl::OkStatus(); +} + +void BuildJaxjitSubmodule(nb::module_& m) { + nb::module_ jitlib = m.def_submodule("jax_jit", "Jax C++ jit library"); + + nb::class_ jit_state_(jitlib, "JitState"); + jit_state_.def_rw("disable_jit", &JitState::disable_jit, nb::arg().none()); + jit_state_.def_rw("enable_x64", &JitState::enable_x64, nb::arg().none()); + jit_state_.def_rw("default_device", &JitState::default_device, + nb::arg().none()); + jit_state_.def_rw("extra_jit_context", &JitState::extra_jit_context, + nb::arg().none()); + jit_state_.def_rw("post_hook", &JitState::post_hook, nb::arg().none()); + + jitlib.def( + "global_state", [&]() { return &GlobalJitState(); }, + nb::rv_policy::reference); + jitlib.def( + "thread_local_state", [&]() { return &ThreadLocalJitState(); }, + nb::rv_policy::reference); + + jitlib.def( + "swap_thread_local_state_disable_jit", + [&](std::optional value) -> std::optional { + auto tls = &ThreadLocalJitState(); + auto result = tls->disable_jit; + tls->disable_jit = value; + return result; + }, + nb::arg("value").none(), nb::rv_policy::reference); + + jitlib.def("get_enable_x64", &GetEnableX64); + jitlib.def("set_thread_local_state_initialization_callback", + [](nb::object f) { initialize_local_state = f; }); + + nb::class_ arg_signature(jitlib, "PyArgSignature"); + arg_signature + .def_prop_ro( + "dtype", + [](const xla::PyArgSignature& sig) { + return xla::ValueOrThrow(xla::PrimitiveTypeToNbDtype(sig.dtype)); + }) + .def_prop_ro("shape", + [](const xla::PyArgSignature& sig) { + return xla::SpanToNbTuple(absl::MakeConstSpan(sig.shape)); + }) + .def_ro("weak_type", &xla::PyArgSignature::weak_type); + jitlib.def("_ArgSignatureOfValue", + xla::ValueOrThrowWrapper(xla::PyArgSignatureOfValue)); + + jitlib.def("_is_float0", &xla::IsFloat0); + + nb::class_ argument_signature(jitlib, "ArgumentSignature"); + argument_signature.def_ro("static_args", &ArgumentSignature::static_args) + .def_ro("static_arg_names", &ArgumentSignature::static_arg_names) + .def_ro("dynamic_arg_names", &ArgumentSignature::dynamic_arg_names) + .def_ro("dynamic_arg_treedefs", &ArgumentSignature::dynamic_arg_treedefs) + .def("__repr__", &ArgumentSignature::DebugString) + .def("__str__", &ArgumentSignature::DebugString) + .def("__hash__", + [](const ArgumentSignature& s) { return absl::HashOf(s); }) + .def("__eq__", [](const ArgumentSignature& a, + const ArgumentSignature& b) { return a == b; }) + .def("__ne__", [](const ArgumentSignature& a, + const ArgumentSignature& b) { return a != b; }); + + jitlib.def( + "parse_arguments", + [](nb::sequence positional_args, nb::sequence keyword_args, + nb::tuple kwnames, absl::Span static_argnums, + absl::Span static_argnames, + xla::PyTreeRegistry* pytree_registry) { + ArgumentSignature signature; + absl::InlinedVector flat_dynamic_args; + nb::object positional_args_seq = nb::steal(PySequence_Fast( + positional_args.ptr(), "positional_args must be a list or tuple")); + if (!positional_args_seq.ptr()) { + throw nb::python_error(); + } + nb::object keyword_args_seq = nb::steal(PySequence_Fast( + keyword_args.ptr(), "keyword_args must be a list or tuple")); + if (!keyword_args_seq.ptr()) { + throw nb::python_error(); + } + absl::Span positional_args_span = + absl::MakeSpan(PySequence_Fast_ITEMS(positional_args_seq.ptr()), + PySequence_Fast_GET_SIZE(positional_args_seq.ptr())); + absl::Span keyword_args_span = + absl::MakeSpan(PySequence_Fast_ITEMS(keyword_args_seq.ptr()), + PySequence_Fast_GET_SIZE(keyword_args_seq.ptr())); + + // Intern the static argument names. + std::vector static_argnames_interned; + static_argnames_interned.reserve(static_argnames.size()); + for (const nb::str& name : static_argnames) { + PyObject* s = name.inc_ref().ptr(); + PyUnicode_InternInPlace(&s); + static_argnames_interned.push_back(nb::steal(s)); + } + + xla::ThrowIfError( + ParseArguments(positional_args_span, keyword_args_span, kwnames, + static_argnums, static_argnames_interned, + pytree_registry, signature, flat_dynamic_args)); + return std::make_pair(std::move(signature), + std::move(flat_dynamic_args)); + }, + nb::arg("positional_args"), nb::arg("keyword_args"), nb::arg("kwnames"), + nb::arg("static_argnums"), nb::arg("static_argnames"), + nb::arg("pytree_registry"), + R"doc(Parses the arguments to a function as jax.jit would. + +Returns a ArgumentSignature and the flattened dynamic arguments. + +Args: + positional_args: The positional arguments. + keyword_args: The keyword arguments. + kwnames: The keyword names. + static_argnums: The static argument numbers. + static_argnames: The static argument names. + pytree_registry: The pytree registry. +)doc"); +} + +} // namespace jax diff --git a/jaxlib/xla/jax_jit.h b/jaxlib/xla/jax_jit.h new file mode 100644 index 000000000000..303d7e69414d --- /dev/null +++ b/jaxlib/xla/jax_jit.h @@ -0,0 +1,265 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_ + +#include + +#include +#include +#include +#include +#include +#include +#include + +// placeholder for index annotation headers +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/python/nb_helpers.h" +#include "xla/python/py_values.h" +#include "xla/python/python_ref_manager.h" +#include "xla/python/pytree.h" +#include "xla/python/sharding.h" +#include "xla/tsl/platform/logging.h" + +namespace jax { + +// Flags, such as JIT disable and the x64 mode, are controlled by: +// - a global flag value, e.g., associated to --jax_enable_x64 +// - possibly a thread-local value, which initially is std::nullopt and +// overrides the global value if set. The thread-local state is +// used to implement context managers that locally override the global state. +struct JitState { + ~JitState() { + if (extra_jit_context) { + // We likely do not hold the GIL if this JitState is thread-local, so we + // hand the Python object to the global reference manager to destroy. + nanobind::object o = std::move(*extra_jit_context); + xla::GlobalPyRefManager()->AddGarbage(absl::MakeSpan(&o, 1)); + extra_jit_context = std::nullopt; + } + } + + std::optional disable_jit; + std::optional enable_x64; + + // Used to manually set the default device jax should use. May be unset even + // in global state, indicating there is no manual override. + // TODO(skyewm): make this a C++ type when all JAX backends support a single + // C++ device interface + std::optional default_device; + + // Extra context that should be included in the JIT cache key. Must be + // hashable and have an equality defined. + std::optional extra_jit_context; + + // A callback that, if present, is called when a JITted function is executed + // from cache. May be unset even in global state. + std::optional post_hook; +}; + +JitState& GlobalJitState(); + +// Requires the GIL. +JitState& ThreadLocalJitState(); + +// Getters for JitState fields that first look in thread-local state, then +// fallback to global state. +bool GetDisableJit(); +bool GetEnableX64(); + +// TODO(skyewm): return a C++ type when all JAX backends support a single C++ +// device interface +std::optional GetDefaultDevice(); +std::optional GetPostHook(); + +// An ArgumentSignature describes the static arguments to a function call, and +// how the dynamic arguments are related to the arguments. Together with the +// values of the dynamic arguments, this fully describes the arguments. +struct ArgumentSignature { + // A PyTreeDef for each dynamic argument, positional arguments first + // followed by keyword arguments. Keyword arguments are in the order given + // by dynamic_arg_names. + absl::InlinedVector dynamic_arg_treedefs; + + // Dynamic keyword argument names. Interned, and sorted by the keyword + // name. Interned values are safe to compare by pointer. + std::vector dynamic_arg_names; + + // Static arguments. Contains the positional arguments sorted in argument + // order, followed by static keyword arguments in the order given by + // `static_arg_names`. + std::vector static_args; + + // Static keyword argument names. Interned, and sorted by keyword name. + std::vector static_arg_names; + + bool operator==(const ArgumentSignature& other) const; + bool operator!=(const ArgumentSignature& other) const { + return !(*this == other); + } + + std::string DebugString() const; +}; + +template +H AbslHashValue(H h, const ArgumentSignature& s) { + h = H::combine(std::move(h), s.dynamic_arg_treedefs, + s.dynamic_arg_names.size(), s.static_args.size(), + s.static_arg_names.size()); + + for (const auto& name : s.dynamic_arg_names) { + h = H::combine(std::move(h), name.ptr()); + } + for (size_t i = 0; i < s.static_args.size(); ++i) { + const auto& static_arg = s.static_args[i]; + Py_hash_t hash; + try { + hash = nanobind::hash(static_arg); + } catch (const nanobind::python_error& e) { + if (!e.matches(PyExc_TypeError)) throw; + throw std::invalid_argument(absl::StrCat( + "Non-hashable static arguments are not supported. An error occurred " + "while trying to hash an object of type ", + nanobind::cast(nanobind::str(static_arg.type())), + ", ", nanobind::cast(nanobind::str(static_arg)), + ". The error was:\n", e.what(), "\n")); + } + h = H::combine(std::move(h), hash); + } + for (const auto& name : s.static_arg_names) { + h = H::combine(std::move(h), name.ptr()); + } + return h; +} + +// Filter out static arguments, flatten and concatenate other arguments (i.e. +// dynamic positional and keyword arguments), filling `arguments` in place. +// Args: +// positional_args: positional arguments +// keyword_args: the values of the keyword arguments +// kwnames: either None or a tuple containing the keyword argument names +// static_argnums: the indices of the static arguments in the positional +// arguments +// static_argnames: the names of the static arguments, which must be interned. +// pytree_registry: the registry to use to convert the arguments to pytrees +// signature: output; describes the static arguments and the identities of the +// dynamic arguments. +// flat_dynamic_args: output; the concatenation of the dynamic positional +// arguments and sorted keyword arguments. +absl::Status ParseArguments( + absl::Span positional_args, + absl::Span keyword_args, nanobind::handle kwnames, + absl::Span static_argnums, + absl::Span static_argnames, + xla::PyTreeRegistry* pytree_registry, ArgumentSignature& signature, + absl::InlinedVector& flat_dynamic_args); + +// The signature of Python jitted function call, partitioned into: +// - dynamic positional arguments (i.e. positional args which are not static) +// - static positional arguments (i.e. the args associated to static_argnums) +// - keyword arguments +// The CallSignature should unambiguously identify a function call, thus, +// equality is based on: +// (a) Same PyTree for all dynamic positional arguments and keyword arguments +// (a) equality of the arguments and keyword arguments ArgSignature +// (a) equality (delegated to Python) of the static arguments. +struct CallSignature { + // Not part of the signature, but we need it for error messages. + absl::string_view function_name; + + ArgumentSignature arg_signature; + + // Shape and dtype for both the dynamic positional arguments and the keyword + // arguments (sorted by keyword name). + absl::InlinedVector dynamic_arg_signatures; + + // The sharding of the jax.Array arguments. + std::vector dynamic_arg_shardings; + + // The layout of the jax.Array arguments. + std::vector> dynamic_arg_layouts; + + absl::InlinedVector committed_args; + + // For JIT, we need this in the key because computation follows the data, so + // we may have multiple executables depending on the devices the data is on. + // This is not the case for PMAP, and is set to `nullptr`. + xla::PjRtDevice* device = nullptr; + bool jax_enable_x64; + + // For JIT on PJIT, we need to fallback to python whenever default_device + // changes. + std::optional default_device; + + // Opaque additional context that should be included as part of the cache key. + std::optional global_extra_jit_context; + std::optional thread_local_extra_jit_context; + + std::vector configs; + + bool operator==(const CallSignature& other) const; + bool operator!=(const CallSignature& other) const { + return !(*this == other); + } + + std::string DebugString() const; +}; + +template +H AbslHashValue(H h, const CallSignature& s) { + h = H::combine(std::move(h), s.arg_signature, s.dynamic_arg_signatures); + + DCHECK(s.dynamic_arg_shardings.empty() || + s.dynamic_arg_shardings.size() == s.dynamic_arg_signatures.size()); + + DCHECK(s.dynamic_arg_layouts.empty() || + s.dynamic_arg_layouts.size() == s.dynamic_arg_signatures.size()); + + // TODO(chky): For now, we are only hashing the pointer of shardings to avoid + // slow python hashing function. Consider implementing hashing function and + // equality checks in C++ in jax::Sharding and use those here. + for (const auto& sharding : s.dynamic_arg_shardings) { + h = H::combine(std::move(h), ShardingHash(sharding)); + } + + for (const auto& layout : s.dynamic_arg_layouts) { + if (layout != nullptr) { + h = H::combine(std::move(h), *layout); + } + } + + h = H::combine(std::move(h), s.committed_args, s.device, s.jax_enable_x64); + + // We do not hash the extra_jit_context fields since calling Python hash + // functions is expensive (~300ns) and we don't expect a large number of + // different contexts. + return h; +} + +// The function to call in `xla.cc` to add the bindings for this module. +void BuildJaxjitSubmodule(nanobind::module_& m); + +} // namespace jax + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_ diff --git a/jaxlib/xla/mlir.cc b/jaxlib/xla/mlir.cc new file mode 100644 index 000000000000..5905c6c6ec8d --- /dev/null +++ b/jaxlib/xla/mlir.cc @@ -0,0 +1,251 @@ +/* Copyright 2021 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/mlir.h" + +#include + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Bytecode/BytecodeWriter.h" +#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" +#include "mlir/Dialect/Func/Extensions/AllExtensions.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LogicalResult.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "stablehlo/dialect/Serialization.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/translate/stablehlo.h" +#include "xla/mlir_hlo/mhlo/transforms/passes.h" +#include "xla/pjrt/mlir_to_hlo.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/refine_polymorphic_shapes.h" +#include "xla/service/llvm_ir/llvm_util.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" + +namespace nb = nanobind; + +namespace xla { +namespace { + +std::string PrintModule(mlir::ModuleOp module) { + std::string s; + llvm::raw_string_ostream os(s); + mlir::OpPrintingFlags flags; + flags.enableDebugInfo(); + module->print(os, flags); + return s; +} + +absl::StatusOr SerializeUsingBytecode(mlir::ModuleOp module) { + std::string bytecode; + llvm::raw_string_ostream os(bytecode); + mlir::BytecodeWriterConfig config; + if (mlir::failed(mlir::writeBytecodeToFile(module, os, config))) { + return absl::InvalidArgumentError("mlir::writeBytecodeToFile failed"); + } + return bytecode; +} + +void EnablePrintBeforeAndAfter(mlir::PassManager& pm) { + auto print_before = [](mlir::Pass*, mlir::Operation*) { return true; }; + auto print_after = [](mlir::Pass*, mlir::Operation*) { return true; }; + pm.enableIRPrinting(print_before, print_after); +} + +// Converts an XlaComputation to a StableHLO mlir::Module string. +// Exists for backwards compatibility. +// TODO(phawkins): port remaining users of XlaComputations to use mlir::Modules +// instead and delete this function. +absl::StatusOr PyXlaComputationToMlirModule( + const XlaComputation& computation) { + mlir::MLIRContext context; + if (VLOG_IS_ON(3)) context.disableMultithreading(); + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ConvertHloToStablehlo(context, &computation.proto())); + return PrintModule(*module); +} + +absl::StatusOr PyMlirModuleToXlaComputation( + absl::string_view mlir_module, bool use_tuple_args, bool return_tuple) { + mlir::MLIRContext context; + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ParseMlirModuleString(mlir_module, context)); + XlaComputation computation; + // SDY dialect may be part of the module which XLA doesn't know about. + TF_RETURN_IF_ERROR(ExportShardyForHloRoundTrip(*module)); + TF_RETURN_IF_ERROR(MlirToXlaComputation(*module, computation, use_tuple_args, + return_tuple, + /*use_shardy=*/false)); + return computation; +} + +absl::StatusOr PyMhloToStablehlo(absl::string_view mlir_module) { + mlir::MLIRContext context; + if (VLOG_IS_ON(3)) context.disableMultithreading(); + // JAX can be customized in a way that involves operations from custom + // dialects showing up in JAX IR. + // `ParseMlirModuleString` won't know about these dialects, but that's fine + // since we just want to convert MHLO ops to StableHLO ops here and leave + // everything else unchanged. + // In order to achieve that, we're allowing unregistered dialects here. + context.allowUnregisteredDialects(true); + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ParseMlirModuleString(mlir_module, context)); + mlir::PassManager pm(&context); + if (VLOG_IS_ON(3)) EnablePrintBeforeAndAfter(pm); + pm.addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); + if (!mlir::succeeded(pm.run(*module))) { + return tsl::errors::InvalidArgument("MHLO => StableHLO failed"); + } + // Use bytecode, passing unregistered dialects with properties causes issues + // when using textual assembly. + TF_ASSIGN_OR_RETURN(std::string bytecode, SerializeUsingBytecode(*module)); + return nb::bytes(bytecode.data(), bytecode.size()); +} + +absl::StatusOr PyStablehloToMhlo(const nb::bytes& mlir_module) { + mlir::MLIRContext context; + if (VLOG_IS_ON(3)) context.disableMultithreading(); + // See PyMhloToStablehlo for an explanation of why we're allowing unregistered + // dialects here. + context.allowUnregisteredDialects(true); + TF_ASSIGN_OR_RETURN( + mlir::OwningOpRef module, + ParseMlirModuleString( + absl::string_view(mlir_module.c_str(), mlir_module.size()), context)); + mlir::PassManager pm(&context); + if (VLOG_IS_ON(3)) EnablePrintBeforeAndAfter(pm); + pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); + if (!mlir::succeeded(pm.run(*module))) { + return tsl::errors::InvalidArgument("StableHLO => MHLO failed"); + } + + // Use bytecode, passing unregistered dialects with properties causes issues + // when using textual assembly. + TF_ASSIGN_OR_RETURN(std::string bytecode, SerializeUsingBytecode(*module)); + return nb::bytes(bytecode.data(), bytecode.size()); +} + +absl::StatusOr PySerializePortableArtifact( + absl::string_view mlir_module, absl::string_view target) { + mlir::MLIRContext context; + if (VLOG_IS_ON(3)) context.disableMultithreading(); + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ParseMlirModuleString(mlir_module, context)); + + // Serialize portable artifact + TF_ASSIGN_OR_RETURN( + std::string bytecode, + SerializeUsingVersionedStablehlo(*module, target, /*inplace=*/true)); + return nb::bytes(bytecode.data(), bytecode.size()); +} + +absl::StatusOr PyDeserializePortableArtifact( + const nb::bytes& bytecode_str) { + mlir::MLIRContext context; + mlir::OwningOpRef module = + mlir::stablehlo::deserializePortableArtifact( + absl::string_view(bytecode_str.c_str(), bytecode_str.size()), + &context); + if (!module) + return tsl::errors::InvalidArgument("Failed to deserialize StableHLO"); + return PrintModule(*module); +} + +} // namespace + +void BuildMlirSubmodule(nb::module_& m) { + nb::module_ mlir_module = m.def_submodule("mlir", "MLIR/XLA integration"); + + mlir_module.def("xla_computation_to_mlir_module", + xla::ValueOrThrowWrapper(PyXlaComputationToMlirModule), + nb::arg("computation")); + mlir_module.def( + "mlir_module_to_xla_computation", + [](const nb::bytes& bytecode, bool use_tuple_args, bool return_tuple) { + return xla::ValueOrThrow(PyMlirModuleToXlaComputation( + absl::string_view(bytecode.c_str(), bytecode.size()), + use_tuple_args, return_tuple)); + }, + nb::arg("mlir_module"), nb::arg("use_tuple_args") = false, + nb::arg("return_tuple") = false); + mlir_module.def("mlir_module_to_xla_computation", + xla::ValueOrThrowWrapper(PyMlirModuleToXlaComputation), + nb::arg("mlir_module"), nb::arg("use_tuple_args") = false, + nb::arg("return_tuple") = false); + mlir_module.def( + "mhlo_to_stablehlo", + [](const nb::bytes& bytecode) { + return xla::ValueOrThrow(PyMhloToStablehlo( + absl::string_view(bytecode.c_str(), bytecode.size()))); + }, + nb::arg("mlir_module")); + mlir_module.def("mhlo_to_stablehlo", + xla::ValueOrThrowWrapper(PyMhloToStablehlo), + nb::arg("mlir_module")); + mlir_module.def("stablehlo_to_mhlo", + xla::ValueOrThrowWrapper(PyStablehloToMhlo), + nb::arg("mlir_module")); + mlir_module.def( + "serialize_portable_artifact", + [](const nb::bytes& bytecode, absl::string_view target) { + return xla::ValueOrThrow(PySerializePortableArtifact( + absl::string_view(bytecode.c_str(), bytecode.size()), target)); + }, + nb::arg("mlir_module"), nb::arg("target")); + mlir_module.def("serialize_portable_artifact", + xla::ValueOrThrowWrapper(PySerializePortableArtifact), + nb::arg("mlir_module"), nb::arg("target")); + mlir_module.def("deserialize_portable_artifact", + xla::ValueOrThrowWrapper(PyDeserializePortableArtifact), + nb::arg("mlir_module")); + mlir_module.def( + "refine_polymorphic_shapes", + [](nb::bytes bytecode, bool enable_shape_assertions, + bool validate_static_shapes, bool enable_shardy) -> nb::bytes { + std::string buffer; + llvm::raw_string_ostream os(buffer); + xla::ThrowIfError(RefinePolymorphicShapes( + absl::string_view(bytecode.c_str(), bytecode.size()), os, + enable_shape_assertions, validate_static_shapes, enable_shardy)); + return nb::bytes(buffer.data(), buffer.size()); + }, + nb::arg("mlir_module"), nb::arg("enable_shape_assertions") = true, + nb::arg("validate_static_shapes") = true, + nb::arg("enable_shardy") = false, + R"(Refines the dynamic shapes for a module. + The "main" function must have static shapes and all the + intermediate dynamic shapes depend only on the input static + shapes. Optionally, also validates that the resulting module has + only static shapes. + )"); +} + +} // namespace xla diff --git a/jaxlib/xla/mlir.h b/jaxlib/xla/mlir.h new file mode 100644 index 000000000000..f0bfd69bca6b --- /dev/null +++ b/jaxlib/xla/mlir.h @@ -0,0 +1,28 @@ +/* Copyright 2021 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_MLIR_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_MLIR_H_ + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" + +namespace xla { + +void BuildMlirSubmodule(nanobind::module_& m); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_MLIR_H_ diff --git a/jaxlib/xla/pjit.cc b/jaxlib/xla/pjit.cc new file mode 100644 index 000000000000..96056708c2fb --- /dev/null +++ b/jaxlib/xla/pjit.cc @@ -0,0 +1,1402 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/pjit.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // NOLINT +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/cleanup/cleanup.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/synchronization/notification.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/config.h" +#include "jaxlib/xla/jax_jit.h" +#include "xla/layout.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/lru_cache.h" +#include "xla/python/guard_lib.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/python/nb_helpers.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/py_array.h" +#include "xla/python/py_executable.h" +#include "xla/python/py_values.h" +#include "xla/python/python_ref_manager.h" +#include "xla/python/pytree.h" +#include "xla/python/sharding.h" +#include "xla/python/traceback.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" +#include "tsl/profiler/lib/traceme.h" + +namespace jax { +namespace { + +namespace nb = nanobind; + +struct PjitCacheEntry { + explicit PjitCacheEntry(xla::PyTreeRegistry* registry) + : out_pytree_def(registry) {} + std::shared_ptr executable; + std::vector in_shardings; + std::vector out_avals; + std::vector out_dtypes; + std::vector> out_shapes; + std::vector out_weak_types; + std::vector out_shardings; + std::vector out_committed; + xla::PyTreeDef out_pytree_def; + // Bitvector of kept arguments from Jaxpr DCE pass. Used to drop some `args` + // in PjitFunction::Call before calling into compiled computation. + std::vector kept_var_bitvec; + std::vector in_device_local_layouts; + + // Ensures a single thread performs the compilation for a given executable. + // + // The first thread (holding the GIL) will create the CacheEntry associated to + // a signature and if the object has been inserted already, other threads + // will wait for the notification. + absl::Notification compilation_complete; + + std::thread::id thread_id = std::this_thread::get_id(); + + bool fall_back_to_python = false; +}; + +// A PjitFunctionCache represents a cache of compiled functions that can be +// shared between one or more PjitFunction objects. It serves two goals: +// - reduce the number of lru caches (hash map) across multiple JITs. +// - make the cache global to increase cache hits (e.g. calling jit(f)(3) twice) +// keeping entries alive as long as the underlying function f is alive. +// Assume the cache is protected by the GIL. +class PjitFunctionCache { + public: + static constexpr int kDefaultCapacity = 4096; + explicit PjitFunctionCache(int capacity); + + // Cache entries are shared_ptr<>s because it's possible the cache entry + // might be evicted before we finish tracing/compiling. + typedef xla::LRUCache> Cache; + + // We include as part of the cache key `global_cache_key` (and any other + // fields that aren't subsumed by the CallSignature we compute for each call). + static std::shared_ptr Lookup( + xla::nb_class_ptr self, nb::handle function, + nb::object global_cache_key); + std::shared_ptr DefaultCache(); + + // These methods require the GIL or the object's lock in no-GIL mode. + int Size() const { return lru_list_.Size(); } + int Capacity() const { return lru_list_.Capacity(); } + void Clear() { + lru_list_.Clear(); + functions_.clear(); + } + + private: + struct Key { + nb::handle function; // Does not hold a reference. + + // Other fields that are part of the arguments to `jit`, but are not + // otherwise part of CallSignature. + nb::object global_cache_key; + + size_t cached_hash; + + bool operator==(const Key& other) const { + bool global_cache_eq; + try { + global_cache_eq = global_cache_key.equal(other.global_cache_key); + } catch (const nanobind::python_error& e) { + throw std::invalid_argument( + absl::StrCat("Equality of global cache key lead to an exception. " + "The error was:\n", + e.what(), "\n")); + } + return function.ptr() == other.function.ptr() && global_cache_eq; + } + + struct Hash { + size_t operator()(const Key& key) const { return key.cached_hash; } + }; + }; + + template + friend H AbslHashValue(H h, const Key& key) { + h = H::combine(std::move(h), key.function.ptr()); + Py_hash_t hash; + try { + hash = nb::hash(key.global_cache_key); + } catch (const nanobind::python_error& e) { + if (!e.matches(PyExc_TypeError)) throw; + throw std::invalid_argument(absl::StrCat( + "Hashing global cache key lead to an exception. The error was:\n", + e.what(), "\n")); + } + h = H::combine(std::move(h), hash); + return h; + } + + struct Value { + explicit Value(std::shared_ptr cache) : cache(std::move(cache)) {} + std::shared_ptr cache; + + // A weak reference to the key function. We use the weak reference to + // register a callback that is triggered when the key function is destroyed. + // We use a weak pointer because we want to allow caching across multiple + // calls to `pjit(f)` if `f` remains alive, but we do not want the cache + // to keep `f` alive if all other references are dropped. + std::optional weakref; + }; + + // lru_list_ and functions_ are protected by the GIL in GIL mode, and by the + // self object lock in freethreading mode. + Cache::LRUList lru_list_; + // We use std::unordered_map because ABSL containers are not exception safe: + std::unordered_map, Key::Hash> functions_; + // mu_ prevents concurrent insertions into functions_ if the gil or critical + // section lock is released during insertion. + absl::Mutex mu_; +}; + +PjitFunctionCache::PjitFunctionCache(int capacity) : lru_list_(capacity) {} + +std::shared_ptr PjitFunctionCache::DefaultCache() { + return std::make_shared(&lru_list_); +} + +/*static*/ std::shared_ptr PjitFunctionCache::Lookup( + xla::nb_class_ptr self, nb::handle function, + nb::object global_cache_key) ABSL_NO_THREAD_SAFETY_ANALYSIS { + // In no-GIL mode, a critical section on self plays the same role that + // the GIL plays in GIL mode. + nb::ft_object_guard lock(self); + { + // Because the gil (or the critical section lock) can be released during + // cache insertion, this forces the lock order to be mu_ then gil so we + // must release the gil first. + nb::gil_scoped_release release; + // Acquire a mutex to avoid problems where the gil is released during + // cache insertion and then a second thread invalidates the cache order. + self->mu_.Lock(); + } + absl::Cleanup unlock = [&self]() ABSL_UNLOCK_FUNCTION(self->mu_) { + self->mu_.Unlock(); + }; + Key key; + key.function = function; + key.global_cache_key = global_cache_key; + key.cached_hash = absl::HashOf(key); + auto insert = self->functions_.emplace(key, nullptr); + if (!insert.second) { + return insert.first->second->cache; + } + std::shared_ptr cache = std::make_shared(&self->lru_list_); + auto callback = + nb::cpp_function([self, key{std::move(key)}](nb::handle weakref) { + nb::ft_object_guard lock(self); + auto it = self->functions_.find(key); + if (it != self->functions_.end()) { + self->functions_.erase(it); + } + }); + PyObject* weakref = PyWeakref_NewRef(function.ptr(), callback.ptr()); + if (weakref) { + std::unique_ptr& entry = insert.first->second; + entry = std::make_unique(cache); + entry->weakref = nb::steal(weakref); + } else { + PyErr_Clear(); + // `function` is not weak-referenceable. Don't bother adding it to the + // shared cache in that case; the `jit` object will hold the only shared + // reference to the cache entry. + self->functions_.erase(insert.first); + } + return cache; +} + +class PjitFunction { + public: + PjitFunction(std::string function_name, std::optional fun, + nb::callable cache_miss, std::vector static_argnums, + std::vector static_argnames, + nb::object global_cache_key, + xla::nb_class_ptr pytree_registry, + nb::callable shard_arg_fallback, + xla::nb_class_ptr cache); + ~PjitFunction(); + + PjitFunction(const PjitFunction&) = delete; + PjitFunction& operator=(const PjitFunction&) = delete; + PjitFunction(PjitFunction&&) = default; + PjitFunction& operator=(PjitFunction&&) = default; + + // nb::object typed subclass for PjitFunction objects. + class pyobject : public nb::object { + public: + NB_OBJECT(pyobject, nb::object, "PjitFunction", + PjitFunction::IsPjitFunction); + pyobject() = default; + PjitFunction* func() const { + return PjitFunction::AsPjitFunctionUnchecked(*this); + } + }; + // Alias as ::object; outside the scope above we won't confuse nanobind's + // macros. + using object = pyobject; + + // Returns true if `h` is a PjitFunction. + static bool IsPjitFunction(nb::handle handle); + // Converts `handle` to a PjitFunction*. Does not do any checking. + static PjitFunction* AsPjitFunctionUnchecked(nb::handle handle); + + absl::StatusOr Call(nb::handle callable, PyObject* const* args, + size_t nargs, PyObject* kwnames); + + void InitExecutables(); + + void ClearPythonReferences(); + + const std::string& function_name() const { return function_name_; } + const std::optional& fun() const { return fun_; } + const nb::callable& cache_miss() const { return cache_miss_; } + const xla::nb_class_ptr& pytree_registry() const { + return pytree_registry_; + } + const nb::callable& shard_arg_fallback() const { return shard_arg_fallback_; } + + const std::vector& static_argnums() const { return static_argnums_; } + const std::vector& static_argnames() const { + return static_argnames_; + } + const nb::object& global_cache_key() const { return global_cache_key_; } + const xla::nb_class_ptr& cache() const { return cache_; } + + int cache_capacity() const { + nb::ft_object_guard lock(cache_); + return executables_->Size(); + } + + void ClearCache() { + nb::ft_object_guard lock(cache_); + executables_->Clear(); + } + + std::shared_ptr executables() { + nb::ft_object_guard lock(cache_); + return executables_; + } + + nb::object PythonSignature() { + if (!fun_.has_value()) { + throw nb::value_error( + absl::StrFormat( + "Calling __signature__ on PjitFunction(%s) not supported.", + function_name_) + .c_str()); + } + static const auto* inspect = + new nb::module_(nb::module_::import_("inspect")); + return inspect->attr("signature")(*fun_); + } + + private: + absl::Status ComputeCallSignature( + absl::Span flat_dynamic_args, + CallSignature& call_signature); + + void PopulateCacheEntry(PjitCacheEntry& cache_entry, + const nb::tuple& out_and_fastpath_data); + + std::string function_name_; + std::optional fun_; + nb::callable cache_miss_; + std::vector static_argnums_; + std::vector static_argnames_; + nb::object global_cache_key_; + + xla::nb_class_ptr pytree_registry_; + nb::callable shard_arg_fallback_; + xla::nb_class_ptr cache_; + + // In no-GIL mode executables_ is protected by the object lock on cache_, + // because it shared an LRU list with cache_. + std::shared_ptr executables_; +}; + +PjitFunction::PjitFunction( + std::string function_name, std::optional fun, + nb::callable cache_miss, std::vector static_argnums, + std::vector static_argnames, nb::object global_cache_key, + xla::nb_class_ptr pytree_registry, + nb::callable shard_arg_fallback, xla::nb_class_ptr cache) + : function_name_(std::move(function_name)), + fun_(std::move(fun)), + cache_miss_(std::move(cache_miss)), + static_argnums_(std::move(static_argnums)), + global_cache_key_(std::move(global_cache_key)), + pytree_registry_(std::move(pytree_registry)), + shard_arg_fallback_(std::move(shard_arg_fallback)), + cache_(std::move(cache)) { + std::sort(static_argnums_.begin(), static_argnums_.end()); + static_argnames_.reserve(static_argnames.size()); + for (nb::str& name : static_argnames) { + PyObject* s = name.inc_ref().ptr(); + PyUnicode_InternInPlace(&s); + static_argnames_.push_back(nb::steal(s)); + } +} + +void PjitFunction::InitExecutables() { + // Construction of the object hasn't completed yet, so we don't need to hold + // the cache lock to mutate executables_. + if (!fun_.has_value()) { + executables_ = cache_->DefaultCache(); + } else { + executables_ = cache_->Lookup(cache_, fun_.value(), global_cache_key_); + } +} + +PjitFunction::~PjitFunction() { + nb::ft_object_guard lock(cache_); + executables_ = nullptr; +} + +void CallShardArgFallback( + nb::handle arg, nb::handle sharding, nb::handle layout, + const nb::callable& fallback, + std::vector>& num_args_arrays, + std::vector& keep_alive_objects) { + tsl::profiler::TraceMe traceme("cpp_pjit_shard_arg_fallback"); + auto py_array_or_bufs = fallback(arg, sharding, layout); + auto py_array = nb::cast(py_array_or_bufs); + num_args_arrays.push_back(tsl::FormRef(py_array.ifrt_array())); + keep_alive_objects.push_back(std::move(py_array_or_bufs)); +} + +// Prepares the input PjRtBuffers from the python arguments. This is equivalent +// to shard_args() in pxla.py but for only a few supported cases. +absl::StatusOr>> +PrepareIfrtInputs(const xla::PyLoadedExecutable& executable, + absl::Span flat_dynamic_args, + bool enable_x64, const std::vector& kept_args, + const std::vector& in_shardings, + const std::vector& in_device_local_layouts, + const nb::callable& shard_arg_fallback, + std::vector& keep_alive_objects) { + const auto& addressable_devices = + executable.ifrt_loaded_executable()->addressable_devices(); + const auto& num_global_devices = + executable.ifrt_loaded_executable()->num_devices(); + int num_args = flat_dynamic_args.size(); + + std::vector> num_args_arrays; + num_args_arrays.reserve(num_args); + + struct CopyGroup { + std::vector indices; + std::vector> arrays; + }; + absl::flat_hash_map, + CopyGroup> + copy_groups; + + xla::DevicePutOptions options; + options.squash_64bit_types = !enable_x64; + options.allow_zero_copy = true; + xla::ifrt::Device* data_device = nullptr; + if (executable.ifrt_loaded_executable()->num_devices() == 1) { + data_device = executable.ifrt_loaded_executable()->addressable_devices()[0]; + } + int dce_i = 0; + for (int i = 0; i < num_args; ++i) { + if (!kept_args[i]) { + continue; + } + int dce_index = dce_i; + ++dce_i; + + const nb::object& arg = flat_dynamic_args[i]; + const nb::object& in_device_local_layout = + in_device_local_layouts[dce_index]; + + auto transfer_guard_formatter = [] { return std::string(""); }; + + if (arg.type().ptr() != xla::PyArray::type().ptr()) { + if (data_device != nullptr && in_device_local_layout.is_none()) { + TF_RETURN_IF_ERROR( + jax::ApplyTransferGuardToHostToDevice(transfer_guard_formatter)); + TF_ASSIGN_OR_RETURN( + auto on_device_fn, + DevicePut(arg, executable.ifrt_loaded_executable()->client(), + data_device, options, xla::ifrt::MemoryKind())); + TF_ASSIGN_OR_RETURN(xla::DevicePutResult on_device, [&]() { + // Must release the GIL before calling IFRT because backends may + // decide to block/sleep for device buffer allocation. + nb::gil_scoped_release gil_release; + return std::move(on_device_fn)(); + }()); + + num_args_arrays.push_back(std::move(on_device.ifrt_array)); + if (on_device.owning_pybuffer) { + keep_alive_objects.push_back(std::move(on_device.owning_pybuffer)); + } + continue; + } else { + CallShardArgFallback(arg, in_shardings[dce_index], + in_device_local_layout, shard_arg_fallback, + num_args_arrays, keep_alive_objects); + continue; + } + } + + xla::PyArray py_array = nb::borrow(arg); + const auto& sharding = py_array.sharding(); + int sharding_num_devices = jax::Sharding::SafeNumDevices(sharding); + + // Currently only committed PyArray inputs or uncommitted PyArray on a + // single device inputs are allowed. This is checked previously in the entry + // point of PjitFunction::Call(). + DCHECK(py_array.committed() || + (!py_array.committed() && sharding_num_devices == 1)); + + if (!in_device_local_layout.is_none()) { + TF_ASSIGN_OR_RETURN(auto arr_layout, py_array.ifrt_array()->layout()); + xla::Layout in_xc_layout = nb::cast( + in_device_local_layout.attr("_to_xla_layout")(py_array.dtype())); + if (in_xc_layout != arr_layout->xla_layout()) { + CallShardArgFallback(arg, in_shardings[dce_index], + in_device_local_layout, shard_arg_fallback, + num_args_arrays, keep_alive_objects); + continue; + } + } + + if (sharding.type().ptr() == jax::PmapSharding::type().ptr()) { + CallShardArgFallback(arg, in_shardings[dce_index], in_device_local_layout, + shard_arg_fallback, num_args_arrays, + keep_alive_objects); + continue; + } + + if (sharding_num_devices != num_global_devices) { + CallShardArgFallback(arg, in_shardings[dce_index], in_device_local_layout, + shard_arg_fallback, num_args_arrays, + keep_alive_objects); + continue; + } + + xla::ifrt::Array* ifrt_array = py_array.ifrt_array(); + // PyArray inputs should have already been checked in + // `xla::PyArgSignatureOfValue()` called by + // `PjitFunction::ComputeCallSignature()`. + DCHECK(ifrt_array != nullptr) << "PyArray has been unexpectedly deleted."; + + const auto& ifrt_sharding = ifrt_array->sharding(); + if (sharding_num_devices == 1 && + ifrt_sharding.devices()->devices().front() != addressable_devices[0]) { + auto& copy_group = + copy_groups[std::make_pair(ifrt_sharding.devices()->devices().front(), + ifrt_sharding.memory_kind())]; + copy_group.indices.push_back(num_args_arrays.size()); + copy_group.arrays.push_back(tsl::FormRef(ifrt_array)); + num_args_arrays.push_back({}); + } else { + num_args_arrays.push_back(tsl::FormRef(ifrt_array)); + } + + keep_alive_objects.push_back(arg); + } + + if (!copy_groups.empty()) { + xla::ifrt::Client* const ifrt_client = + executable.ifrt_loaded_executable()->client(); + xla::ifrt::DeviceListRef ifrt_devices = + ifrt_client->MakeDeviceList({addressable_devices[0]}); + for (auto& [key, group] : copy_groups) { + TF_ASSIGN_OR_RETURN( + auto copied_ifrt_arrays, + ifrt_client->CopyArrays(absl::MakeSpan(group.arrays), ifrt_devices, + /*memory_kind=*/std::nullopt, + xla::ifrt::ArrayCopySemantics::kReuseInput)); + for (int i = 0; i < copied_ifrt_arrays.size(); ++i) { + num_args_arrays[group.indices[i]] = std::move(copied_ifrt_arrays[i]); + } + } + } + + return num_args_arrays; +} + +absl::StatusOr PjitFunction::Call(nb::handle callable, + PyObject* const* args, + size_t nargs, PyObject* kwnames) { + tsl::profiler::TraceMe traceme( + [&] { return absl::StrCat("PjitFunction(", function_name_, ")"); }); + + // Make sure we trigger a garbage collection on JIT function calls. Otherwise + // code like + // f = jit(...) + // while True: + // f(x) + // may never free temporary buffers for copies of arguments. + xla::GlobalPyRefManager()->MaybeCollectGarbage(); + + if (GetDisableJit()) { + if (!fun_.has_value()) { + throw nb::value_error( + absl::StrFormat("Disable jit is not supported in the AOT path since " + "the function is not available for (%s)", + function_name_) + .c_str()); + } + return nb::steal( + PyObject_Vectorcall(fun_.value().ptr(), args, nargs, kwnames)); + } + + // Calls the cache_miss_ function. This just calls the Python function; it may + // return nullptr value if a Python exception is thrown. + auto cache_miss = [&]() -> nb::tuple { + return nb::steal( + PyObject_Vectorcall(cache_miss_.ptr(), args, nargs, kwnames)); + }; + + // Call the cache_miss() function, extracting the output data and ignoring + // the fastpath data. If the cache miss returns a Python error, returns + // nullptr and leaves the Python error set. + auto fallback_to_cache_miss = [&]() { + nb::tuple cache_miss_output = cache_miss(); + if (!cache_miss_output.ptr()) { + return nb::object(); + } + return nb::object(cache_miss_output[0]); + }; + + size_t num_positional_args = PyVectorcall_NARGS(nargs); + size_t num_keyword_args = kwnames ? PyTuple_GET_SIZE(kwnames) : 0; + absl::Span positional_args(args, num_positional_args); + absl::Span keyword_args(args + num_positional_args, + num_keyword_args); + + CallSignature call_signature; + std::vector keep_alive_objects; + absl::InlinedVector flat_dynamic_args; + auto status = ParseArguments( + positional_args, keyword_args, kwnames, static_argnums_, static_argnames_, + pytree_registry_.get(), call_signature.arg_signature, flat_dynamic_args); + if (!status.ok()) { + VLOG(2) << "ParseArguments failed: " << status; + return fallback_to_cache_miss(); + } + + // Perform a few checks for the arguments. Currently we are only allowing + // committed PyArray inputs. For other cases, e.g. Tracers or ShapedArray, it + // will fallback to python. For jit, numpy arrays and scalars are also + // allowed, which we will check later. + for (const auto& arg : flat_dynamic_args) { + if (arg.type().ptr() != xla::PyArray::type().ptr()) { + continue; + } + + xla::PyArray py_array = nb::borrow(arg); + + // Only allow committed PyArray in cpp pjit for now as the logic on handling + // sharding for uncommitted PyArray is complicated and still under + // development. + // + // TODO(chky): Consider support uncommitted PyArray in cpp when the python + // side stablizes. + if (!py_array.committed() && + jax::Sharding::SafeNumDevices(py_array.sharding()) > 1) { + VLOG(2) << "PyArray argument is not committed and number of global " + "devices is more than 1; fallback to python."; + return fallback_to_cache_miss(); + } + } + + status = ComputeCallSignature(flat_dynamic_args, call_signature); + if (!status.ok()) { + VLOG(2) << "ComputeCallSignature failed: " << status; + return fallback_to_cache_miss(); + } + + VLOG(2) << "CallSignature:\n" << call_signature.DebugString(); + bool inserted = false; + std::shared_ptr cache_entry; + { + nb::ft_object_guard lock(cache_); + cache_entry = executables_->GetOrCreateIfAbsent( + call_signature, [this, &inserted](const CallSignature& unused) { + inserted = true; + return std::make_shared(pytree_registry_.get()); + }); + } + + if (!cache_entry->compilation_complete.HasBeenNotified()) { + // In case of several threads attempting to compile the executable, only + // the one that inserted the item will perform the compilation. + if (inserted) { + nb::object out_and_fastpath_data; + nb::tuple out_tuple; + VLOG(2) << "Cache miss for " << call_signature.DebugString(); + bool remove_cache = false; + try { + // Calls Python and may release the GIL. May also throw if + // compilation/tracing fails. + out_and_fastpath_data = cache_miss(); + if (!out_and_fastpath_data.ptr()) { + throw nb::python_error(); + } + out_tuple = nb::cast(out_and_fastpath_data); + + PopulateCacheEntry(*cache_entry, out_tuple); + + if (out_tuple.size() > 2 && out_tuple[2].is_valid()) { + remove_cache = nb::cast(out_tuple[2]); + } + } catch (const std::exception& e) { + VLOG(2) << "cache miss fail: " << e.what(); + cache_entry->fall_back_to_python = true; + cache_entry->compilation_complete.Notify(); + throw; + } + cache_entry->compilation_complete.Notify(); + + if (remove_cache) { + nb::ft_object_guard lock(cache_); + executables_->Remove(call_signature); + } + + // We have already computed the result in the miss path so we can return + // it. We are even *required* to do so if there are donated arguments, + // because any donated buffers will now be invalid. + return nb::object(out_tuple[0]); + } else { + if (cache_entry->thread_id == std::this_thread::get_id()) { + auto error_string = absl::StrCat("Recursively calling jit: ", + call_signature.DebugString()); + PyErr_SetString(PyExc_RecursionError, error_string.c_str()); + throw nb::python_error(); + } + // Release the GIL while we wait, making sure the compile thread can + // lock it. + nb::gil_scoped_release release; + cache_entry->compilation_complete.WaitForNotification(); + } + } + + if (cache_entry->fall_back_to_python) { + VLOG(2) << "cpp pjit fallback to python."; + return fallback_to_cache_miss(); + } + + // A vector of [num_inputs]. + auto num_args_arrays = PrepareIfrtInputs( + *cache_entry->executable, flat_dynamic_args, + call_signature.jax_enable_x64, cache_entry->kept_var_bitvec, + cache_entry->in_shardings, cache_entry->in_device_local_layouts, + shard_arg_fallback_, keep_alive_objects); + + if (!num_args_arrays.ok()) { + VLOG(2) << "Failed to prepare IFRT inputs: " << num_args_arrays.status(); + return fallback_to_cache_miss(); + } + + xla::ifrt::ExecuteOptions execute_options = + cache_entry->executable->options(); + execute_options.launch_id = cache_entry->executable->GetNextLaunchId(); + execute_options.execution_stream_id = + tsl::Env::Default()->GetCurrentThreadId(); + + // A vector of [num_outputs]. + std::vector> output_arrays; + { + nb::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN(auto result, + cache_entry->executable->ifrt_executable()->Execute( + absl::MakeSpan(*num_args_arrays), execute_options, + /*devices=*/std::nullopt)); + output_arrays = std::move(result.outputs); + } + + auto traceback = xla::Traceback::Get(); + + // Convert the ifrt::Array objects to PyArray. + int num_outputs = output_arrays.size(); + absl::InlinedVector outputs; + outputs.reserve(num_outputs); + for (int i = 0; i < num_outputs; ++i) { + // Creating the PyArray result. In addition to the IFRT arrays, the metadata + // like `aval` and `sharding` are retrieved from the cache for this + // function, which are produced by the python path in `cache_miss`. + xla::PyArray py_array( + cache_entry->out_avals[i], cache_entry->out_weak_types[i], + cache_entry->out_dtypes[i], cache_entry->out_shapes[i], + cache_entry->out_shardings[i], cache_entry->executable->client(), + traceback, std::move(output_arrays[i]), + /*committed=*/cache_entry->out_committed.at(i), /*skip_checks=*/true); + + outputs.push_back(std::move(py_array)); + } + + nb::object out = nb::steal( + cache_entry->out_pytree_def.Unflatten(outputs).release().ptr()); + + // If there is a post-hook function, call it with the inputs and the outputs. + std::optional post_hook = GetPostHook(); + if (post_hook) { + nb::tuple args_tuple = + nb::steal(PyTuple_New(num_positional_args)); + for (size_t i = 0; i < num_positional_args; ++i) { + Py_INCREF(args[i]); + PyTuple_SET_ITEM(args_tuple.ptr(), i, args[i]); + } + nb::dict kwargs; + if (kwnames) { + for (size_t i = 0; i < num_keyword_args; ++i) { + kwargs[nb::handle(PyTuple_GET_ITEM(kwnames, i))] = + nb::borrow(args[num_positional_args + i]); + } + } + (*post_hook)(nb::handle(callable.ptr()), args_tuple, kwargs, + nb::handle(out.ptr())); + } + + return out; +} + +absl::Status PjitFunction::ComputeCallSignature( + absl::Span flat_dynamic_args, CallSignature& signature) { + signature.function_name = function_name_; + + // Get dynamic argument signatures. + JitState& global_state = jax::GlobalJitState(); + JitState& tls = jax::ThreadLocalJitState(); + bool jax_enable_x64 = GetEnableX64(); + + signature.default_device = GetDefaultDevice(); + signature.jax_enable_x64 = jax_enable_x64; + + auto& dynamic_arg_signatures = signature.dynamic_arg_signatures; + dynamic_arg_signatures.reserve(flat_dynamic_args.size()); + auto& dynamic_arg_shardings = signature.dynamic_arg_shardings; + dynamic_arg_shardings.reserve(flat_dynamic_args.size()); + auto& dynamic_arg_layouts = signature.dynamic_arg_layouts; + dynamic_arg_layouts.reserve(flat_dynamic_args.size()); + + for (nb::handle arg : flat_dynamic_args) { + TF_ASSIGN_OR_RETURN(auto arg_signature, + xla::PyArgSignatureOfValue(arg, jax_enable_x64)); + signature.dynamic_arg_signatures.push_back(std::move(arg_signature)); + + // It should be already checked previously in the entry point of + // PjitFunction::Call(). + if (arg.type().ptr() == xla::PyArray::type().ptr()) { + auto py_array = nb::borrow(arg); + signature.dynamic_arg_shardings.push_back(py_array.sharding()); + auto layout = py_array.layout(); + if (absl::IsUnimplemented(layout.status())) { + signature.dynamic_arg_layouts.push_back(nullptr); + } else { + signature.dynamic_arg_layouts.push_back(*std::move(layout)); + } + signature.committed_args.push_back(py_array.committed()); + } else { + signature.dynamic_arg_shardings.push_back(nb::none()); + signature.dynamic_arg_layouts.push_back(nullptr); + signature.committed_args.push_back(false); + } + } + + signature.thread_local_extra_jit_context = tls.extra_jit_context; + signature.global_extra_jit_context = global_state.extra_jit_context; + signature.configs = JitConfigs(); + + return absl::OkStatus(); +} + +void PjitFunction::PopulateCacheEntry(PjitCacheEntry& cache_entry, + const nb::tuple& out_and_fastpath_data) { + DCHECK_GE(out_and_fastpath_data.size(), 2); + + if (out_and_fastpath_data[1].is_none()) { + VLOG(2) << "fastpath_data is none"; + cache_entry.fall_back_to_python = true; + return; + } + + nb::tuple fastpath_data = nb::cast(out_and_fastpath_data[1]); + + cache_entry.executable = nb::cast>( + fastpath_data.attr("xla_executable")); + + nb::sequence in_shardings = fastpath_data.attr("in_shardings"); + cache_entry.in_shardings.reserve(nb::len(in_shardings)); + for (nb::handle sharding : in_shardings) { + cache_entry.in_shardings.push_back(nb::borrow(sharding)); + } + + nb::sequence out_shardings = fastpath_data.attr("out_shardings"); + cache_entry.out_shardings.reserve(nb::len(out_shardings)); + for (nb::handle sharding : out_shardings) { + cache_entry.out_shardings.push_back(nb::borrow(sharding)); + } + + nb::sequence out_committed = fastpath_data.attr("out_committed"); + cache_entry.out_committed.reserve(nb::len(out_committed)); + for (nb::handle c : out_committed) { + cache_entry.out_committed.push_back(nb::cast(c)); + } + + nb::sequence out_avals = fastpath_data.attr("out_avals"); + cache_entry.out_avals.reserve(nb::len(out_avals)); + cache_entry.out_dtypes.reserve(nb::len(out_avals)); + cache_entry.out_shapes.reserve(nb::len(out_avals)); + cache_entry.out_weak_types.reserve(nb::len(out_avals)); + for (nb::handle aval : out_avals) { + cache_entry.out_avals.push_back(nb::borrow(aval)); + cache_entry.out_dtypes.push_back(aval.attr("dtype")); + cache_entry.out_shapes.push_back( + nb::cast>(aval.attr("shape"))); + cache_entry.out_weak_types.push_back( + nb::cast(aval.attr("weak_type"))); + } + + cache_entry.out_pytree_def = nb::cast( + nb::handle(fastpath_data.attr("out_pytree_def").ptr())); + + nb::sequence kept_var_bitvec = fastpath_data.attr("kept_var_bitvec"); + cache_entry.kept_var_bitvec.reserve(nb::len(kept_var_bitvec)); + for (nb::handle k : kept_var_bitvec) { + cache_entry.kept_var_bitvec.push_back(nb::cast(k)); + } + + nb::sequence in_device_local_layouts = + fastpath_data.attr("in_device_local_layouts"); + cache_entry.in_device_local_layouts.reserve(nb::len(in_device_local_layouts)); + for (nb::handle dll : in_device_local_layouts) { + cache_entry.in_device_local_layouts.push_back(nb::borrow(dll)); + } +} + +// Helper function used by the tp_clear GC method. +void PjitFunction::ClearPythonReferences() { + // TODO(mattjj): phawkins@ observed that the xla::PyTreeRegistry + // pytree_registry_ attribute of PjitFunction could in principle also have + // python references to clear + nb::callable cache_miss; + std::optional fun; + nb::callable shard_arg_fallback; + // Swap values for nulls before they are destroyed. See the Python + // Py_CLEAR() documentation for a discussion of this topic. + std::swap(cache_miss_, cache_miss); + std::swap(fun_, fun); + std::swap(shard_arg_fallback_, shard_arg_fallback); +} + +struct PjitFunctionObject { + PyObject_HEAD; +#if PY_VERSION_HEX < 0x030C0000 + PyObject* dict; // Dictionary for __dict__ + PyObject* weakrefs; // Weak references; for use by the Python interpreter. +#endif // PY_VERSION_HEX < 0x030C0000 + vectorcallfunc vectorcall; + PjitFunction fun; + + // Doubly-linked list of PjitFunctionObjects, protected by + // PjitFunctionStore::mu_ or the GIL in GIL mode. + PjitFunctionObject* next; + PjitFunctionObject* prev; +}; + +// Contains a list of all PjitFunctionObjects. +// Thread-safe. +class PjitFunctionStore { + public: + void Insert(PjitFunctionObject* o) { + nb::ft_lock_guard lock(mu_); + o->next = compiled_functions_; + o->prev = nullptr; + if (o->next) { + o->next->prev = o; + } + compiled_functions_ = o; + } + + void Remove(PjitFunctionObject* o) { + nb::ft_lock_guard lock(mu_); + if (o->next) { + o->next->prev = o->prev; + } + if (o->prev) { + o->prev->next = o->next; + } else { + compiled_functions_ = o->next; + } + } + + void ClearCaches() { + std::vector< + std::pair>> + caches; + { + nb::ft_lock_guard lock(mu_); + for (PjitFunctionObject* fn = compiled_functions_; fn != nullptr; + fn = fn->next) { + caches.emplace_back(fn->fun.cache(), fn->fun.executables()); + } + } + for (auto& [cache, executables] : caches) { + nb::ft_object_guard lock(cache); + executables->Clear(); + } + }; + + private: + // Protected by the GIL in GIL mode, and by mu_ in freethreading mode. + nb::ft_mutex mu_; + PjitFunctionObject* compiled_functions_; +}; + +PjitFunctionStore pjit_function_store; + +PyObject* PjitFunction_Type = nullptr; + +bool PjitFunction::IsPjitFunction(nb::handle handle) { + return handle.type().ptr() == PjitFunction_Type; +} + +PjitFunction* PjitFunction::AsPjitFunctionUnchecked(nb::handle handle) { + return &(reinterpret_cast(handle.ptr())->fun); +} + +PjitFunction* AsPjitFunction(nb::handle handle) { + if (!PjitFunction::IsPjitFunction(handle)) { + throw xla::XlaRuntimeError(xla::InvalidArgument("Expected a PjitFunction")); + } + return PjitFunction::AsPjitFunctionUnchecked(handle); +} + +extern "C" { + +PyObject* PjitFunction_tp_vectorcall(PyObject* callable, PyObject* const* args, + size_t nargs, PyObject* kwnames) { + PjitFunctionObject* o = reinterpret_cast(callable); + tsl::profiler::TraceMe traceme([&] { + return absl::StrCat("PjitFunction(", o->fun.function_name(), ")"); + }); + try { + absl::StatusOr out = + o->fun.Call(callable, args, nargs, kwnames); + if (!out.ok()) { + PyErr_SetString(PyExc_ValueError, out.status().ToString().c_str()); + return nullptr; + } + return out.value().release().ptr(); + } catch (nb::python_error& e) { + e.restore(); + return nullptr; + } catch (nb::cast_error& e) { + PyErr_SetString(PyExc_ValueError, e.what()); + return nullptr; + } catch (std::invalid_argument& e) { + PyErr_SetString(PyExc_ValueError, e.what()); + return nullptr; + } catch (std::runtime_error& e) { + PyErr_SetString(PyExc_ValueError, e.what()); + return nullptr; + } +} + +PyObject* PjitFunction_tp_new(PyTypeObject* subtype, PyObject* args, + PyObject* kwds) { + PjitFunctionObject* self = + reinterpret_cast(subtype->tp_alloc(subtype, 0)); + if (!self) return nullptr; +#if PY_VERSION_HEX < 0x030C0000 + self->dict = nullptr; + self->weakrefs = nullptr; +#endif // PY_VERSION_HEX < 0x030C0000 + self->vectorcall = PjitFunction_tp_vectorcall; + return reinterpret_cast(self); +} + +void PjitFunction_tp_dealloc(PyObject* self) { + PyObject_GC_UnTrack(self); + PyTypeObject* tp = Py_TYPE(self); + PjitFunctionObject* o = reinterpret_cast(self); + pjit_function_store.Remove(o); + PyObject_ClearWeakRefs(self); +#if PY_VERSION_HEX < 0x030C0000 + Py_CLEAR(o->dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_ClearManagedDict(self); +#else + PyObject_ClearManagedDict(self); +#endif // PY_VERSION_HEX < 0x030C0000 + o->fun.~PjitFunction(); + tp->tp_free(self); + Py_DECREF(tp); +} + +int PjitFunction_tp_traverse(PyObject* self, visitproc visit, void* arg) { + // TODO(mattjj): phawkins@ observed that the xla::PyTreeRegistry + // pytree_registry_ attribute of PjitFunction could in principle also have + // python references to visit + PjitFunctionObject* o = reinterpret_cast(self); + // https://docs.python.org/3/c-api/typeobj.html#c.PyTypeObject.tp_traverse + Py_VISIT(Py_TYPE(self)); +#if PY_VERSION_HEX < 0x030C0000 + Py_VISIT(o->dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_VisitManagedDict(self, visit, arg); +#else + PyObject_VisitManagedDict(self, visit, arg); +#endif // PY_VERSION_HEX < 0x030C0000 + Py_VISIT(o->fun.cache_miss().ptr()); + Py_VISIT(o->fun.shard_arg_fallback().ptr()); + if (o->fun.fun()) { + Py_VISIT(o->fun.fun()->ptr()); + } + return 0; +} + +int PjitFunction_tp_clear(PyObject* self) { + PjitFunctionObject* o = reinterpret_cast(self); +#if PY_VERSION_HEX < 0x030C0000 + Py_CLEAR(o->dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_ClearManagedDict(self); +#else + PyObject_ClearManagedDict(self); +#endif // PY_VERSION_HEX < 0x030C0000 + o->fun.ClearPythonReferences(); + return 0; +} + +// Implements the Python descriptor protocol so JIT-compiled functions can be +// used as bound methods. See: +// https://docs.python.org/3/howto/descriptor.html#functions-and-methods +PyObject* PjitFunction_tp_descr_get(PyObject* self, PyObject* obj, + PyObject* type) { + if (obj == nullptr || obj == Py_None) { + Py_INCREF(self); + return self; + } + return PyMethod_New(self, obj); +} + +static PyGetSetDef PjitFunction_tp_getset[] = { + // Having a __dict__ seems necessary to allow !functool.wraps to override + // __doc__. + {const_cast("__dict__"), PyObject_GenericGetDict, + PyObject_GenericSetDict, nullptr, nullptr}, + {nullptr, nullptr, nullptr, nullptr, nullptr}}; + +PyObject* PjitFunction_tp_repr(PyObject* self) { + try { + const std::string& repr = absl::StrFormat( + "", nb::cast(nb::repr( + nb::getattr(self, "__wrapped__")))); + return PyUnicode_FromString(repr.c_str()); + } catch (...) { + // Ignore all errors when accessing a repr. + return PyUnicode_FromString(""); + } +} + +} // extern "C" + +void InitializePjitFunction( + PjitFunctionObject* fn_obj, std::string function_name, + std::optional fun, nb::callable cache_miss, + std::vector static_argnums, std::vector static_argnames, + nb::object global_cache_key, + xla::nb_class_ptr pytree_registry, + nb::callable shard_arg_fallback, + xla::nb_class_ptr cache) { + fn_obj->next = fn_obj->prev = nullptr; + if (nb::isinstance(global_cache_key)) { + global_cache_key = nb::tuple(global_cache_key); + } + new (&fn_obj->fun) PjitFunction( + std::move(function_name), std::move(fun), std::move(cache_miss), + std::move(static_argnums), std::move(static_argnames), + std::move(global_cache_key), std::move(pytree_registry), + std::move(shard_arg_fallback), std::move(cache)); + // Handled separately because it is not exception safe to call this + // in the constructor because it leaves the object improperly constructed. + fn_obj->fun.InitExecutables(); + + // Only add the executable to the store after executables_ has been + // initialized. We want only fully constructed executables in the store. + pjit_function_store.Insert(fn_obj); +} + +nb::object MakePjitFunction( + std::string function_name, std::optional fun, + nb::callable cache_miss, std::vector static_argnums, + std::vector static_argnames, nb::object global_cache_key, + xla::nb_class_ptr pytree_registry, + nb::callable shard_arg_fallback, + std::optional> cache) { + nb::object obj = nb::steal(PjitFunction_tp_new( + reinterpret_cast(PjitFunction_Type), nullptr, nullptr)); + PjitFunctionObject* fn_obj = reinterpret_cast(obj.ptr()); + if (!cache) { + cache = xla::make_nb_class( + PjitFunctionCache::kDefaultCapacity); + } + InitializePjitFunction( + fn_obj, std::move(function_name), std::move(fun), std::move(cache_miss), + std::move(static_argnums), std::move(static_argnames), + std::move(global_cache_key), std::move(pytree_registry), + std::move(shard_arg_fallback), std::move(*cache)); + return obj; +} + +// Version numbers for the pickled representations of +// PjitFunction. Increment these if changing them. +const int kPjitFunctionPickleVersion = 1; + +PyMemberDef PjitFunction_members[] = { + {"__vectorcalloffset__", T_PYSSIZET, + static_cast(offsetof(PjitFunctionObject, vectorcall)), + READONLY, nullptr}, +#if PY_VERSION_HEX < 0x030C0000 + {"__dictoffset__", T_PYSSIZET, + static_cast(offsetof(PjitFunctionObject, dict)), READONLY, + nullptr}, + {"__weaklistoffset__", T_PYSSIZET, + static_cast(offsetof(PjitFunctionObject, weakrefs)), READONLY, + nullptr}, +#endif // PY_VERSION_HEX < 0x030C0000 + {nullptr, 0, 0, 0, nullptr}, +}; + +PyType_Slot PjitFunction_slots[] = { + {Py_tp_new, reinterpret_cast(PjitFunction_tp_new)}, + {Py_tp_dealloc, reinterpret_cast(PjitFunction_tp_dealloc)}, + {Py_tp_traverse, reinterpret_cast(PjitFunction_tp_traverse)}, + {Py_tp_clear, reinterpret_cast(PjitFunction_tp_clear)}, + {Py_tp_getset, reinterpret_cast(PjitFunction_tp_getset)}, + {Py_tp_descr_get, reinterpret_cast(PjitFunction_tp_descr_get)}, + {Py_tp_call, reinterpret_cast(PyVectorcall_Call)}, + {Py_tp_repr, reinterpret_cast(PjitFunction_tp_repr)}, + {Py_tp_members, reinterpret_cast(PjitFunction_members)}, + {0, nullptr}, +}; + +} // namespace + +void BuildPjitSubmodule(nb::module_& m) { + nb::class_ cache(m, "PjitFunctionCache"); + cache.def(nb::init(), + nb::arg("capacity") = PjitFunctionCache::kDefaultCapacity); + cache.def("size", &PjitFunctionCache::Size, nb::lock_self()); + cache.def("capacity", &PjitFunctionCache::Capacity, nb::lock_self()); + cache.def("clear", &PjitFunctionCache::Clear, nb::lock_self()); + cache.def_static("clear_all", []() { pjit_function_store.ClearCaches(); }); + cache.def( + "__getstate__", + // Pickles as an empty cache; the client can repopulate as needed. + [](const PjitFunctionCache& cache) { + nb::dict pickle; + pickle["version"] = kPjitFunctionPickleVersion; + pickle["capacity"] = cache.Capacity(); + return pickle; + }, + nb::lock_self()); + cache.def("__setstate__", + [](PjitFunctionCache* cache, const nb::dict& pickle) { + int version = nb::cast(pickle["version"]); + if (version != kPjitFunctionPickleVersion) { + throw std::invalid_argument(absl::StrFormat( + "Invalid PjitFunction pickle version, got %d, expected %d", + version, kPjitFunctionPickleVersion)); + } + int capacity = nb::cast(pickle["capacity"]); + new (cache) PjitFunctionCache(capacity); + }); + + // We need to use heap-allocated type objects because we want to add + // additional methods dynamically. + std::string name = + absl::StrCat(nb::cast(m.attr("__name__")), ".PjitFunction"); + PyType_Spec PjitFunction_spec = { +#if PY_VERSION_HEX < 0x030B0000 + // Work around for https://github.com/python/cpython/issues/89478 + // CPython 3.10 and earlier assume that the .name value remains alive + // forever. + /*.name=*/strdup(name.c_str()), +#else + /*.name=*/name.c_str(), +#endif // PY_VERSION_HEX < 0x030B0000 + /*.basicsize=*/static_cast(sizeof(PjitFunctionObject)), + /*.itemsize=*/0, +#if PY_VERSION_HEX < 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_HAVE_VECTORCALL, +#else // PY_VERSION_HEX < 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_HAVE_VECTORCALL | Py_TPFLAGS_MANAGED_DICT | + Py_TPFLAGS_MANAGED_WEAKREF, +#endif // PY_VERSION_HEX < 0x030C0000 + /*.slots=*/PjitFunction_slots, + }; + PjitFunction_Type = PyType_FromSpec(&PjitFunction_spec); + if (!PjitFunction_Type) { + throw nb::python_error(); + } + nb::object cfun = nb::borrow(PjitFunction_Type); + + // Add PjitFunction to the xla_extension module so it can be pickled. + m.attr("PjitFunction") = cfun; + cfun.attr("__getstate__") = nb::cpp_function( + [](const PjitFunction::object& self) { + PjitFunction* fn = self.func(); + nb::dict pickle; + pickle["version"] = kPjitFunctionPickleVersion; + pickle["function_name"] = fn->function_name(); + if (fn->fun().has_value()) { + pickle["fun"] = *fn->fun(); + } + pickle["cache_miss"] = fn->cache_miss(); + pickle["static_argnums"] = fn->static_argnums(); + pickle["static_argnames"] = nb::cast(fn->static_argnames()); + pickle["global_cache_key"] = fn->global_cache_key(); + pickle["pytree_registry"] = nb::cast(fn->pytree_registry()); + pickle["shard_arg_fallback"] = fn->shard_arg_fallback(); + pickle["cache"] = fn->cache(); + return pickle; + }, + nb::is_method()); + cfun.attr("__setstate__") = nb::cpp_function( + [](nb::object& self, const nb::dict& pickle) { + int version = nb::cast(pickle["version"]); + if (version != kPjitFunctionPickleVersion) { + throw std::invalid_argument(absl::StrFormat( + "Invalid PjitFunction pickle version, got %d, expected %d. " + "Pickling/Unpickling jitted functions using different JAX " + "versions is not supported.", + version, kPjitFunctionPickleVersion)); + } + std::string function_name = + nb::cast(pickle["function_name"]); + std::optional fun; + if (pickle.contains("fun")) { + fun = nb::cast(pickle["fun"]); + } + nb::callable cache_miss = nb::cast(pickle["cache_miss"]); + std::vector static_argnums = + nb::cast>(pickle["static_argnums"]); + std::vector static_argnames = + nb::cast>(pickle["static_argnames"]); + nb::object global_cache_key = pickle["global_cache_key"]; + xla::nb_class_ptr pytree_registry = + nb::cast>( + nb::handle(pickle["pytree_registry"].ptr())); + nb::callable shard_arg_fallback = + nb::cast(pickle["shard_arg_fallback"]); + xla::nb_class_ptr cache = + nb::cast>(pickle["cache"]); + InitializePjitFunction( + reinterpret_cast(self.ptr()), + std::move(function_name), std::move(fun), std::move(cache_miss), + std::move(static_argnums), std::move(static_argnames), + std::move(global_cache_key), std::move(pytree_registry), + std::move(shard_arg_fallback), std::move(cache)); + }, + nb::is_method()); + cfun.attr("__signature__") = + xla::nb_property_readonly([](nb::handle self) -> nb::object { + return AsPjitFunction(self)->PythonSignature(); + }); + cfun.attr("_cache_miss") = + xla::nb_property_readonly([](nb::handle self) -> nb::object { + return AsPjitFunction(self)->cache_miss(); + }); + // All private members are only for testing/debugging purposes + cfun.attr("_cache_size") = nb::cpp_function( + [](nb::handle self) -> int { + return AsPjitFunction(self)->cache_capacity(); + }, + nb::is_method()); + cfun.attr("_clear_cache") = nb::cpp_function( + [](nb::handle self) { AsPjitFunction(self)->ClearCache(); }, + nb::is_method()); + + m.def( + "pjit", + [](std::string function_name, std::optional fun, + nb::callable cache_miss, std::vector static_argnums, + std::vector static_argnames, nb::object global_cache_key, + nb::object pytree_registry, nb::callable shard_arg_fallback, + std::optional> cache) { + xla::nb_class_ptr registry = + nb::cast>( + nb::handle(pytree_registry.ptr())); + return MakePjitFunction( + std::move(function_name), std::move(fun), std::move(cache_miss), + std::move(static_argnums), std::move(static_argnames), + std::move(global_cache_key), std::move(registry), + std::move(shard_arg_fallback), std::move(cache)); + }, + nb::arg("function_name"), nb::arg("fun").none(), nb::arg("cache_miss"), + nb::arg("static_argnums"), nb::arg("static_argnames"), + nb::arg("global_cache_key"), nb::arg("pytree_registry"), + nb::arg("shard_arg_fallback"), nb::arg("cache").none() = nb::none()); +} + +} // namespace jax diff --git a/jaxlib/xla/pjit.h b/jaxlib/xla/pjit.h new file mode 100644 index 000000000000..545fb2307783 --- /dev/null +++ b/jaxlib/xla/pjit.h @@ -0,0 +1,27 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_PJIT_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_PJIT_H_ + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" + +namespace jax { + +void BuildPjitSubmodule(nanobind::module_& m); +} + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_PJIT_H_ diff --git a/jaxlib/xla/pmap_lib.cc b/jaxlib/xla/pmap_lib.cc new file mode 100644 index 000000000000..5582eccf4f8b --- /dev/null +++ b/jaxlib/xla/pmap_lib.cc @@ -0,0 +1,1180 @@ +/* Copyright 2021 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/pmap_lib.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/synchronization/notification.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/config.h" +#include "jaxlib/xla/jax_jit.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/python/nb_helpers.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/py_array.h" +#include "xla/python/py_client.h" +#include "xla/python/py_device.h" +#include "xla/python/py_executable.h" +#include "xla/python/py_values.h" +#include "xla/python/python_ref_manager.h" +#include "xla/python/pytree.h" +#include "xla/python/sharded_device_array.h" +#include "xla/python/sharding.h" +#include "xla/python/to_ifrt_sharding.h" +#include "xla/python/traceback.h" +#include "xla/python/types.h" +#include "xla/status_macros.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/python/lib/core/numpy.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/profiler/lib/traceme.h" + +namespace jax { + +namespace nb = nanobind; + +namespace { + +// Specifies how to shard the inputs. Even though everything could be computed +// from `sharding_specs` and the argument shape, we cache derived computations +// for performance. +struct InputSpec { + InputSpec(nb::object indices, nb::object array_sharding) + : indices(std::move(indices)), + array_sharding(std::move(array_sharding)) {} + nb::object indices; + nb::object array_sharding; +}; + +// An object containing the arguments to create Array from the +// output buffers. +struct ResultSpec { + public: + explicit ResultSpec(nb::object aval) + : out_aval(std::move(aval)), + weak_type(nb::cast(out_aval.attr("weak_type"))) {} + nb::object out_aval; + bool weak_type; +}; + +// The result of `ShardArg`. +struct ShardArgResult { + // Points to the on-device array. + // ifrt_array->sharding().num_shards() == `num_devices`. + tsl::RCReference ifrt_array; + // The Python argument will be always be copied to `owning_sda`. + nb::object owning_sda; +}; + +// Shards a single argument over devices. +// +// We currently only support fully in C++, C++ Array. For all +// other usages, we call a Python function returning C++ Array +// that will be casted back to the C++ objects. +// +// This function is not usable for JAX extensions that do not comply with the +// PjRt interfaces. +// +// Arguments: +// `arg`: The object to shard across `devices`. If a `Array`, +// a fast-path will be executed if it's already correctly sharded. +// +// Returns a failure absl::Status when an unrecoverable error occurred, so we +// don't need to fallback to Python. +// +// Both `devices` and `sharding_spec` has the same length. +absl::StatusOr ShardArg( + nb::handle arg, absl::Span devices, + const InputSpec& input_spec, nb::handle py_devices, + const nb::callable& python_fallback) { + if (arg.type().ptr() == xla::PyArray::type().ptr()) { + auto py_array = nb::borrow(arg); + if (py_array.sharding().type().ptr() == + input_spec.array_sharding.type().ptr()) { + auto* pmap_sharding = nb::cast(py_array.sharding()); + auto* cached_pmap_sharding = + nb::cast(input_spec.array_sharding); + + if (pmap_sharding->sharding_spec() == + cached_pmap_sharding->sharding_spec()) { + ShardArgResult result; + result.owning_sda = nb::borrow(arg); + result.ifrt_array = tsl::FormRef(py_array.ifrt_array()); + if (result.ifrt_array == nullptr) { + return xla::InvalidArgument("Array has been deleted."); + } + if (result.ifrt_array->sharding().devices()->devices() != devices) { + absl::InlinedVector ifrt_devices; + ifrt_devices.reserve(devices.size()); + ifrt_devices.insert(ifrt_devices.end(), devices.begin(), + devices.end()); + // pmap does not support memory_kind for now. + auto* ifrt_client = result.ifrt_array->client(); + TF_ASSIGN_OR_RETURN(auto copied_ifrt_arrays, + ifrt_client->CopyArrays( + absl::MakeSpan(&result.ifrt_array, 1), + ifrt_client->MakeDeviceList(ifrt_devices), + xla::ifrt::MemoryKind(), + xla::ifrt::ArrayCopySemantics::kReuseInput)); + result.ifrt_array = std::move(copied_ifrt_arrays.front()); + } + return result; + } + } + } + + auto ndarray = xla::nb_numpy_ndarray::ensure(arg); + if (ndarray && PyArray_CheckExact(arg.ptr()) && + xla::DtypeToPrimitiveType(ndarray.dtype()).status().ok()) { + tsl::profiler::TraceMe traceme("ndarray pmap ShardArg"); + nb::list indices = nb::list(input_spec.indices); + nb::list py_devices_list = nb::cast(py_devices); + auto n_devices = py_devices_list.size(); + if (indices.size() != n_devices) { + return xla::InvalidArgument("indices vs devices mismatch: %d vs %d", + indices.size(), n_devices); + } + + std::vector> per_device_arrays; + per_device_arrays.reserve(n_devices); + absl::InlinedVector devices; + devices.reserve(n_devices); + // TODO(hyeontaek): The created array will never be disassembled. We should + // omit collecting shapes and make the OpaqueSharding non-disassemblable? + std::vector shapes; + shapes.reserve(n_devices); + + nb::list owning_pylist; + ShardArgResult result; + result.owning_sda = owning_pylist; + const bool jax_enable_x64 = GetEnableX64(); + + std::vector device_put_fns; + device_put_fns.reserve(n_devices); + xla::DevicePutOptions options; + options.squash_64bit_types = !jax_enable_x64; + options.allow_zero_copy = true; + for (size_t i = 0; i < n_devices; ++i) { + auto to_device = nb::cast(py_devices_list[i]); + if (to_device->client().get() == nullptr) { + return xla::InvalidArgument("Cannot copy to unattached devices."); + } + + TF_ASSIGN_OR_RETURN( + device_put_fns.emplace_back(), + DevicePut(arg[indices[i]], to_device->client()->ifrt_client(), + to_device->device(), options, xla::ifrt::MemoryKind())); + } + std::vector device_puts; + device_puts.reserve(n_devices); + { + nb::gil_scoped_release gil_release; + for (auto& device_put_fn : device_put_fns) { + TF_ASSIGN_OR_RETURN(auto device_put, std::move(device_put_fn)()); + device_puts.push_back(std::move(device_put)); + } + } + for (auto& device_put : device_puts) { + per_device_arrays.push_back(std::move(device_put.ifrt_array)); + devices.push_back( + per_device_arrays.back()->sharding().devices()->devices().front()); + shapes.push_back(per_device_arrays.back()->shape()); + if (device_put.owning_pybuffer) { + owning_pylist.append(device_put.owning_pybuffer); + } + } + + if (per_device_arrays.empty()) { + return xla::InvalidArgument("Per-device arrays must not be empty."); + } + // TODO(hyeontaek): The logical shape here is inaccurate. We + // may want to avoid creating a new Array or specialize Array + // to disallow access to the logical shape. + xla::ifrt::Shape shape = per_device_arrays.front()->shape(); + TF_ASSIGN_OR_RETURN( + auto ifrt_sharding, + xla::GetIfrtConcreteSharding(input_spec.array_sharding, shape, shapes)); + TF_ASSIGN_OR_RETURN( + result.ifrt_array, + per_device_arrays.front() + ->client() + ->AssembleArrayFromSingleDeviceArrays( + std::move(shape), std::move(ifrt_sharding), + absl::MakeSpan(per_device_arrays), + xla::ifrt::ArrayCopySemantics::kReuseInput, + xla::ifrt::SingleDeviceShardSemantics::kAddressableShards)); + return result; + } + tsl::profiler::TraceMe traceme("pmap_lib_shard_arg_python_fallback"); + auto py_array_or_bufs = python_fallback(arg, input_spec.array_sharding); + + auto py_array = nb::cast(py_array_or_bufs); + ShardArgResult result; + result.owning_sda = nb::borrow(py_array_or_bufs); + result.ifrt_array = tsl::FormRef(py_array.ifrt_array()); + return result; +} + +struct PmapCacheEntry { + explicit PmapCacheEntry(xla::PyTreeRegistry* registry) + : out_pytree_def(registry) {} + std::shared_ptr executable; + // The value `backend.local_devices()`. + nb::object py_devices; // To pass back to Python. + std::vector devices; + std::vector input_specs; + xla::PyTreeDef out_pytree_def; + // Objects necessary to build the out Array objects. + std::vector out_result_specs; + + std::vector out_array_shardings; + std::vector out_dtypes; + std::vector> out_shapes; + std::vector out_committed; + + // Ensures a single thread performs the compilation for a given executable. + // + // The first thread (holding the GIL) will create the CacheEntry associated to + // a signature and if the object has been inserted already, other threads + // will wait for the notification. + absl::Notification compilation_complete; + + bool fall_back_to_python = false; +}; + +} // namespace + +// A `PmapFunction` is associated to a `jax.pmap(f)` and takes care of the +// bookkeeping of the different signatures used and the dispatch of calls to +// the correct underlying `PyLoadedExecutable`. This class is thread-safe. +class PmapFunction { + public: + PmapFunction(nb::callable fun, nb::callable cache_miss, + std::vector static_argnums, + nb::callable python_shard_arg_fallback, + xla::nb_class_ptr pytree_registry) + : fun_(std::move(fun)), + cache_miss_(std::move(cache_miss)), + static_argnums_(std::move(static_argnums)), + pytree_registry_(std::move(pytree_registry)), + python_shard_arg_fallback_(std::move(python_shard_arg_fallback)) { + std::sort(static_argnums_.begin(), static_argnums_.end()); + + function_name_ = + nb::cast(nb::str(nb::getattr(fun_, "__name__", fun_))); + } + PmapFunction(const PmapFunction&) = delete; + PmapFunction& operator=(const PmapFunction& other) = delete; + PmapFunction(PmapFunction&&) = default; + PmapFunction& operator=(PmapFunction&&) = default; + + // This function will: + // (a) flatten the inputs using pytree + // (b) get buffer objects from the arguments + // (c) call the executable + // (d) construct `Array` objects from the outputs + // (e) reconstruct the `PyTree`. + absl::StatusOr Call(nb::handle callable, PyObject* const* args, + size_t nargs, PyObject* kwnames); + + nb::object PythonSignature() { + static const auto* inspect = + new nb::module_(nb::module_::import_("inspect")); + return inspect->attr("signature")(fun_); + } + + int cache_size() { + nb::ft_lock_guard lock(mu_); + return executables_.size(); + } + void cache_clear() { + nb::ft_lock_guard lock(mu_); + return executables_.clear(); + } + const nb::callable& fun() const { return fun_; } + const nb::callable& cache_miss() const { return cache_miss_; } + const std::string& function_name() const { return function_name_; } + const xla::nb_class_ptr& pytree_registry() const { + return pytree_registry_; + } + const nb::callable& python_shard_arg_fallback() const { + return python_shard_arg_fallback_; + } + const std::vector& static_argnums() const { return static_argnums_; } + + // nb::object typed subclass for PmapFunction objects. + class pyobject : public nb::object { + public: + NB_OBJECT(pyobject, nb::object, "PmapFunction", + PmapFunction::IsPmapFunction); + pyobject() = default; + PmapFunction* func() const { + return PmapFunction::AsPmapFunctionUnchecked(*this); + } + }; + // Alias as ::object; outside the scope above we won't confuse nanobind's + // macros. + using object = pyobject; + + // Returns true if `h` is a PmapFunction. + static bool IsPmapFunction(nb::handle handle); + // Converts `handle` to a PmapFunction*. Does not do any checking. + static PmapFunction* AsPmapFunctionUnchecked(nb::handle handle); + + // Helper function used by the tp_clear GC method. + void ClearPythonReferences() { + nb::callable fun, cache_miss, python_shard_arg_fallback; + // Swap values for nulls before they are destroyed. See the Python + // Py_CLEAR() documentation for a discussion of this topic. + std::swap(fun_, fun); + std::swap(cache_miss_, cache_miss); + std::swap(python_shard_arg_fallback_, python_shard_arg_fallback); + } + + // Updates the signature of arguments for a pmapped function. + // + // It deals with the arguments signatures and also of the global and + // thread-local jit context. + absl::Status ComputeCallSignature( + absl::Span flat_dynamic_args, + CallSignature& signature) { + signature.function_name = function_name_; + + // Get dynamic argument signatures. + JitState& global_state = jax::GlobalJitState(); + JitState& tls = jax::ThreadLocalJitState(); + const bool jax_enable_x64 = GetEnableX64(); + signature.jax_enable_x64 = jax_enable_x64; + for (nb::handle arg : flat_dynamic_args) { + auto signature_or_error = xla::PyArgSignatureOfValue(arg, jax_enable_x64); + if (!signature_or_error.ok()) { + VLOG(2) << "PyArgSignatureOfValue failed: " + << signature_or_error.status(); + return signature_or_error.status(); + } + signature.dynamic_arg_signatures.push_back( + std::move(signature_or_error).value()); + } + signature.thread_local_extra_jit_context = tls.extra_jit_context; + signature.global_extra_jit_context = global_state.extra_jit_context; + signature.configs = JitConfigs(); + return absl::Status(); + } + + // Returns, for debugging purposes (e.g. finding why some call misses the + // cache and recompiles), the list of the string representations of the keys. + // + // The format can change at any time. + std::string DebugCacheKeys() { + nb::ft_lock_guard lock(mu_); + std::vector key_strings = { + absl::StrCat("The cache contains ", executables_.size(), " elements:")}; + // We will be able to use auto& [key, _] when TF uses C++ 17. + for (auto& pair : executables_) { + key_strings.push_back(pair.first.DebugString()); + } + return absl::StrJoin(key_strings, "\n\n"); + } + + private: + // Mutates `cache_entry` in place. + void PopulateCacheEntry(PmapCacheEntry& cache_entry, + const nb::tuple& out_and_fastpath_data); + + bool always_fallback_to_python_ = false; + + nb::callable fun_; // The Python function to pmap. + std::string function_name_; + // See JAX _cpp_pmap in api.py for documentation. + nb::callable cache_miss_; + + // We need to know the static arguments to remove them from the arguments + // passed to the underlying PyLoadedExecutable. In sorted order. + std::vector static_argnums_; + xla::nb_class_ptr pytree_registry_; + // We need a `shared_ptr` here to ensure value pointer stability, and to + // ensure that the cache entry remains alive in the presence of concurrent + // removals. + absl::flat_hash_map> + executables_; + + // The fallback function to use with `ShardArgs`. + // TODO(jblespiau): Add support for more types from C++. + nb::callable python_shard_arg_fallback_; + + // Protect methods in FT: + nb::ft_mutex mu_; +}; + +void PmapFunction::PopulateCacheEntry(PmapCacheEntry& cache_entry, + const nb::tuple& out_and_fastpath_data) { + CHECK_EQ(out_and_fastpath_data.size(), 2); + if (out_and_fastpath_data[1].is_none()) { + cache_entry.fall_back_to_python = true; + return; + } + + nb::tuple pmap_data = nb::cast(out_and_fastpath_data[1]); + if (nb::cast(pmap_data.attr("version")) != 1) { + throw xla::XlaRuntimeError(absl::StrCat( + "The versions of jaxlib and Jax are incompatible (pmap cpp version 1 " + "expected, but got ", + nb::cast(pmap_data.attr("version")), + "Upgrade jaxlib and jax. Provided data was:", + nb::cast(nb::str(nb::repr(pmap_data))))); + } + // See api.nb::_PmapFastpathData in the JAX code base for the expected + // namedtuple. + std::shared_ptr executable; + try { + executable = nb::cast>( + pmap_data.attr("xla_executable")); + } catch (const nb::cast_error& e) { + // Backends that don't implement the C++ PjRt APIs + cache_entry.fall_back_to_python = true; + always_fallback_to_python_ = true; + return; + } + cache_entry.executable = std::move(executable); + const std::vector>& devices = + cache_entry.executable->AddressableDevices(); + cache_entry.devices.reserve(devices.size()); + for (auto& device : devices) { + cache_entry.devices.push_back(device->device()); + } + + // Inputs shard args details. + nb::list input_indices = pmap_data.attr("input_indices"); + + cache_entry.py_devices = pmap_data.attr("input_devices"); + auto input_devices = nb::cast>>( + pmap_data.attr("input_devices")); + + nb::list input_array_shardings = pmap_data.attr("input_array_shardings"); + + cache_entry.input_specs.reserve(input_array_shardings.size()); + + for (int i = 0; i < input_array_shardings.size(); ++i) { + cache_entry.input_specs.emplace_back(input_indices[i], + input_array_shardings[i]); + } + + // Outputs specs. + auto out_tree = nb::cast(pmap_data.attr("out_pytree_def")); + cache_entry.out_pytree_def = std::move(out_tree); + nb::list out_avals = pmap_data.attr("out_avals"); + + cache_entry.out_result_specs.reserve(out_avals.size()); + cache_entry.out_dtypes.reserve(out_avals.size()); + cache_entry.out_shapes.reserve(out_avals.size()); + + for (int i = 0; i < out_avals.size(); ++i) { + cache_entry.out_dtypes.push_back(out_avals[i].attr("dtype")); + cache_entry.out_shapes.push_back( + nb::cast>(out_avals[i].attr("shape"))); + cache_entry.out_result_specs.emplace_back(out_avals[i]); + } + + nb::list out_array_shardings = pmap_data.attr("out_array_shardings"); + + DCHECK(out_array_shardings.size() == 0 || + out_avals.size() == out_array_shardings.size()); + + cache_entry.out_array_shardings.reserve(out_array_shardings.size()); + for (nb::handle out_array_sharding : out_array_shardings) { + cache_entry.out_array_shardings.push_back( + nb::borrow(out_array_sharding)); + } + + nb::list out_committed = pmap_data.attr("out_committed"); + + DCHECK(out_committed.size() == 0 || out_avals.size() == out_committed.size()); + + cache_entry.out_committed.reserve(out_committed.size()); + for (nb::handle c : out_committed) { + cache_entry.out_committed.push_back(nb::cast(c)); + } +} + +absl::StatusOr PmapFunction::Call(nb::handle callable, + PyObject* const* args, + size_t nargs, PyObject* kwnames) { + xla::GlobalPyRefManager()->MaybeCollectGarbage(); + + // Calls the cache_miss_ function. This just calls the Python function; it may + // return nullptr value if a Python exception is thrown. + auto cache_miss = [&]() -> nb::tuple { + return nb::steal( + PyObject_Vectorcall(cache_miss_.ptr(), args, nargs, kwnames)); + }; + + // Call the cache_miss() function, extracting the output data and ignoring + // the fastpath data. If the cache miss returns a Python error, returns + // nullptr and leaves the Python error set. + auto fallback_to_cache_miss = [&]() { + nb::tuple cache_miss_output = cache_miss(); + if (!cache_miss_output.ptr()) { + return nb::object(); + } + return nb::object(cache_miss_output[0]); + }; + + if (always_fallback_to_python_) { + return fallback_to_cache_miss(); + } + + size_t num_positional_args = PyVectorcall_NARGS(nargs); + size_t num_keyword_args = kwnames ? PyTuple_GET_SIZE(kwnames) : 0; + absl::Span positional_args(args, num_positional_args); + absl::Span keyword_args(args + num_positional_args, + num_keyword_args); + CallSignature call_signature; + absl::InlinedVector flat_dynamic_args; + std::vector keep_alive_objects; + absl::Status status = + ParseArguments(positional_args, keyword_args, kwnames, static_argnums_, + /*static_argnames=*/{}, pytree_registry_.get(), + call_signature.arg_signature, flat_dynamic_args); + if (!status.ok()) { + VLOG(2) << "ParseArguments failed: " << status; + return fallback_to_cache_miss(); + } + + status = ComputeCallSignature(flat_dynamic_args, call_signature); + if (!status.ok()) { + return fallback_to_cache_miss(); + } + + // Retrieve/Maybe add the executable to the cache. + bool inserted = false; + std::shared_ptr cache_entry_ptr; + { + nb::ft_lock_guard lock(mu_); + std::shared_ptr& entry_ref = executables_[call_signature]; + if (!entry_ref) { + inserted = true; + entry_ref = std::make_shared(pytree_registry_.get()); + } + cache_entry_ptr = entry_ref; + } + PmapCacheEntry& cache_entry = *cache_entry_ptr; + + if (!cache_entry.compilation_complete.HasBeenNotified()) { + // In case of several threads attempting to compile the executable, only + // the one that inserted the item will perform the compilation. + if (inserted) { + nb::object out_and_fastpath_data; + nb::tuple out_tuple; + VLOG(2) << "Cache miss for " << call_signature.DebugString(); + try { + // Calls Python and may release the GIL. May also throw if + // compilation/tracing fails. + out_and_fastpath_data = cache_miss(); + if (!out_and_fastpath_data.ptr()) { + throw nb::python_error(); + } + out_tuple = nb::cast(out_and_fastpath_data); + + PopulateCacheEntry(cache_entry, out_tuple); + } catch (const std::exception& e) { + cache_entry.fall_back_to_python = true; + cache_entry.compilation_complete.Notify(); + throw; + } + cache_entry.compilation_complete.Notify(); + + // We have already computed the result in the miss path so we can return + // it. We are even *required* to do so if there are donated arguments, + // because any donated buffers will now be invalid. + return nb::object(out_tuple[0]); + } else { + // Release the GIL while we wait, making sure the compile thread can + // lock it. + nb::gil_scoped_release release; + cache_entry.compilation_complete.WaitForNotification(); + } + } + if (cache_entry.fall_back_to_python) { + return fallback_to_cache_miss(); + } + + // 1. Parse arguments. + std::vector& input_devices = cache_entry.devices; + std::vector& input_specs = cache_entry.input_specs; + const int num_args = flat_dynamic_args.size(); + + // We need [num_args] for the `Execute` call below. + std::vector> num_args_arrays(num_args); + for (int i = 0; i < num_args; ++i) { + TF_ASSIGN_OR_RETURN( + ShardArgResult sharded_arg, + ShardArg(flat_dynamic_args[i], input_devices, input_specs[i], + cache_entry.py_devices, python_shard_arg_fallback_)); + + num_args_arrays[i] = std::move(sharded_arg.ifrt_array); + if (sharded_arg.owning_sda) { + keep_alive_objects.push_back(std::move(sharded_arg.owning_sda)); + } + } + + xla::ifrt::ExecuteOptions execute_options = cache_entry.executable->options(); + execute_options.launch_id = cache_entry.executable->GetNextLaunchId(); + execute_options.execution_stream_id = + tsl::Env::Default()->GetCurrentThreadId(); + + // A vector of [num_outputs]. + std::vector> output_arrays; + { + nb::gil_scoped_release gil_release; + auto ifrt_executable = cache_entry.executable->ifrt_executable(); + TF_ASSIGN_OR_RETURN( + auto result, ifrt_executable->Execute(absl::MakeSpan(num_args_arrays), + execute_options, + /*devices=*/std::nullopt)); + output_arrays = std::move(result.outputs); + } + + // TODO(jblespiau): We don't need to create the PyBuffer objects. + // Having a C++ `Array`, keeping internally the PjRtBuffer + // objects is sufficient, and we can lazily create the `PyBuffer` only if + // we access them from Python. + auto traceback = xla::Traceback::Get(); + // TODO(jblespiau): Change the `client` function to return a reference. + xla::nb_class_ptr client = cache_entry.executable->client(); + + // Convert the PjRtBuffer objects to PyBuffer, and invert the order from + // [num_devices, num_args] to [num_args, num_devices]. + const int num_outputs = output_arrays.size(); + std::vector flat_sharded_device_arrays; + flat_sharded_device_arrays.reserve(num_outputs); + + const auto& output_specs = cache_entry.out_result_specs; + + TF_RET_CHECK(cache_entry.out_array_shardings.size() == num_outputs); + for (int i = 0; i < num_outputs; ++i) { + const ResultSpec& result_spec = output_specs[i]; + xla::PyArray py_array( + result_spec.out_aval, result_spec.weak_type, cache_entry.out_dtypes[i], + cache_entry.out_shapes[i], cache_entry.out_array_shardings[i], client, + traceback, std::move(output_arrays[i]), cache_entry.out_committed[i], + /*skip_checks=*/true); + + flat_sharded_device_arrays.push_back(std::move(py_array)); + } + + nb::object out = + cache_entry.out_pytree_def.Unflatten(flat_sharded_device_arrays); + + // If there is a post-hook function, call it with the inputs and the outputs. + std::optional post_hook = GetPostHook(); + if (post_hook) { + nb::tuple args_tuple = + nb::steal(PyTuple_New(num_positional_args)); + for (size_t i = 0; i < num_positional_args; ++i) { + Py_INCREF(args[i]); + PyTuple_SET_ITEM(args_tuple.ptr(), i, args[i]); + } + nb::dict kwargs; + if (kwnames) { + for (size_t i = 0; i < num_keyword_args; ++i) { + kwargs[nb::handle(PyTuple_GET_ITEM(kwnames, i))] = + nb::borrow(args[num_positional_args + i]); + } + } + + (*post_hook)(callable, args_tuple, kwargs, out); + } + + return out; +} + +struct JaxPmapFunctionObject { + PyObject_HEAD; +#if PY_VERSION_HEX < 0x030C0000 + PyObject* dict; // Dictionary for __dict__ + PyObject* weakrefs; // Weak references; for use by the Python interpreter. +#endif // PY_VERSION_HEX < 0x030C0000 + vectorcallfunc vectorcall; + PmapFunction fun; +}; + +PyObject* JaxPmapFunction_Type = nullptr; + +bool PmapFunction::IsPmapFunction(nb::handle handle) { + return handle.type().ptr() == JaxPmapFunction_Type; +} + +PmapFunction* PmapFunction::AsPmapFunctionUnchecked(nb::handle handle) { + return &(reinterpret_cast(handle.ptr())->fun); +} + +absl::StatusOr AsPmapFunction(nb::handle handle) { + if (!PmapFunction::IsPmapFunction(handle)) { + return xla::InvalidArgument("Expected a PmapFunction"); + } + return PmapFunction::AsPmapFunctionUnchecked(handle); +} + +namespace { + +extern "C" { + +PyObject* JaxPmapFunction_tp_vectorcall(PyObject* callable, + PyObject* const* args, size_t nargs, + PyObject* kwnames) { + JaxPmapFunctionObject* o = reinterpret_cast(callable); + tsl::profiler::TraceMe traceme([&] { + return absl::StrCat("JaxPmapFunction(", o->fun.function_name(), ")"); + }); + try { + absl::StatusOr out = + o->fun.Call(callable, args, nargs, kwnames); + if (!out.ok()) { + PyErr_SetString(PyExc_ValueError, out.status().ToString().c_str()); + return nullptr; + } + return out.value().release().ptr(); + } catch (nb::python_error& e) { + e.restore(); + return nullptr; + } catch (nb::cast_error& e) { + PyErr_SetString(PyExc_ValueError, e.what()); + return nullptr; + } catch (std::invalid_argument& e) { + PyErr_SetString(PyExc_ValueError, e.what()); + return nullptr; + } +} + +PyObject* JaxPmapFunction_tp_new(PyTypeObject* subtype, PyObject* args, + PyObject* kwds) { + JaxPmapFunctionObject* self = + reinterpret_cast(subtype->tp_alloc(subtype, 0)); + if (!self) return nullptr; +#if PY_VERSION_HEX < 0x030C0000 + self->dict = nullptr; + self->weakrefs = nullptr; +#endif // PY_VERSION_HEX < 0x030C0000 + self->vectorcall = JaxPmapFunction_tp_vectorcall; + return reinterpret_cast(self); +} + +void JaxPmapFunction_tp_dealloc(PyObject* self) { + PyObject_GC_UnTrack(self); + PyTypeObject* tp = Py_TYPE(self); + JaxPmapFunctionObject* o = reinterpret_cast(self); + PyObject_ClearWeakRefs(self); +#if PY_VERSION_HEX < 0x030C0000 + Py_CLEAR(o->dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_ClearManagedDict(self); +#else + PyObject_ClearManagedDict(self); +#endif // PY_VERSION_HEX < 0x030C0000 + o->fun.~PmapFunction(); + tp->tp_free(self); + Py_DECREF(tp); +} + +int JaxPmapFunction_tp_traverse(PyObject* self, visitproc visit, void* arg) { + JaxPmapFunctionObject* o = reinterpret_cast(self); + // https://docs.python.org/3/c-api/typeobj.html#c.PyTypeObject.tp_traverse + Py_VISIT(Py_TYPE(self)); +#if PY_VERSION_HEX < 0x030C0000 + Py_VISIT(o->dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_VisitManagedDict(self, visit, arg); +#else + PyObject_VisitManagedDict(self, visit, arg); +#endif // PY_VERSION_HEX < 0x030C0000 + Py_VISIT(o->fun.fun().ptr()); + Py_VISIT(o->fun.cache_miss().ptr()); + return 0; +} + +int JaxPmapFunction_tp_clear(PyObject* self) { + JaxPmapFunctionObject* o = reinterpret_cast(self); +#if PY_VERSION_HEX < 0x030C0000 + Py_CLEAR(o->dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_ClearManagedDict(self); +#else + PyObject_ClearManagedDict(self); +#endif // PY_VERSION_HEX < 0x030C0000 + o->fun.ClearPythonReferences(); + return 0; +} + +// Implements the Python descriptor protocol so PMAP-compiled functions can be +// used as bound methods. See: +// https://docs.python.org/3/howto/descriptor.html#functions-and-methods +PyObject* JaxPmapFunction_tp_descr_get(PyObject* self, PyObject* obj, + PyObject* type) { + if (obj == nullptr || obj == Py_None) { + Py_INCREF(self); + return self; + } + return PyMethod_New(self, obj); +} + +static PyGetSetDef JaxPmapFunction_tp_getset[] = { + // Having a __dict__ seems necessary to allow !functool.wraps to override + // __doc__. + {const_cast("__dict__"), PyObject_GenericGetDict, + PyObject_GenericSetDict, nullptr, nullptr}, + {nullptr, nullptr, nullptr, nullptr, nullptr}}; + +PyMemberDef JaxPmapFunction_members[] = { + {"__vectorcalloffset__", T_PYSSIZET, + static_cast(offsetof(JaxPmapFunctionObject, vectorcall)), + READONLY, nullptr}, +#if PY_VERSION_HEX < 0x030C0000 + {"__dictoffset__", T_PYSSIZET, + static_cast(offsetof(JaxPmapFunctionObject, dict)), READONLY, + nullptr}, + {"__weaklistoffset__", T_PYSSIZET, + static_cast(offsetof(JaxPmapFunctionObject, weakrefs)), + READONLY, nullptr}, +#endif // PY_VERSION_HEX < 0x030C0000 + {nullptr, 0, 0, 0, nullptr}, +}; + +PyType_Slot JaxPmapFunction_slots[] = { + {Py_tp_new, reinterpret_cast(JaxPmapFunction_tp_new)}, + {Py_tp_dealloc, reinterpret_cast(JaxPmapFunction_tp_dealloc)}, + {Py_tp_traverse, reinterpret_cast(JaxPmapFunction_tp_traverse)}, + {Py_tp_clear, reinterpret_cast(JaxPmapFunction_tp_clear)}, + {Py_tp_getset, reinterpret_cast(JaxPmapFunction_tp_getset)}, + {Py_tp_descr_get, reinterpret_cast(JaxPmapFunction_tp_descr_get)}, + {Py_tp_call, reinterpret_cast(PyVectorcall_Call)}, + {Py_tp_members, reinterpret_cast(JaxPmapFunction_members)}, + {0, nullptr}, +}; + +} // extern "C" + +nb::object MakePmapFunction( + nb::callable fun, nb::callable cache_miss, std::vector static_argnums, + nb::callable python_shard_arg_fallback, + xla::nb_class_ptr pytree_registry) { + nb::object obj = nb::steal(JaxPmapFunction_tp_new( + reinterpret_cast(JaxPmapFunction_Type), nullptr, nullptr)); + JaxPmapFunctionObject* buf = + reinterpret_cast(obj.ptr()); + new (&buf->fun) PmapFunction( + std::move(fun), std::move(cache_miss), std::move(static_argnums), + std::move(python_shard_arg_fallback), std::move(pytree_registry)); + return obj; +} + +// Version numbers for the pickled representations. +// Increment these if changing them. +const int kPmapFunctionPickleVersion = 1; + +} // namespace + +void BuildPmapSubmodule(nb::module_& m) { + nb::module_ pmap_lib = m.def_submodule("pmap_lib", "Jax C++ pmap library"); + + nb::class_ no_sharding(pmap_lib, "NoSharding"); + no_sharding.def(nb::init<>()) + .def("__getstate__", + [](const NoSharding& self) { return nb::make_tuple(); }) + .def("__setstate__", + [](NoSharding& self, nb::tuple t) { new (&self) NoSharding(); }) + .def("__repr__", + [](const NoSharding& chuncked) { return "NoSharding()"; }) + .def("__eq__", + [](const NoSharding& self, nb::object obj) { + return nb::isinstance(obj); + }) + .def("__hash__", [](const NoSharding& self) { + const size_t hash = absl::HashOf(self); + return nb::int_(hash); + }); + + nb::class_ chunked(pmap_lib, "Chunked"); + chunked.def(nb::init>()) + .def("__getstate__", + [](const Chunked& self) { return nb::make_tuple(self.chunks); }) + .def("__setstate__", + [](Chunked& self, nb::tuple t) { + new (&self) Chunked{nb::cast>(t[0])}; + }) + .def_ro("chunks", &Chunked::chunks) + .def("__repr__", + [](const Chunked& chuncked) { + return absl::StrCat("Chunked(", + absl::StrJoin(chuncked.chunks, ","), ")"); + }) + .def("__eq__", [](const Chunked& self, nb::object other) { + if (!nb::isinstance(other)) { + return false; + } + return self == nb::cast(other); + }); + + nb::class_ unstacked(pmap_lib, "Unstacked"); + unstacked.def(nb::init()) + .def("__getstate__", + [](const Unstacked& self) { return nb::make_tuple(self.size); }) + .def("__setstate__", + [](Unstacked& self, nb::tuple t) { + new (&self) Unstacked{nb::cast(t[0])}; + }) + .def_ro("size", &Unstacked::size) + .def("__repr__", + [](const Unstacked& x) { + return absl::StrCat("Unstacked(", x.size, ")"); + }) + .def("__eq__", [](const Unstacked& self, nb::object other) { + if (!nb::isinstance(other)) { + return false; + } + return self == nb::cast(other); + }); + + nb::class_ sharded_axis(pmap_lib, "ShardedAxis"); + sharded_axis.def(nb::init()) + .def("__getstate__", + [](const ShardedAxis& self) { return nb::make_tuple(self.axis); }) + .def("__setstate__", + [](ShardedAxis& self, nb::tuple t) { + new (&self) ShardedAxis{nb::cast(t[0])}; + }) + .def_ro("axis", &ShardedAxis::axis) + .def("__repr__", + [](const ShardedAxis& x) { + return absl::StrCat("ShardedAxis(axis=", x.axis, ")"); + }) + .def("__eq__", [](const ShardedAxis& self, const ShardedAxis& other) { + return self == other; + }); + + nb::class_ replicated(pmap_lib, "Replicated"); + replicated.def(nb::init()) + .def("__getstate__", + [](const Replicated& self) { return nb::make_tuple(self.replicas); }) + .def("__setstate__", + [](Replicated& self, nb::tuple t) { + new (&self) Replicated{nb::cast(t[0])}; + }) + .def_ro("replicas", &Replicated::replicas) + .def("__repr__", + [](const Replicated& x) { + return absl::StrCat("Replicated(replicas=", x.replicas, ")"); + }) + .def("__eq__", [](const Replicated& self, const Replicated& other) { + return self == other; + }); + + nb::class_ sharding_spec(pmap_lib, "ShardingSpec"); + sharding_spec + .def(nb::init(), nb::arg("sharding"), + nb::arg("mesh_mapping")) + .def("__getstate__", + [](const ShardingSpec& self) { + auto sharding = + xla::SpanToNbTuple(absl::MakeConstSpan(self.GetSharding())); + auto mesh_mapping = + xla::SpanToNbTuple(absl::MakeConstSpan(self.GetMeshMapping())); + return nb::make_tuple(sharding, mesh_mapping); + }) + .def("__setstate__", + [](ShardingSpec& self, nb::tuple t) { + new (&self) + ShardingSpec{nb::cast>(t[0]), + nb::cast>(t[1])}; + }) + .def_prop_ro( + "sharding", + [](const ShardingSpec& self) { + return xla::SpanToNbTuple(absl::MakeConstSpan(self.GetSharding())); + }) + .def_prop_ro("mesh_mapping", + [](const ShardingSpec& self) { + return xla::SpanToNbTuple( + absl::MakeConstSpan(self.GetMeshMapping())); + }) + .def("__eq__", [](const ShardingSpec& self, + const ShardingSpec& other) { return self == other; }) + .def("__hash__", [](const ShardingSpec& self) { + const size_t hash = absl::HashOf(self); + return nb::int_(hash); + }); + + // We need to use heap-allocated type objects because we want to add + // additional methods dynamically. + + std::string name = + absl::StrCat(nb::cast(m.attr("__name__")), ".PmapFunction"); + PyType_Spec pmap_function_spec = { +#if PY_VERSION_HEX < 0x030B0000 + // Work around for https://github.com/python/cpython/issues/89478 + // CPython 3.10 and earlier assume that the .name value remains alive + // forever. + /*.name=*/strdup(name.c_str()), +#else + /*.name=*/name.c_str(), +#endif // PY_VERSION_HEX < 0x030B0000 + /*.basicsize=*/static_cast(sizeof(JaxPmapFunctionObject)), + /*.itemsize=*/0, +#if PY_VERSION_HEX < 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_HAVE_VECTORCALL, +#else // PY_VERSION_HEX >= 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_HAVE_VECTORCALL | Py_TPFLAGS_MANAGED_DICT | + Py_TPFLAGS_MANAGED_WEAKREF, +#endif // PY_VERSION_HEX >= 0x030C0000 + /*.slots=*/JaxPmapFunction_slots, + }; + + JaxPmapFunction_Type = PyType_FromSpec(&pmap_function_spec); + if (!JaxPmapFunction_Type) { + throw nb::python_error(); + } + nb::object cfun = nb::borrow(JaxPmapFunction_Type); + + // Add PmapFunction to the xla_extension module so it can be pickled. + m.attr("PmapFunction") = cfun; + + cfun.attr("__signature__") = + xla::nb_property_readonly([](nb::handle self) -> nb::object { + PmapFunction* fun = xla::ValueOrThrow(AsPmapFunction(self)); + return fun->PythonSignature(); + }); + // Required by `post_hook`. + cfun.attr("_cache_miss") = + xla::nb_property_readonly([](nb::handle self) -> nb::object { + PmapFunction* fun = xla::ValueOrThrow(AsPmapFunction(self)); + return fun->cache_miss(); + }); + cfun.attr("__getstate__") = nb::cpp_function( + [](const PmapFunction::object& self) { + PmapFunction* fn = self.func(); + nb::dict pickle; + pickle["version"] = kPmapFunctionPickleVersion; + pickle["fun"] = fn->fun(); + pickle["cache_miss"] = fn->cache_miss(); + pickle["static_argnums"] = fn->static_argnums(); + pickle["python_shard_arg_fallback"] = fn->python_shard_arg_fallback(); + pickle["pytree_registry"] = nb::cast(fn->pytree_registry()); + return pickle; + }, + nb::is_method()); + cfun.attr("__setstate__") = nb::cpp_function( + [](PmapFunction::object& self, const nb::dict& pickle) { + int version = nb::cast(pickle["version"]); + if (version != kPmapFunctionPickleVersion) { + throw std::invalid_argument(absl::StrFormat( + "Invalid PmapFunction pickle version, got %d, expected %d. " + "Pickling/Unpickling jitted functions using different JAX " + "versions is not supported.", + version, kPmapFunctionPickleVersion)); + } + nb::callable fun = nb::cast(pickle["fun"]); + nb::callable cache_miss = nb::cast(pickle["cache_miss"]); + std::vector static_argnums = + nb::cast>(pickle["static_argnums"]); + nb::callable python_shard_arg_fallback = + nb::cast(pickle["python_shard_arg_fallback"]); + xla::nb_class_ptr pytree_registry = + nb::cast>( + pickle["pytree_registry"]); + new (&(reinterpret_cast(self.ptr())->fun)) + PmapFunction(std::move(fun), std::move(cache_miss), + std::move(static_argnums), + std::move(python_shard_arg_fallback), + std::move(pytree_registry)); + }, + nb::is_method()); + + // This is only for testing/debugging purposes. + cfun.attr("_cache_size") = + xla::nb_property_readonly([](nb::handle self) -> nb::object { + PmapFunction* fun = xla::ValueOrThrow(AsPmapFunction(self)); + return nb::cast(fun->cache_size()); + }); + + cfun.attr("_cache_clear") = nb::cpp_function( + [](nb::handle self) { + PmapFunction* fun = xla::ValueOrThrow(AsPmapFunction(self)); + fun->cache_clear(); + }, + nb::is_method()); + + cfun.attr("_debug_cache_keys") = nb::cpp_function( + [](nb::handle self) -> std::string { + PmapFunction* fun = xla::ValueOrThrow(AsPmapFunction(self)); + return fun->DebugCacheKeys(); + }, + nb::is_method()); + + pmap_lib.def( + "pmap", + [](nb::callable fun, nb::callable cache_miss, + std::vector static_argnums, nb::callable shard_arg_fallback, + nb::object pytree_registry) -> nb::object { + xla::nb_class_ptr registry = + nb::cast>(pytree_registry); + return MakePmapFunction( + std::move(fun), std::move(cache_miss), std::move(static_argnums), + std::move(shard_arg_fallback), std::move(registry)); + }, + nb::arg("fun"), nb::arg("cache_miss"), nb::arg("static_argnums"), + nb::arg("shard_arg_fallback"), nb::arg("pytree_registry")); +} + +} // namespace jax diff --git a/jaxlib/xla/pmap_lib.h b/jaxlib/xla/pmap_lib.h new file mode 100644 index 000000000000..9ad60a03daf6 --- /dev/null +++ b/jaxlib/xla/pmap_lib.h @@ -0,0 +1,37 @@ +/* Copyright 2021 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_PMAP_LIB_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_PMAP_LIB_H_ + +#include +#include +#include + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" + +// TODO(jblespiau): The current implementation moves the Python logic to C++, +// as a preliminary step to executing the `pmap` execution path from C++. +// It implements the current Python behavior (thus, it may not be optimal, and +// we will be able to modify it later). + +namespace jax { + +void BuildPmapSubmodule(nanobind::module_& m); + +} // namespace jax + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_PMAP_LIB_H_ diff --git a/jaxlib/xla/sdy.cc b/jaxlib/xla/sdy.cc new file mode 100644 index 000000000000..c6d1145517d8 --- /dev/null +++ b/jaxlib/xla/sdy.cc @@ -0,0 +1,143 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/sdy.h" + +#include +#include + +#include "mhlo/transforms/passes.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Bytecode/BytecodeWriter.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/tuple.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "shardy/dialect/sdy/ir/dialect.h" +#include "shardy/dialect/sdy/ir/utils.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" +#include "xla/mlir_hlo/mhlo/transforms/passes.h" +#include "xla/pjrt/mlir_to_hlo.h" +#include "xla/pjrt/status_casters.h" +#include "xla/service/spmd/shardy/constants.h" +#include "xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.h" +#include "xla/service/spmd/shardy/sdy_round_trip/pipelines.h" +#include "xla/service/spmd/shardy/utils.h" +#include "xla/tsl/framework/mlir/status_scoped_diagnostic_handler.h" + +namespace nb = nanobind; + +namespace xla { + +namespace { + +absl::StatusOr SerializeUsingBytecode(mlir::ModuleOp module) { + std::string bytecode; + llvm::raw_string_ostream os(bytecode); + mlir::BytecodeWriterConfig config; + if (mlir::failed(mlir::writeBytecodeToFile(module, os, config))) { + return absl::InvalidArgumentError("mlir::writeBytecodeToFile failed"); + } + return bytecode; +} + +} // namespace + +void BuildSdySubmodule(nb::module_& m) { + nb::module_ mlir_module = m.def_submodule("sdy", "Shardy/XLA integration"); + + mlir_module + // TODO(b/707574930): define a C API for the XLA pipelines. + .def( + "sdy_round_trip_export_pipeline", + [](const nb::bytes& bytecode) -> nb::bytes { + mlir::MLIRContext context; + mlir::OwningOpRef module = + xla::ValueOrThrow(ParseMlirModuleString( + absl::string_view(bytecode.c_str(), bytecode.size()), + context)); + mlir::PassManager pm(&context); + sdy::addSdyRoundTripExportPipeline(pm); + tsl::StatusScopedDiagnosticHandler diagnosticHandler(&context); + ThrowIfError(diagnosticHandler.consumeStatus(pm.run(module.get()))); + std::string module_str = + xla::ValueOrThrow(SerializeUsingBytecode(module.get())); + return nb::bytes(module_str.data(), module_str.size()); + }, + nb::arg("module")) + .def( + "sdy_round_trip_import_shardings", + [](const nb::bytes& bytecode) -> nb::bytes { + mlir::MLIRContext context; + mlir::OwningOpRef module = + xla::ValueOrThrow(ParseMlirModuleString( + absl::string_view(bytecode.c_str(), bytecode.size()), + context)); + mlir::PassManager pm(&context); + pm.addPass(xla::sdy::createSdyRoundTripImportShardyAttrsPass()); + tsl::StatusScopedDiagnosticHandler diagnosticHandler(&context); + ThrowIfError(diagnosticHandler.consumeStatus(pm.run(module.get()))); + std::string module_str = + xla::ValueOrThrow(SerializeUsingBytecode(module.get())); + return nb::bytes(module_str.data(), module_str.size()); + }, + nb::arg("module")) + .def("lowered_with_shardy", + [](const nb::bytes& bytecode) -> bool { + mlir::MLIRContext context; + mlir::OwningOpRef module = + xla::ValueOrThrow(ParseMlirModuleString( + absl::string_view(bytecode.c_str(), bytecode.size()), + context)); + return mlir::sdy::getMeshAttr(module.get(), "mesh") || + sdy::tryGetFrontendAttr( + module.get(), sdy::kMeshesRoundTripAttr) + .has_value(); + }) + // TODO(bartchr): delete this and all uses of it once I have JAX export + // support multiple meshes. + .def("get_mesh", [](const nb::bytes& bytecode) -> nb::list { + mlir::MLIRContext context; + mlir::OwningOpRef module = + xla::ValueOrThrow(ParseMlirModuleString( + absl::string_view(bytecode.c_str(), bytecode.size()), context)); + auto mesh_op = + mlir::SymbolTable::lookupNearestSymbolFrom( + module.get(), mlir::StringAttr::get(&context, "mesh")); + if (!mesh_op) { + return {}; + } + nb::list mesh_shape; + for (auto axis : mesh_op.getMeshAttr().getAxes()) { + mesh_shape.append( + nb::make_tuple(axis.getName().str(), axis.getSize())); + } + return mesh_shape; + }); +} + +} // namespace xla diff --git a/jaxlib/xla/sdy.h b/jaxlib/xla/sdy.h new file mode 100644 index 000000000000..5d8c8c2eb7dd --- /dev/null +++ b/jaxlib/xla/sdy.h @@ -0,0 +1,28 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_SDY_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_SDY_H_ + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" + +namespace xla { + +void BuildSdySubmodule(nanobind::module_& m); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_SDY_H_ diff --git a/jaxlib/xla/weakref_lru_cache.cc b/jaxlib/xla/weakref_lru_cache.cc new file mode 100644 index 000000000000..80498f30aaef --- /dev/null +++ b/jaxlib/xla/weakref_lru_cache.cc @@ -0,0 +1,400 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/weakref_lru_cache.h" + +#include + +#include +#include +#include +#include +#include +#include +#include // NOLINT +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/cleanup/cleanup.h" +#include "absl/hash/hash.h" +#include "absl/strings/str_cat.h" +#include "absl/synchronization/mutex.h" +#include "absl/synchronization/notification.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "xla/pjrt/lru_cache.h" +#include "xla/tsl/platform/logging.h" + +namespace nb = nanobind; + +namespace jax { +namespace { + +// Minimal wrapper to expose a nb::dict_iterator's value as something +// hashable with Abseil. +class HashablePyDictEntry { + public: + explicit HashablePyDictEntry(std::pair entry) + : entry_(entry) {} + + template + friend H AbslHashValue(H h, const HashablePyDictEntry& v) { + return H::combine(std::move(h), nb::hash(v.entry_.first), + nb::hash(v.entry_.second)); + } + + std::pair entry_; +}; + +// Similarly, a minimalist adaptor around the nb::detail::dict_iterator +// itself. Note that the iterator "is" also a Value. Does not meet the full +// standard iterator requirements, only enough to support H::combine_unordered. +class HashablePyDictIter { + public: + using iterator_category = std::input_iterator_tag; + + explicit HashablePyDictIter(nb::detail::dict_iterator& iter) : iter_(iter) {} + + // Minimal set of iterator operations. + HashablePyDictEntry operator*() const { return HashablePyDictEntry(*iter_); } + bool operator!=(const HashablePyDictIter& rhs) const { + return iter_ != rhs.iter_; + } + void operator++() { ++iter_; } + + private: + nb::detail::dict_iterator& iter_; +}; + +struct HashableKey { + nb::object context; + nb::args args; + nb::kwargs kwargs; + + template + friend H AbslHashValue(H h, const HashableKey& key) { + // Note: Despite the fact this is an ABSL hash function, it's safe to call + // functions that may throw exceptions such as nb::hash(), because it is + // used by an LRUCache, which uses a std::unordered_map, which is + // exception-safe. + h = H::combine(std::move(h), nb::hash(key.context), nb::hash(key.args)); + nb::detail::dict_iterator begin = key.kwargs.begin(); + nb::detail::dict_iterator end = key.kwargs.end(); + h = H::combine_unordered(std::move(h), HashablePyDictIter(begin), + HashablePyDictIter(end)); + h = H::combine(std::move(h), key.kwargs.size()); + return h; + } +}; + +} // namespace + +class WeakrefLRUCache : public std::enable_shared_from_this { + public: + class Key { + public: + Key(nb::object context, nb::args args, nb::kwargs kwargs) + : context_(std::move(context)), + args_(std::move(args)), + kwargs_(std::move(kwargs)), + cached_hash_(absl::HashOf(HashableKey{context_, args_, kwargs_})) {} + + bool operator==(const Key& other) const { + return context_.equal(other.context_) && args_.equal(other.args_) && + kwargs_.equal(other.kwargs_); + } + + template + friend H AbslHashValue(H h, const Key& key) { + return H::combine(std::move(h), key.cached_hash_); + } + + nb::object context() const { return context_; } + nb::args args() const { return args_; } + nb::kwargs kwargs() const { return kwargs_; } + + int tp_traverse(visitproc visit, void* arg) const { + Py_VISIT(context_.ptr()); + Py_VISIT(args_.ptr()); + Py_VISIT(kwargs_.ptr()); + return 0; + } + + private: + nb::object context_; + nb::args args_; + nb::kwargs kwargs_; + size_t cached_hash_; + }; + + struct CacheEntry { + bool has_result = false; + nb::object result; + absl::Notification completed; + std::thread::id thread_id = std::this_thread::get_id(); + + int tp_traverse(visitproc visit, void* arg) const { + Py_VISIT(result.ptr()); + return 0; + } + }; + + struct CacheInfo { + int64_t hits; + int64_t misses; + int64_t maxsize; + int64_t currsize; + }; + + struct WeakrefCacheKey { + nb::weakref ref; + size_t cached_hash; + }; + + using Cache = xla::LRUCache>; + + struct WeakrefCacheValue { + std::shared_ptr cache; + }; + + struct WeakrefKeyHash { + size_t operator()(const WeakrefCacheKey& v) const { return v.cached_hash; } + }; + + struct WeakrefKeyEq { + bool operator()(const WeakrefCacheKey& lhs, + const WeakrefCacheKey& rhs) const { + return lhs.ref.equal(rhs.ref); + } + }; + + WeakrefLRUCache(nb::callable cache_context_fn, nb::callable fn, + int64_t maxsize) + : cache_context_fn_(cache_context_fn), fn_(fn), lru_list_(maxsize) {} + + std::shared_ptr GetCache(WeakrefCacheKey key) { + WeakrefCacheValue& value = entries_[key]; + if (!value.cache) { + value.cache = std::make_shared(&lru_list_); + } + return value.cache; + } + + nb::object Call(nb::object weakref_key, nb::args args, + nb::kwargs kwargs) ABSL_NO_THREAD_SAFETY_ANALYSIS { + nb::object context = cache_context_fn_(); + + // We precompute all of the hash values needed by the various maps rather + // than computing them during the std::unordered_map insertions. At the very + // least, MSVC's std::unordered_map has undefined behavior if the hash + // function throws an exception + // (https://learn.microsoft.com/en-us/cpp/standard-library/unordered-map-class?view=msvc-170#emplace). + Key key(context, args, kwargs); + size_t wrcache_hash = static_cast(nb::hash(weakref_key)); + + // No hash computations after this point. + + auto weakref_gc_callback = nb::cpp_function( + [this_weak = weak_from_this(), wrcache_hash](nb::handle weakref) { + auto cache = this_weak.lock(); + if (cache == nullptr) { + return; + } + // Set up PyCriticalSection for cache python associated object; + auto py_cache = nb::find(cache); + // This should never happen as python cache should always be found + CHECK(py_cache.ptr() != nullptr); + nb::ft_object_guard lock(py_cache); + + // The object the reference referred to is now in the process of being + // destroyed, so we cannot refer to its contents. Python weakref + // objects compare based on identity if the object they refer to is + // gone, so the hash lookup will work fine. + auto it = cache->entries_.find( + WeakrefCacheKey{nb::borrow(weakref), wrcache_hash}); + if (it == cache->entries_.end()) { + return; + } + // Create temp-var to avoid re-entrant erase. + auto tmp = std::move(it->second); + cache->entries_.erase(it); + }); + nb::weakref weakref = nb::weakref(weakref_key, weakref_gc_callback); + WeakrefCacheKey wrcache_key{weakref, wrcache_hash}; + std::shared_ptr cache_ptr = GetCache(wrcache_key); + Cache& cache = *cache_ptr; + ++total_queries_; + + bool inserted = false; + std::shared_ptr entry; + { + // Because the gil can be released during cache insertion, this forces + // the lock order to be mu_ then gil so we must release the gil first. + nb::gil_scoped_release release; + // Acquire a mutex to avoid problems where the gil is released during + // cache insertion and then a second thread invalidates the cache order. + mu_.Lock(); + } + { + // GetOrCreateIfAbsent calls into Python hash and equality functions, + // which may throw exceptions. The use of absl::Cleanup ensures mu_ is + // released if that happens. + absl::Cleanup unlock = [this]() + ABSL_UNLOCK_FUNCTION(mu_) { mu_.Unlock(); }; + entry = cache.GetOrCreateIfAbsent(key, [&inserted](const Key& key) { + inserted = true; + return std::make_shared(); + }); + } + if (!entry->completed.HasBeenNotified()) { + if (inserted) { + ++misses_; + absl::Cleanup notify = [&] { entry->completed.Notify(); }; + entry->result = fn_(weakref_key, *args, **kwargs); + entry->has_result = true; + } else { + if (entry->thread_id == std::this_thread::get_id()) { + auto error_string = + absl::StrCat("Recursively calling ", + nb::cast(nb::repr(weakref_key)), + nb::cast(nb::repr(args))); + PyErr_SetString(PyExc_RecursionError, error_string.c_str()); + throw nb::python_error(); + } + nb::gil_scoped_release release; + entry->completed.WaitForNotification(); + } + } + + if (entry->has_result) { + return entry->result; + } else { + ++misses_; + return fn_(weakref_key, *args, **kwargs); + } + } + std::vector GetKeys() { + std::vector results; + mu_.Lock(); + for (const auto& wr_entry : entries_) { + for (const auto& rest : *wr_entry.second.cache) { + nb::tuple result = + nb::make_tuple(*wr_entry.first.ref, rest.first.context(), + rest.first.args(), rest.first.kwargs()); + results.push_back(std::move(result)); + } + } + mu_.Unlock(); + return results; + } + CacheInfo GetCacheInfo() const { + CacheInfo result; + result.hits = total_queries_ - misses_; + result.misses = misses_; + result.maxsize = lru_list_.Capacity(); + result.currsize = lru_list_.Size(); + return result; + } + void Clear() { + total_queries_ = misses_ = 0; + std::vector> deferred_deletes; + deferred_deletes.reserve(entries_.size()); + for (auto& entry : entries_) { + deferred_deletes.push_back(std::move(entry.second.cache)); + } + entries_.clear(); + deferred_deletes.clear(); + } + + nb::callable cache_context_fn_; + nb::callable fn_; + Cache::LRUList lru_list_; + std::unordered_map + entries_; + int64_t misses_ = 0; + int64_t total_queries_ = 0; + absl::Mutex mu_; + + static int tp_traverse(PyObject* self, visitproc visit, void* arg) { + WeakrefLRUCache* cache = nb::inst_ptr(self); + Py_VISIT(Py_TYPE(self)); + Py_VISIT(cache->cache_context_fn_.ptr()); + Py_VISIT(cache->fn_.ptr()); + for (const auto& [wr_key, wr_value] : cache->entries_) { + Py_VISIT(wr_key.ref.ptr()); + for (const auto& [key, cache_value] : *wr_value.cache) { + int rval = key.tp_traverse(visit, arg); + if (rval != 0) { + return rval; + } + if (cache_value.value.has_value()) { + cache_value.value->get()->tp_traverse(visit, arg); + } + } + } + return 0; + } + + static int tp_clear(PyObject* self) { + WeakrefLRUCache* cache = nb::inst_ptr(self); + cache->Clear(); + cache->cache_context_fn_.reset(); + cache->fn_.reset(); + return 0; + } + + static PyType_Slot slots_[]; +}; + +/* static */ PyType_Slot WeakrefLRUCache::slots_[] = { + {Py_tp_traverse, (void*)WeakrefLRUCache::tp_traverse}, + {Py_tp_clear, (void*)WeakrefLRUCache::tp_clear}, + {0, nullptr}, +}; + +void BuildWeakrefLRUCacheAPI(nb::module_& m) { + auto weakref_lru_cache = + nb::class_(m, "WeakrefLRUCache", + nb::is_weak_referenceable(), + nb::type_slots(WeakrefLRUCache::slots_)) + .def("__call__", &WeakrefLRUCache::Call, nb::lock_self()) + .def("cache_keys", &WeakrefLRUCache::GetKeys, nb::lock_self()) + .def("cache_info", &WeakrefLRUCache::GetCacheInfo, nb::lock_self()) + .def("cache_clear", &WeakrefLRUCache::Clear, nb::lock_self()); + nb::class_(weakref_lru_cache, + "WeakrefLRUCacheInfo") + .def_ro("hits", &WeakrefLRUCache::CacheInfo::hits) + .def_ro("misses", &WeakrefLRUCache::CacheInfo::misses) + .def_ro("maxsize", &WeakrefLRUCache::CacheInfo::maxsize) + .def_ro("currsize", &WeakrefLRUCache::CacheInfo::currsize) + .def("__repr__", [](WeakrefLRUCache::CacheInfo& info) { + return absl::StrCat( + "WeakrefLRUCache(hits=", info.hits, ", misses=", info.misses, + ", maxsize=", info.maxsize, ", currsize=", info.currsize, ")"); + }); + m.def( + "weakref_lru_cache", + [](nb::callable cache_context_fn, nb::callable fn, int64_t maxsize) { + return std::make_shared(cache_context_fn, fn, maxsize); + }, + nb::arg("cache_context_fn"), nb::arg("fn"), nb::arg("maxsize") = 2048); +} + +} // namespace jax diff --git a/jaxlib/xla/weakref_lru_cache.h b/jaxlib/xla/weakref_lru_cache.h new file mode 100644 index 000000000000..444e01cef575 --- /dev/null +++ b/jaxlib/xla/weakref_lru_cache.h @@ -0,0 +1,28 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_WEAKREF_LRU_CACHE_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_WEAKREF_LRU_CACHE_H_ + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" + +namespace jax { + +void BuildWeakrefLRUCacheAPI(nanobind::module_& m); + +} // namespace jax + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_WEAKREF_LRU_CACHE_H_ diff --git a/jaxlib/xla/xla.cc b/jaxlib/xla/xla.cc index 5f39b9173b89..fdd4456b238c 100644 --- a/jaxlib/xla/xla.cc +++ b/jaxlib/xla/xla.cc @@ -46,6 +46,7 @@ limitations under the License. #include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep #include "nanobind/stl/variant.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/sdy.h" #include "xla/backends/cpu/collectives/cpu_collectives.h" #include "xla/pjrt/c/pjrt_c_api.h" #include "xla/pjrt/distributed/client.h" @@ -64,7 +65,6 @@ limitations under the License. #include "xla/python/pjrt_ifrt/pjrt_attribute_map_util.h" #include "xla/python/py_client.h" #include "xla/python/py_program.h" -#include "xla/python/sdy.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/python/lib/core/numpy.h" // NOLINT @@ -84,6 +84,15 @@ limitations under the License. #include "xla/backends/cpu/collectives/mpi_collectives.h" #endif // !_WIN32 && !PLATFORM_GOOGLE +#include "jaxlib/xla/config.h" +#include "jaxlib/xla/custom_call_sharding.h" +#include "jaxlib/xla/dlpack.h" +#include "jaxlib/xla/jax_jit.h" +#include "jaxlib/xla/mlir.h" +#include "jaxlib/xla/pjit.h" +#include "jaxlib/xla/pmap_lib.h" +#include "jaxlib/xla/weakref_lru_cache.h" +#include "jaxlib/xla/xla_compiler.h" #include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/pjrt_api.h" @@ -92,22 +101,15 @@ limitations under the License. #include "xla/pjrt/pjrt_common.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_layout.h" -#include "xla/python/config.h" -#include "xla/python/custom_call_sharding.h" -#include "xla/python/dlpack.h" #include "xla/python/guard_lib.h" -#include "xla/python/jax_jit.h" #include "xla/python/logging.h" // IWYU pragma: keep -#include "xla/python/mlir.h" #include "xla/python/nb_absl_flat_hash_map.h" // IWYU pragma: keep #include "xla/python/nb_absl_span.h" // IWYU pragma: keep #include "xla/python/nb_class_ptr.h" #include "xla/python/ops.h" -#include "xla/python/pjit.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/python/pjrt_ifrt/pjrt_executable.h" #include "xla/python/pjrt_ifrt/pjrt_topology.h" -#include "xla/python/pmap_lib.h" #include "xla/python/pprof_profile_builder.h" #include "xla/python/profiler.h" #include "xla/python/py_array.h" @@ -120,8 +122,6 @@ limitations under the License. #include "xla/python/pytree.h" #include "xla/python/sharding.h" #include "xla/python/traceback.h" -#include "xla/python/weakref_lru_cache.h" -#include "xla/python/xla_compiler.h" #include "xla/tsl/distributed_runtime/preemption/preemption_sync_manager.h" #include "xla/tsl/platform/status.h" #include "tsl/platform/platform.h" diff --git a/jaxlib/xla/xla_compiler.cc b/jaxlib/xla/xla_compiler.cc new file mode 100644 index 000000000000..f4719b450988 --- /dev/null +++ b/jaxlib/xla/xla_compiler.cc @@ -0,0 +1,1639 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/xla_compiler.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/ndarray.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/dlpack.h" +#include "xla/array.h" +#include "xla/client/executable_build_options.h" +#include "xla/debug_options_flags.h" +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/ffi.h" +#include "xla/ffi/ffi_api.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_module_group.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/hlo/transforms/simplifiers/flatten_call_graph.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" +#include "xla/hlo/transforms/simplifiers/tuple_simplifier.h" +#include "xla/layout.h" +#include "xla/layout_util.h" +#include "xla/literal.h" +#include "xla/pjrt/compile_options.pb.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/nb_absl_span.h" // IWYU pragma: keep +#include "xla/python/nb_helpers.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/py_client.h" +#include "xla/python/types.h" +#include "xla/service/call_inliner.h" +#include "xla/service/computation_placer.h" +#include "xla/service/custom_call_target_registry.h" +#include "xla/service/hlo.pb.h" +#include "xla/service/hlo_graph_dumper.h" +#include "xla/service/hlo_module_config.h" +#include "xla/service/name_uniquer.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/tsl/lib/strings/proto_serialization.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" +#include "xla/xla.pb.h" +#include "xla/xla_data.pb.h" + +namespace nanobind { +namespace detail { + +template <> +struct type_caster { + public: + NB_TYPE_CASTER_FROM_PYTHON_ONLY(xla::OpMetadata, + const_name("xla::OpMetadata")); + + bool from_python(handle h, uint8_t, cleanup_list*) noexcept { + handle op_type = getattr(h, "op_type"); + if (!op_type.is_none()) { + value.set_op_type(cast(op_type)); + } + handle op_name = getattr(h, "op_name"); + if (!op_name.is_none()) { + value.set_op_name(cast(op_name)); + } + handle source_file = getattr(h, "source_file"); + if (!source_file.is_none()) { + value.set_source_file(cast(source_file)); + } + handle source_line = getattr(h, "source_line"); + if (!source_line.is_none()) { + value.set_source_line(cast(source_line)); + } + return true; + } +}; + +} // namespace detail +} // namespace nanobind + +namespace xla { +namespace { + +namespace nb = nanobind; + +struct Uniquer { + absl::Mutex mu; + NameUniquer name_uniquer ABSL_GUARDED_BY(mu); +}; + +Uniquer* GetUniquer() { + static Uniquer* uniquer = new Uniquer; + return uniquer; +} + +static std::string UniquifyName(const std::string& name) { + Uniquer* uniquer = GetUniquer(); + absl::MutexLock lock(&uniquer->mu); + return uniquer->name_uniquer.GetUniqueName(name); +} + +// Converts a computation to a serialized HloModuleProto. +absl::StatusOr GetComputationSerializedProto( + const XlaComputation& computation) { + std::string result; + if (!tsl::SerializeToStringDeterministic(computation.proto(), &result)) { + return Unknown("Failed to serialize the HloModuleProto."); + } + return nb::bytes(result.data(), result.size()); +} + +// Converts a hlo module to a serialized HloModuleProto. +absl::StatusOr GetHloModuleSerializedProto(const HloModule& module) { + std::string result; + if (!tsl::SerializeToStringDeterministic(module.ToProto(), &result)) { + return Unknown("Failed to serialize the HloModuleProto."); + } + return nb::bytes(result.data(), result.size()); +} + +// Converts a serialized HloModuleProto into a HloModule. +absl::StatusOr> HloModuleFromSerializedProto( + const nb::bytes& bytes) { + HloModuleProto proto; + proto.ParseFromArray(bytes.c_str(), bytes.size()); + TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config, + HloModule::CreateModuleConfigFromProto( + proto, GetDebugOptionsFromFlags())); + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + HloModule::CreateFromProto(proto, module_config)); + return std::shared_ptr(std::move(module)); +} + +absl::StatusOr> GetHloModule( + const XlaComputation& computation) { + TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config, + HloModule::CreateModuleConfigFromProto( + computation.proto(), GetDebugOptionsFromFlags())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr module, + HloModule::CreateFromProto(computation.proto(), module_config)); + return std::shared_ptr(std::move(module)); +} + +// Converts a computation to textual HLO form. +absl::StatusOr GetComputationHloText( + const XlaComputation& computation, bool print_large_constants = false) { + TF_ASSIGN_OR_RETURN(std::shared_ptr hlo_module, + GetHloModule(computation)); + HloPrintOptions options; + options = HloPrintOptions::ShortParsable(); + options.set_print_large_constants(print_large_constants); + return hlo_module->ToString(options); +} + +// Converts a computation to HLO dot graph form. +absl::StatusOr GetComputationHloDotGraph( + const XlaComputation& computation) { + TF_ASSIGN_OR_RETURN(std::shared_ptr hlo_module, + GetHloModule(computation)); + return RenderGraph(*hlo_module->entry_computation(), /*label=*/"", + hlo_module->config().debug_options(), + RenderedGraphFormat::kDot); +} + +// Hashes the HLO module. +absl::StatusOr HashComputation(const XlaComputation& computation) { + TF_ASSIGN_OR_RETURN(std::shared_ptr hlo_module, + GetHloModule(computation)); + return absl::HashOf(*hlo_module); +} +// Safe version of ShapeUtil::MakeShapeWithDenseLayout that fails gracefully on +// invalid input. +absl::StatusOr MakeShapeWithDenseLayout( + PrimitiveType element_type, absl::Span dims, + std::optional> minor_to_major, + std::optional> dynamic_dimensions) { + Shape shape; + if (dynamic_dimensions) { + TF_ASSIGN_OR_RETURN( + shape, ShapeUtil::MakeValidatedShape(element_type, dims, + dynamic_dimensions.value())); + } else { + TF_ASSIGN_OR_RETURN(shape, + ShapeUtil::MakeValidatedShape(element_type, dims)); + } + if (minor_to_major) { + *shape.mutable_layout() = LayoutUtil::MakeLayout(*minor_to_major); + TF_RETURN_IF_ERROR( + LayoutUtil::ValidateLayoutForShape(shape.layout(), shape)); + } + + return shape; +} + +// Pybind function for HloSharding.iota_tile, which is a non-crashing factory +// that produces a HloSharding instance backed by tile assignment of a +// transposed and reshaped iota array of device ids. More specifically the tile +// assignment array is as if it is produced by the following numpy code: +// numpy.arange(math.prod(dims)).reshape(reshape_dims) +// .transpose(transpose_perm).reshape(math.prod(dims)) +// where: +// `dims`: is the dimensions of the tile assignment array, which corresponds to +// OpSharding.tile_assignment_dimensions. +// `reshape_dims`: is the dimensions the 1D iota array is reshaped to. +// `transpose_perm`: is the dimension permutation to transpose `reshape_dims`. +// `subgroup_types`: indicates the subgroups of the last `subgroup_types.size()` +// dimensions in `dims`. +// +// In practice, `reshape_dims` often maps to the axises of user defined device +// mesh, and `transpose_perm` often maps to the user specification of how a +// tensor is partitioned based on the axes defined in the mesh, e.g. for a mesh +// of size 4x2x2 as AxBxC: +// PartitionSpec('A', 'B', 'C') corresponds to reshape_dims=[4,2,2], +// transpose_perm=[0,1,2] (no transpose) +// PartitionSpec('B', 'A', 'C') corresponds to reshape_dims=[4,2,2], +// transpose_perm=[1,0,2] (swap A and B) +absl::StatusOr IotaTileHelper( + absl::Span dims, absl::Span reshape_dims, + absl::Span transpose_perm, + absl::Span subgroup_types) { + if (dims.empty()) { + return InvalidArgument("`dims` should not be empty."); + } + if (reshape_dims.size() != transpose_perm.size()) { + return InvalidArgument( + "`reshape_dims` and `transpose_perm` should have the same size, saw " + "[%s] v.s. [%s]", + absl::StrJoin(reshape_dims, ","), absl::StrJoin(transpose_perm, ",")); + } + if (!reshape_dims.empty() && Product(dims) != Product(reshape_dims)) { + return InvalidArgument( + "Cannot reshape from `dims` [%s] to `reshape_dims` [%s].", + absl::StrJoin(dims, ","), absl::StrJoin(reshape_dims, ",")); + } + if (subgroup_types.size() > dims.size()) { + return InvalidArgument( + "`subgroup_types`(%lld) should not have more dimensions than " + "`dims`(%lld).", + subgroup_types.size(), dims.size()); + } + if (reshape_dims.empty()) { + return subgroup_types.empty() + ? HloSharding::IotaTile(dims) + : HloSharding::Subgroup(TileAssignment(dims), subgroup_types); + } + return subgroup_types.empty() + ? HloSharding::IotaTile(dims, reshape_dims, transpose_perm) + : HloSharding::Subgroup( + TileAssignment(dims, reshape_dims, transpose_perm), + subgroup_types); +} + +// Registers a 'fn' as a custom call target. +// +// `fn` must be a custom call implementation function pointer (XLA_FFI_Handler* +// when implemented as FFI handler) encapsulated in a PyCapsule object or a +// a dictionary of function pointers (also encapsulated in a PyCapsule). +// +// See XLA_FFI_ExecutionStage documentation for more details about the +// custom execution stages. +absl::Status PyRegisterCustomCallTarget(const std::string& fn_name, + nb::object fn, + const std::string& platform, + int api_version, + XLA_FFI_Handler_Traits traits) { + // Register legacy custom call target (untyped void* API). + if (api_version == 0) { + if (traits != 0) { + return absl::InvalidArgumentError( + "Custom call target registration with traits is not supported for " + "api_version=0"); + } + + nb::capsule capsule; + if (!nb::try_cast(fn, capsule)) { + return absl::InvalidArgumentError( + "Custom call target registration with api_version=0 requires a " + "PyCapsule fn object"); + } + + CustomCallTargetRegistry::Global()->Register( + fn_name, static_cast(capsule.data()), platform); + return absl::OkStatus(); + } + + // Register XLA FFI handler (typed API with explicit function signatures). + if (api_version == 1) { + nb::capsule capsule; + if (nb::try_cast(fn, capsule)) { + return ffi::TakeStatus(ffi::Ffi::RegisterStaticHandler( + xla::ffi::GetXlaFfiApi(), fn_name, platform, + reinterpret_cast( + static_cast(capsule.data())))); + } + + nb::dict bundle; + if (nb::try_cast(fn, bundle)) { + auto handler = [&](const char* name) -> absl::StatusOr { + if (!bundle.contains(name)) return nullptr; + + nb::capsule capsule; + if (!nb::try_cast(bundle[name], capsule)) { + return absl::InvalidArgumentError( + "Custom call target registration with api_version=1 requires a " + "PyCapsule fn object for all dict keys"); + } + + return reinterpret_cast(capsule.data()); + }; + + XLA_FFI_Handler_Bundle bundle; + TF_ASSIGN_OR_RETURN(bundle.instantiate, handler("instantiate")); + TF_ASSIGN_OR_RETURN(bundle.prepare, handler("prepare")); + TF_ASSIGN_OR_RETURN(bundle.initialize, handler("initialize")); + TF_ASSIGN_OR_RETURN(bundle.execute, handler("execute")); + + return ffi::TakeStatus(ffi::Ffi::RegisterStaticHandler( + xla::ffi::GetXlaFfiApi(), fn_name, platform, bundle, traits)); + } + + return absl::InvalidArgumentError( + "Unsupported custom call target type for api_version=1"); + } + + return absl::UnimplementedError(absl::StrFormat( + "API version %d is not supported by RegisterCustomCallTarget. " + "Supported versions are 0 and 1.", + api_version)); +} + +absl::Status PyRegisterCustomTypeId(absl::string_view type_name, + nb::object type_id) { + nb::capsule capsule; + if (!nb::try_cast(type_id, capsule)) { + return absl::InvalidArgumentError( + "The type_id argument to register_custom_call_type_id must be a " + "PyCapsule object holding a pointer to a XLA_FFI_TypeId."); + } + XLA_FFI_TypeId* type_id_ptr = + reinterpret_cast(static_cast(capsule.data())); + return ffi::TakeStatus(ffi::Ffi::RegisterTypeId(xla::ffi::GetXlaFfiApi(), + type_name, type_id_ptr)); +} + +template +void DefRepeatedProperty(nb::class_& cls, const char* name, + Container* (T::*getter)()) { + cls.def_prop_rw( + name, + [getter](T& obj) { + Container* elems = (obj.*getter)(); + std::vector result; + result.reserve(elems->size()); + std::copy(elems->begin(), elems->end(), std::back_inserter(result)); + return result; + }, + [getter](T& obj, std::vector new_elems) { + Container* elems = (obj.*getter)(); + elems->Clear(); + elems->Reserve(new_elems.size()); + for (typename Container::value_type& e : new_elems) { + elems->Add(std::move(e)); + } + }); +} + +template +void DefRepeatedEnumProperty(nb::class_& cls, const char* name, + Container* (T::*getter)()) { + cls.def_prop_rw( + name, + [getter](T& obj) { + Container* elems = (obj.*getter)(); + std::vector result; + result.reserve(elems->size()); + std::copy(elems->begin(), elems->end(), std::back_inserter(result)); + return result; + }, + [getter](T& obj, nb::sequence new_elems) { + Container* elems = (obj.*getter)(); + elems->Clear(); + for (nb::handle e : new_elems) { + elems->Add(nb::cast(e.attr("value"))); + } + }); +} + +template +Array NDArrayToArray(nb::ndarray ndarray) { + std::vector shapes; + shapes.reserve(ndarray.ndim()); + for (int i = 0; i < ndarray.ndim(); ++i) { + shapes.push_back(ndarray.shape(i)); + } + xla::Array array(shapes); + array.Each([&](absl::Span indices, int64_t* val) { + int64_t offset = indices.back(); + int64_t multiplier = 1; + for (int i = ndarray.ndim() - 1; i > 0; --i) { + multiplier *= ndarray.shape(i); + offset += indices[i - 1] * multiplier; + } + *val = *(ndarray.data() + offset); + }); + return array; +} + +absl::StatusOr SubgroupWithTileAssignmentHelper( + nb::ndarray tile_assignment, + absl::Span subgroup_types) { + return HloSharding::Subgroup(NDArrayToArray(tile_assignment), subgroup_types); +} + +nb::ndarray<> LiteralToNdarray(Literal& obj) { + const Shape& shape = obj.shape(); + + if (!shape.has_layout()) { + throw XlaRuntimeError( + "Creating an array is only supported for Literals with a layout."); + } + + const Layout& layout = shape.layout(); + + if (!layout.tiles().empty()) { + throw XlaRuntimeError( + "Creating an array from a tiled Literal is not supported."); + } + + if (!LayoutUtil::IsDenseArray(shape)) { + throw XlaRuntimeError( + "Creating an array is only supported for dense Literals."); + } + + xla::PrimitiveType primitive_type = shape.element_type(); + nb::dlpack::dtype dtype = + ValueOrThrow(PrimitiveTypeToNbDLDataType(primitive_type)); + + absl::Span dimensions = shape.dimensions(); + std::vector unsigned_dimensions(dimensions.begin(), dimensions.end()); + auto strides = StridesForShape(primitive_type, dimensions, layout); + + return nb::ndarray<>(obj.untyped_data(), unsigned_dimensions.size(), + unsigned_dimensions.data(), {}, strides.data(), dtype, + nb::device::cpu::value, 0); +} + +} // namespace + +void BuildXlaCompilerSubmodule(nb::module_& m) { + // Shapes + nb::class_ layout_class(m, "Layout"); + layout_class.def(nb::init>()) + .def("__init__", + [](Layout* self, nb::sequence minor_to_major, nb::sequence tiling, + int64_t element_size_in_bits) { + std::vector xla_tiles; + xla_tiles.reserve(nb::len(tiling.ptr())); + for (auto tile : tiling) { + xla_tiles.push_back(Tile( + SequenceToVector(nb::cast(tile)))); + } + std::vector xla_minor_to_major = + SequenceToVector(minor_to_major); + new (self) + Layout(xla_minor_to_major, xla_tiles, element_size_in_bits); + }) + .def("minor_to_major", + [](Layout layout) { return SpanToNbTuple(layout.minor_to_major()); }) + .def("element_size_in_bits", &Layout::element_size_in_bits) + .def("tiling", + [](Layout layout) { + std::vector result; + result.reserve(layout.tiles().size()); + for (auto& t : layout.tiles()) { + result.push_back(SpanToNbTuple(t.dimensions())); + } + return result; + }) + .def("__eq__", [](const Layout& layout, + const Layout& other) { return layout == other; }) + .def("__ne__", [](const Layout& layout, + const Layout& other) { return layout != other; }) + .def("__str__", &Layout::ToString) + .def("__hash__", + [](const Layout& layout) { return absl::HashOf(layout); }) + .def("to_string", &Layout::ToString) + .def("__getstate__", + [](const Layout& self) -> nb::tuple { + auto proto = self.ToProto(); + std::string result; + if (!tsl::SerializeToStringDeterministic(proto, &result)) { + // throw converted by PyBind to a Python RuntimeError. + throw XlaRuntimeError( + absl::StrCat("Layout.py_pickle: ", + "SerializeToStringDeterministic failed")); + } + return nb::make_tuple(nb::bytes(result.data(), result.size())); + }) + .def("__setstate__", [](Layout* self, nb::tuple t) { + LayoutProto result; + nb::bytes serialized = nb::cast(t[0]); + result.ParseFromArray(serialized.c_str(), serialized.size()); + new (self) Layout(Layout::CreateFromProto(result)); + }); + + nb::class_ shape_class(m, "Shape"); + shape_class + .def("__init__", + [](Shape* self, const std::string& s) { + new (self) Shape(ValueOrThrow(ParseShape(s))); + }) + .def_static( + "tuple_shape", + [](std::vector shapes) -> Shape { + return ShapeUtil::MakeTupleShape(shapes); + }, + "Constructs a tuple shape.") + .def_static("array_shape", + xla::ValueOrThrowWrapper( + [](PrimitiveType type, nb::sequence dims_seq, + std::optional layout_seq, + std::optional> dynamic_dimensions) + -> absl::StatusOr { + std::vector dims = + SequenceToVector(dims_seq); + if (layout_seq) { + std::vector layout = + SequenceToVector(*layout_seq); + return MakeShapeWithDenseLayout(type, dims, layout, + dynamic_dimensions); + } else { + return MakeShapeWithDenseLayout( + type, dims, std::nullopt, dynamic_dimensions); + } + }), + "Constructs an array shape.", nb::arg("type"), + nb::arg("dims"), nb::arg("layout").none() = std::nullopt, + nb::arg("dynamic_dimensions").none() = std::nullopt) + .def_static( + "array_shape", + xla::ValueOrThrowWrapper( + [](nb_dtype dtype, nb::sequence dims_seq, + std::optional layout_seq, + std::optional> dynamic_dimensions) + -> absl::StatusOr { + PrimitiveType type = ValueOrThrow(DtypeToPrimitiveType(dtype)); + std::vector dims = SequenceToVector(dims_seq); + if (layout_seq) { + std::vector layout = + SequenceToVector(*layout_seq); + return MakeShapeWithDenseLayout(type, dims, layout, + dynamic_dimensions); + } else { + return MakeShapeWithDenseLayout(type, dims, std::nullopt, + dynamic_dimensions); + } + }), + "Constructs an array shape.", nb::arg("type"), nb::arg("dims"), + nb::arg("layout").none() = std::nullopt, + nb::arg("dynamic_dimensions").none() = std::nullopt) + .def_static("token_shape", []() { return ShapeUtil::MakeTokenShape(); }) + .def_static( + "scalar_shape", + [](PrimitiveType type) -> Shape { + return ShapeUtil::MakeScalarShape(type); + }, + "Constructs a scalar shape.", nb::arg("type")) + .def_static( + "scalar_shape", + [](nb_dtype dtype) -> Shape { + PrimitiveType type = xla::ValueOrThrow(DtypeToPrimitiveType(dtype)); + return ShapeUtil::MakeScalarShape(type); + }, + "Constructs a scalar shape.", nb::arg("type")) + .def("dimensions", + [](const Shape& shape) -> nb::tuple { + return SpanToNbTuple(shape.dimensions()); + }) + .def("layout", + [](const Shape& shape) -> Layout { return shape.layout(); }) + .def("xla_element_type", &Shape::element_type) + .def("element_type", + [](const Shape& shape) { + return xla::ValueOrThrow( + PrimitiveTypeToNbDtype(shape.element_type())); + }) + .def("numpy_dtype", + [](const Shape& shape) { + if (shape.IsTuple()) { + return nb_dtype("O"); + } + return xla::ValueOrThrow( + PrimitiveTypeToNbDtype(shape.element_type())); + }) + .def("is_tuple", &Shape::IsTuple) + .def("is_array", &Shape::IsArray) + .def("is_token", &Shape::IsToken) + .def("is_static", &Shape::is_static) + .def("is_dynamic", &Shape::is_dynamic) + .def("is_dynamic_dimension", &Shape::is_dynamic_dimension, + nb::arg("dimension")) + .def("set_dynamic_dimension", &Shape::set_dynamic_dimension, + nb::arg("dimension"), nb::arg("is_dynamic")) + .def("rank", &Shape::rank) + .def("to_serialized_proto", + [](const Shape& shape) { + ShapeProto proto = shape.ToProto(); + std::string s = proto.SerializeAsString(); + return nb::bytes(s.data(), s.size()); + }) + .def("tuple_shapes", + [](const Shape& shape) { + return std::vector(shape.tuple_shapes()); + }) + .def("leaf_count", + [](const Shape& shape) { return ShapeUtil::GetLeafCount(shape); }) + .def( + "with_major_to_minor_layout_if_absent", + [](const Shape& shape) { + Shape out = shape; + ShapeUtil::ForEachMutableSubshape( + &out, [](Shape* subshape, const ShapeIndex&) { + if (!subshape->has_layout()) { + LayoutUtil::SetToDefaultLayout(subshape); + } + }); + return out; + }, + "Returns a copy of a shape with missing layouts set to " + "major-to-minor.") + .def("__eq__", [](const Shape& shape, + const Shape& other) { return shape == other; }) + .def("__ne__", [](const Shape& shape, + const Shape& other) { return shape != other; }) + .def("__hash__", [](const Shape& shape) { return absl::HashOf(shape); }) + .def("__repr__", [](const Shape& shape) { + return shape.ToString(/*print_layout=*/true); + }); + + nb::class_(m, "ProgramShape") + .def( + "__init__", + [](ProgramShape* self, absl::Span params, Shape result) { + new (self) ProgramShape(); + for (const Shape& param : params) { + *self->add_parameters() = param; + } + *self->mutable_result() = result; + }) + .def("parameter_shapes", + static_cast& (ProgramShape::*)() const>( + &ProgramShape::parameters)) + .def("result_shape", &ProgramShape::result) + .def("__repr__", &ProgramShape::ToString); + + nb::class_(m, "ShapeIndex") + .def("__init__", + [](ShapeIndex* self, const std::vector& v) { + new (self) ShapeIndex(v.begin(), v.end()); + }) + .def("__repr__", &ShapeIndex::ToString) + .def("__eq__", [](const ShapeIndex& shape_ind, + const ShapeIndex& other) { return shape_ind == other; }) + .def("__ne__", [](const ShapeIndex& shape_ind, + const ShapeIndex& other) { return shape_ind != other; }) + .def("__hash__", + [](const ShapeIndex& shape_ind) { return absl::HashOf(shape_ind); }); + + // Literals + nb::class_(m, "Literal") + .def(nb::init()) + .def("__repr__", &Literal::ToString) + .def( + "__array__", + [](std::shared_ptr obj, std::optional dtype, + std::optional copy) { + // Provides the interface required by numpy to create a np.ndarray. + // Currently don't support the __dl_pack__ interface but can be + // added with very little effort it if needed. + + nb::ndarray np_array(LiteralToNdarray(*obj)); + + if (dtype.has_value()) { + throw XlaRuntimeError( + "Passing of dtype to __array__ not currently supported."); + } + + if (copy.has_value() && *copy) { + // when a copy is requested we _must_ return a copy: + // https://numpy.org/doc/2.1/reference/generated/numpy.ndarray.__array__.html + return np_array.cast(nb::rv_policy::copy); + } + + return np_array.cast(nb::rv_policy::reference_internal, + nb::cast(obj)); + }, + nb::arg("dtype").none() = nb::none(), + nb::arg("copy").none() = nb::none()) + .def("shape", &Literal::shape); + + nb::class_(m, "XlaComputation") + .def("__init__", + [](XlaComputation* self, + const nb::bytes& serialized_hlo_module_proto) { + HloModuleProto proto; + proto.ParseFromArray(serialized_hlo_module_proto.c_str(), + serialized_hlo_module_proto.size()); + new (self) XlaComputation(proto); + }) + .def("get_hlo_module", xla::ValueOrThrowWrapper(GetHloModule)) + .def("program_shape", + xla::ValueOrThrowWrapper(&XlaComputation::GetProgramShape)) + .def("name", &XlaComputation::name) + .def("as_serialized_hlo_module_proto", + xla::ValueOrThrowWrapper(GetComputationSerializedProto)) + .def("as_hlo_text", xla::ValueOrThrowWrapper(GetComputationHloText), + nb::arg("print_large_constants") = false) + .def("as_hlo_dot_graph", + xla::ValueOrThrowWrapper(GetComputationHloDotGraph)) + .def("hash", xla::ValueOrThrowWrapper(HashComputation)) + .def("as_hlo_module", xla::ValueOrThrowWrapper(GetHloModule)); + + nb::class_ hlo_print_options_class(m, "HloPrintOptions"); + hlo_print_options_class.def(nb::init<>()) + .def_static("short_parsable", &HloPrintOptions::ShortParsable) + .def_static("canonical", &HloPrintOptions::Canonical) + .def_static("fingerprint", &HloPrintOptions::Fingerprint) + .def_prop_rw("print_large_constants", + &HloPrintOptions::print_large_constants, + &HloPrintOptions::set_print_large_constants) + .def_prop_rw("print_metadata", &HloPrintOptions::print_metadata, + &HloPrintOptions::set_print_metadata) + .def_prop_rw("print_backend_config", + &HloPrintOptions::print_backend_config, + &HloPrintOptions::set_print_backend_config) + .def_prop_rw("print_result_shape", &HloPrintOptions::print_result_shape, + &HloPrintOptions::set_print_result_shape) + .def_prop_rw("print_operand_shape", &HloPrintOptions::print_operand_shape, + &HloPrintOptions::set_print_operand_shape) + .def_prop_rw("print_operand_names", &HloPrintOptions::print_operand_names, + &HloPrintOptions::set_print_operand_names) + .def_prop_rw("print_ids", &HloPrintOptions::print_ids, + &HloPrintOptions::set_print_ids) + .def_prop_rw("print_extra_attributes", + &HloPrintOptions::print_extra_attributes, + &HloPrintOptions::set_print_extra_attributes) + .def_prop_rw("print_program_shape", &HloPrintOptions::print_program_shape, + &HloPrintOptions::set_print_program_shape) + .def_prop_rw("print_percent", &HloPrintOptions::print_percent, + &HloPrintOptions::set_print_percent) + .def_prop_rw("print_control_dependencies", + &HloPrintOptions::print_control_dependencies, + &HloPrintOptions::set_print_control_dependencies) + .def_prop_rw("compact_operands", &HloPrintOptions::compact_operands, + &HloPrintOptions::set_compact_operands) + .def_prop_rw("include_layout_in_shapes", + &HloPrintOptions::include_layout_in_shapes, + &HloPrintOptions::set_include_layout_in_shapes) + .def_prop_rw("canonicalize_instruction_names", + &HloPrintOptions::canonicalize_instruction_names, + &HloPrintOptions::set_canonicalize_instruction_names) + .def_prop_rw("canonicalize_computations", + &HloPrintOptions::canonicalize_computations, + &HloPrintOptions::set_canonicalize_computations) + .def_prop_rw("indent_amount", &HloPrintOptions::indent_amount, + &HloPrintOptions::set_indent_amount) + .def_prop_rw("is_in_nested_computation", + &HloPrintOptions::is_in_nested_computation, + &HloPrintOptions::set_is_in_nested_computation); + + // HloModule.computations() returns raw pointers. + // pybind seems to prefer smart pointers. + // We give pybind a smart pointer to a wrapper around a raw pointer to satisfy + // pybind and avoid double frees. + class ComputationWrapper { + public: + ComputationWrapper(const HloComputation* comp, + const std::shared_ptr module) + : comp_(comp), module_(module) {} + absl::string_view name() const { return comp_->name(); } + void render_html(const std::string& filename) { + std::string html = xla::ValueOrThrow(RenderGraph( + *comp_, /*label=*/"", comp_->parent()->config().debug_options(), + RenderedGraphFormat::kHtml, HloRenderOptions())); + xla::ThrowIfError(tsl::WriteStringToFile( + tsl::Env::Default(), absl::StrCat(filename, ".html"), html)); + } + + private: + const HloComputation* comp_; + // The module owns the computations: if its destructor is called, the + // computations are freed. To prevent that from happening in cases where the + // module Python object goes out of scope and gets garbage collected before + // the computations, we keep a shared_ptr to the module that originated the + // computation. + const std::shared_ptr module_; + }; + + nb::class_ hlo_computation_class(m, "HloComputation"); + + hlo_computation_class.def_prop_ro("name", &ComputationWrapper::name) + .def("render_html", &ComputationWrapper::render_html); + + nb::class_ hlo_module_class(m, "HloModule"); + hlo_module_class.def_prop_ro("name", &HloModule::name) + .def( + "to_string", + static_cast( + &HloModule::ToString), + nb::arg("options") = HloPrintOptions()) + .def("as_serialized_hlo_module_proto", + xla::ValueOrThrowWrapper(GetHloModuleSerializedProto)) + .def("from_serialized_hlo_module_proto", + xla::ValueOrThrowWrapper(HloModuleFromSerializedProto)) + .def("computations", + [](const std::shared_ptr m) + -> std::vector> { + std::vector> computations; + for (HloComputation* comp : m->computations()) + computations.push_back( + std::make_shared(comp, m)); + return computations; + }) + .def_prop_ro("spmd_output_sharding", + [](const HloModule& m) -> std::optional { + if (!m.has_spmd_output_sharding()) return std::nullopt; + return m.spmd_output_sharding().ToProto(); + }) + .def_prop_ro("spmd_parameters_shardings", + [](const HloModule& m) + -> std::optional> { + if (!m.has_spmd_parameters_shardings()) + return std::nullopt; + std::vector param_shardings; + for (const auto& parameter_sharding : + m.spmd_parameters_shardings()) { + param_shardings.push_back(parameter_sharding.ToProto()); + } + return param_shardings; + }); + + nb::class_ hlo_module_group_class(m, "HloModuleGroup"); + hlo_module_group_class + .def("__init__", + [](HloModuleGroup* self, const std::string& name, + const std::vector>& hlo_modules) { + std::vector> modules; + modules.reserve(hlo_modules.size()); + for (const auto& m : hlo_modules) { + modules.push_back(m->Clone(/*suffix=*/"")); + } + new (self) HloModuleGroup(name, std::move(modules)); + }) + .def_prop_ro("name", &HloModuleGroup::name) + .def("to_string", &HloModuleGroup::ToString) + .def("to_modules", + [](HloModuleGroup& m) -> std::vector> { + std::vector> modules = + m.ConsumeModules(); + std::vector> shared_modules; + shared_modules.reserve(modules.size()); + for (auto& module : modules) { + shared_modules.push_back(std::move(module)); + } + return shared_modules; + }); + + m.def("hlo_module_to_dot_graph", + [](const HloModule& hlo_module) -> std::string { + return xla::ValueOrThrow(RenderGraph( + *hlo_module.entry_computation(), /*label=*/"", + hlo_module.config().debug_options(), RenderedGraphFormat::kDot)); + }); + m.def( + "hlo_module_cost_analysis", + xla::ValueOrThrowWrapper([](PyClient* client, const HloModule& module) + -> absl::StatusOr { + TF_ASSIGN_OR_RETURN(auto analysis, + client->pjrt_client()->GetHloCostAnalysis()); + TF_RETURN_IF_ERROR(module.entry_computation()->Accept(analysis.get())); + + // Convert from HloCostAnalysis::Properties to a standard map. + nb::dict ret; + analysis->properties().ForEach([&](absl::string_view key, float val) { + ret[nb::str(key.data(), key.size())] = nb::cast(val); + }); + return ret; + })); + m.def("hlo_module_from_text", + xla::ValueOrThrowWrapper( + [](const std::string& hlo_module_text) + -> absl::StatusOr> { + auto hlo_module = + xla::ParseAndReturnUnverifiedModule(hlo_module_text); + TF_RETURN_IF_ERROR(hlo_module.status()); + std::shared_ptr result(std::move(*hlo_module)); + return result; + })); + + nb::class_ xla_op_class(m, "XlaOp"); + + nb::class_(m, "XlaBuilder") + .def("__init__", + [](XlaBuilder* self, const std::string& name) { + new (self) XlaBuilder(UniquifyName(name)); + }) + // TODO(phawkins): delete capitalized names after updating callers. + .def("Build", + xla::ValueOrThrowWrapper( + [](XlaBuilder& builder, std::optional root) { + return root ? builder.Build(*root) : builder.Build(); + }), + "Builds a computation from the contents of the builder.", + nb::arg("root") = std::nullopt) + .def("GetShape", xla::ValueOrThrowWrapper(&XlaBuilder::GetShape)) + .def("build", + xla::ValueOrThrowWrapper( + [](XlaBuilder& builder, std::optional root) { + return root ? builder.Build(*root) : builder.Build(); + }), + "Builds a computation from the contents of the builder.", + nb::arg("root") = std::nullopt) + .def("clear_op_metadata", &XlaBuilder::ClearOpMetadata) + .def("get_shape", xla::ValueOrThrowWrapper(&XlaBuilder::GetShape)) + .def( + "get_program_shape", + [](const XlaBuilder& builder, + std::optional root) -> absl::StatusOr { + return root ? builder.GetProgramShape(*root) + : builder.GetProgramShape(); + }, + nb::arg("root") = std::nullopt) + .def("is_constant", xla::ValueOrThrowWrapper(&XlaBuilder::IsConstant)) + .def("set_op_metadata", &XlaBuilder::SetOpMetadata) + .def("set_sharding", &XlaBuilder::SetSharding) + .def("clear_sharding", &XlaBuilder::ClearSharding) + .def("set_frontend_attributes", &XlaBuilder::SetFrontendAttributes) + .def("clear_frontend_attributes", &XlaBuilder::ClearFrontendAttributes) + .def("setup_alias", + [](XlaBuilder& builder, const std::vector& output_index, + int64_t param_number, const std::vector& param_index) { + builder.SetUpAlias( + ShapeIndex(output_index.begin(), output_index.end()), + param_number, + ShapeIndex(param_index.begin(), param_index.end())); + }); + + // Device assignments + nb::class_(m, "DeviceAssignment") + .def_static( + "create", + xla::ValueOrThrowWrapper([](nb::ndarray> array) + -> absl::StatusOr { + if (array.ndim() != 2) { + return InvalidArgument( + "Argument to DeviceAssignment constructor must be a " + "2D array, received an %dD array.", + array.ndim()); + } + DeviceAssignment result(array.shape(0), array.shape(1)); + for (int i = 0; i < array.shape(0); ++i) { + for (int j = 0; j < array.shape(1); ++j) { + result(i, j) = array(i, j); + } + } + return result; + })) + .def("replica_count", &DeviceAssignment::replica_count) + .def("computation_count", &DeviceAssignment::computation_count) + .def("__repr__", &DeviceAssignment::ToString) + .def("serialize", + xla::ValueOrThrowWrapper( + [](const DeviceAssignment& da) -> absl::StatusOr { + DeviceAssignmentProto proto; + da.Serialize(&proto); + std::string result; + if (!tsl::SerializeToStringDeterministic(proto, &result)) { + return Unknown( + "Failed to serialize the DeviceAssignmentProto."); + } + return nb::bytes(result.data(), result.size()); + })); + + nb::class_ compile_options(m, "CompileOptions"); + compile_options + .def("__init__", + [](CompileOptions* self) { + new (self) CompileOptions(); + DebugOptions* debug_options = + self->executable_build_options.mutable_debug_options(); + // Sets fast-math-disabling default options expected by JAX. + debug_options->set_xla_cpu_enable_fast_min_max(false); + debug_options->set_xla_gpu_enable_fast_min_max(false); + }) + .def("__getstate__", + [](const CompileOptions& self) -> nb::tuple { + auto proto = ValueOrThrow(self.ToProto()); + std::string result; + if (!tsl::SerializeToStringDeterministic(proto, &result)) { + // throw converted by PyBind to a Python RuntimeError. + throw XlaRuntimeError( + absl::StrCat("CompileOptions.py_pickle: ", + "SerializeToStringDeterministic failed")); + } + return nb::make_tuple(nb::bytes(result.data(), result.size())); + }) + .def("__setstate__", + [](CompileOptions* self, nb::tuple t) { + CompileOptionsProto result; + nb::bytes serialized = nb::cast(t[0]); + result.ParseFromArray(serialized.c_str(), serialized.size()); + new (self) CompileOptions( + ValueOrThrow(CompileOptions::FromProto(result))); + }) + .def("SerializeAsString", + [](const CompileOptions& self) -> nb::bytes { + auto proto = ValueOrThrow(self.ToProto()); + std::string result; + if (!tsl::SerializeToStringDeterministic(proto, &result)) { + // throw converted by PyBind to a Python RuntimeError. + throw XlaRuntimeError( + absl::StrCat("CompileOptions.SerializeAsString: ", + "SerializeToStringDeterministic failed")); + } + return nb::bytes(result.data(), result.size()); + }) + .def_static("ParseFromString", + [](nb::bytes s) { + CompileOptionsProto result; + result.ParseFromArray(s.c_str(), s.size()); + return ValueOrThrow(CompileOptions::FromProto(result)); + }) + .def_rw("argument_layouts", &CompileOptions::argument_layouts) + .def_rw("parameter_is_tupled_arguments", + &CompileOptions::parameter_is_tupled_arguments) + .def_rw("compile_portable_executable", + &CompileOptions::compile_portable_executable) + .def_ro("executable_build_options", + &CompileOptions::executable_build_options) + .def_rw("env_option_overrides", &CompileOptions::env_option_overrides) + // TODO(phawkins): the following fields exist for backward compatibility. + // Remove them after JAX has been updated not to use them. + .def_rw("tuple_arguments", &CompileOptions::parameter_is_tupled_arguments) + .def_prop_rw( + "num_replicas", + [](const CompileOptions& options) { + return options.executable_build_options.num_replicas(); + }, + [](CompileOptions& options, int num_replicas) { + options.executable_build_options.set_num_replicas(num_replicas); + }) + .def_prop_rw( + "num_partitions", + [](const CompileOptions& options) { + return options.executable_build_options.num_partitions(); + }, + [](CompileOptions& options, int num_partitions) { + options.executable_build_options.set_num_partitions(num_partitions); + }) + .def_prop_rw( + "profile_version", + [](const CompileOptions& options) { return options.profile_version; }, + [](CompileOptions& options, int64_t profile_version) { + options.profile_version = profile_version; + }) + .def_prop_rw( + "device_assignment", + [](const CompileOptions& options) -> std::optional { + return options.executable_build_options.has_device_assignment() + ? std::optional( + options.executable_build_options + .device_assignment()) + : std::nullopt; + }, + [](CompileOptions& options, + const DeviceAssignment& device_assignment) { + options.executable_build_options.set_device_assignment( + device_assignment); + }); + + // Custom-call targets. + m.def( + "register_custom_call_target", + [](nb::object fn_name_py, nb::object fn, const std::string& platform, + int api_version, XLA_FFI_Handler_Traits traits) { + std::string fn_name; + if (!nb::try_cast(fn_name_py, fn_name)) { + nb::bytes bytes = nb::cast(fn_name_py); + fn_name = std::string(bytes.c_str(), bytes.size()); + } + xla::ThrowIfError(PyRegisterCustomCallTarget( + fn_name, std::move(fn), platform, api_version, traits)); + }, + nb::arg("fn_name"), nb::arg("fn"), nb::arg("platform"), + nb::arg("api_version") = 0, nb::arg("traits") = 0); + + m.def( + "custom_call_targets", + [](const std::string& platform) -> nb::dict { + nb::dict targets; + for (const auto& [name, target] : + CustomCallTargetRegistry::Global()->registered_symbols(platform)) { + targets[nb::str(name.data(), name.size())] = nb::capsule(target); + } + + auto ffi_handlers = ffi::StaticRegisteredHandlers(platform); + if (!ffi_handlers.ok()) return targets; + + for (const auto& [name, registration] : *ffi_handlers) { + nb::dict bundle; + auto export_handler = [&](absl::string_view name, + XLA_FFI_Handler* h) { + if (h != nullptr) { + bundle[nb::str(name.data(), name.size())] = + nb::capsule(reinterpret_cast(h)); + } + }; + export_handler("prepare", registration.bundle.prepare); + export_handler("initialize", registration.bundle.initialize); + export_handler("execute", registration.bundle.execute); + targets[nb::str(name.data(), name.size())] = std::move(bundle); + } + return targets; + }, + nb::arg("platform")); + + nb::enum_(m, "AutotuneCacheMode") + .value("UNSPECIFIED", DebugOptions::AUTOTUNE_CACHE_MODE_UNSPECIFIED) + .value("UPDATE", DebugOptions::AUTOTUNE_CACHE_MODE_UPDATE) + .value("READ", DebugOptions::AUTOTUNE_CACHE_MODE_READ); + + m.def( + "register_custom_type_id", + [](absl::string_view type_name, nb::object type_id) { + xla::ThrowIfError(PyRegisterCustomTypeId(type_name, type_id)); + }, + nb::arg("type_name"), nb::arg("type_id")); + + nb::class_(m, "DebugOptions") + .def("__repr__", &DebugOptions::DebugString) + .def_prop_rw("xla_backend_optimization_level", + &DebugOptions::xla_backend_optimization_level, + &DebugOptions::set_xla_backend_optimization_level) + .def_prop_rw("xla_cpu_enable_fast_math", + &DebugOptions::xla_cpu_enable_fast_math, + &DebugOptions::set_xla_cpu_enable_fast_math) + .def_prop_rw("xla_cpu_enable_xprof_traceme", + &DebugOptions::xla_cpu_enable_xprof_traceme, + &DebugOptions::set_xla_cpu_enable_xprof_traceme) + .def_prop_rw("xla_cpu_fast_math_honor_infs", + &DebugOptions::xla_cpu_fast_math_honor_infs, + &DebugOptions::set_xla_cpu_fast_math_honor_infs) + .def_prop_rw("xla_cpu_fast_math_honor_nans", + &DebugOptions::xla_cpu_fast_math_honor_nans, + &DebugOptions::set_xla_cpu_fast_math_honor_nans) + .def_prop_rw("xla_cpu_fast_math_honor_division", + &DebugOptions::xla_cpu_fast_math_honor_division, + &DebugOptions::set_xla_cpu_fast_math_honor_division) + .def_prop_rw("xla_cpu_fast_math_honor_functions", + &DebugOptions::xla_cpu_fast_math_honor_functions, + &DebugOptions::set_xla_cpu_fast_math_honor_functions) + .def_prop_rw("xla_detailed_logging", &DebugOptions::xla_detailed_logging, + &DebugOptions::set_xla_detailed_logging) + .def_prop_rw("xla_enable_dumping", &DebugOptions::xla_enable_dumping, + &DebugOptions::set_xla_enable_dumping) + .def_prop_rw("xla_gpu_enable_fast_min_max", + &DebugOptions::xla_gpu_enable_fast_min_max, + &DebugOptions::set_xla_gpu_enable_fast_min_max) + .def_prop_rw("xla_gpu_dump_autotune_results_to", + &DebugOptions::xla_gpu_dump_autotune_results_to, + [](DebugOptions* self, std::string value) { + self->set_xla_gpu_dump_autotune_results_to(value); + }) + .def_prop_rw("xla_gpu_load_autotune_results_from", + &DebugOptions::xla_gpu_load_autotune_results_from, + [](DebugOptions* self, std::string value) { + self->set_xla_gpu_load_autotune_results_from(value); + }) + .def_prop_rw("xla_gpu_cuda_data_dir", + &DebugOptions::xla_gpu_cuda_data_dir, + [](DebugOptions* self, std::string value) { + self->set_xla_gpu_cuda_data_dir(value); + }) + .def_prop_rw("xla_llvm_disable_expensive_passes", + &DebugOptions::xla_llvm_disable_expensive_passes, + &DebugOptions::set_xla_llvm_disable_expensive_passes) + .def_prop_rw( + "xla_disable_hlo_passes", + [](DebugOptions* self) { + return absl::StrJoin(self->xla_disable_hlo_passes(), ","); + }, + [](DebugOptions* self, std::string value) { + self->clear_xla_disable_hlo_passes(); + for (const auto& passname : + std::vector(absl::StrSplit(value, ','))) { + self->add_xla_disable_hlo_passes(passname); + } + }) + .def_prop_rw( + "xla_enable_hlo_passes_only", + [](DebugOptions* self) { + return absl::StrJoin(self->xla_enable_hlo_passes_only(), ","); + }, + [](DebugOptions* self, std::string value) { + self->clear_xla_enable_hlo_passes_only(); + for (const auto& passname : + std::vector(absl::StrSplit(value, ','))) { + self->add_xla_enable_hlo_passes_only(passname); + } + }) + .def_prop_rw("xla_test_all_input_layouts", + &DebugOptions::xla_test_all_input_layouts, + &DebugOptions::set_xla_test_all_input_layouts) + .def_prop_rw("xla_force_host_platform_device_count", + &DebugOptions::xla_force_host_platform_device_count, + &DebugOptions::set_xla_force_host_platform_device_count) + .def_prop_rw("xla_dump_to", &DebugOptions::xla_dump_to, + [](DebugOptions* self, std::string value) { + self->set_xla_dump_to(value); + }) + .def_prop_rw("xla_dump_hlo_module_re", + &DebugOptions::xla_dump_hlo_module_re, + [](DebugOptions* self, std::string value) { + self->set_xla_dump_hlo_module_re(value); + }) + .def_prop_rw("xla_dump_hlo_pass_re", &DebugOptions::xla_dump_hlo_pass_re, + [](DebugOptions* self, std::string value) { + self->set_xla_dump_hlo_pass_re(value); + }) + .def_prop_rw("xla_dump_hlo_as_text", &DebugOptions::xla_dump_hlo_as_text, + &DebugOptions::set_xla_dump_hlo_as_text) + .def_prop_rw("xla_dump_hlo_as_proto", + &DebugOptions::xla_dump_hlo_as_proto, + &DebugOptions::set_xla_dump_hlo_as_proto) + .def_prop_rw("xla_dump_hlo_as_dot", &DebugOptions::xla_dump_hlo_as_dot, + &DebugOptions::set_xla_dump_hlo_as_dot) + .def_prop_rw("xla_dump_hlo_as_url", &DebugOptions::xla_dump_hlo_as_url, + &DebugOptions::set_xla_dump_hlo_as_url) + .def_prop_rw("xla_dump_hlo_as_html", &DebugOptions::xla_dump_hlo_as_html, + &DebugOptions::set_xla_dump_hlo_as_html) + .def_prop_rw("xla_dump_fusion_visualization", + &DebugOptions::xla_dump_fusion_visualization, + &DebugOptions::set_xla_dump_fusion_visualization) + .def_prop_rw("xla_dump_hlo_snapshots", + &DebugOptions::xla_dump_hlo_snapshots, + &DebugOptions::set_xla_dump_hlo_snapshots) + .def_prop_rw("xla_dump_max_hlo_modules", + &DebugOptions::xla_dump_max_hlo_modules, + &DebugOptions::set_xla_dump_max_hlo_modules) + .def_prop_rw("xla_dump_module_metadata", + &DebugOptions::xla_dump_module_metadata, + &DebugOptions::set_xla_dump_module_metadata) + .def_prop_rw("xla_dump_compress_protos", + &DebugOptions::xla_dump_compress_protos, + &DebugOptions::set_xla_dump_compress_protos) + .def_prop_rw("xla_dump_hlo_as_long_text", + &DebugOptions::xla_dump_hlo_as_long_text, + &DebugOptions::set_xla_dump_hlo_as_long_text) + .def_prop_rw("xla_dump_disable_metadata", + &DebugOptions::xla_dump_disable_metadata, + &DebugOptions::set_xla_dump_disable_metadata) + .def_prop_rw("xla_dump_hlo_pipeline_re", + &DebugOptions::xla_dump_hlo_pipeline_re, + [](DebugOptions* self, std::string value) { + self->set_xla_dump_hlo_pipeline_re(value); + }) + .def_prop_rw("xla_gpu_dump_autotune_logs_to", + &DebugOptions::xla_gpu_dump_autotune_logs_to, + [](DebugOptions* self, std::string value) { + self->set_xla_gpu_dump_autotune_logs_to(value); + }) + .def_prop_rw("xla_gpu_kernel_cache_file", + &DebugOptions::xla_gpu_kernel_cache_file, + [](DebugOptions* self, std::string value) { + self->set_xla_gpu_kernel_cache_file(value); + }) + .def_prop_rw( + "xla_gpu_enable_llvm_module_compilation_parallelism", + &DebugOptions::xla_gpu_enable_llvm_module_compilation_parallelism, + &DebugOptions::set_xla_gpu_enable_llvm_module_compilation_parallelism) + .def_prop_rw("xla_gpu_per_fusion_autotune_cache_dir", + &DebugOptions::xla_gpu_per_fusion_autotune_cache_dir, + [](DebugOptions* self, std::string value) { + self->set_xla_gpu_per_fusion_autotune_cache_dir(value); + }) + .def_prop_rw("xla_gpu_experimental_autotune_cache_mode", + &DebugOptions::xla_gpu_experimental_autotune_cache_mode, + &DebugOptions::set_xla_gpu_experimental_autotune_cache_mode); + + nb::class_(m, "ExecutableBuildOptions") + .def(nb::init<>()) + .def("__repr__", &ExecutableBuildOptions::ToString) + .def_prop_rw( + "fdo_profile", + [](const ExecutableBuildOptions& options) { + return nb::bytes(options.fdo_profile().data(), + options.fdo_profile().size()); + }, + [](ExecutableBuildOptions& options, nb::bytes fdo_profile) { + options.set_fdo_profile( + std::string(fdo_profile.c_str(), fdo_profile.size())); + }) + .def_prop_rw( + "result_layout", + [](const ExecutableBuildOptions& options) -> std::optional { + return options.result_layout() + ? std::optional(*options.result_layout()) + : std::nullopt; + }, + &ExecutableBuildOptions::set_result_layout) + .def_prop_rw("num_replicas", &ExecutableBuildOptions::num_replicas, + &ExecutableBuildOptions::set_num_replicas) + .def_prop_rw("num_partitions", &ExecutableBuildOptions::num_partitions, + &ExecutableBuildOptions::set_num_partitions) + .def_prop_ro("debug_options", + &ExecutableBuildOptions::mutable_debug_options, + nb::rv_policy::reference, nb::keep_alive<1, 0>()) + .def_prop_rw( + "device_assignment", + [](const ExecutableBuildOptions& options) + -> std::optional { + return options.has_device_assignment() + ? std::optional( + options.device_assignment()) + : std::nullopt; + }, + &ExecutableBuildOptions::set_device_assignment) + .def("compilation_environments_from_serialized_proto", + [](ExecutableBuildOptions& options, + const nb::bytes& serialized_proto) { + xla::CompilationEnvironmentsProto env_proto; + env_proto.ParseFromArray(serialized_proto.c_str(), + serialized_proto.size()); + auto comp_envs = xla::ValueOrThrow( + xla::CompilationEnvironments::CreateFromProto(env_proto)); + *options.mutable_comp_envs() = std::move(*comp_envs); + }) + .def_prop_rw("exec_time_optimization_effort", + &ExecutableBuildOptions::exec_time_optimization_effort, + &ExecutableBuildOptions::set_exec_time_optimization_effort) + .def_prop_rw("memory_fitting_effort", + &ExecutableBuildOptions::memory_fitting_effort, + &ExecutableBuildOptions::set_memory_fitting_effort) + .def_prop_rw( + "optimization_level", &ExecutableBuildOptions::optimization_level, + [](ExecutableBuildOptions& options, int value) { + options.set_optimization_level( + static_cast(value)); + }) + .def_prop_rw( + "memory_fitting_level", &ExecutableBuildOptions::memory_fitting_level, + [](ExecutableBuildOptions& options, int value) { + options.set_memory_fitting_level( + static_cast(value)); + }) + .def_prop_rw("use_spmd_partitioning", + &ExecutableBuildOptions::use_spmd_partitioning, + &ExecutableBuildOptions::set_use_spmd_partitioning) + .def_prop_rw("use_auto_spmd_partitioning", + &ExecutableBuildOptions::use_auto_spmd_partitioning, + &ExecutableBuildOptions::set_use_auto_spmd_partitioning) + .def_prop_rw( + "auto_spmd_partitioning_mesh_shape", + &ExecutableBuildOptions::auto_spmd_partitioning_mesh_shape, + &ExecutableBuildOptions::set_auto_spmd_partitioning_mesh_shape) + .def_prop_rw("auto_spmd_partitioning_mesh_ids", + &ExecutableBuildOptions::auto_spmd_partitioning_mesh_ids, + &ExecutableBuildOptions::set_auto_spmd_partitioning_mesh_ids) + .def_prop_rw( + "allow_spmd_sharding_propagation_to_parameters", + [](const ExecutableBuildOptions& options) -> std::vector { + return std::vector( + options.allow_spmd_sharding_propagation_to_parameters().begin(), + options.allow_spmd_sharding_propagation_to_parameters().end()); + }, + [](ExecutableBuildOptions& options, std::vector values) { + absl::InlinedVector v(values.begin(), values.end()); + options.set_allow_spmd_sharding_propagation_to_parameters(v); + }) + .def_prop_rw( + "allow_spmd_sharding_propagation_to_output", + [](const ExecutableBuildOptions& options) -> std::vector { + return std::vector( + options.allow_spmd_sharding_propagation_to_output().begin(), + options.allow_spmd_sharding_propagation_to_output().end()); + }, + [](ExecutableBuildOptions& options, std::vector values) { + absl::InlinedVector v(values.begin(), values.end()); + options.set_allow_spmd_sharding_propagation_to_output(v); + }) + .def_prop_rw("use_shardy_partitioner", + &ExecutableBuildOptions::use_shardy_partitioner, + &ExecutableBuildOptions::set_use_shardy_partitioner); + + nb::enum_ op_sharding_type(m, "OpSharding_Type", + nb::is_arithmetic()); + op_sharding_type.value("REPLICATED", OpSharding::REPLICATED) + .value("MAXIMAL", OpSharding::MAXIMAL) + .value("MANUAL", OpSharding::MANUAL) + .value("TUPLE", OpSharding::TUPLE) + .value("OTHER", OpSharding::OTHER) + .value("UNKNOWN", OpSharding::UNKNOWN); + + nb::enum_ op_sharding_shard_group_type( + m, "OpSharding_ShardGroupType"); + op_sharding_shard_group_type.value("AS", OpSharding::AS) + .value("LIKE", OpSharding::LIKE); + + nb::class_ op_sharding(m, "OpSharding"); + op_sharding + .def_prop_ro_static( + "Type", + [op_sharding_type](const nb::object&) { return op_sharding_type; }) + .def_prop_ro_static("ShardGroupType", + [op_sharding_shard_group_type](const nb::object&) { + return op_sharding_shard_group_type; + }) + .def(nb::init<>()) + .def("__getstate__", + [](const OpSharding& self) { + std::string serialized = self.SerializeAsString(); + return nb::make_tuple( + nb::bytes(serialized.data(), serialized.size())); + }) + .def("__setstate__", + [](OpSharding* self, nb::tuple t) { + new (self) OpSharding(); + nb::bytes serialized = nb::cast(t[0]); + self->ParseFromArray(serialized.c_str(), serialized.size()); + }) + .def_prop_rw("type", &xla::OpSharding::type, &xla::OpSharding::set_type) + .def_prop_rw("replicate_on_last_tile_dim", + &xla::OpSharding::replicate_on_last_tile_dim, + &xla::OpSharding::set_replicate_on_last_tile_dim) + .def_prop_rw("is_shard_group", &xla::OpSharding::is_shard_group, + &xla::OpSharding::set_is_shard_group) + .def_prop_rw("shard_group_id", &xla::OpSharding::shard_group_id, + &xla::OpSharding::set_shard_group_id) + .def_prop_rw("shard_group_type", &xla::OpSharding::shard_group_type, + &xla::OpSharding::set_shard_group_type) + .def("__repr__", + [](const xla::OpSharding& self) { return self.DebugString(); }) + .def("ParseFromString", + [](OpSharding& sharding, const nb::bytes& s) { + sharding.ParseFromArray(s.c_str(), s.size()); + }) + .def("SerializeToString", + [](const OpSharding& sharding) { + std::string serialized = sharding.SerializeAsString(); + return nb::bytes(serialized.data(), serialized.size()); + }) + .def("clone", + [](const OpSharding& sharding) { return OpSharding(sharding); }); + DefRepeatedProperty(op_sharding, "tile_assignment_dimensions", + &xla::OpSharding::mutable_tile_assignment_dimensions); + DefRepeatedProperty(op_sharding, "tile_assignment_devices", + &xla::OpSharding::mutable_tile_assignment_devices); + DefRepeatedProperty(op_sharding, "iota_reshape_dims", + &xla::OpSharding::mutable_iota_reshape_dims); + DefRepeatedProperty(op_sharding, "iota_transpose_perm", + &xla::OpSharding::mutable_iota_transpose_perm); + DefRepeatedProperty(op_sharding, "tuple_shardings", + &xla::OpSharding::mutable_tuple_shardings); + DefRepeatedEnumProperty(op_sharding, "last_tile_dims", + &xla::OpSharding::mutable_last_tile_dims); + + nb::class_ hlo_sharding(m, "HloSharding"); + hlo_sharding + .def_static("from_proto", + xla::ValueOrThrowWrapper(xla::HloSharding::FromProto)) + .def_static("from_string", xla::ValueOrThrowWrapper(xla::ParseSharding)) + .def_static( + "tuple_sharding", + [](xla::Shape shape, + std::vector shardings) -> xla::HloSharding { + return HloSharding::Tuple(shape, shardings); + }, + "Constructs a tuple sharding.") + .def_static( + "iota_tile", xla::ValueOrThrowWrapper(IotaTileHelper), + nb::arg("dims"), + nb::arg("reshape_dims") = absl::Span(), + nb::arg("transpose_perm") = absl::Span(), + nb::arg("subgroup_types") = absl::Span()) + .def_static("manual", [] { return HloSharding::Manual(); }) + .def_static("replicate", [] { return HloSharding::Replicate(); }) + .def_static("unknown", [] { return HloSharding::Unknown(); }) + .def_static( + "subgroup_with_device_ordering", + xla::ValueOrThrowWrapper(SubgroupWithTileAssignmentHelper), + nb::arg("tile_assignment"), + nb::arg("subgroup_types") = absl::Span()) + .def("__eq__", [](const xla::HloSharding& a, + const xla::HloSharding& b) { return a == b; }) + .def("__hash__", + [](const xla::HloSharding& self) { return absl::HashOf(self); }) + .def("is_replicated", &xla::HloSharding::IsReplicated) + .def("is_manual", &xla::HloSharding::IsManual) + .def("is_unknown", &xla::HloSharding::IsUnknown) + .def("is_tiled", &xla::HloSharding::IsTiled) + .def("is_maximal", &xla::HloSharding::IsTileMaximal) + .def("tile", [](const xla::HloSharding& self, + xla::Shape shape) { return self.TileShape(shape); }) + // tile_assignment.array() is computed using an internal cache, + // which is why nb::lock_self() is required. It may be preferable to move + // this locking into the TileAssignment class if we find it to race with + // non-Python users of that class. + .def( + "tuple_elements", + [](const xla::HloSharding& self) { return self.tuple_elements(); }, + nb::lock_self()) + .def( + "num_devices", + [](const xla::HloSharding& self) { + return self.tile_assignment().num_elements(); + }, + nb::lock_self()) + .def( + "num_dimensions", + [](const xla::HloSharding& self) { + return self.tile_assignment().num_dimensions(); + }, + nb::lock_self()) + .def( + "tile_assignment_dimensions", + [](const xla::HloSharding& self) { + absl::Span span = + self.tile_assignment().dimensions(); + CHECK(span.data()); + return span; + }, + nb::lock_self()) + .def( + "tile_assignment_devices", + [](const xla::HloSharding& self) { + auto span = + absl::MakeConstSpan(self.tile_assignment().array().data(), + self.tile_assignment().num_elements()); + CHECK(span.data()); + return span; + }, + nb::lock_self()) + .def("replicate_on_last_tile_dim", + &xla::HloSharding::ReplicateOnLastTileDim) + .def("subgroup_types", &xla::HloSharding::subgroup_types) + .def("__repr__", + [](const xla::HloSharding& self) { return self.ToString(); }) + .def("to_proto", &xla::HloSharding::ToProto); + + nb::class_ frontend_attributes(m, "FrontendAttributes"); + frontend_attributes.def(nb::init<>()) + .def("__setitem__", + [](FrontendAttributes* attr, std::string key, std::string value) { + (*attr->mutable_map())[key] = value; + }); + + nb::enum_(m, "PrecisionConfig_Precision") + .value("DEFAULT", PrecisionConfig::DEFAULT) + .value("HIGH", PrecisionConfig::HIGH) + .value("HIGHEST", PrecisionConfig::HIGHEST); + + nb::enum_(m, "ResultAccuracy_Mode") + .value("DEFAULT", ResultAccuracy::DEFAULT) + .value("HIGHEST", ResultAccuracy::HIGHEST); + + nb::enum_(m, "FftType") + .value("FFT", FftType::FFT) + .value("IFFT", FftType::IFFT) + .value("RFFT", FftType::RFFT) + .value("IRFFT", FftType::IRFFT); + + // Hlo Module Passes + nb::class_ hlo_pass_interface(m, "HloPassInterface"); + hlo_pass_interface.def_prop_ro("name", &HloPassInterface::name) + .def("is_pass_pipeline", &HloPassInterface::IsPassPipeline) + .def("run", + [](HloPassInterface& pass, HloModule* module) -> bool { + return xla::ValueOrThrow(pass.Run(module)); + }) + .def("run_on_module_group", + [](HloPassInterface& pass, HloModuleGroup* module_group) -> bool { + return xla::ValueOrThrow(pass.RunOnModuleGroup(module_group)); + }); + + nb::class_(m, "HloDCE").def(nb::init<>()); + nb::class_(m, "CallInliner").def(nb::init<>()); + nb::class_(m, "FlattenCallGraph") + .def(nb::init<>()); + nb::class_(m, "TupleSimplifier") + .def(nb::init<>()); +} // NOLINT(readability/fn_size) +} // namespace xla diff --git a/jaxlib/xla/xla_compiler.h b/jaxlib/xla/xla_compiler.h new file mode 100644 index 000000000000..f3ffe5fe9440 --- /dev/null +++ b/jaxlib/xla/xla_compiler.h @@ -0,0 +1,28 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_XLA_COMPILER_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_XLA_COMPILER_H_ + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" + +namespace xla { + +void BuildXlaCompilerSubmodule(nanobind::module_& m); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_XLA_COMPILER_H_ From ff718862f2b632a1f8e75c4b2c16f48e0f27724d Mon Sep 17 00:00:00 2001 From: Gleb Pobudzey Date: Mon, 24 Mar 2025 13:49:18 -0700 Subject: [PATCH 129/483] [Mosaic GPU] Adding a new layout WGMMAColFragLayout to be able to load a 1d array and broadcast it along the leading dimension to a 2d shape as an input to a wgmma. In this new layout the first 4 threads of a warp group hold 8 uniques values. These values are replicated in each (thread_idx % 4) group. PiperOrigin-RevId: 740058172 --- .../mosaic/gpu/fragmented_array.py | 95 ++++++++++++++++++- tests/mosaic/gpu_test.py | 34 +++++++ 2 files changed, 127 insertions(+), 2 deletions(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index c2b61c6d5bfe..b730e34e2ed0 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -420,6 +420,23 @@ def thread_idxs(self, shape): yield (row,) +@dataclasses.dataclass(frozen=True) +class WGMMAColFragLayout: + """[n] matrix, where n % 8 == 0.""" + + def thread_idxs(self, shape): + index = ir.IndexType.get() + assert len(shape) == 1 + assert shape[0] % 8 == 0 + + tid = arith.index_cast(ir.IndexType.get(), mgpu.thread_idx()) + lane_id = arith.remui(tid, c(WARP_SIZE, index)) + col_base = arith.muli(arith.remui(lane_id, c(4, index)), c(2, index)) + + for col_group in range(0, shape[0], 8): + col = arith.addi(col_base, c(col_group, index)) + yield (col,) + @dataclasses.dataclass(frozen=True) class WGSplatFragLayout: """A fragmented array where all the values are equal represented as a register per thread. @@ -530,10 +547,11 @@ def linear_thread_idxs(self): yield arith.addi(off, c(i * WARPGROUP_SIZE * self.vec_size, tidx.type)) -FragmentedLayout = WGSplatFragLayout | WGStridedFragLayout | WGMMARowFragLayout | TiledLayout +FragmentedLayout = WGSplatFragLayout | WGStridedFragLayout | WGMMARowFragLayout | WGMMAColFragLayout | TiledLayout WGMMA_ROW_LAYOUT = WGMMARowFragLayout() +WGMMA_COL_LAYOUT = WGMMAColFragLayout() # The tiled layout is equivalent to one described here in PTX documentation: # https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-d @@ -651,6 +669,12 @@ def __init__( if _registers.ndim != 2 or _registers.shape[-1] != 2: raise ValueError(f"Invalid register array shape: {_registers.shape}") + # Registers are [n_tiles] in WGMMA_COL layout + # Each element is a vector of size 2. + case WGMMAColFragLayout(): + if _registers.ndim != 1: + raise ValueError(f"Invalid register array shape: {_registers.shape}") + # Registers are flat case WGStridedFragLayout(shape): [reg_size] = ir.VectorType(_registers.flat[0].type).shape @@ -731,6 +755,36 @@ def load_wgmma_row( registers = np.array(registers).reshape(-1, 2) return cls(_registers=registers, _layout=layout, _is_signed=is_signed) + @classmethod + def load_wgmma_col( + cls, + ref: ir.Value, + *, + is_signed: bool | None = None, + ): + if not ir.MemRefType.isinstance(ref.type): + raise TypeError(ref.type) + + ref_ty = ir.MemRefType(ref.type) + shape = tuple(ref_ty.shape) + layout = WGMMAColFragLayout() + + if len(shape) != 1: + raise ValueError("WGMMAColFragLayout requires a 1D shape.") + + if shape[0] % 8: + raise ValueError( + f"WGMMAColFragLayout requires {shape[0]=} to be a multiple of 8." + ) + + vec_ty = ir.VectorType.get((2,), ref_ty.element_type) + new_regs = np.full((shape[0] // 8,), llvm.mlir_undef(vec_ty)) + + for col_tile, (idx,) in enumerate(layout.thread_idxs(shape)): + reg = vector.load(vec_ty, ref, [idx]) + new_regs[col_tile] = reg + + return cls(_registers=new_regs, _layout=layout, _is_signed=is_signed) @classmethod def splat(cls, value, shape, layout=None, *, is_signed: bool | None = None): @@ -755,6 +809,9 @@ def shape(self): case WGMMARowFragLayout(): row_tiles = self.registers.shape[0] return (row_tiles * 64,) + case WGMMAColFragLayout(): + col_tiles = self.registers.shape[0] + return (col_tiles * 8,) case WGStridedFragLayout(shape): return shape case WGSplatFragLayout(shape=shape): @@ -768,7 +825,7 @@ def shape(self): def mlir_dtype(self): reg_ty = self.registers.flat[0].type match self.layout: - case WGStridedFragLayout() | TiledLayout(): + case WGStridedFragLayout() | WGMMAColFragLayout() | TiledLayout(): return ir.VectorType(reg_ty).element_type case WGMMARowFragLayout() | WGSplatFragLayout(): return reg_ty @@ -1745,6 +1802,23 @@ def broadcast_minor(self, n): _registers=new_regs, _layout=WGMMA_LAYOUT, _is_signed=self.is_signed ) + def broadcast_major(self, m): + if not isinstance(self.layout, WGMMAColFragLayout): + raise NotImplementedError + + if m % 64: + raise ValueError("Number of rows must be divisible by 64") + + reg_shape = WGMMA_LAYOUT.registers_shape((m, self.shape[0])) + new_regs = np.empty(reg_shape, dtype=object) + for col_tile, reg in np.ndenumerate(self.registers): + tile = [slice(None)] * len(new_regs.shape) + tile[1] = col_tile + new_regs[tuple(tile)] = reg + return FragmentedArray( + _registers=new_regs, _layout=WGMMA_LAYOUT, _is_signed=self.is_signed + ) + def select(self, on_true, on_false): if ( not ir.IntegerType.isinstance(self.mlir_dtype) @@ -1802,6 +1876,8 @@ def vs_unsupported(): match self.layout: case WGMMARowFragLayout(): self._store_untiled_wgmma_row(ref) + case WGMMAColFragLayout(): + self._store_untiled_wgmma_col(ref) case WGSplatFragLayout(): vs_unsupported() self._store_untiled_splat(ref) @@ -1865,6 +1941,21 @@ def _store_untiled_wgmma_row(self, ref: ir.Value): ): memref.store(value, ref, [idx]) + def _store_untiled_wgmma_col(self, ref: ir.Value): + """Stores an array with a WGMMA col layout.""" + assert isinstance(self.layout, WGMMAColFragLayout) + index = ir.IndexType.get() + tid = arith.index_cast(ir.IndexType.get(), mgpu.thread_idx()) + tid_wg = arith.remui(tid, c(WARPGROUP_SIZE, index)) + + # Consecutive groups of 4 threads replicate the same data, so we only need to + # transfer data from one group. + is_first = arith.cmpi(arith.CmpIPredicate.ult, tid_wg, c(4, index)) + + with utils.when(is_first): + for (idx,), reg in zip(self.layout.thread_idxs(self.shape), self.registers): + vector.store(reg, ref, [idx]) + def _store_untiled_tiled(self, ref: ir.Value, *, vector_store: bool = True): """Stores an array with a tiled layout. Not optimized at the moment.""" if utils.bitwidth(self.mlir_dtype) < 8: diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 478064188750..d9f56ee1d454 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1982,6 +1982,40 @@ def kernel(ctx, *args): )(inp) np.testing.assert_array_equal(inp, result) + @parameterized.product( + in_shape=((128,), (64,)), dtype=[jnp.float16, jnp.float32] + ) + def test_wgmma_col_load_store_with_layout(self, in_shape, dtype): + def kernel(ctx, *args): + gmem_input, gmem_output, (smem_input, smem_output) = args + copy(gmem_input, smem_input) + t = mgpu.FragmentedArray.load_wgmma_col(smem_input) + t.store_untiled(smem_output) + copy(smem_output, gmem_output) + + inp = out = self.prng.uniform(-1, 1, in_shape).astype(dtype) + result = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (inp,), out, [inp, out], + )(inp) + np.testing.assert_array_equal(result, inp) + + @parameterized.parameters((128, 128), (128, 64), (64, 128)) + def test_broadcast_major(self, m, n): + def kernel(ctx, *args): + gmem_input, gmem_output, () = args + t = mgpu.FragmentedArray.load_wgmma_col(gmem_input) + t.broadcast_major(m).store_untiled(gmem_output) + + inp = self.prng.uniform(-1, 1, (n,)).astype(jnp.float16) + out_shape = jax.ShapeDtypeStruct((m, n), jnp.float16) + + result = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (inp,), out_shape, () + )(inp) + + out_ref = jax.lax.broadcast_in_dim(inp, (m, n), (1,)) + np.testing.assert_array_equal(result, out_ref) + def test_warp_tree_reduce(self): def kernel(ctx, out, *_): del ctx From e75226392384824fecbd8859b3f2a4da84f898ab Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 24 Mar 2025 13:49:48 -0700 Subject: [PATCH 130/483] [pallas] Index Pallas refs instead of using `pl.load` and `pl.store` Indexing is less verbose and is thus easier to read in most cases. The functional API is really only necessary for masked loads and stores. PiperOrigin-RevId: 740058341 --- .../pallas/ops/tpu/flash_attention.py | 116 ++++++++---------- .../splash_attention_kernel.py | 40 +++--- tests/pallas/ops_test.py | 2 +- tests/pallas/pallas_test.py | 26 ++-- tests/pallas/tpu_pallas_test.py | 17 ++- 5 files changed, 85 insertions(+), 116 deletions(-) diff --git a/jax/experimental/pallas/ops/tpu/flash_attention.py b/jax/experimental/pallas/ops/tpu/flash_attention.py index 0cb3d798d09e..ef8dd61abacb 100644 --- a/jax/experimental/pallas/ops/tpu/flash_attention.py +++ b/jax/experimental/pallas/ops/tpu/flash_attention.py @@ -391,9 +391,9 @@ def body(i, _): l_prev = l_scratch_ref[batch_idx] q = q_tile_ref[batch_idx] # [block_q, head_dim] start_k = i * block_k - k = pl.load( - k_tile_ref, (*batch_idx, pl.dslice(start_k, block_k), slice(None)) - ) # [block_k, head_dim] + k = k_tile_ref[ + (*batch_idx, pl.dslice(start_k, block_k), slice(None)) + ] # [block_k, head_dim] s = jax.lax.dot_general( q, k, TRANS_B_DIM_NUMBERS, preferred_element_type=jnp.float32 @@ -403,10 +403,9 @@ def body(i, _): # TODO(tanburn) Should the attention bias be added before or after # multiplication by sm_scale? if ab_tile_ref is not None: - ab = pl.load( - ab_tile_ref, + ab = ab_tile_ref[ (*batch_idx, pl.dslice(None), pl.dslice(start_k, block_k)) - ).astype(jnp.float32) + ].astype(jnp.float32) s += ab if sm_scale != 1.0: @@ -422,10 +421,9 @@ def body(i, _): q_segment_ids = pltpu.repeat( q_segment_ids_tile_ref[batch_idx[0]], repeats, axis=1 ) # [block_q, block_k]. - kv_segment_ids = pl.load( - kv_segment_ids_tile_ref, - (batch_idx[0], pl.dslice(1), pl.dslice(start_k, block_k)), - ) # [1, block_k]. + kv_segment_ids = kv_segment_ids_tile_ref[ + batch_idx[0], :1, pl.dslice(start_k, block_k) + ] # [1, block_k]. mask = jnp.equal(q_segment_ids, kv_segment_ids).astype(jnp.bool_) if causal: @@ -471,9 +469,7 @@ def body(i, _): l_next_inv_safe = jnp.where(l_next == 0.0, 1.0, 1.0 / l_next) acc_scratch_ref[batch_idx] *= l_broadcast(l_corr * l_next_inv_safe) - v = pl.load( - v_tile_ref, (*batch_idx, pl.dslice(start_k, block_k), slice(None)) - ) + v = v_tile_ref[(*batch_idx, pl.dslice(start_k, block_k), slice(None))] o_curr = jax.lax.dot( p.astype(v.dtype), v, preferred_element_type=jnp.float32 ) @@ -529,15 +525,13 @@ def _flash_attention_kernel_single_batch_single_step( raise NotImplementedError( f"kv block size must be a multiple of {NUM_LANES}" ) - q_segment_ids = pl.load( - q_segment_ids_tile_ref, (batch_idx[0],) - ) # [block_q, NUM_LANES]. + q_segment_ids = q_segment_ids_tile_ref[ + batch_idx[0] + ] # [block_q, NUM_LANES]. q_segment_ids = pltpu.repeat( q_segment_ids, repeats, axis=1 ) # [block_q, block_k]. - kv_segment_ids = pl.load( - kv_segment_ids_tile_ref, (batch_idx[0], pl.dslice(1)) - ) # [1, block_k]. + kv_segment_ids = kv_segment_ids_tile_ref[batch_idx[0], :1] # [1, block_k]. mask = jnp.equal(q_segment_ids, kv_segment_ids).astype(jnp.bool_) if causal: @@ -840,33 +834,27 @@ def q_body(j, _): start_q = j * block_q def k_body(i, _): start_k = i * block_k - k = pl.load(k_tile_ref, (0, 0, pl.ds(start_k, block_k), slice(None))) - v = pl.load(v_tile_ref, (0, 0, pl.ds(start_k, block_k), slice(None))) - q = pl.load(q_tile_ref, (0, 0, pl.ds(start_q, block_q), slice(None)) - ) # [block_q, head_dim] - l = pl.load(l_tile_ref, (0, 0, pl.ds(start_q, block_q), slice(None)) - ) # [block_q, 128] - m = pl.load(m_tile_ref, (0, 0, pl.ds(start_q, block_q), slice(None)) - ) # [block_q, 128] - do = pl.load(do_tile_ref, (0, 0, pl.ds(start_q, block_q), slice(None)) - ) # [block_q, 128] - di = pl.load(di_tile_ref, (0, 0, pl.ds(start_q, block_q), slice(None)) - ).astype(jnp.float32) # [block_q, 128] + k = k_tile_ref[0, 0, pl.ds(start_k, block_k), :] + v = v_tile_ref[0, 0, pl.ds(start_k, block_k), :] + q = q_tile_ref[0, 0, pl.ds(start_q, block_q), :] # [block_q, head_dim] + l = l_tile_ref[0, 0, pl.ds(start_q, block_q), :] # [block_q, 128] + m = m_tile_ref[0, 0, pl.ds(start_q, block_q), :] # [block_q, 128] + do = do_tile_ref[0, 0, pl.ds(start_q, block_q), :] # [block_q, 128] + di = di_tile_ref[0, 0, pl.ds(start_q, block_q), :].astype( + jnp.float32 + ) # [block_q, 128] capped_logits = lax.dot_general( q, k, TRANS_B_DIM_NUMBERS, preferred_element_type=jnp.float32 ) # [block_q_major, block_k] if ab_tile_ref is not None: - ab = pl.load( - ab_tile_ref, - ( - 0, - 0, - pl.dslice(j * block_q, block_q), - pl.dslice(i * block_k, block_k), - ), - ).astype(jnp.float32) + ab = ab_tile_ref[ + 0, + 0, + pl.dslice(j * block_q, block_q), + pl.dslice(i * block_k, block_k), + ].astype(jnp.float32) capped_logits += ab if sm_scale != 1.0: @@ -878,15 +866,15 @@ def k_body(i, _): if rem: raise NotImplementedError( ) - q_segment_ids = pl.load( - q_segment_ids_tile_ref, (0, pl.ds(start_q, block_q), slice(None)) - ) # [block_q, NUM_LANES]. + q_segment_ids = q_segment_ids_tile_ref[ + 0, pl.ds(start_q, block_q), : + ] # [block_q, NUM_LANES]. q_segment_ids = pltpu.repeat( q_segment_ids, repeats, axis=1 ) # [block_q, block_k]. - kv_segment_ids = pl.load( - kv_segment_ids_tile_ref, (slice(None), 0, pl.ds(start_k, block_k)) - ) # [1, block_k]. + kv_segment_ids = kv_segment_ids_tile_ref[ + :, 0, pl.ds(start_k, block_k) + ] # [1, block_k]. mask = jnp.equal(q_segment_ids, kv_segment_ids).astype(jnp.bool_) if causal: @@ -913,9 +901,9 @@ def k_body(i, _): 1 / l, block_k // MIN_BLOCK_SIZE, axis=1 ) # [block_q_major, block_k_major] dv = lax.dot(p.T.astype(do.dtype), do, preferred_element_type=jnp.float32) - pl.store(dv_scratch_ref, (pl.ds(start_k, block_k), slice(None)), - pl.load(dv_scratch_ref, (pl.ds(start_k, block_k), slice(None))) - + dv.astype(dv_scratch_ref.dtype)) + dv_scratch_ref[pl.ds(start_k, block_k), :] += dv.astype( + dv_scratch_ref.dtype + ) # di: [block_q, 128] # do: [block_q, head_dim] @@ -931,9 +919,9 @@ def k_body(i, _): # ds: [block_q_major, block_k_major] # q: [block_q_major, head_dim] dk = lax.dot(ds.T.astype(do.dtype), q, preferred_element_type=jnp.float32) - pl.store(dk_scratch_ref, (pl.ds(start_k, block_k), slice(None)), - pl.load(dk_scratch_ref, (pl.ds(start_k, block_k), slice(None))) - + dk.astype(dk_scratch_ref.dtype)) + dk_scratch_ref[pl.ds(start_k, block_k), :] += dk.astype( + dk_scratch_ref.dtype + ) lax.fori_loop(0, block_k_major // block_k, k_body, None, unroll=True) if causal: @@ -1192,12 +1180,8 @@ def start_new_sequence(): def body(i, _): k_slice = pl.ds(i * block_k, block_k) q = q_tile_ref[0, 0, :, :] - k = pl.load( - k_tile_ref, (0, 0, k_slice, slice(None)), - ) # [block_k, head_dim] - v = pl.load( - v_tile_ref, (0, 0, k_slice, slice(None)), - ) # [block_k, head_dim] + k = k_tile_ref[0, 0, k_slice, :] # [block_k, head_dim] + v = v_tile_ref[0, 0, k_slice, :] # [block_k, head_dim] l = l_tile_ref[0, 0, :, :] # [block_q_major, 128] m = m_tile_ref[0, 0, :, :] # [block_q_major, 128] do = do_tile_ref[0, 0, :, :] # [block_q_major, head_dim] @@ -1208,9 +1192,9 @@ def body(i, _): ) if ab_tile_ref is not None: - ab = pl.load( - ab_tile_ref, (0, 0, pl.dslice(None), pl.dslice(i * block_k, block_k)) - ).astype(jnp.float32) + ab = ab_tile_ref[0, 0, :, pl.dslice(i * block_k, block_k)].astype( + jnp.float32 + ) capped_logits += ab if sm_scale != 1.0: @@ -1226,9 +1210,7 @@ def body(i, _): q_segment_ids = pltpu.repeat( q_segment_ids_tile_ref[0], repeats, axis=1 ) # [block_q, block_k]. - kv_segment_ids = pl.load( - kv_segment_ids_tile_ref, (slice(None), 0, k_slice) - ) # [1, block_k]. + kv_segment_ids = kv_segment_ids_tile_ref[:, 0, k_slice] # [1, block_k]. mask = jnp.equal(q_segment_ids, kv_segment_ids).astype(jnp.bool_) if causal: @@ -1269,10 +1251,8 @@ def body(i, _): ds = ds * sm_scale if ds_tile_ref is not None: - pl.store( - ds_tile_ref, - (0, 0, pl.dslice(None), pl.dslice(i * block_k, block_k)), - ds.astype(ds_tile_ref.dtype), + ds_tile_ref[0, 0, :, pl.dslice(i * block_k, block_k)] = ds.astype( + ds_tile_ref.dtype ) # dp: [block_q_major, block_k] diff --git a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py index d0fb6f2f9670..b69b0e36f177 100644 --- a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py +++ b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py @@ -599,9 +599,9 @@ def _apply_mask_and_soft_cap( masks = [] if mask_ref is not None: if k_in_lanes: - mask = pl.load(mask_ref, (slice(None), k_slice)) + mask = mask_ref[:, k_slice] else: - mask = pl.load(mask_ref, (k_slice, slice(None))) + mask = mask_ref[k_slice, :] masks.append( jnp.bitwise_or(mask, jnp.broadcast_to(should_not_mask, mask.shape)) @@ -630,7 +630,7 @@ def _apply_mask_and_soft_cap( k_sequence = k_offset + jax.lax.broadcasted_iota( jnp.int32, (k_slice.size, bq), 0 ) - q_sequence = pl.load(q_sequence_ref, (pl.ds(1), slice(None))) # [1, bq] + q_sequence = q_sequence_ref[:1, :] # [1, bq] q_sequence = jnp.broadcast_to(q_sequence, (k_slice.size, bq)) assert q_sequence.shape == k_sequence.shape @@ -644,7 +644,7 @@ def _apply_mask_and_soft_cap( if q_segment_ids_ref is not None: if k_in_lanes: - kv_ids = pl.load(kv_segment_ids_ref, (pl.ds(1), k_slice)) # [1, k_slice] + kv_ids = kv_segment_ids_ref[:1, k_slice] # [1, k_slice] repeats, rem = divmod(kv_ids.shape[1], NUM_LANES) if rem: raise NotImplementedError(f"block_kv must be a multiple of {NUM_LANES}") @@ -655,9 +655,9 @@ def _apply_mask_and_soft_cap( if rem: raise NotImplementedError(f"block_q must be a multiple of {NUM_LANES}") kv_ids = pltpu.repeat( - pl.load(kv_segment_ids_ref, (k_slice, slice(None))), repeats, axis=1 + kv_segment_ids_ref[k_slice, :], repeats, axis=1 ) # [k_slice, bq] - q_ids = pl.load(q_segment_ids_ref, (pl.ds(1), slice(None))) # [1, bq] + q_ids = q_segment_ids_ref[:1, :] # [1, bq] masks.append(q_ids == kv_ids) def cap_logits(logits): @@ -743,9 +743,9 @@ def body(kv_compute_index, _): q = q_ref[...] if q_layout == HEAD_DIM_MINOR else q_ref[...].T qk_dims = NT_DIM_NUMBERS if k_layout == HEAD_DIM_MINOR else NN_DIM_NUMBERS if k_layout == HEAD_DIM_MINOR: - k = pl.load(k_ref, (slice_k, slice(None))) + k = k_ref[slice_k, :] else: - k = pl.load(k_ref, (slice(None), slice_k)) + k = k_ref[:, slice_k] qk = lax.dot_general(q, k, qk_dims, preferred_element_type=float32) assert qk.shape == (bq, bkv_compute) @@ -794,9 +794,9 @@ def body(kv_compute_index, _): sv_dims = NN_DIM_NUMBERS if v_layout == HEAD_DIM_MINOR else NT_DIM_NUMBERS if v_layout == HEAD_DIM_MINOR: - v = pl.load(v_ref, (slice_k, slice(None))) + v = v_ref[slice_k, :] else: - v = pl.load(v_ref, (slice(None), slice_k)) + v = v_ref[:, slice_k] v = v.astype(float32) o_curr = lax.dot_general(s_curr, v, sv_dims) @@ -1688,13 +1688,13 @@ def body(i, _): q = q_ref[...] # We keep q potentially transposed, since it's always RHS def _load_kv(ref, layout): if layout == HEAD_DIM_MINOR: - return pl.load(ref, (slice_k, slice(None))) - return pl.load(ref, (slice(None), slice_k)).T + return ref[slice_k, :] + return ref[:, slice_k].T k = _load_kv(k_ref, k_layout) v = _load_kv(v_ref, v_layout) - logsumexp = pl.load(logsumexp_ref, (pl.ds(1), slice(None))) + logsumexp = logsumexp_ref[:1, :] do = do_ref[...] - di = pl.load(di_ref, (pl.ds(1), slice(None))) + di = di_ref[:1, :] qk_dims = NT_DIM_NUMBERS if q_layout == HEAD_DIM_MINOR else NN_DIM_NUMBERS qk_uncapped = lax.dot_general( @@ -1718,10 +1718,8 @@ def _load_kv(ref, layout): ) p = jnp.exp(qk - logsumexp) dv = lax.dot(p.astype(do.dtype), do, preferred_element_type=jnp.float32) - dv = dv.astype(dv_scratch_ref.dtype) + pl.load( - dv_scratch_ref, (slice_k, slice(None)) - ) - pl.store(dv_scratch_ref, (slice_k, slice(None)), dv) + dv = dv.astype(dv_scratch_ref.dtype) + dv_scratch_ref[slice_k, :] + dv_scratch_ref[slice_k, :] = dv dp = lax.dot_general( v, do, NT_DIM_NUMBERS, @@ -1737,10 +1735,8 @@ def _load_kv(ref, layout): dk = lax.dot_general( ds.astype(do.dtype), q, dk_dims, preferred_element_type=jnp.float32 ) - dk = dk.astype(dk_scratch_ref.dtype) + pl.load( - dk_scratch_ref, (slice_k, slice(None)) - ) - pl.store(dk_scratch_ref, (slice_k, slice(None)), dk) + dk = dk.astype(dk_scratch_ref.dtype) + dk_scratch_ref[slice_k, :] + dk_scratch_ref[slice_k, :] = dk if dq_scratch_ref is not None or dq_ref is not None: dq = lax.dot_general( ds.T.astype(k.dtype), k, NN_DIM_NUMBERS, diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 38426747d85d..8d5dc471e847 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -1937,7 +1937,7 @@ def test_masked_oob_load_store_slice(self): def masked_oob_load_store_slice(x_ref, mask_ref, start_idx_ref, o_ref): x = pl.load(x_ref, (pl.dslice(start_idx_ref[()], n)), mask=mask_ref[:], other=-1.) - pl.store(o_ref, (pl.dslice(None),), x) + o_ref[...] = x x = random.normal(random.key(0), (n,)) slice_start = random.randint(random.key(2), (), 1, n) diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 9e5130b8f449..781934ecd682 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -128,8 +128,8 @@ def matmul_block_spec(x, y, *, bm, bn, bk, interpret, debug=False): def matmul_kernel(x_ref, y_ref, o_ref): acc = jnp.zeros(o_ref.shape, dtype=jnp.float32) def body(i, acc_ref): - x_block = pl.load(x_ref, (slice(None), pl.ds(i * bk, bk))) - y_block = pl.load(y_ref, (pl.ds(i * bk, bk), slice(None))) + x_block = x_ref[:, pl.ds(i * bk, bk)] + y_block = y_ref[pl.ds(i * bk, bk), :] acc_ref[:, :] += pl.dot(x_block, y_block) acc = for_loop(k // bk, body, acc).astype(o_ref.dtype) o_ref[:, :] = acc @@ -624,8 +624,9 @@ def test_unused_ref(self): out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32), ) def dummy(_, o_ref): - pl.store(o_ref, (jnp.arange(m)[:, None], jnp.arange(n)[None, :]), - jnp.ones_like(o_ref)) + o_ref[jnp.arange(m)[:, None], jnp.arange(n)[None, :]] = jnp.ones_like( + o_ref + ) key = random.key(0) x = random.normal(key, (m, n)) @@ -667,8 +668,7 @@ def test_using_pallas_slice(self): out_shape=out_shape, ) def slice_kernel(x_ref, y_ref): - x = pl.load(x_ref, (pl.dslice(0, 4), pl.dslice(0, 4))) - pl.store(y_ref, (pl.dslice(4), pl.dslice(4)), x) + y_ref[:4, :4] = x_ref[:4, :4] x = random.normal(random.key(0), (m, n)) y = slice_kernel(x) y_ref = x[:4] @@ -1733,7 +1733,7 @@ def test_range_while_loop(self): def kernel(x_ref, r_ref): @pl.when(pl.program_id(0) == 0) def _(): - pl.store(r_ref, (0, 0), 0) + r_ref[0, 0] = 0 def cond(carry): i, j = carry @@ -1745,8 +1745,7 @@ def body(carry): sl = jax.lax.div(i, 128) l = jax.lax.rem(i, 128) v = x_ref[0, sl, l] - s = pl.load(r_ref, (0, 0)) - pl.store(r_ref, (0, 0), s + v) + r_ref[0, 0] += v return io + 1, j i = 128 @@ -1798,7 +1797,7 @@ def test_non_range_while_loop(self): def kernel(x_ref, r_ref): @pl.when(pl.program_id(0) == 0) def _(): - pl.store(r_ref, (0, 0), 0) + r_ref[0, 0] = 0 def cond(state): i, s = state @@ -1808,14 +1807,11 @@ def body(state): i, s = state sl = jax.lax.div(i, jnp.astype(128, i.dtype)) l = jax.lax.rem(i, jnp.astype(128, i.dtype)) - v = pl.load(x_ref, (0, sl, l)) + v = x_ref[0, sl, l] return i + 1, s + v i = jnp.int32(0) - s = pl.load(r_ref, (0, 0)) - - i, s = jax.lax.while_loop(cond, body, (i, s)) - pl.store(r_ref, (0, 0), s) + _, r_ref[0, 0] = jax.lax.while_loop(cond, body, (i, r_ref[0, 0])) x = jnp.arange(4096) x = jnp.reshape(x, [4, 8, 128]) diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 55831ff6af1d..128fe50687a0 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -145,8 +145,7 @@ def body(_, x_ref, o_ref): x = jnp.arange(8 * 8 * 128, dtype=jnp.int32).reshape((8 * 8, 128)) def _x_transform(i, s_ref): - s = pl.load(s_ref, (i,)) - return (s, 0) + return (s_ref[i], 0) out = self.pallas_call( body, @@ -225,7 +224,7 @@ def kernel(s_refs, src, to_store, dst, *scratch_refs): assert s2.shape == (3,) assert s3 is None store_idx = s_ref[pl.program_id(0)] - pl.store(dst, (pl.dslice(store_idx, 1), slice(None)), to_store[...]) + dst[pl.dslice(store_idx, 1), :] = to_store[...] # Pass a pytree of scalar return kernel((s, np.arange(3, dtype=np.int32), None), x, to_store) @@ -281,7 +280,7 @@ def body(_, x_ref, o_ref): x = jnp.arange(2 * 8 * 8 * 128, dtype=jnp.int32).reshape((2, 8 * 8, 128)) def _x_transform(i, s_ref): - s = pl.load(s_ref, (i,)) + s = s_ref[i] return (s, 0) def f(x): @@ -423,7 +422,7 @@ def body(_, x_ref, o_ref): x = jnp.arange(8 * 8 * 128, dtype=jnp.int32).reshape((8 * 8, 128)) def _x_transform(i, s_ref): - s = pl.load(s_ref, (i,)) + s = s_ref[i] return (s, 0) s = s[None] @@ -457,7 +456,7 @@ def body(_, x_ref, o_ref): x = jnp.arange(2 * 8 * 8 * 128, dtype=jnp.int32).reshape((2, 8 * 8, 128)) def _x_transform(i, s_ref): - s = pl.load(s_ref, (i,)) + s = s_ref[i] return (s, 0) s = jnp.tile(s[None], [2, 1]) @@ -1139,8 +1138,7 @@ def kernel(x_hbm_ref, y_hbm_ref, sem_val_ref, dma_sem): def test_hbm_hbm_dma(self): def kernel(x_hbm_ref, y_hbm_ref): def body(sem): - pltpu.async_copy(x_hbm_ref.at[pl.ds(8), :], y_hbm_ref.at[:, pl.ds(128)], - sem).wait() + pltpu.async_copy(x_hbm_ref.at[:8, :], y_hbm_ref.at[:, :128], sem).wait() pl.run_scoped(body, pltpu.SemaphoreType.DMA) x = jnp.arange(8 * 128.).reshape((8, 128)) y = self.pallas_call( @@ -2570,8 +2568,7 @@ def body(scalar_ref, x_ref, o_ref): x = jnp.arange(8 * 8 * 128, dtype=jnp.int32).reshape((8 * 8, 128)) def _x_transform(i, s_ref): - s = pl.load(s_ref, (i,)) - return (s, 0) + return (s_ref[i], 0) pallas_call = self.pallas_call( body, From 777d8f27408ba519579fbc8307cb4bef0572fb9f Mon Sep 17 00:00:00 2001 From: Gleb Pobudzey Date: Mon, 24 Mar 2025 14:16:26 -0700 Subject: [PATCH 131/483] [Mosaic GPU] Adding pallas bindings to broadcast over the leading dimension and load a ref into WGMMAColFragLayout format. PiperOrigin-RevId: 740068368 --- jax/_src/pallas/mosaic_gpu/lowering.py | 6 ++++++ jax/_src/pallas/mosaic_gpu/primitives.py | 9 ++++++++ jax/experimental/mosaic/gpu/__init__.py | 2 ++ tests/pallas/mosaic_gpu_test.py | 27 ++++++++++++++++++------ 4 files changed, 38 insertions(+), 6 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 004a6e7f2760..607b3028f93b 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1262,6 +1262,12 @@ def _broadcast_in_dim_lowering_rule( and x.layout == mgpu.WGMMA_ROW_LAYOUT ): return x.broadcast_minor(y_aval.shape[-1]) + if ( + broadcast_dimensions == (1,) + and y_aval.ndim == x_aval.ndim + 1 + and x.layout == mgpu.WGMMA_COL_LAYOUT + ): + return x.broadcast_major(y_aval.shape[-2]) if broadcast_dimensions: raise NotImplementedError return x.broadcast(shape) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index a27137964349..9665f14254f8 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -110,6 +110,10 @@ def _load_p_lowering_rule( return mgpu.FragmentedArray.load_wgmma_row( x_ref, is_signed=mgpu_utils.is_signed(x_aval.dtype) ) + case mgpu.WGMMAColFragLayout(): + return mgpu.FragmentedArray.load_wgmma_col( + x_ref, is_signed=mgpu_utils.is_signed(x_aval.dtype) + ) case mgpu.WGStridedFragLayout(shape=shape, vec_size=vec_size): ref_ty = ir.MemRefType(x_ref.type) if shape != tuple(ref_ty.shape): @@ -878,6 +882,8 @@ class Layout(enum.Enum): WGMMA = enum.auto() #: [m] matrix, where m % 64 == 0. WGMMA_ROW = enum.auto() + #: [n] matrix, where n % 8 == 0. + WGMMA_COL = enum.auto() WG_SPLAT = enum.auto() WG_STRIDED = enum.auto() @@ -897,6 +903,9 @@ def check_no_args(): case Layout.WGMMA_ROW: check_no_args() return mgpu.WGMMA_ROW_LAYOUT + case Layout.WGMMA_COL: + check_no_args() + return mgpu.WGMMA_COL_LAYOUT case Layout.WG_SPLAT: return mgpu.WGSplatFragLayout(*args, **kwargs) # pytype: disable=missing-parameter case Layout.WG_STRIDED: diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index d004c7deb3df..867fd84b8b3c 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -54,7 +54,9 @@ FragmentedLayout as FragmentedLayout, WGMMA_LAYOUT as WGMMA_LAYOUT, WGMMA_ROW_LAYOUT as WGMMA_ROW_LAYOUT, + WGMMA_COL_LAYOUT as WGMMA_COL_LAYOUT, WGMMARowFragLayout as WGMMARowFragLayout, + WGMMAColFragLayout as WGMMAColFragLayout, WGSplatFragLayout as WGSplatFragLayout, WGStridedFragLayout as WGStridedFragLayout, optimization_barrier as optimization_barrier, diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index d31d1c9d41b2..b33857df40b6 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -638,6 +638,7 @@ def kernel(x_ref, o_ref, barrier_ref): src_memory_space=[plgpu.SMEM, plgpu.GMEM], layout=[ plgpu.Layout.WGMMA_ROW, + plgpu.Layout.WGMMA_COL, plgpu.Layout.WG_STRIDED((128,), vec_size=1), None, ], @@ -661,15 +662,27 @@ def kernel(x_ref, o_ref): x = jnp.arange(2 * 128, dtype=jnp.float32).reshape(2, 128) np.testing.assert_array_equal(f(x), x) - @parameterized.product(src_memory_space=[plgpu.SMEM, plgpu.GMEM]) - def test_load_row_input_to_wgmma_with_transforms(self, src_memory_space): + @parameterized.product(src_memory_space=[plgpu.SMEM], + layout=[ + plgpu.Layout.WGMMA_ROW, + plgpu.Layout.WGMMA_COL, + ],) + def test_load_row_input_to_wgmma_with_transforms(self, src_memory_space, layout): m, k, n = 64, 128, 192 key1, key2 = jax.random.split(jax.random.key(42), 2) - a = jax.random.uniform(key1, shape=(m,), dtype=jnp.float16) + if layout == plgpu.Layout.WGMMA_ROW: + input_shape = (m,) + broadcast_dim = 0 + expand_dim = 1 + else: + input_shape = (k,) + broadcast_dim = 1 + expand_dim = 0 + a = jax.random.uniform(key1, shape=input_shape, dtype=jnp.float16) b = jax.random.uniform(key2, shape=(k, n), dtype=jnp.float16) def kernel(x_ref, y_ref, o_ref): - x = plgpu.load(x_ref, (), layout=plgpu.Layout.WGMMA_ROW) - x = lax.broadcast_in_dim(x, (m, k), [0]) + x = plgpu.load(x_ref, (), layout=layout) + x = lax.broadcast_in_dim(x, (m, k), [broadcast_dim]) def compute(acc_ref): plgpu.wgmma(acc_ref, x, y_ref) @@ -697,7 +710,9 @@ def compute(acc_ref): out_specs=out_spec, ) - out_ref = jnp.broadcast_to(a[:, None], (m, k)) @ b + out_ref = ( + jnp.broadcast_to(jnp.expand_dims(a, axis=expand_dim), (m, k)) @ b + ) np.testing.assert_allclose(f(a, b), out_ref, rtol=1e-3) def test_indexing_before_transpose(self): From 60b3e5156aed252132f54ab6ea337935dfaa2804 Mon Sep 17 00:00:00 2001 From: Michael Whittaker Date: Mon, 24 Mar 2025 15:02:44 -0700 Subject: [PATCH 132/483] Reduced sharding in various tests. PiperOrigin-RevId: 740084295 --- tests/BUILD | 10 +++++----- tests/pallas/BUILD | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/BUILD b/tests/BUILD index 4fab173b4e15..0a8fb9459044 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -596,9 +596,9 @@ jax_multiplatform_test( name = "lax_scipy_spectral_dac_test", srcs = ["lax_scipy_spectral_dac_test.py"], shard_count = { - "cpu": 40, - "gpu": 40, - "tpu": 40, + "cpu": 10, + "gpu": 10, + "tpu": 10, }, deps = [ "//jax:internal_test_util", @@ -1432,8 +1432,8 @@ jax_multiplatform_test( "gpu_p100x2_shardy", ], shard_count = { - "gpu": 10, - "tpu": 10, + "gpu": 5, + "tpu": 5, }, tags = [ "multiaccelerator", diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 7581ff78802b..1ea05c700938 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -328,7 +328,7 @@ jax_multiplatform_test( "tpu_gmm_test.py", ], enable_backends = ["tpu"], - shard_count = 50, + shard_count = 5, tags = [ "noasan", # Times out. "nomsan", # Times out. From 5f1ab2ee6713934b37b5d272e67494a1816cbdba Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 24 Mar 2025 15:39:22 -0700 Subject: [PATCH 133/483] Skip checking of manylinux compliance for `jax` wheel. If `auditwheel show` is executed on `jax` wheel, the following message is printed: ``` INFO:auditwheel.main_show:This does not look like a platform wheel, no ELF executable or shared library file (including compiled Python C extension) found in the wheel archive ``` PiperOrigin-RevId: 740096302 --- ci/utilities/run_auditwheel.sh | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ci/utilities/run_auditwheel.sh b/ci/utilities/run_auditwheel.sh index 30b6a3b51865..b8f80c3e6778 100755 --- a/ci/utilities/run_auditwheel.sh +++ b/ci/utilities/run_auditwheel.sh @@ -26,6 +26,10 @@ if [[ -z "$WHEELS" ]]; then fi for wheel in $WHEELS; do + # Skip checking manylinux compliance for jax wheel. + if [[ "$wheel" =~ 'jax-' ]]; then + continue + fi printf "\nRunning auditwheel on the following wheel:" ls $wheel OUTPUT_FULL=$(python -m auditwheel show $wheel) From c1904dc7eb6e74c85daae57b4d79dfaf353f850f Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 24 Mar 2025 16:05:45 -0700 Subject: [PATCH 134/483] Update the docstring to mesh to use computation follows data and jax.jit APIs. Fixes https://github.com/jax-ml/jax/issues/27390 PiperOrigin-RevId: 740104692 --- jax/_src/mesh.py | 40 +++++++++------------------------------- 1 file changed, 9 insertions(+), 31 deletions(-) diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index b490febf7b0c..a8003e693459 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -194,16 +194,9 @@ def _name_to_type(self): class Mesh(_BaseMesh, contextlib.ContextDecorator): """Declare the hardware resources available in the scope of this manager. - In particular, all ``axis_names`` become valid resource names inside the - managed block and can be used e.g. in the ``in_axis_resources`` argument of - :py:func:`jax.experimental.pjit.pjit`. Also see JAX's multi-process programming - model (https://jax.readthedocs.io/en/latest/multi_process.html) - and the Distributed arrays and automatic parallelization tutorial + See the Distributed arrays and automatic parallelization tutorial (https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) - - If you are compiling in multiple threads, make sure that the - ``with Mesh`` context manager is inside the function that the threads will - execute. + and Explicit sharding tutorial (https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html) Args: devices: A NumPy ndarray object containing JAX device objects (as @@ -214,32 +207,17 @@ class Mesh(_BaseMesh, contextlib.ContextDecorator): Examples: - >>> from jax.experimental.pjit import pjit >>> from jax.sharding import Mesh - >>> from jax.sharding import PartitionSpec as P + >>> from jax.sharding import PartitionSpec as P, NamedSharding >>> import numpy as np ... - >>> inp = np.arange(16).reshape((8, 2)) - >>> devices = np.array(jax.devices()).reshape(4, 2) - ... >>> # Declare a 2D mesh with axes `x` and `y`. - >>> global_mesh = Mesh(devices, ('x', 'y')) - >>> # Use the mesh object directly as a context manager. - >>> with global_mesh: - ... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp) - - >>> # Initialize the Mesh and use the mesh as the context manager. - >>> with Mesh(devices, ('x', 'y')) as global_mesh: - ... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp) - - >>> # Also you can use it as `with ... as ...`. - >>> global_mesh = Mesh(devices, ('x', 'y')) - >>> with global_mesh as m: - ... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp) - - >>> # You can also use it as `with Mesh(...)`. - >>> with Mesh(devices, ('x', 'y')): - ... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp) + >>> devices = np.array(jax.devices()).reshape(4, 2) + >>> mesh = Mesh(devices, ('x', 'y')) + >>> inp = np.arange(16).reshape(8, 2) + >>> arr = jax.device_put(inp, NamedSharding(mesh, P('x', 'y'))) + >>> out = jax.jit(lambda x: x * 2)(arr) + >>> assert out.sharding == NamedSharding(mesh, P('x', 'y')) """ devices: np.ndarray From 49aad1b97fa937583ff21df194fae4cd50be20eb Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 24 Mar 2025 16:39:35 -0700 Subject: [PATCH 135/483] Add the missing `flatbuffers` dependency for the tests that run under `:build_jaxlib=false`. PiperOrigin-RevId: 740115575 --- BUILD.bazel | 1 + 1 file changed, 1 insertion(+) diff --git a/BUILD.bazel b/BUILD.bazel index 5700fcef2e77..ebf852a60924 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -125,6 +125,7 @@ COMMON_DEPS = py_deps([ "opt_einsum", "hypothesis", "cloudpickle", + "flatbuffers", ]) py_import( From 51560bf3f55e50469377f08151a94d36bef0b655 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 24 Mar 2025 18:24:22 -0700 Subject: [PATCH 136/483] [JAX] [XLA:Python] Migrate pytree module to JAX. PiperOrigin-RevId: 740142231 --- jaxlib/xla/BUILD | 49 +- jaxlib/xla/jax_jit.cc | 2 +- jaxlib/xla/jax_jit.h | 2 +- jaxlib/xla/pjit.cc | 2 +- jaxlib/xla/pmap_lib.cc | 2 +- jaxlib/xla/pytree.cc | 1825 +++++++++++++++++++++++++++++++++++++++ jaxlib/xla/pytree.h | 408 +++++++++ jaxlib/xla/pytree.proto | 32 + jaxlib/xla/xla.cc | 2 +- 9 files changed, 2315 insertions(+), 9 deletions(-) create mode 100644 jaxlib/xla/pytree.cc create mode 100644 jaxlib/xla/pytree.h create mode 100644 jaxlib/xla/pytree.proto diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD index 592d9d1c24f3..2edc183bc49b 100644 --- a/jaxlib/xla/BUILD +++ b/jaxlib/xla/BUILD @@ -14,13 +14,16 @@ load( "//jaxlib:jax.bzl", + "cc_proto_library", "if_oss", + "jax_visibility", "nanobind_extension", "py_deps", "py_strict_library", "py_strict_test", "pytype_strict_library", ) +# Placeholder: load proto_library licenses(["notice"]) @@ -50,6 +53,7 @@ nanobind_extension( ":mlir", ":pjit", ":pmap_lib", + ":pytree", ":sdy", ":weakref_lru_cache", ":xla_compiler", @@ -103,7 +107,6 @@ nanobind_extension( "@xla//xla/python:profiler", "@xla//xla/python:py_client", "@xla//xla/python:python_ref_manager", - "@xla//xla/python:pytree", "@xla//xla/python:refine_polymorphic_shapes", "@xla//xla/python:traceback", "@xla//xla/python:types", @@ -250,6 +253,7 @@ cc_library( ], features = ["-use_header_modules"], deps = [ + ":pytree", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", @@ -270,7 +274,6 @@ cc_library( "@xla//xla/python:nb_helpers", "@xla//xla/python:py_client", "@xla//xla/python:python_ref_manager", - "@xla//xla/python:pytree", "@xla//xla/python:types", "@xla//xla/tsl/platform:logging", ], @@ -328,6 +331,7 @@ cc_library( deps = [ ":config", ":jax_jit", + ":pytree", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", @@ -352,7 +356,6 @@ cc_library( "@xla//xla/python:nb_numpy", "@xla//xla/python:py_client", "@xla//xla/python:python_ref_manager", - "@xla//xla/python:pytree", "@xla//xla/python:traceback", "@xla//xla/python/ifrt", "@xla//xla/tsl/concurrency:ref_count", @@ -376,6 +379,7 @@ cc_library( deps = [ ":config", ":jax_jit", + ":pytree", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/hash", @@ -399,7 +403,6 @@ cc_library( "@xla//xla/python:nb_numpy", "@xla//xla/python:py_client", "@xla//xla/python:python_ref_manager", - "@xla//xla/python:pytree", "@xla//xla/python:traceback", "@xla//xla/python:types", "@xla//xla/python/ifrt", @@ -411,6 +414,44 @@ cc_library( ], ) +proto_library( + name = "pytree_proto", + srcs = ["pytree.proto"], +) + +cc_proto_library( + name = "pytree_cc_proto", + deps = [":pytree_proto"], +) + +cc_library( + name = "pytree", + srcs = ["pytree.cc"], + hdrs = ["pytree.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + visibility = jax_visibility("jaxlib/xla/pytree"), + deps = [ + ":pytree_cc_proto", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla/pjrt:exceptions", + "@xla//xla/python:nb_class_ptr", + "@xla//xla/tsl/platform:logging", + ], +) + cc_library( name = "sdy", srcs = ["sdy.cc"], diff --git a/jaxlib/xla/jax_jit.cc b/jaxlib/xla/jax_jit.cc index 754272a078ed..23abe9a8404a 100644 --- a/jaxlib/xla/jax_jit.cc +++ b/jaxlib/xla/jax_jit.cc @@ -53,13 +53,13 @@ limitations under the License. #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/pytree.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_layout.h" #include "xla/pjrt/status_casters.h" #include "xla/python/nb_absl_inlined_vector.h" // IWYU pragma: keep #include "xla/python/nb_absl_span.h" // IWYU pragma: keep #include "xla/python/py_values.h" -#include "xla/python/pytree.h" #include "xla/python/sharding.h" #include "xla/python/types.h" #include "xla/tsl/platform/logging.h" diff --git a/jaxlib/xla/jax_jit.h b/jaxlib/xla/jax_jit.h index 303d7e69414d..a000ef6773b2 100644 --- a/jaxlib/xla/jax_jit.h +++ b/jaxlib/xla/jax_jit.h @@ -34,11 +34,11 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "nanobind/nanobind.h" +#include "jaxlib/xla/pytree.h" #include "xla/pjrt/pjrt_client.h" #include "xla/python/nb_helpers.h" #include "xla/python/py_values.h" #include "xla/python/python_ref_manager.h" -#include "xla/python/pytree.h" #include "xla/python/sharding.h" #include "xla/tsl/platform/logging.h" diff --git a/jaxlib/xla/pjit.cc b/jaxlib/xla/pjit.cc index 96056708c2fb..a13d3f0b52e3 100644 --- a/jaxlib/xla/pjit.cc +++ b/jaxlib/xla/pjit.cc @@ -52,6 +52,7 @@ limitations under the License. #include "nanobind/stl/vector.h" // IWYU pragma: keep #include "jaxlib/xla/config.h" #include "jaxlib/xla/jax_jit.h" +#include "jaxlib/xla/pytree.h" #include "xla/layout.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/lru_cache.h" @@ -69,7 +70,6 @@ limitations under the License. #include "xla/python/py_executable.h" #include "xla/python/py_values.h" #include "xla/python/python_ref_manager.h" -#include "xla/python/pytree.h" #include "xla/python/sharding.h" #include "xla/python/traceback.h" #include "xla/tsl/concurrency/ref_count.h" diff --git a/jaxlib/xla/pmap_lib.cc b/jaxlib/xla/pmap_lib.cc index 5582eccf4f8b..c6849f8c25fd 100644 --- a/jaxlib/xla/pmap_lib.cc +++ b/jaxlib/xla/pmap_lib.cc @@ -46,6 +46,7 @@ limitations under the License. #include "nanobind/stl/vector.h" // IWYU pragma: keep #include "jaxlib/xla/config.h" #include "jaxlib/xla/jax_jit.h" +#include "jaxlib/xla/pytree.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/status_casters.h" #include "xla/python/ifrt/array.h" @@ -64,7 +65,6 @@ limitations under the License. #include "xla/python/py_executable.h" #include "xla/python/py_values.h" #include "xla/python/python_ref_manager.h" -#include "xla/python/pytree.h" #include "xla/python/sharded_device_array.h" #include "xla/python/sharding.h" #include "xla/python/to_ifrt_sharding.h" diff --git a/jaxlib/xla/pytree.cc b/jaxlib/xla/pytree.cc new file mode 100644 index 000000000000..dd5a0bd9cf69 --- /dev/null +++ b/jaxlib/xla/pytree.cc @@ -0,0 +1,1825 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Caution: this code uses exceptions. The exception use is local to the +// binding code and the idiomatic way to emit Python exceptions. + +#include "jaxlib/xla/pytree.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/tuple.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/pytree.pb.h" +#include "xla/pjrt/exceptions.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/tsl/platform/logging.h" + +namespace xla { + +namespace nb = nanobind; + +constexpr int kSequenceKeyHashSalt = 1; +constexpr int kFlattenedIndexKeyHashSalt = 42; + +PyTreeRegistry::PyTreeRegistry(bool enable_none, bool enable_tuple, + bool enable_namedtuple, bool enable_list, + bool enable_dict) { + auto add_builtin_type = [&](PyTypeObject* type_obj, PyTreeKind kind) { + nb::object type = + nb::borrow(reinterpret_cast(type_obj)); + auto registration = std::make_unique(); + registration->kind = kind; + registration->type = type; + CHECK(registrations_.emplace(type, std::move(registration)).second); + }; + if (enable_none) { + add_builtin_type(Py_TYPE(Py_None), PyTreeKind::kNone); + } + if (enable_tuple) { + add_builtin_type(&PyTuple_Type, PyTreeKind::kTuple); + } + enable_namedtuple_ = enable_namedtuple; + if (enable_list) { + add_builtin_type(&PyList_Type, PyTreeKind::kList); + } + if (enable_dict) { + add_builtin_type(&PyDict_Type, PyTreeKind::kDict); + } +} + +void PyTreeRegistry::Register( + nb::object type, nb::callable to_iterable, nb::callable from_iterable, + std::optional to_iterable_with_keys) { + auto registration = std::make_unique(); + registration->kind = PyTreeKind::kCustom; + registration->type = type; + registration->to_iterable = std::move(to_iterable); + registration->from_iterable = std::move(from_iterable); + registration->to_iterable_with_keys = std::move(to_iterable_with_keys); + nb::ft_lock_guard lock(mu_); + auto it = registrations_.emplace(type, std::move(registration)); + if (!it.second) { + throw std::invalid_argument( + absl::StrFormat("Duplicate custom PyTreeDef type registration for %s.", + nb::cast(nb::repr(type)))); + } +} + +void PyTreeRegistry::RegisterDataclass(nb::object type, + std::vector data_fields, + std::vector meta_fields) { + auto registration = std::make_unique(); + registration->kind = PyTreeKind::kDataclass; + registration->type = type; + registration->data_fields = std::move(data_fields); + registration->meta_fields = std::move(meta_fields); + nb::ft_lock_guard lock(mu_); + auto it = registrations_.emplace(type, std::move(registration)); + if (!it.second) { + throw std::invalid_argument(absl::StrFormat( + "Duplicate custom dataclass PyTreeDef type registration for %s.", + nb::cast(nb::repr(std::move(type))))); + } +} + +std::pair +PyTreeRegistry::Registration::ToIterable(nanobind::handle o) const { + nb::object out = to_iterable(o); + nb::tuple leaves_and_aux_data; + if (!nb::try_cast(out, leaves_and_aux_data) || + leaves_and_aux_data.size() != 2) { + throw std::invalid_argument(absl::StrCat( + "The to_iterable function for a custom PyTree node should return " + "a (children, aux_data) tuple, got ", + nb::cast(nb::repr(out)))); + } + nb::iterable leaves; + if (!nb::try_cast(leaves_and_aux_data[0], leaves)) { + throw std::invalid_argument(absl::StrCat( + "The to_iterable function for a custom PyTree node should return " + "a (children, aux_data) tuple where 'children' is iterable, " + "got ", + nb::cast(nb::repr(out)))); + } + return std::make_pair(std::move(leaves), nb::object(leaves_and_aux_data[1])); +} + +std::pair>, nb::object> +PyTreeRegistry::Registration::ToIterableWithKeys(nb::handle o) const { + // Backwards compatibility case: return dummy FlattenedIndexKey for each leaf. + std::vector> result; + if (!to_iterable_with_keys.has_value()) { + auto [leaves, aux_data] = ToIterable(o); + for (nb::handle leaf : leaves) { + result.push_back(std::make_pair( + make_nb_class(result.size()), nb::borrow(leaf))); + } + return std::make_pair(std::move(result), std::move(aux_data)); + } + nb::object out = to_iterable_with_keys.value()(o); + nb::tuple leaves_and_aux_data; + if (!nb::try_cast(out, leaves_and_aux_data) || + leaves_and_aux_data.size() != 2) { + throw std::invalid_argument(absl::StrCat( + "The to_iterable_with_keys function for a custom PyTree " + "node should return a (key_leaf_pairs, aux_data) tuple, got ", + nb::cast(nb::repr(out)))); + } + nb::iterable key_leaf_pairs; + if (!nb::try_cast(leaves_and_aux_data[0], key_leaf_pairs)) { + throw std::invalid_argument(absl::StrCat( + "The to_iterable_with_keys function for a custom PyTree node should " + "return a (key_leaf_pairs, aux_data) tuple where 'key_leaf_pairs' is " + "iterable, got ", + nb::cast(nb::repr(leaves_and_aux_data)))); + } + for (nb::handle key_leaf_pair : key_leaf_pairs) { + nb::tuple key_leaf_pair_tuple; + if (!nb::try_cast(key_leaf_pair, key_leaf_pair_tuple) || + key_leaf_pair_tuple.size() != 2) { + throw std::invalid_argument(absl::StrCat( + "The to_iterable_with_keys function for a custom PyTree node should " + "return a (key_leaf_pairs, aux_data) tuple where 'child", + nb::cast(nb::repr(key_leaf_pair)))); + } + result.push_back(std::make_pair(nb::borrow(key_leaf_pair_tuple[0]), + nb::borrow(key_leaf_pair_tuple[1]))); + } + return std::make_pair(std::move(result), nb::object(leaves_and_aux_data[1])); +} + +int PyTreeRegistry::Registration::tp_traverse(visitproc visit, void* arg) { + Py_VISIT(type.ptr()); + Py_VISIT(to_iterable.ptr()); + Py_VISIT(from_iterable.ptr()); + for (const auto& field : data_fields) { + Py_VISIT(field.ptr()); + } + for (const auto& field : meta_fields) { + Py_VISIT(field.ptr()); + } + return 0; +} + +// Computes the node kind of a given Python object. +PyTreeKind PyTreeRegistry::KindOfObject( + nb::handle obj, PyTreeRegistry::Registration const** custom) const { + const PyTreeRegistry::Registration* registration = Lookup(obj.type()); + if (registration) { + if (registration->kind == PyTreeKind::kCustom || + registration->kind == PyTreeKind::kDataclass) { + *custom = registration; + } else { + *custom = nullptr; + } + return registration->kind; + } else if (nb::isinstance(obj) && nb::hasattr(obj, "_fields")) { + // We can only identify namedtuples heuristically, here by the presence of + // a _fields attribute. + return PyTreeKind::kNamedTuple; + } else { + return PyTreeKind::kLeaf; + } +} + +/*static*/ const PyTreeRegistry::Registration* PyTreeRegistry::Lookup( + nb::handle type) const { + nb::ft_lock_guard lock(mu_); + auto it = registrations_.find(type); + return it == registrations_.end() ? nullptr : it->second.get(); +} + +/*static*/ std::vector GetSortedPyDictKeys(PyObject* py_dict) { + std::vector keys; + keys.reserve(PyDict_Size(py_dict)); + PyObject* key; + Py_ssize_t pos = 0; + while (PyDict_Next(py_dict, &pos, &key, /*value=*/nullptr)) { + keys.push_back(nb::borrow(key)); + } + + try { + std::stable_sort( + keys.begin(), keys.end(), [](const nb::object& a, const nb::object& b) { + int cmp = PyObject_RichCompareBool(a.ptr(), b.ptr(), Py_LT); + if (cmp == -1) { + throw nb::python_error(); + } + return cmp; + }); + } catch (nb::python_error& e) { + nb::raise_from(e, PyExc_ValueError, + "Comparator raised exception while sorting pytree " + "dictionary keys."); + } + return keys; +} + +/*static*/ bool IsSortedPyDictKeysEqual(absl::Span lhs, + absl::Span rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + for (int i = 0; i < lhs.size(); ++i) { + if (lhs[i].not_equal(rhs[i])) { + return false; + } + } + return true; +} + +bool PyTreeDef::operator==(const PyTreeDef& other) const { + if (traversal_.size() != other.traversal_.size()) { + return false; + } + for (size_t i = 0; i < traversal_.size(); ++i) { + const Node& a = traversal_[i]; + const Node& b = other.traversal_[i]; + if (a.kind != b.kind || a.arity != b.arity || + (a.node_data.ptr() == nullptr) != (b.node_data.ptr() == nullptr) || + (a.sorted_dict_keys.size() != b.sorted_dict_keys.size()) || + a.custom != b.custom) { + return false; + } + if (a.node_data && a.node_data.not_equal(b.node_data)) { + return false; + } + if (!IsSortedPyDictKeysEqual(a.sorted_dict_keys, b.sorted_dict_keys)) { + return false; + } + // We don't need to test equality of num_leaves and num_nodes since they + // are derivable from the other node data. + } + return true; +} + +nb::object PyTreeRegistry::FlattenOneLevel(nb::handle x) const { + return FlattenOneLevelImpl(x, /*with_keys=*/false); +} + +nb::object PyTreeRegistry::FlattenOneLevelWithKeys(nb::handle x) const { + return FlattenOneLevelImpl(x, /*with_keys=*/true); +} + +nb::object PyTreeRegistry::FlattenOneLevelImpl(nb::handle x, + bool with_keys) const { + PyTreeRegistry::Registration const* custom; + PyTreeKind kind = KindOfObject(x, &custom); + switch (kind) { + case PyTreeKind::kNone: + return nb::make_tuple(nb::make_tuple(), nb::none()); + case PyTreeKind::kTuple: { + if (with_keys) { + auto size = PyTuple_GET_SIZE(x.ptr()); + nb::object key_leaves = nb::steal(PyTuple_New(size)); + for (int i = 0; i < size; ++i) { + nb::object key = make_nb_class(i); + nb::object value = + nb::borrow(PyTuple_GET_ITEM(x.ptr(), i)); + PyTuple_SET_ITEM(key_leaves.ptr(), i, + nb::make_tuple(key, value).release().ptr()); + } + return nb::make_tuple(std::move(key_leaves), nb::none()); + } + return nb::make_tuple(nb::borrow(x), nb::none()); + } + case PyTreeKind::kList: { + if (with_keys) { + auto size = PyList_GET_SIZE(x.ptr()); + nb::object key_leaves = nb::steal(PyTuple_New(size)); + for (int i = 0; i < size; ++i) { + nb::object key = make_nb_class(i); + nb::object value = + nb::borrow(PyList_GET_ITEM(x.ptr(), i)); + PyTuple_SET_ITEM(key_leaves.ptr(), i, + nb::make_tuple(key, value).release().ptr()); + } + return nb::make_tuple(std::move(key_leaves), nb::none()); + } + return nb::make_tuple(nb::borrow(x), nb::none()); + } + case PyTreeKind::kDict: { + nb::dict dict = nb::borrow(x); + std::vector sorted_keys = GetSortedPyDictKeys(dict.ptr()); + nb::tuple keys = nb::steal(PyTuple_New(sorted_keys.size())); + nb::tuple values = nb::steal(PyTuple_New(sorted_keys.size())); + for (size_t i = 0; i < sorted_keys.size(); ++i) { + nb::object& key = sorted_keys[i]; + nb::object value = nb::object(dict[key]); + if (with_keys) { + value = nb::make_tuple(make_nb_class(key), value); + } + PyTuple_SET_ITEM(values.ptr(), i, value.release().ptr()); + PyTuple_SET_ITEM(keys.ptr(), i, sorted_keys[i].release().ptr()); + } + return nb::make_tuple(std::move(values), std::move(keys)); + } + case PyTreeKind::kNamedTuple: { + nb::tuple in = nb::borrow(x); + nb::list out; + if (with_keys) { + // Get key names from NamedTuple fields. + nb::tuple fields; + if (!nb::try_cast(nb::getattr(in, "_fields"), fields) || + in.size() != fields.size()) { + throw std::invalid_argument( + "A namedtuple's _fields attribute should have the same size as " + "the tuple."); + } + auto field_iter = fields.begin(); + for (nb::handle entry : in) { + out.append(nb::make_tuple( + make_nb_class(nb::str(*field_iter)), entry)); + } + return nb::make_tuple(std::move(out), x.type()); + } + for (size_t i = 0; i < in.size(); ++i) { + out.append(in[i]); + } + return nb::make_tuple(std::move(out), x.type()); + } + case PyTreeKind::kCustom: { + if (with_keys) { + auto [leaves, aux_data] = custom->ToIterableWithKeys(x); + return nb::make_tuple(std::move(leaves), std::move(aux_data)); + } + auto [leaves, aux_data] = custom->ToIterable(x); + return nb::make_tuple(std::move(leaves), std::move(aux_data)); + } + case PyTreeKind::kDataclass: { + auto data_size = custom->data_fields.size(); + nb::list leaves = nb::steal(PyList_New(data_size)); + for (int leaf = 0; leaf < data_size; ++leaf) { + nb::object value = nb::getattr(x, custom->data_fields[leaf]); + if (with_keys) { + value = nb::make_tuple( + make_nb_class(custom->data_fields[leaf]), value); + } + PyList_SET_ITEM(leaves.ptr(), leaf, value.release().ptr()); + } + auto meta_size = custom->meta_fields.size(); + nb::object aux_data = nb::steal(PyTuple_New(meta_size)); + for (int meta_leaf = 0; meta_leaf < meta_size; ++meta_leaf) { + PyTuple_SET_ITEM( + aux_data.ptr(), meta_leaf, + nb::getattr(x, custom->meta_fields[meta_leaf]).release().ptr()); + } + return nb::make_tuple(std::move(leaves), std::move(aux_data)); + } + default: + DCHECK(kind == PyTreeKind::kLeaf); + return nb::none(); + } +} + +/* static */ PyType_Slot PyTreeRegistry::slots_[] = { + {Py_tp_traverse, (void*)PyTreeRegistry::tp_traverse}, + {Py_tp_clear, (void*)PyTreeRegistry::tp_clear}, + {0, nullptr}, +}; + +/* static */ int PyTreeRegistry::tp_traverse(PyObject* self, visitproc visit, + void* arg) { + PyTreeRegistry* registry = nb::inst_ptr(self); + Py_VISIT(Py_TYPE(self)); + nb::ft_lock_guard lock(registry->mu_); + for (const auto& [key, value] : registry->registrations_) { + Py_VISIT(key.ptr()); + int rval = value->tp_traverse(visit, arg); + if (rval != 0) { + return rval; + } + } + return 0; +} + +/* static */ int PyTreeRegistry::tp_clear(PyObject* self) { + PyTreeRegistry* registry = nb::inst_ptr(self); + nb::ft_lock_guard lock(registry->mu_); + registry->registrations_.clear(); + return 0; +} + +/* static */ PyType_Slot DictKey::slots_[] = { + {Py_tp_traverse, (void*)DictKey::tp_traverse}, + {Py_tp_clear, (void*)DictKey::tp_clear}, + {0, nullptr}, +}; + +/* static */ int DictKey::tp_traverse(PyObject* self, visitproc visit, + void* arg) { + DictKey* key = nb::inst_ptr(self); + Py_VISIT(key->key_.ptr()); + return 0; +} + +/* static */ int DictKey::tp_clear(PyObject* self) { + DictKey* dictkey = nb::inst_ptr(self); + nb::object tmp; + std::swap(tmp, dictkey->key_); + return 0; +} + +std::string SequenceKey::ToString() const { + return absl::StrFormat("[%d]", idx_); +} + +std::string SequenceKey::ToReprString() const { + return absl::StrFormat("SequenceKey(idx=%d)", idx_); +} + +std::string DictKey::ToString() const { + return absl::StrFormat("[%s]", nb::cast(nb::repr(key_))); +} + +std::string DictKey::ToReprString() const { + return absl::StrFormat("DictKey(key=%s)", + nb::cast(nb::repr(key_))); +} + +std::string GetAttrKey::ToString() const { + return absl::StrFormat(".%s", nb::cast(name_)); +} + +std::string GetAttrKey::ToReprString() const { + return absl::StrFormat("GetAttrKey(name='%s')", + nb::cast(name_)); +} + +std::string FlattenedIndexKey::ToString() const { + return absl::StrFormat("[]", key_); +} + +std::string FlattenedIndexKey::ToReprString() const { + return absl::StrFormat("FlattenedIndexKey(key=%d)", key_); +} + +bool SequenceKey::Equals(const nb::object& other) { + SequenceKey other_key(0); + if (!nb::try_cast(other, other_key)) return false; + return idx_ == other_key.idx(); +} + +bool DictKey::Equals(const nb::object& other) { + DictKey other_key(nb::none()); + if (!nb::try_cast(other, other_key)) return false; + return key_.equal(other_key.key()); +} + +bool GetAttrKey::Equals(const nb::object& other) { + GetAttrKey other_key(nb::str("")); + if (!nb::try_cast(other, other_key)) return false; + return name_.equal(other_key.name()); +} + +bool FlattenedIndexKey::Equals(const nb::object& other) { + FlattenedIndexKey other_key(0); + if (!nb::try_cast(other, other_key)) return false; + return key_ == other_key.key(); +} + +nanobind::tuple SequenceKey::MatchArgs(nanobind::handle unused) { + return nanobind::make_tuple("idx"); +}; + +nanobind::tuple DictKey::MatchArgs(nanobind::handle unused) { + return nanobind::make_tuple("key"); +}; + +nanobind::tuple GetAttrKey::MatchArgs(nanobind::handle unused) { + return nanobind::make_tuple("name"); +}; + +nanobind::tuple FlattenedIndexKey::MatchArgs(nanobind::handle unused) { + return nanobind::make_tuple("key"); +}; + +template +void PyTreeDef::FlattenImpl(nb::handle handle, T& leaves, + const std::optional& leaf_predicate, + std::optional>& keypath) { + Node node; + const int start_num_nodes = traversal_.size(); + const int start_num_leaves = leaves.size(); + bool is_known_leaf = false; + if (leaf_predicate) { + nb::object o = (*leaf_predicate)(handle); + // Historically we accepted "truthy" values from leaf predicates. Accept + // None here to keep existing clients happy. + if (o.is_none()) { + is_known_leaf = false; + } else if (!nb::try_cast(o, is_known_leaf)) { + throw std::invalid_argument(absl::StrCat( + "is_leaf predicate returned a non-boolean value ", + nb::cast(nb::repr(o)), "; expected a boolean")); + } + } + if (is_known_leaf) { + nb::object value = nb::borrow(handle); + if (keypath.has_value()) { + const std::vector& frozen_keypath = keypath.value(); + nb::object kp_tuple = nb::steal(PyTuple_New(frozen_keypath.size())); + for (int i = 0; i < frozen_keypath.size(); ++i) { + PyTuple_SET_ITEM(kp_tuple.ptr(), i, + nb::object(frozen_keypath[i]).release().ptr()); + } + value = nb::make_tuple(std::move(kp_tuple), std::move(value)); + } + if constexpr (std::is_same_v) { + leaves.append(std::move(value)); + } else { + leaves.push_back(std::move(value)); + } + } else { + node.kind = registry_->KindOfObject(handle, &node.custom); + auto recurse = [this, &leaf_predicate, &leaves]( + nb::handle child, + std::optional>& keypath) { + if (Py_EnterRecursiveCall( + " in flatten; PyTree may have cyclical node references.")) { + return; + } + FlattenImpl(child, leaves, leaf_predicate, keypath); + Py_LeaveRecursiveCall(); + }; + switch (node.kind) { + case PyTreeKind::kNone: + // Nothing to do. + break; + case PyTreeKind::kTuple: { + node.arity = PyTuple_GET_SIZE(handle.ptr()); + for (int i = 0; i < node.arity; ++i) { + if (keypath.has_value()) { + keypath->push_back(make_nb_class(i)); + } + recurse(PyTuple_GET_ITEM(handle.ptr(), i), keypath); + if (keypath.has_value()) { + keypath->pop_back(); + } + } + break; + } + case PyTreeKind::kList: { + node.arity = PyList_GET_SIZE(handle.ptr()); + for (int i = 0; i < node.arity; ++i) { + if (keypath.has_value()) { + keypath->push_back(make_nb_class(i)); + } + recurse(PyList_GET_ITEM(handle.ptr(), i), keypath); + if (keypath.has_value()) { + keypath->pop_back(); + } + } + break; + } + case PyTreeKind::kDict: { + nb::dict dict = nb::borrow(handle); + + std::vector keys = GetSortedPyDictKeys(dict.ptr()); + for (nb::object& key : keys) { + if (keypath.has_value()) { + keypath->push_back(make_nb_class(key)); + } + recurse(dict[key], keypath); + if (keypath.has_value()) { + keypath->pop_back(); + } + } + node.arity = dict.size(); + node.sorted_dict_keys = std::move(keys); + break; + } + case PyTreeKind::kCustom: { + if (keypath.has_value()) { + auto [leaves, aux_data] = node.custom->ToIterableWithKeys(handle); + node.node_data = std::move(aux_data); + node.arity = 0; + for (auto& [key, leaf] : leaves) { + keypath->push_back(key); + ++node.arity; + recurse(leaf, keypath); + keypath->pop_back(); + } + } else { + auto [leaves, aux_data] = node.custom->ToIterable(handle); + node.node_data = std::move(aux_data); + node.arity = 0; + for (nb::handle entry : leaves) { + ++node.arity; + recurse(entry, keypath); + } + } + break; + } + case PyTreeKind::kDataclass: { + auto meta_size = node.custom->meta_fields.size(); + nb::object aux_data = nb::steal(PyTuple_New(meta_size)); + for (int meta_leaf = 0; meta_leaf < meta_size; ++meta_leaf) { + PyTuple_SET_ITEM( + aux_data.ptr(), meta_leaf, + nb::getattr(handle, node.custom->meta_fields[meta_leaf]) + .release() + .ptr()); + } + node.node_data = std::move(aux_data); + auto data_size = node.custom->data_fields.size(); + node.arity = data_size; + for (int leaf = 0; leaf < data_size; ++leaf) { + if (keypath.has_value()) { + keypath->push_back( + make_nb_class(node.custom->data_fields[leaf])); + } + recurse(nb::getattr(handle, node.custom->data_fields[leaf]), keypath); + if (keypath.has_value()) { + keypath->pop_back(); + } + } + break; + } + case PyTreeKind::kNamedTuple: { + nb::tuple tuple = nb::borrow(handle); + node.arity = tuple.size(); + node.node_data = nb::borrow(tuple.type()); + if (keypath.has_value()) { + // Get key names from NamedTuple fields. + nb::tuple fields; + if (!nb::try_cast(nb::getattr(tuple, "_fields"), fields) || + tuple.size() != fields.size()) { + throw std::invalid_argument( + "A namedtuple's _fields attribute should have the same size as " + "the tuple."); + } + auto field_iter = fields.begin(); + for (nb::handle entry : tuple) { + keypath->push_back(make_nb_class(nb::str(*field_iter))); + field_iter++; + recurse(entry, keypath); + keypath->pop_back(); + } + } else { + for (nb::handle entry : tuple) { + recurse(entry, keypath); + } + } + break; + } + default: + DCHECK(node.kind == PyTreeKind::kLeaf); + auto value = nb::borrow(handle); + if (keypath.has_value()) { + const std::vector& frozen_keypath = keypath.value(); + nb::object kp_tuple = nb::steal(PyTuple_New(frozen_keypath.size())); + for (int i = 0; i < frozen_keypath.size(); ++i) { + PyTuple_SET_ITEM(kp_tuple.ptr(), i, + nb::object(frozen_keypath[i]).release().ptr()); + } + value = nb::make_tuple(std::move(kp_tuple), std::move(value)); + } + if constexpr (std::is_same_v) { + leaves.append(std::move(value)); + } else { + leaves.push_back(std::move(value)); + } + } + } + node.num_nodes = traversal_.size() - start_num_nodes + 1; + node.num_leaves = leaves.size() - start_num_leaves; + traversal_.push_back(std::move(node)); +} + +void PyTreeDef::Flatten(nb::handle handle, + absl::InlinedVector& leaves, + std::optional leaf_predicate) { + std::optional> keypath = std::nullopt; + FlattenImpl(handle, leaves, leaf_predicate, keypath); +} + +void PyTreeDef::Flatten(nb::handle handle, std::vector& leaves, + std::optional leaf_predicate) { + std::optional> keypath = std::nullopt; + FlattenImpl(handle, leaves, leaf_predicate, keypath); +} + +void PyTreeDef::Flatten(nb::handle handle, nb::list& leaves, + std::optional leaf_predicate) { + std::optional> keypath = std::nullopt; + FlattenImpl(handle, leaves, leaf_predicate, keypath); +} + +/*static*/ std::pair, nb_class_ptr> +PyTreeDef::Flatten(nb::handle x, nb_class_ptr registry, + std::optional leaf_predicate) { + auto def = make_nb_class(registry); + std::vector leaves; + def->Flatten(x, leaves, leaf_predicate); + return std::make_pair(std::move(leaves), std::move(def)); +} + +void PyTreeDef::FlattenWithPath(nb::handle handle, nanobind::list& leaves, + std::optional leaf_predicate) { + std::optional> keypath = std::vector(); + FlattenImpl(handle, leaves, leaf_predicate, keypath); +} + +/*static*/ bool PyTreeDef::AllLeaves(PyTreeRegistry* registry, + const nb::iterable& x) { + const PyTreeRegistry::Registration* custom; + for (const nb::handle& h : x) { + if (registry->KindOfObject(h, &custom) != PyTreeKind::kLeaf) return false; + } + return true; +} + +template +nb::object PyTreeDef::UnflattenImpl(T leaves) const { + absl::InlinedVector agenda; + auto it = leaves.begin(); + int leaf_count = 0; + for (const Node& node : traversal_) { + if (agenda.size() < node.arity) { + throw std::logic_error("Too few elements for TreeDef node."); + } + switch (node.kind) { + case PyTreeKind::kLeaf: + if (it == leaves.end()) { + throw std::invalid_argument(absl::StrFormat( + "Too few leaves for PyTreeDef; expected %d, got %d", num_leaves(), + leaf_count)); + } + agenda.push_back(nb::borrow(*it)); + ++it; + ++leaf_count; + break; + + case PyTreeKind::kNone: + case PyTreeKind::kTuple: + case PyTreeKind::kNamedTuple: + case PyTreeKind::kList: + case PyTreeKind::kDict: + case PyTreeKind::kCustom: + case PyTreeKind::kDataclass: { + const int size = agenda.size(); + absl::Span span; + if (node.arity > 0) { + span = absl::Span(&agenda[size - node.arity], node.arity); + } + nb::object o = MakeNode(node, span); + agenda.resize(size - node.arity); + agenda.push_back(o); + break; + } + } + } + if (it != leaves.end()) { + throw std::invalid_argument(absl::StrFormat( + "Too many leaves for PyTreeDef; expected %d.", num_leaves())); + } + if (agenda.size() != 1) { + throw std::logic_error("PyTreeDef traversal did not yield a singleton."); + } + return std::move(agenda.back()); +} + +nb::object PyTreeDef::Unflatten(nb::iterable leaves) const { + return UnflattenImpl(leaves); +} + +nb::object PyTreeDef::Unflatten(absl::Span leaves) const { + return UnflattenImpl(leaves); +} + +/*static*/ nb::object PyTreeDef::MakeNode(const PyTreeDef::Node& node, + absl::Span children) { + if (children.size() != node.arity) { + throw std::logic_error("Node arity mismatch."); + } + switch (node.kind) { + case PyTreeKind::kLeaf: + throw std::logic_error("MakeNode not implemented for leaves."); + + case PyTreeKind::kNone: + return nb::none(); + + case PyTreeKind::kTuple: + case PyTreeKind::kNamedTuple: { + nb::object tuple = nb::steal(PyTuple_New(node.arity)); + for (int i = 0; i < node.arity; ++i) { + PyTuple_SET_ITEM(tuple.ptr(), i, children[i].release().ptr()); + } + if (node.kind == PyTreeKind::kNamedTuple) { + return node.node_data(*tuple); + } else { + return tuple; + } + } + + case PyTreeKind::kList: { + nb::object list = nb::steal(PyList_New(node.arity)); + for (int i = 0; i < node.arity; ++i) { + PyList_SET_ITEM(list.ptr(), i, children[i].release().ptr()); + } + return list; + } + + case PyTreeKind::kDict: { + nb::dict dict; + for (int i = 0; i < node.arity; ++i) { + dict[node.sorted_dict_keys[i]] = std::move(children[i]); + } + return std::move(dict); + break; + } + case PyTreeKind::kCustom: { + nb::object tuple = nb::steal(PyTuple_New(node.arity)); + for (int i = 0; i < node.arity; ++i) { + PyTuple_SET_ITEM(tuple.ptr(), i, children[i].release().ptr()); + } + return node.custom->from_iterable(node.node_data, tuple); + } + + case PyTreeKind::kDataclass: { + nb::kwargs kwargs; + auto meta_size = node.custom->meta_fields.size(); + for (int i = 0; i < meta_size; ++i) { + kwargs[node.custom->meta_fields[i]] = + nb::borrow(nb::tuple(node.node_data)[i]); + } + auto data_size = node.custom->data_fields.size(); + for (int i = 0; i < data_size; ++i) { + kwargs[node.custom->data_fields[i]] = std::move(children[i]); + } + return node.custom->type(**kwargs); + } + } + throw std::logic_error("Unreachable code."); +} + +nb::list PyTreeDef::FlattenUpTo(nb::handle xs) const { + nb::list leaves = nb::steal(PyList_New(num_leaves())); + std::vector agenda; + agenda.push_back(nb::borrow(xs)); + auto it = traversal_.rbegin(); + int leaf = num_leaves() - 1; + while (!agenda.empty()) { + if (it == traversal_.rend()) { + throw std::invalid_argument(absl::StrFormat( + "Tree structures did not match: %s vs %s", + nb::cast(nb::repr(xs)), ToString())); + } + const Node& node = *it; + nb::object object = agenda.back(); + agenda.pop_back(); + ++it; + + switch (node.kind) { + case PyTreeKind::kLeaf: + if (leaf < 0) { + throw std::logic_error("Leaf count mismatch."); + } + PyList_SET_ITEM(leaves.ptr(), leaf, object.release().ptr()); + --leaf; + break; + + case PyTreeKind::kNone: + if (!object.is_none()) { + throw std::invalid_argument(absl::StrFormat( + "Expected None, got %s.\n\n" + "In previous releases of JAX, flatten-up-to used to " + "consider None to be a tree-prefix of non-None values. To obtain " + "the previous behavior, you can usually write:\n" + " jax.tree.map(lambda x, y: None if x is None else f(x, y), a, " + "b, is_leaf=lambda x: x is None)", + nb::cast(nb::repr(object)))); + } + break; + + case PyTreeKind::kTuple: { + if (!PyTuple_CheckExact(object.ptr())) { + throw std::invalid_argument( + absl::StrFormat("Expected tuple, got %s.", + nb::cast(nb::repr(object)))); + } + nb::tuple tuple = nb::borrow(object); + if (tuple.size() != node.arity) { + throw std::invalid_argument(absl::StrFormat( + "Tuple arity mismatch: %d != %d; tuple: %s.", tuple.size(), + node.arity, nb::cast(nb::repr(object)))); + } + for (nb::handle entry : tuple) { + agenda.push_back(nb::borrow(entry)); + } + break; + } + + case PyTreeKind::kList: { + if (!PyList_CheckExact(object.ptr())) { + throw std::invalid_argument( + absl::StrFormat("Expected list, got %s.", + nb::cast(nb::repr(object)))); + } + nb::list list = nb::borrow(object); + if (list.size() != node.arity) { + throw std::invalid_argument(absl::StrFormat( + "List arity mismatch: %d != %d; list: %s.", list.size(), + node.arity, nb::cast(nb::repr(object)))); + } + for (nb::handle entry : list) { + agenda.push_back(nb::borrow(entry)); + } + break; + } + + case PyTreeKind::kDict: { + if (!PyDict_CheckExact(object.ptr())) { + throw std::invalid_argument( + absl::StrFormat("Expected dict, got %s.", + nb::cast(nb::repr(object)))); + } + nb::dict dict = nb::borrow(object); + std::vector keys = GetSortedPyDictKeys(dict.ptr()); + if (!IsSortedPyDictKeysEqual(keys, node.sorted_dict_keys)) { + // Convert to a nb::list for nb::repr to avoid having to stringify a + // vector. This is error path so it is fine to pay conversion cost. + throw std::invalid_argument( + absl::StrFormat("Dict key mismatch; expected keys: %s; dict: %s.", + nb::cast( + nb::repr(nb::cast(node.sorted_dict_keys))), + nb::cast(nb::repr(object)))); + } + for (nb::handle key : keys) { + agenda.push_back(dict[key]); + } + break; + } + + case PyTreeKind::kNamedTuple: { + if (!nb::isinstance(object) || + !nb::hasattr(object, "_fields")) { + throw std::invalid_argument( + absl::StrFormat("Expected named tuple, got %s.", + nb::cast(nb::repr(object)))); + } + nb::tuple tuple = nb::borrow(object); + if (tuple.size() != node.arity) { + throw std::invalid_argument(absl::StrFormat( + "Named tuple arity mismatch: %d != %d; tuple: %s.", tuple.size(), + node.arity, nb::cast(nb::repr(object)))); + } + if (tuple.type().not_equal(node.node_data)) { + throw std::invalid_argument(absl::StrFormat( + "Named tuple type mismatch: expected type: %s, tuple: %s.", + nb::cast(nb::repr(node.node_data)), + nb::cast(nb::repr(object)))); + } + for (nb::handle entry : tuple) { + agenda.push_back(nb::borrow(entry)); + } + break; + } + + case PyTreeKind::kCustom: { + auto* registration = registry_->Lookup(object.type()); + if (registration != node.custom) { + throw std::invalid_argument(absl::StrFormat( + "Custom node type mismatch: expected type: %s, value: %s.", + nb::cast(nb::repr(node.custom->type)), + nb::cast(nb::repr(object)))); + } + auto [leaves, aux_data] = node.custom->ToIterable(object); + if (node.node_data.not_equal(aux_data)) { + throw std::invalid_argument(absl::StrFormat( + "Mismatch custom node data: %s != %s; value: %s.", + nb::cast(nb::repr(node.node_data)), + nb::cast(nb::repr(aux_data)), + nb::cast(nb::repr(object)))); + } + int arity = 0; + for (nb::handle entry : leaves) { + ++arity; + agenda.push_back(nb::borrow(entry)); + } + if (arity != node.arity) { + throw std::invalid_argument(absl::StrFormat( + "Custom type arity mismatch: %d != %d; value: %s.", arity, + node.arity, nb::cast(nb::repr(object)))); + } + break; + } + + case PyTreeKind::kDataclass: { + auto* registration = registry_->Lookup(object.type()); + if (registration != node.custom) { + throw std::invalid_argument(absl::StrFormat( + "Custom dataclasss node type mismatch: expected type: %s, value: " + "%s.", + nb::cast(nb::repr(node.custom->type)), + nb::cast(nb::repr(std::move(object))))); + } + auto meta_size = node.custom->meta_fields.size(); + nb::object aux_data = nb::steal(PyTuple_New(meta_size)); + for (int meta_leaf = 0; meta_leaf < meta_size; ++meta_leaf) { + PyTuple_SET_ITEM( + aux_data.ptr(), meta_leaf, + nb::getattr(object, node.custom->meta_fields[meta_leaf]) + .release() + .ptr()); + } + if (node.node_data.not_equal(aux_data)) { + throw std::invalid_argument(absl::StrFormat( + "Mismatch custom dataclass node data: %s != %s; value: %s.", + nb::cast(nb::repr(node.node_data)), + nb::cast(nb::repr(aux_data)), + nb::cast(nb::repr(object)))); + } + auto data_size = node.custom->data_fields.size(); + if (data_size != node.arity) { + throw std::invalid_argument(absl::StrFormat( + "Custom type arity mismatch: %d != %d; value: %s.", data_size, + node.arity, nb::cast(nb::repr(object)))); + } + for (int leaf = 0; leaf < data_size; ++leaf) { + agenda.push_back(nb::borrow( + nb::getattr(object, node.custom->data_fields[leaf]))); + } + break; + } + } + } + if (it != traversal_.rend() || leaf != -1) { + throw std::invalid_argument( + absl::StrFormat("Tree structures did not match: %s vs %s", + nb::cast(nb::repr(xs)), ToString())); + } + return leaves; +} + +nb::object PyTreeDef::Walk(const nb::callable& f_node, nb::handle f_leaf, + nb::iterable leaves) const { + std::vector agenda; + auto it = leaves.begin(); + for (const Node& node : traversal_) { + switch (node.kind) { + case PyTreeKind::kLeaf: { + if (it == leaves.end()) { + throw std::invalid_argument("Too few leaves for PyTreeDef"); + } + + nb::object leaf = nb::borrow(*it); + agenda.push_back(f_leaf.is_none() ? std::move(leaf) + : f_leaf(std::move(leaf))); + ++it; + break; + } + + case PyTreeKind::kNone: + case PyTreeKind::kTuple: + case PyTreeKind::kNamedTuple: + case PyTreeKind::kList: + case PyTreeKind::kDict: + case PyTreeKind::kCustom: + case PyTreeKind::kDataclass: { + if (agenda.size() < node.arity) { + throw std::logic_error("Too few elements for custom type."); + } + nb::object tuple = nb::steal(PyTuple_New(node.arity)); + for (int i = node.arity - 1; i >= 0; --i) { + PyTuple_SET_ITEM(tuple.ptr(), i, agenda.back().release().ptr()); + agenda.pop_back(); + } + nb::object node_data = node.node_data; + if (node.kind == PyTreeKind::kDict) { + // Convert to a nb::list for f_node invocation. + node_data = nb::cast(node.sorted_dict_keys); + } + agenda.push_back(f_node(tuple, node_data ? node_data : nb::none())); + } + } + } + if (it != leaves.end()) { + throw std::invalid_argument("Too many leaves for PyTreeDef"); + } + if (agenda.size() != 1) { + throw std::logic_error("PyTreeDef traversal did not yield a singleton."); + } + return std::move(agenda.back()); +} + +nb::object PyTreeDef::FromIterableTreeHelper( + nb::handle xs, + absl::InlinedVector::const_reverse_iterator* it) const { + if (*it == traversal_.rend()) { + throw std::invalid_argument("Tree structures did not match."); + } + const Node& node = **it; + ++*it; + if (node.kind == PyTreeKind::kLeaf) { + return nb::borrow(xs); + } + nb::iterable iterable = nb::borrow(xs); + std::vector ys; + ys.reserve(node.arity); + for (nb::handle x : iterable) { + ys.push_back(nb::borrow(x)); + } + if (ys.size() != node.arity) { + throw std::invalid_argument("Arity mismatch between trees"); + } + for (int j = node.arity - 1; j >= 0; --j) { + ys[j] = FromIterableTreeHelper(ys[j], it); + } + + return MakeNode(node, absl::MakeSpan(ys)); +} + +nb::object PyTreeDef::FromIterableTree(nb::handle xs) const { + auto it = traversal_.rbegin(); + nb::object out = FromIterableTreeHelper(xs, &it); + if (it != traversal_.rend()) { + throw std::invalid_argument("Tree structures did not match."); + } + return out; +} + +nb_class_ptr PyTreeDef::Compose(const PyTreeDef& inner) const { + if (inner.registry_ != registry_) { + throw std::invalid_argument( + "PyTree registries of PyTreeDefs passed to Compose() must match."); + } + auto out = make_nb_class(registry_ref_); + out->traversal_.reserve(static_cast(num_leaves()) * + inner.num_nodes() + + num_nodes() - num_leaves()); + for (const Node& n : traversal_) { + if (n.kind == PyTreeKind::kLeaf) { + absl::c_copy(inner.traversal_, std::back_inserter(out->traversal_)); + } else { + out->traversal_.push_back(n); + } + } + out->SetNumLeavesAndNumNodes(); + return out; +} + +/*static*/ nb_class_ptr PyTreeDef::Tuple( + nb_class_ptr registry, nb::list defs) { + auto out = make_nb_class(std::move(registry)); + int num_leaves = 0; + for (nb::handle def_handle : defs) { + const PyTreeDef* def = nb::cast(def_handle); + if (def->registry() != out->registry()) { + throw std::invalid_argument( + "PyTree registries of PyTreeDefs passed to Tuple() must match."); + } + absl::c_copy(def->traversal_, std::back_inserter(out->traversal_)); + num_leaves += def->num_leaves(); + } + Node node; + node.kind = PyTreeKind::kTuple; + node.arity = defs.size(); + node.num_leaves = num_leaves; + node.num_nodes = out->traversal_.size() + 1; + out->traversal_.push_back(node); + return out; +} + +std::vector> PyTreeDef::Children() const { + std::vector> children; + if (traversal_.empty()) { + return children; + } + Node const& root = traversal_.back(); + children.resize(root.arity); + int pos = traversal_.size() - 1; + for (int i = root.arity - 1; i >= 0; --i) { + children[i] = make_nb_class(registry_ref_); + const Node& node = traversal_.at(pos - 1); + if (pos < node.num_nodes) { + throw std::logic_error("children() walked off start of array"); + } + std::copy(traversal_.begin() + pos - node.num_nodes, + traversal_.begin() + pos, + std::back_inserter(children[i]->traversal_)); + pos -= node.num_nodes; + } + if (pos != 0) { + throw std::logic_error("pos != 0 at end of PyTreeDef::Children"); + } + return children; +} + +std::string PyTreeDef::ToString() const { + std::vector agenda; + for (const Node& node : traversal_) { + if (agenda.size() < node.arity) { + throw std::logic_error("Too few elements for container."); + } + + std::string children = + absl::StrJoin(agenda.end() - node.arity, agenda.end(), ", "); + std::string representation; + switch (node.kind) { + case PyTreeKind::kLeaf: + agenda.push_back("*"); + continue; + case PyTreeKind::kNone: + representation = "None"; + break; + case PyTreeKind::kTuple: + // Tuples with only one element must have a trailing comma. + if (node.arity == 1) children += ","; + representation = absl::StrCat("(", children, ")"); + break; + case PyTreeKind::kList: + representation = absl::StrCat("[", children, "]"); + break; + case PyTreeKind::kDict: { + if (node.sorted_dict_keys.size() != node.arity) { + throw std::logic_error("Number of keys and entries does not match."); + } + representation = "{"; + std::string separator; + auto child_iter = agenda.end() - node.arity; + for (const nb::handle& key : node.sorted_dict_keys) { + absl::StrAppendFormat(&representation, "%s%s: %s", separator, + nb::cast(nb::repr(key)), + *child_iter); + child_iter++; + separator = ", "; + } + representation += "}"; + break; + } + + case PyTreeKind::kNamedTuple: + case PyTreeKind::kCustom: + case PyTreeKind::kDataclass: { + std::string kind; + std::string data; + if (node.kind == PyTreeKind::kNamedTuple) { + kind = "namedtuple"; + if (node.node_data) { + // Node data for named tuples is the type. + data = absl::StrFormat( + "[%s]", nb::cast( + nb::str(nb::getattr(node.node_data, "__name__")))); + } + } else { + kind = nb::cast( + nb::str(nb::getattr(node.custom->type, "__name__"))); + if (node.node_data) { + data = absl::StrFormat( + "[%s]", nb::cast(nb::str(node.node_data))); + } + } + + representation = + absl::StrFormat("CustomNode(%s%s, [%s])", kind, data, children); + break; + } + } + agenda.erase(agenda.end() - node.arity, agenda.end()); + agenda.push_back(std::move(representation)); + } + if (agenda.size() != 1) { + throw std::logic_error("PyTreeDef traversal did not yield a singleton."); + } + return absl::StrCat("PyTreeDef(", agenda.back(), ")"); +} + +nb::object PyTreeDef::ToPickle() const { + nb::list traversal; + for (const auto& node : traversal_) { + nb::object node_data = node.node_data; + if (node.kind == PyTreeKind::kDict) { + // Convert to a nb::list for pickling to avoid having to pickle a vector. + // Pickle should be a rare operation so this conversion cost is hopefully + // on non-critical path. + node_data = nb::cast(node.sorted_dict_keys); + } + traversal.append( + nb::make_tuple(static_cast(node.kind), node.arity, + node_data ? node_data : nb::none(), + node.custom != nullptr ? node.custom->type : nb::none(), + node.num_leaves, node.num_nodes)); + } + return nb::make_tuple(nb::cast(registry_ref_), traversal); +} + +void PyTreeDef::FromPickle(nb::object pickle) { + for (const auto& item : nb::cast(pickle)) { + auto t = nb::cast(item); + if (t.size() != 6) { + throw xla::XlaRuntimeError("Malformed pickled PyTreeDef"); + } + Node& node = traversal_.emplace_back(); + node.kind = static_cast(nb::cast(t[0])); + node.arity = nb::cast(t[1]); + switch (node.kind) { + case PyTreeKind::kNamedTuple: + node.node_data = t[2]; + break; + case PyTreeKind::kDict: + node.sorted_dict_keys = nb::cast>(t[2]); + break; + case PyTreeKind::kCustom: + case PyTreeKind::kDataclass: + node.node_data = t[2]; + break; + default: + if (!t[2].is_none()) { + throw xla::XlaRuntimeError("Malformed pickled PyTreeDef"); + } + break; + } + if (node.kind == PyTreeKind::kCustom || + node.kind == PyTreeKind::kDataclass) { + node.custom = t[3].is_none() ? nullptr : registry()->Lookup(t[3]); + if (node.custom == nullptr) { + throw xla::XlaRuntimeError( + absl::StrCat("Unknown custom type in pickled PyTreeDef: ", + nb::cast(nb::repr(t[3])))); + } + } else { + if (!t[3].is_none()) { + throw xla::XlaRuntimeError("Malformed pickled PyTreeDef"); + } + } + node.num_leaves = nb::cast(t[4]); + node.num_nodes = nb::cast(t[5]); + } +} + +void PyTreeDef::SetNumLeavesAndNumNodes() { + // num_leaves and num_nodes are fully determined by arity. + std::vector> starts; + int num_leaves = 0; + for (int i = 0; i < traversal_.size(); ++i) { + std::pair start = {num_leaves, i}; + if (traversal_[i].kind == PyTreeKind::kLeaf) { + num_leaves += 1; + } + if (traversal_[i].arity == 0) { + starts.push_back(start); + } else { + starts.resize(starts.size() - (traversal_[i].arity - 1)); + } + traversal_[i].num_leaves = num_leaves - starts.back().first; + traversal_[i].num_nodes = i + 1 - starts.back().second; + } +} + +void PyTreeDef::SerializeTo(jax::PyTreeDefProto& result) const { + absl::flat_hash_map interned_strings; + auto intern_str = [&](const std::string& key) { + auto [it, added] = + interned_strings.emplace(key, result.interned_strings_size()); + if (added) { + result.add_interned_strings(key); + } + return it->second; + }; + for (const auto& node : traversal_) { + auto* node_data = result.add_nodes(); + node_data->set_arity(node.arity); + switch (node.kind) { + case PyTreeKind::kLeaf: + node_data->set_type(jax::PyTreeNodeType::PY_TREE_KIND_LEAF); + break; + case PyTreeKind::kList: + node_data->set_type(jax::PyTreeNodeType::PY_TREE_KIND_LIST); + break; + case PyTreeKind::kNone: + node_data->set_type(jax::PyTreeNodeType::PY_TREE_KIND_NONE); + break; + case PyTreeKind::kTuple: + node_data->set_type(jax::PyTreeNodeType::PY_TREE_KIND_TUPLE); + break; + case PyTreeKind::kDict: + node_data->set_type(jax::PyTreeNodeType::PY_TREE_KIND_DICT); + for (auto& key : node.sorted_dict_keys) { + if (!nb::isinstance(key)) { + throw std::invalid_argument( + "Only string keys are supported in proto pytree " + "serialization."); + } + node_data->mutable_dict_keys()->add_str_id( + intern_str(nb::cast(key))); + } + break; + default: + throw std::invalid_argument( + "User-defined nodes are not supported when serializing pytrees as " + "protocol buffers. You should either convert the user-defined " + "nodes to another type or use pickle instead."); + break; + } + } +} + +nb_class_ptr PyTreeDef::DeserializeFrom( + nb_class_ptr registry, const jax::PyTreeDefProto& input) { + std::vector interned_strings; + interned_strings.reserve(input.interned_strings().size()); + for (auto& s : input.interned_strings()) { + interned_strings.push_back(nb::cast(s)); + } + nb_class_ptr result = + make_nb_class(std::move(registry)); + for (auto& node_proto : input.nodes()) { + result->traversal_.emplace_back(); + auto& node = result->traversal_.back(); + node.arity = node_proto.arity(); + node.custom = nullptr; + switch (node_proto.type()) { + case jax::PyTreeNodeType::PY_TREE_KIND_LEAF: + node.kind = PyTreeKind::kLeaf; + break; + case jax::PyTreeNodeType::PY_TREE_KIND_LIST: + node.kind = PyTreeKind::kList; + break; + case jax::PyTreeNodeType::PY_TREE_KIND_NONE: + node.kind = PyTreeKind::kNone; + break; + case jax::PyTreeNodeType::PY_TREE_KIND_TUPLE: + node.kind = PyTreeKind::kTuple; + break; + case jax::PyTreeNodeType::PY_TREE_KIND_DICT: + node.kind = PyTreeKind::kDict; + for (uint32_t str_id : node_proto.dict_keys().str_id()) { + if (str_id >= interned_strings.size()) { + throw std::invalid_argument( + "Malformed pytree proto (dict_key out of range)."); + } + node.sorted_dict_keys.push_back(interned_strings.at(str_id)); + } + break; + default: + throw std::invalid_argument( + "Malformed pytree proto (invalid node type)"); + break; + } + } + result->SetNumLeavesAndNumNodes(); + return result; +} + +std::optional> PyTreeDef::GetNodeData() + const { + if (traversal_.empty()) { + throw std::logic_error("empty PyTreeDef traversal."); + } + auto builtin_type = [](PyTypeObject* type_obj) { + return nb::borrow(reinterpret_cast(type_obj)); + }; + const auto& node = traversal_.back(); + switch (node.kind) { + case PyTreeKind::kLeaf: + return std::nullopt; + case PyTreeKind::kNone: + return std::make_pair(builtin_type(Py_TYPE(Py_None)), nb::none()); + case PyTreeKind::kTuple: + return std::make_pair(builtin_type(&PyTuple_Type), nb::none()); + case PyTreeKind::kList: + return std::make_pair(builtin_type(&PyList_Type), nb::none()); + case PyTreeKind::kDict: + return std::make_pair(builtin_type(&PyDict_Type), + nb::cast(node.sorted_dict_keys)); + case PyTreeKind::kNamedTuple: + return std::make_pair(node.node_data, nb::none()); + case PyTreeKind::kCustom: + case PyTreeKind::kDataclass: + return std::make_pair(node.custom->type, node.node_data); + } +} + +nb_class_ptr PyTreeDef::MakeFromNodeDataAndChildren( + nb_class_ptr registry, + std::optional> node_data, + nb::iterable children) { + nb_class_ptr result = + make_nb_class(std::move(registry)); + int num_leaves = 0; + int arity = 0; + for (nb::handle pchild : children) { + const PyTreeDef& child = nb::cast(pchild); + absl::c_copy(child.traversal_, std::back_inserter(result->traversal_)); + num_leaves += child.num_leaves(); + ++arity; + } + result->traversal_.emplace_back(); + auto& node = result->traversal_.back(); + node.arity = arity; + node.custom = nullptr; + node.num_leaves = num_leaves; + node.num_nodes = result->traversal_.size(); + if (node_data == std::nullopt) { + node.kind = PyTreeKind::kLeaf; + ++node.num_leaves; + return result; + } + int is_nt = PyObject_IsSubclass(node_data->first.ptr(), + reinterpret_cast(&PyTuple_Type)); + if (is_nt == -1) { + throw nb::python_error(); + } + if (is_nt != 0 && nb::hasattr(node_data->first, "_fields")) { + node.kind = PyTreeKind::kNamedTuple; + node.node_data = node_data->first; + return result; + } + auto* registration = result->registry()->Lookup(node_data->first); + if (registration == nullptr) { + throw std::logic_error(absl::StrFormat( + "Could not find type: %s.", + nb::cast(nb::repr(node_data->first)))); + } + node.kind = registration->kind; + if (node.kind == PyTreeKind::kCustom || node.kind == PyTreeKind::kDataclass) { + node.custom = registration; + node.node_data = node_data->second; + } else if (node.kind == PyTreeKind::kNamedTuple) { + node.node_data = node_data->first; + } else if (node.kind == PyTreeKind::kDict) { + node.sorted_dict_keys = + nb::cast>(node_data->second); + } + return result; +} + +int PyTreeDef::Node::tp_traverse(visitproc visit, void* arg) const { + Py_VISIT(node_data.ptr()); + for (const auto& key : sorted_dict_keys) { + Py_VISIT(key.ptr()); + } + return 0; +} + +/* static */ int PyTreeDef::tp_traverse(PyObject* self, visitproc visit, + void* arg) { + PyTreeDef* treedef = nb::inst_ptr(self); + Py_VISIT(Py_TYPE(self)); + Py_VISIT(treedef->registry_ref_.ptr()); + for (const auto& node : treedef->traversal_) { + node.tp_traverse(visit, arg); + } + return 0; +} + +/* static */ int PyTreeDef::tp_clear(PyObject* self) { + PyTreeDef* treedef = nb::inst_ptr(self); + treedef->registry_ref_.reset(); + treedef->traversal_.clear(); + return 0; +} + +/* static */ PyType_Slot PyTreeDef::slots_[] = { + {Py_tp_traverse, (void*)PyTreeDef::tp_traverse}, + {Py_tp_clear, (void*)PyTreeDef::tp_clear}, + {0, nullptr}, +}; + +void BuildPytreeSubmodule(nb::module_& m) { + nb::module_ pytree = m.def_submodule("pytree", "Python tree library"); + pytree.attr("version") = nb::int_(3); + + nb::class_ treedef(pytree, "PyTreeDef", + nb::type_slots(PyTreeDef::slots_)); + + nb::class_ registry(m, "PyTreeRegistry", nb::dynamic_attr(), + nb::type_slots(PyTreeRegistry::slots_)); + + registry.def(nb::init(), + nb::arg("enable_none") = true, nb::arg("enable_tuple") = true, + nb::arg("enable_namedtuple") = true, + nb::arg("enable_list") = true, nb::arg("enable_dict") = true); + registry.def( + "flatten", + [](nb_class_ptr registry, nb::object x, + std::optional leaf_predicate) { + nb::list leaves; + nb_class_ptr def = + make_nb_class(std::move(registry)); + def->Flatten(x, leaves, leaf_predicate); + return nb::make_tuple(std::move(leaves), std::move(def)); + }, + nb::arg("tree").none(), nb::arg("leaf_predicate").none() = std::nullopt); + registry.def("flatten_one_level", &PyTreeRegistry::FlattenOneLevel, + nb::arg("tree").none()); + registry.def("flatten_one_level_with_keys", + &PyTreeRegistry::FlattenOneLevelWithKeys, + nb::arg("tree").none()); + registry.def( + "flatten_with_path", + [](nb_class_ptr registry, nb::object x, + std::optional leaf_predicate) { + nb::list leaves; + nb_class_ptr def = + make_nb_class(std::move(registry)); + def->FlattenWithPath(x, leaves, leaf_predicate); + return nb::make_tuple(std::move(leaves), std::move(def)); + }, + nb::arg("tree").none(), nb::arg("leaf_predicate").none() = std::nullopt); + registry.def("register_node", &PyTreeRegistry::Register, + nb::arg("type").none(), nb::arg("to_iterable").none(), + nb::arg("from_iterable").none(), + nb::arg("to_iterable_with_keys").none() = std::nullopt); + registry.def("register_dataclass_node", &PyTreeRegistry::RegisterDataclass); + registry.def("__reduce__", + [](nb::object self) { return self.attr("__name__"); }); + + pytree.attr("_default_registry") = make_nb_class( + /*enable_none=*/true, /*enable_tuple=*/true, /*enable_namedtuple=*/true, + /*enable_list=*/true, /*enable_dict*/ true); + pytree.def("default_registry", + [registry = nb::cast>( + pytree.attr("_default_registry"))]() { return registry; }); + + pytree.attr("PyTreeRegistry") = m.attr("PyTreeRegistry"); + pytree.def("tuple", &PyTreeDef::Tuple); + pytree.def("all_leaves", &PyTreeDef::AllLeaves); + + treedef.def("unflatten", + static_cast( + &PyTreeDef::Unflatten)); + treedef.def("flatten_up_to", &PyTreeDef::FlattenUpTo, nb::arg("tree").none()); + treedef.def("compose", &PyTreeDef::Compose); + treedef.def( + "walk", &PyTreeDef::Walk, + "Walk pytree, calling f_node(node, node_data) at nodes, and f_leaf " + "at leaves", + nb::arg("f_node"), nb::arg("f_leaf"), nb::arg("leaves")); + treedef.def("from_iterable_tree", &PyTreeDef::FromIterableTree); + treedef.def("children", &PyTreeDef::Children); + treedef.def_prop_ro("num_leaves", &PyTreeDef::num_leaves); + treedef.def_prop_ro("num_nodes", &PyTreeDef::num_nodes); + treedef.def("__repr__", &PyTreeDef::ToString); + treedef.def("__eq__", + [](const PyTreeDef& a, const PyTreeDef& b) { return a == b; }); + treedef.def("__ne__", + [](const PyTreeDef& a, const PyTreeDef& b) { return a != b; }); + treedef.def("__hash__", [](const PyTreeDef& t) { return absl::HashOf(t); }); + treedef.def("serialize_using_proto", [](const PyTreeDef& a) { + jax::PyTreeDefProto result; + a.SerializeTo(result); + std::string serialized = result.SerializeAsString(); + return nb::bytes(serialized.data(), serialized.size()); + }); + treedef.def_static( + "deserialize_using_proto", + [](nb_class_ptr registry, nb::bytes data) { + jax::PyTreeDefProto input; + absl::string_view serialized(data.c_str(), data.size()); + if (serialized.size() > std::numeric_limits::max()) { + throw xla::XlaRuntimeError( + "Pytree serialization too large to deserialize."); + } + if (!input.ParseFromArray(serialized.data(), serialized.size())) { + throw xla::XlaRuntimeError("Could not deserialize PyTreeDefProto."); + } + return PyTreeDef::DeserializeFrom(std::move(registry), input); + }, + nb::arg("registry"), nb::arg("data")); + treedef.def("node_data", &PyTreeDef::GetNodeData, + "Returns None if a leaf-pytree, else (type, node_data)"); + treedef.def_static( + "make_from_node_data_and_children", + &PyTreeDef::MakeFromNodeDataAndChildren, nb::arg("registry"), + nb::arg("node_data").none(), nb::arg("children"), + "Reconstructs a pytree from `node_data()` and `children()`."); + treedef.def("__getstate__", &PyTreeDef::ToPickle); + treedef.def("__setstate__", [](PyTreeDef& t, nb::object o) { + nb::tuple pickle = nb::cast(o); + if (pickle.size() != 2) { + throw xla::XlaRuntimeError( + "Malformed pickled PyTreeDef, expected 2-tuple"); + } + auto registry = nb::cast>(pickle[0]); + new (&t) PyTreeDef(registry); + t.FromPickle(pickle[1]); + }); + + nb::class_ sequence_key(pytree, "SequenceKey"); + sequence_key.def(nb::init(), nb::arg("idx")); + sequence_key.def("__str__", &SequenceKey::ToString); + sequence_key.def("__repr__", &SequenceKey::ToReprString); + sequence_key.def("__eq__", &SequenceKey::Equals); + sequence_key.def("__hash__", [](const SequenceKey& key) { + return key.idx() + kSequenceKeyHashSalt; + }); + sequence_key.def_prop_ro("idx", &SequenceKey::idx); + sequence_key.def_prop_ro_static("__match_args__", &SequenceKey::MatchArgs); + sequence_key.def("__getstate__", + [](SequenceKey& key) { return nb::make_tuple(key.idx()); }); + sequence_key.def("__setstate__", + [](SequenceKey& key, const nb::tuple& state) { + if (state.size() != 1) { + throw xla::XlaRuntimeError( + "Malformed pickled SequenceKey, expected 1-tuple"); + } + new (&key) SequenceKey(nb::cast(state[0])); + }); + + nb::class_ dict_key(pytree, "DictKey", + nb::type_slots(DictKey::slots_)); + dict_key.def(nb::init(), nb::arg("key")); + dict_key.def("__str__", &DictKey::ToString); + dict_key.def("__repr__", &DictKey::ToReprString); + dict_key.def("__eq__", &DictKey::Equals); + dict_key.def("__hash__", + [](const DictKey& key) { return nanobind::hash(key.key()); }); + dict_key.def_prop_ro("key", &DictKey::key); + dict_key.def_prop_ro_static("__match_args__", &DictKey::MatchArgs); + dict_key.def("__getstate__", + [](DictKey& key) { return nb::make_tuple(key.key()); }); + dict_key.def("__setstate__", [](DictKey& key, const nb::tuple& state) { + if (state.size() != 1) { + throw xla::XlaRuntimeError("Malformed pickled DictKey, expected 1-tuple"); + } + new (&key) DictKey(nb::cast(state[0])); + }); + + nb::class_ get_attr_key(pytree, "GetAttrKey"); + get_attr_key.def(nb::init(), nb::arg("name")); + get_attr_key.def("__str__", &GetAttrKey::ToString); + get_attr_key.def("__repr__", &GetAttrKey::ToReprString); + get_attr_key.def("__eq__", &GetAttrKey::Equals); + get_attr_key.def("__hash__", + [](const GetAttrKey& key) { return nb::hash(key.name()); }); + get_attr_key.def_prop_ro("name", &GetAttrKey::name); + get_attr_key.def_prop_ro_static("__match_args__", &GetAttrKey::MatchArgs); + get_attr_key.def("__getstate__", + [](GetAttrKey& key) { return nb::make_tuple(key.name()); }); + get_attr_key.def("__setstate__", [](GetAttrKey& key, const nb::tuple& state) { + if (state.size() != 1) { + throw xla::XlaRuntimeError( + "Malformed pickled GetAttrKey, expected 1-tuple"); + } + new (&key) GetAttrKey(nb::str(state[0])); + }); + + nb::class_ flattened_index_key(pytree, + "FlattenedIndexKey"); + flattened_index_key.def(nb::init(), nb::arg("key")); + flattened_index_key.def("__str__", &FlattenedIndexKey::ToString); + flattened_index_key.def("__repr__", &FlattenedIndexKey::ToReprString); + flattened_index_key.def("__eq__", &FlattenedIndexKey::Equals); + flattened_index_key.def("__hash__", [](const FlattenedIndexKey& key) { + return key.key() + kFlattenedIndexKeyHashSalt; + }); + flattened_index_key.def_prop_ro("key", &FlattenedIndexKey::key); + flattened_index_key.def_prop_ro_static("__match_args__", + &FlattenedIndexKey::MatchArgs); + flattened_index_key.def("__getstate__", [](FlattenedIndexKey& key) { + return nb::make_tuple(key.key()); + }); + flattened_index_key.def( + "__setstate__", [](FlattenedIndexKey& key, const nb::tuple& state) { + if (state.size() != 1) { + throw xla::XlaRuntimeError( + "Malformed pickled FlattenedIndexKey, expected 1-tuple"); + } + new (&key) FlattenedIndexKey(nb::cast(state[0])); + }); +} + +} // namespace xla diff --git a/jaxlib/xla/pytree.h b/jaxlib/xla/pytree.h new file mode 100644 index 000000000000..722fe41169a0 --- /dev/null +++ b/jaxlib/xla/pytree.h @@ -0,0 +1,408 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_PYTREE_H_ +#define JAXLIB_XLA_PYTREE_H_ + +// See https://jax.readthedocs.io/en/latest/pytrees.html for the documentation +// about pytree. + +#include + +#include +#include +#include +#include +#include +#include + +// placeholder for index annotation headers +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "jaxlib/xla/pytree.pb.h" +#include "xla/python/nb_class_ptr.h" + +namespace xla { + +enum class PyTreeKind { + kLeaf, // An opaque leaf node + kNone, // None. + kTuple, // A tuple + kNamedTuple, // A collections.namedtuple + kList, // A list + kDict, // A dict + kCustom, // A custom type. + kDataclass, // A dataclass. +}; + +// Registry of custom node types. +class PyTreeRegistry { + public: + PyTreeRegistry(bool enable_none, bool enable_tuple, bool enable_namedtuple, + bool enable_list, bool enable_dict); + + PyTreeRegistry(const PyTreeRegistry&) = delete; + PyTreeRegistry(PyTreeRegistry&&) = delete; + PyTreeRegistry& operator=(const PyTreeRegistry&) = delete; + PyTreeRegistry& operator=(PyTreeRegistry&&) = delete; + + struct Registration { + PyTreeKind kind; + + // The following values are populated for custom types. + // The Python type object, used to identify the type. + nanobind::object type; + // A function with signature: object -> (iterable, aux_data) + nanobind::callable to_iterable; + // A function with signature: (aux_data, iterable) -> object + nanobind::callable from_iterable; + // A function with signature: (aux_data, iterable(keypath, leaf)) -> object + std::optional to_iterable_with_keys; + + // Helper that calls to_iterable and validates that it returns a pair + // of an iterable and an aux_data object + std::pair ToIterable( + nanobind::handle o) const; + // Helper that calls to_iterable_with_keys and validates that it returns a + // pair of an iterable of key-leaf pairs and an aux_data object. If + // to_iterable_with_keys is not available, return a dummy key for each leaf, + // similar to the current jax.tree_util.FlattenedIndexKey. + std::pair>, + nanobind::object> + ToIterableWithKeys(nanobind::handle o) const; + + // For dataclasses. + std::vector data_fields; + std::vector meta_fields; + + int tp_traverse(visitproc visit, void* arg); + }; + + // Registers a new custom type. Objects of `type` will be treated as container + // node types in PyTrees. + void Register( + nanobind::object type, nanobind::callable to_iterable, + nanobind::callable from_iterable, + std::optional to_iterable_with_keys = std::nullopt); + // Same, but for dataclasses. + void RegisterDataclass(nanobind::object type, + std::vector data_fields, + std::vector meta_fields); + + // Finds the custom type registration for `type`. Returns nullptr if none + // exists. + const Registration* Lookup(nanobind::handle type) const; + + PyTreeKind KindOfObject(nanobind::handle obj, + PyTreeRegistry::Registration const** custom) const; + + // Flattens a pytree one level, returning either a tuple of the leaves and + // the node data, or None, if the entry is a leaf. + nanobind::object FlattenOneLevel(nanobind::handle x) const; + // Similar to above but returns a key-leaf pair for each leaf. + nanobind::object FlattenOneLevelWithKeys(nanobind::handle x) const; + // Underlying implementation of FlattenOneLevel and FlattenOneLevelWithKeys. + nanobind::object FlattenOneLevelImpl(nanobind::handle x, + bool with_keys) const; + + static PyType_Slot slots_[]; + + private: + struct TypeHash { + using is_transparent = void; + size_t operator()(const nanobind::object& t) const { + return absl::HashOf(t.ptr()); + } + size_t operator()(const nanobind::handle& t) const { + return absl::HashOf(t.ptr()); + } + }; + struct TypeEq { + using is_transparent = void; + bool operator()(const nanobind::object& a, + const nanobind::object& b) const { + return a.ptr() == b.ptr(); + } + bool operator()(const nanobind::object& a, + const nanobind::handle& b) const { + return a.ptr() == b.ptr(); + } + }; + mutable nanobind::ft_mutex mu_; + absl::flat_hash_map, TypeHash, + TypeEq> + registrations_; // Guarded by mu_ + bool enable_namedtuple_; + + static int tp_traverse(PyObject* self, visitproc visit, void* arg); + static int tp_clear(PyObject* self); +}; + +class SequenceKey { + public: + explicit SequenceKey(int idx) : idx_(idx) {}; + std::string ToReprString() const; + std::string ToString() const; + bool Equals(const nanobind::object& other); + int idx() const { return idx_; } + static nanobind::tuple MatchArgs(nanobind::handle unused); + + private: + int idx_; +}; + +class DictKey { + public: + explicit DictKey(nanobind::object key) : key_(key) {}; + std::string ToReprString() const; + std::string ToString() const; + bool Equals(const nanobind::object& other); + nanobind::object key() const { return key_; } + static nanobind::tuple MatchArgs(nanobind::handle unused); + static PyType_Slot slots_[]; + + private: + nanobind::object key_; + static int tp_traverse(PyObject* self, visitproc visit, void* arg); + static int tp_clear(PyObject* self); +}; + +class GetAttrKey { + public: + explicit GetAttrKey(nanobind::str name) : name_(name) {}; + std::string ToReprString() const; + std::string ToString() const; + bool Equals(const nanobind::object& other); + nanobind::str name() const { return name_; } + static nanobind::tuple MatchArgs(nanobind::handle unused); + + private: + nanobind::str name_; +}; + +class FlattenedIndexKey { + public: + explicit FlattenedIndexKey(int key) : key_(key) {}; + std::string ToReprString() const; + std::string ToString() const; + bool Equals(const nanobind::object& other); + int key() const { return key_; } + static nanobind::tuple MatchArgs(nanobind::handle unused); + + private: + int key_; +}; + +// A PyTreeDef describes the tree structure of a PyTree. A PyTree is a tree of +// Python values, where the interior nodes are tuples, lists, dictionaries, or +// user-defined containers, and the leaves are other objects. +class PyTreeDef { + public: + // Unowned registry: the registry must remain live at least as long as the + // PyTreeDef. It is the caller's responsibility to enforce this. + explicit PyTreeDef(PyTreeRegistry* registry) : registry_(registry) {} + + explicit PyTreeDef(nb_class_ptr registry) + : registry_(registry.get()), registry_ref_(std::move(registry)) {} + + // Flattens a Pytree into a list of leaves and a PyTreeDef. + // Returns references to the flattened objects, which might be temporary + // objects in the case of custom pytype handlers. + static std::pair, nb_class_ptr> + Flatten(nanobind::handle x, nb_class_ptr registry, + std::optional leaf_predicate = std::nullopt); + + // Flattens a Pytree into a list of `leaves` and a PyTreeDef (this). + // `leaves` owns references to the flattened objects, which might be + // temporary objects in the case of custom pytype handlers. + void Flatten(nanobind::handle handle, std::vector& leaves, + std::optional leaf_predicate = std::nullopt); + void Flatten(nanobind::handle handle, + absl::InlinedVector& leaves, + std::optional leaf_predicate = std::nullopt); + void Flatten(nanobind::handle handle, nanobind::list& leaves, + std::optional leaf_predicate = std::nullopt); + + void FlattenWithPath( + nanobind::handle handle, nanobind::list& leaves, + std::optional leaf_predicate = std::nullopt); + + // Tests whether the given list is a flat list of leaves. + static bool AllLeaves(PyTreeRegistry* registry, const nanobind::iterable& x); + + // Flattens a Pytree up to this PyTreeDef. 'this' must be a tree prefix of + // the tree-structure of 'x'. For example, if we flatten a value + // [(1, (2, 3)), {"foo": 4}] with a treedef [(*, *), *], the result is the + // list of leaves [1, (2, 3), {"foo": 4}]. + nanobind::list FlattenUpTo(nanobind::handle x) const; + + // Returns an unflattened PyTree given an iterable of leaves and a PyTreeDef. + nanobind::object Unflatten(nanobind::iterable leaves) const; + nanobind::object Unflatten(absl::Span leaves) const; + + // Composes two PyTreeDefs, replacing the leaves of this tree with copies of + // `inner`. The returned PyTreeDef holds a reference to its registry. + nb_class_ptr Compose(const PyTreeDef& inner) const; + + // Makes a Tuple PyTreeDef out of a vector of PyTreeDefs. + static nb_class_ptr Tuple(nb_class_ptr registry, + nanobind::list defs); + + // The returned PyTreeDefs hold a reference to the registry. + std::vector> Children() const; + + // Maps a function over a PyTree structure, applying f_leaf to each leaf, and + // f_node(node, node_data) to each container node. + nanobind::object Walk(const nanobind::callable& f_node, + nanobind::handle f_leaf, + nanobind::iterable leaves) const; + + // Given a tree of iterables with the same node/leaf structure as this PyTree, + // build the corresponding PyTree. + // TODO(phawkins): use flattening everywhere instead and delete this method. + nanobind::object FromIterableTree(nanobind::handle xs) const; + + int num_leaves() const { + if (traversal_.empty()) { + return 0; + } + return traversal_.back().num_leaves; + } + + int num_nodes() const { return traversal_.size(); } + + PyTreeRegistry* registry() const { return registry_; } + + size_t Hash() const; + + bool operator==(const PyTreeDef& other) const; + bool operator!=(const PyTreeDef& other) const { return !(*this == other); } + + std::string ToString() const; + + // Transforms the PyTreeDef into a pickleable object. Used to implement + // `PyTreeDef.__getstate__`. + nanobind::object ToPickle() const; + + // Transforms the object returned by `ToPickleable()` back to PyTreeDef. Used + // to implement `PyTreeDef.__setstate__`. + void FromPickle(nanobind::object pickleable); + + void SerializeTo(jax::PyTreeDefProto& result) const; + + static nb_class_ptr DeserializeFrom( + nb_class_ptr registry, const jax::PyTreeDefProto& input); + + std::optional> GetNodeData() + const; + + static nb_class_ptr MakeFromNodeDataAndChildren( + nb_class_ptr registry, + std::optional> node_data, + nanobind::iterable children); + + static PyType_Slot slots_[]; + + private: + void SetNumLeavesAndNumNodes(); + + struct Node { + PyTreeKind kind = PyTreeKind::kLeaf; + + // Arity for non-kLeaf types. + int arity = 0; + + // Kind-specific auxiliary data. For a kNamedTuple, contains the tuple type + // object. For a kDict, use `sorted_dict_keys` field below. For a kCustom + // type, contains the auxiliary data returned by the `to_iterable` function. + nanobind::object node_data; + + // Kind-specific auxiliary data specialized for kDict. Use a c++ vector + // to hold the sorted dict keys instead of a py::list to avoid creating + // a new python list object when flattening kDict. For deeply nested dict, + // using c++ vector instead of py::list avoids creating too many python + // objects that make python gc sweep slow. + std::vector sorted_dict_keys; + + // Custom type registration. Must be null for non-custom types. + const PyTreeRegistry::Registration* custom = nullptr; + + // Number of leaf nodes in the subtree rooted at this node. + int num_leaves = 0; + + // Number of leaf and interior nodes in the subtree rooted at this node. + int num_nodes = 0; + + int tp_traverse(visitproc visit, void* arg) const; + }; + template + friend H AbslHashValue(H h, const Node& n); + + template + friend H AbslHashValue(H h, const PyTreeDef& t); + + // Helper that manufactures an instance of a node given its children. + static nanobind::object MakeNode(const Node& node, + absl::Span children); + + // Recursive helper used to implement FromIterableTree() + nanobind::object FromIterableTreeHelper( + nanobind::handle xs, + absl::InlinedVector::const_reverse_iterator* it) + const; + + template + void FlattenImpl(nanobind::handle handle, T& leaves, + const std::optional& leaf_predicate, + std::optional>& keypath); + + template + nanobind::object UnflattenImpl(T leaves) const; + + static int tp_traverse(PyObject* self, visitproc visit, void* arg); + static int tp_clear(PyObject* self); + + // Pytree registry. Not owned. + PyTreeRegistry* registry_; + // If this class holds a reference to `registry`, it is held by + // `registry_ref_`. + nb_class_ptr registry_ref_; + + // Nodes, in a post-order traversal. We use an ordered traversal to minimize + // allocations, and post-order corresponds to the order we need to rebuild the + // tree structure. + absl::InlinedVector traversal_; +}; + +template +H AbslHashValue(H h, const PyTreeDef::Node& n) { + h = H::combine(std::move(h), n.kind, n.arity, n.custom); + return h; +} + +template +H AbslHashValue(H h, const PyTreeDef& t) { + h = H::combine(std::move(h), t.traversal_); + return h; +} + +void BuildPytreeSubmodule(nanobind::module_& m); + +} // namespace xla + +#endif // JAXLIB_XLA_PYTREE_H_ diff --git a/jaxlib/xla/pytree.proto b/jaxlib/xla/pytree.proto new file mode 100644 index 000000000000..73c087ef55ab --- /dev/null +++ b/jaxlib/xla/pytree.proto @@ -0,0 +1,32 @@ +syntax = "proto3"; + +package jax; + +enum PyTreeNodeType { + PY_TREE_KIND_INVALID = 0; + PY_TREE_KIND_LEAF = 1; + PY_TREE_KIND_LIST = 2; + PY_TREE_KIND_NONE = 3; + PY_TREE_KIND_TUPLE = 4; + PY_TREE_KIND_DICT = 5; +} + +message DictKeysProto { + repeated uint32 str_id = 1; +} + +message PyTreeNodeDefProto { + // Recovers the tree structure. + uint32 arity = 1; + // Node type. + PyTreeNodeType type = 2; + // Only set when type == DICT. + DictKeysProto dict_keys = 3; +} + +// A Pytree. +message PyTreeDefProto { + repeated PyTreeNodeDefProto nodes = 1; + // Extra strings. + repeated string interned_strings = 2; +} diff --git a/jaxlib/xla/xla.cc b/jaxlib/xla/xla.cc index fdd4456b238c..bd3ed3205fb2 100644 --- a/jaxlib/xla/xla.cc +++ b/jaxlib/xla/xla.cc @@ -91,6 +91,7 @@ limitations under the License. #include "jaxlib/xla/mlir.h" #include "jaxlib/xla/pjit.h" #include "jaxlib/xla/pmap_lib.h" +#include "jaxlib/xla/pytree.h" #include "jaxlib/xla/weakref_lru_cache.h" #include "jaxlib/xla/xla_compiler.h" #include "xla/pjrt/distributed/key_value_store_interface.h" @@ -119,7 +120,6 @@ limitations under the License. #include "xla/python/py_executable.h" #include "xla/python/py_memory_space.h" #include "xla/python/python_ref_manager.h" -#include "xla/python/pytree.h" #include "xla/python/sharding.h" #include "xla/python/traceback.h" #include "xla/tsl/distributed_runtime/preemption/preemption_sync_manager.h" From b4922df2206707cdf94aec466708e4de2bc52d7c Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Tue, 25 Mar 2025 01:23:47 +0000 Subject: [PATCH 137/483] [attrs] allow setattr on a previously non-existant attr Before this change, we handled attrs for initial-style primitives like jit/scan like this: 1. the traceable would form a jaxpr and see what attrs were touched (by jax_getattr or jax_setattr), 2. for each such attr, the traceable would do jax_getattr to get the current value, tree-flatten, pass the flat valuesinto the (pure) bind, get the new values out, tree-unflatten, then jax_setattr the result. That approach would error if the function called `jax_setattr` to set a previously non-existant attr. That is, this would work: ```python from jax.experimental.attrs import jax_setattr class Thing: ... thing = Thing() jax_setattr(thing, 'x', 1.0) ``` but it wouldn't work under a `jax.jit`. This commit makes the same code work under a jit. We just 1. in partial_eval.py's `to_jaxpr`, ensure attrs added during jaxpr formation are deleted, using a special sentinel value `dne_sentinel` to indicate the attribute initially did not exist before tracing; 2. in pjit.py's `_get_states`, when reading initial attr values before the pjit_p bind, if the attribute does not exist we don't try to read it and instead just use `dne_sentinel` as the value, which is a convenient empty pytree; 3. in pjit.py's `_attr_token` for jit caching, when forming the cache key based on the current attr states, we map attrs that don't exist to `dne_sentinel` (rather than just erroring when the attr doesn't exist, as before). In short, we use a special value to indicate "does not exist". If `jax_getattr` supported the 'default' argument, the code would be a little cleaner since we could avoid the `if hasattr` stuff. And that's probably a useful feature to have anyway. We can add that in a follow-up. This PR only makes setattr-to-nonexistant-attr work with jit. We'll add scan etc in follow-ups. --- jax/_src/interpreters/partial_eval.py | 12 +++++--- jax/_src/pjit.py | 12 ++++---- jax/experimental/attrs.py | 3 +- tests/attrs_test.py | 43 +++++++++++++++++++++++++++ 4 files changed, 59 insertions(+), 11 deletions(-) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 07c516fd95c7..58b97ce2f3da 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -42,8 +42,8 @@ mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx, InputType, OutputType, get_referent, JaxprEqnContext) from jax._src.state.types import AbstractRef, ReadEffect -from jax._src.tree_util import (PyTreeDef, treedef_tuple, - tree_flatten, tree_structure) +from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_flatten, + tree_structure, register_static) from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list, merge_lists, partition_list, OrderedSet, as_hashable_function, weakref_lru_cache, subs_list, @@ -1699,7 +1699,7 @@ def to_jaxpr(self, trace: DynamicJaxprTrace, jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals) jaxpr, constvals = _inline_literals(jaxpr, constvals) init_trees = [tree_structure(init_val) for init_val in self.attrs_inits] - set_states(self.attrs_tracked, self.attrs_inits) + set_states(self.attrs_tracked, self.attrs_inits) # reset to initial values return jaxpr, list(constvals), zip(init_trees, end_trees, self.attrs_tracked) def to_jaxpr2(self, out_tracers: Sequence[core.Tracer], @@ -2246,11 +2246,15 @@ def trace_to_jaxpr_dynamic2( AttrStates = list def set_states(attrs_tracked: AttrsTracked, vals: AttrStates): for ((obj, attr), val) in zip(attrs_tracked, vals): - setattr(obj, attr, val) + setattr(obj, attr, val) if val is not dne_sentinel else delattr(obj, attr) def get_states(attrs_tracked: AttrsTracked): return [getattr(obj, attr) for (obj, attr) in attrs_tracked] +@register_static +class DoesNotExist: ... +dne_sentinel = DoesNotExist() + def infer_lambda_input_type( axes_specs: Sequence[AbstractedAxesSpec] | None, diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index d690cd6e9c67..bcdbe6b1bdb7 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -240,10 +240,10 @@ def _set_states(attrs_tracked, vals): jax_setattr(obj, attr, val) def _get_states(attrs_tracked): - from jax.experimental.attrs import jax_getattr + from jax.experimental.attrs import jax_getattr, dne_sentinel vals = [] for treedef, _, (obj, attr) in attrs_tracked: - tree = jax_getattr(obj, attr) + tree = jax_getattr(obj, attr) if hasattr(obj, attr) else dne_sentinel leaves, treedef_ = tree_flatten(tree) assert treedef == treedef_ vals.extend(leaves) @@ -1354,11 +1354,11 @@ def _attr_token( fun: lu.WrappedFun, in_type: core.InputType | tuple[core.AbstractValue, ...] ) -> int: - from jax.experimental.attrs import jax_getattr + from jax.experimental.attrs import jax_getattr, dne_sentinel cases = seen_attrs_get(fun, in_type) for i, records in enumerate(cases): for obj, attr, treedef, avals in records: - val = jax_getattr(obj, attr) + val = jax_getattr(obj, attr) if hasattr(obj, attr) else dne_sentinel vals, treedef_ = tree_flatten(val) avals_ = map(core.shaped_abstractify, vals) if treedef != treedef_ or avals != avals_: break @@ -1367,8 +1367,8 @@ def _attr_token( return len(cases) def _attr_update(fun, in_type, i, attrs_tracked): - from jax.experimental.attrs import jax_getattr - leaves = lambda obj, attr: tree_leaves(jax_getattr(obj, attr)) + from jax.experimental.attrs import jax_getattr, dne_sentinel + leaves = lambda obj, attr: tree_leaves(jax_getattr(obj, attr) if hasattr(obj, attr) else dne_sentinel) records = [(obj, attr, init_tree, map(core.shaped_abstractify, leaves(obj, attr))) for init_tree, _, (obj, attr) in attrs_tracked] cases = seen_attrs_get(fun, in_type) diff --git a/jax/experimental/attrs.py b/jax/experimental/attrs.py index 4e1dc4b8f493..bb4c7bf83b3f 100644 --- a/jax/experimental/attrs.py +++ b/jax/experimental/attrs.py @@ -36,6 +36,7 @@ Pytree = Any register = api_util.register_class_with_attrs +dne_sentinel = pe.dne_sentinel def jax_getattr(obj: Any, attr: str): with core.take_current_trace() as t: @@ -65,7 +66,7 @@ def new_tracer(x): return tracer if (obj, attr) not in frame.attrs_tracked: - init_val = getattr(obj, attr) + init_val = getattr(obj, attr, dne_sentinel) frame.attrs_inits.append(init_val) init_vals, init_tree = tree_flatten(init_val) tracers = map(new_tracer, init_vals) diff --git a/tests/attrs_test.py b/tests/attrs_test.py index 2334a7b98f91..169df3712899 100644 --- a/tests/attrs_test.py +++ b/tests/attrs_test.py @@ -360,6 +360,49 @@ def body(i, _): return i + 1, None _, _ = jax.lax.scan(body, 0, None, length=3) # don't crash + @parameterized.parameters([True, False]) + def test_setattr_doesnt_exist(self, jit): + class Thing: + ... + thing = Thing() + + def f(x): + assert (not jit) or tracing_is_ok + jax_setattr(thing, 'x', x) + + if jit: + f = jax.jit(f) + + tracing_is_ok = True + self.assertFalse(hasattr(thing, 'x')) + f(1.0) + self.assertEqual(thing.x, 1.0) + f(2.0) + self.assertEqual(thing.x, 2.0) + + tracing_is_ok = False + f(3.0) + self.assertEqual(thing.x, 3.0) + + del thing.x + f(4.0) + self.assertEqual(thing.x, 4.0) + + tracing_is_ok = True + f(5) + self.assertEqual(thing.x, 5) + + def test_setattr_doesnt_exist_doesnt_leave_sentinel_around(self): + class Thing: + ... + thing = Thing() + + def f(x): + jax_setattr(thing, 'x', x) + + jax.make_jaxpr(f)(3.) + self.assertFalse(hasattr(thing, 'x')) + class AttrsJVPTest(jtu.JaxTestCase): From ca30ce69197f90a12003b654b7697272a7a44c88 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Tue, 25 Mar 2025 03:40:28 -0700 Subject: [PATCH 138/483] [Mosaic GPU] Add warpgroup lowering for `AxisIndex` in Pallas. PiperOrigin-RevId: 740280136 --- jax/_src/pallas/mosaic_gpu/lowering.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 607b3028f93b..677f63c6674a 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1738,6 +1738,7 @@ def _block_id(ctx: LoweringRuleContext, dim: gpu_dialect.Dimension) -> ir.Value: @register_lowering_rule(lax.axis_index_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.axis_index_p, mgpu.ThreadSemantics.Warpgroup) def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): axis_names = ctx.module_ctx.axis_names if not axis_names or axis_name not in axis_names: From fce11d0e472c2479cba6869262bc117cc20b95e7 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Tue, 25 Mar 2025 03:53:27 -0700 Subject: [PATCH 139/483] [Mosaic GPU] Use `math.inf` instead of `None` when short-cutting default layout inference. default_vector_size is initialized with `math.inf` and is never `None`. PiperOrigin-RevId: 740283678 --- jax/experimental/mosaic/gpu/layout_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py index dec75e4db1a0..402a8c08a4ef 100644 --- a/jax/experimental/mosaic/gpu/layout_inference.py +++ b/jax/experimental/mosaic/gpu/layout_inference.py @@ -564,7 +564,7 @@ def update_default_vector_size(op: ir.OpView): for op in module.body: traverse_op(op, update_default_vector_size) - if default_vector_size is None: # Nothing to annotate. + if default_vector_size == math.inf: # Nothing to annotate. return def to_default_layout(ty: ir.Type) -> ir.Attribute | None: From 9bbff1e4469bc0078edc6176e006d861411c8c00 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 25 Mar 2025 04:27:39 -0700 Subject: [PATCH 140/483] Update XLA dependency to use revision http://github.com/openxla/xla/commit/d505fef9c5eb6cc1bf282fdf62139783d7fe4ec5. PiperOrigin-RevId: 740293121 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 996ee511f835..8fcda2281ea7 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "9a8dd0796bcfeb00e4e6d09d74726db5c7d4a003" -XLA_SHA256 = "4e3248d37a1b0598de3e93e8e46ede060578bc45bfbdfaf24d91ab598543b770" +XLA_COMMIT = "d505fef9c5eb6cc1bf282fdf62139783d7fe4ec5" +XLA_SHA256 = "4fe51bd389428ce65415b08693f966b142fe8218ced771becab9033503a70a3d" def repo(): tf_http_archive( From 4ed257065a5d4de6e24826cf6546bedada78985f Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 25 Mar 2025 05:45:27 -0700 Subject: [PATCH 141/483] Fix ODR problem in jax_jit.h. We need to include the type caster for std::string_view if we use nb::cast. PiperOrigin-RevId: 740311318 --- jaxlib/xla/jax_jit.h | 1 + 1 file changed, 1 insertion(+) diff --git a/jaxlib/xla/jax_jit.h b/jaxlib/xla/jax_jit.h index a000ef6773b2..254ed11ba78c 100644 --- a/jaxlib/xla/jax_jit.h +++ b/jaxlib/xla/jax_jit.h @@ -34,6 +34,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "nanobind/nanobind.h" +#include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "jaxlib/xla/pytree.h" #include "xla/pjrt/pjrt_client.h" #include "xla/python/nb_helpers.h" From ad7550de6de003f88d38417be466411c620ce5c4 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Tue, 25 Mar 2025 06:08:15 -0700 Subject: [PATCH 142/483] [Mosaic GPU] Add warpgroup lowering for `SetMaxRegisters` in Pallas. PiperOrigin-RevId: 740318556 --- jax/_src/pallas/mosaic_gpu/primitives.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 9665f14254f8..fe28766cfb96 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -952,6 +952,7 @@ def _set_max_registers_abstract_eval(n, *, action): @lowering.register_lowering_rule(set_max_registers_p, mgpu.ThreadSemantics.Lane) +@lowering.register_lowering_rule(set_max_registers_p, mgpu.ThreadSemantics.Warpgroup) def _set_max_registers_lowering( ctx: lowering.LoweringRuleContext, n, *, action ): From 411450b8b896f374758127ba109f93db5b75e742 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Tue, 25 Mar 2025 06:41:04 -0700 Subject: [PATCH 143/483] Fix Jax XLA FFI callback handlers for OSS GPU. OSS Jax builds for GPU backends split `jaxlib` into three wheels and since we cannot expect a stable C++ ABI among the shared libraries, we refactor to ensure: 1. C++ objects are not created/consumed by different shared libraries. 2. Static objects are declared and defined appropriately. This PR: 1. Migrates Jax XLA FFI callback handlers from XLA's Internal FFI API to the [External FFI API](https://github.com/openxla/xla/tree/main/xla/ffi#xla-ffi-external-vs-internal-apis). Note that we update both CPU and GPU handlers because we cannot mix Internal and External APIs. 2. Updates how FFI GPU handlers are registered, now analogous to how the original GPU custom call was registered. 3. Adds an `xla::ffi::ExecutionContext` member to `ifrt::PjRtLoadedExectuable` holding opaque pointers to callbacks. 4. Updates Jax `callback.py` to call the new FFI callback handlers. PiperOrigin-RevId: 740327296 --- jax/_src/callback.py | 146 +++++++++++++++-------- jax_plugins/cuda/__init__.py | 4 + jax_plugins/rocm/__init__.py | 4 + jaxlib/cuda/BUILD | 9 ++ jaxlib/cuda/cuda_plugin_extension.cc | 11 ++ jaxlib/gpu/py_client_gpu.cc | 168 +++++++++++++++++++++++++++ jaxlib/gpu/py_client_gpu.h | 3 + jaxlib/rocm/BUILD | 9 ++ jaxlib/rocm/rocm_plugin_extension.cc | 12 ++ jaxlib/xla/xla_client.py | 2 +- 10 files changed, 318 insertions(+), 50 deletions(-) diff --git a/jax/_src/callback.py b/jax/_src/callback.py index bdceb98d92b7..92c275e7e924 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -33,6 +33,7 @@ from jax._src import tree_util from jax._src import util from jax._src import xla_bridge as xb +from jax._src.lib import xla_extension_version from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -200,7 +201,11 @@ def _callback_op_sharding( # program has bulk array semantics, so we run the callback with a MAXIMAL # sharding and hence execute it only once on the full logical value). if config.use_shardy_partitioner.value: - op_sharding = sharding_impls.SdyArrayShardingList([ + # For shardy, we need to have the same number of shardy annotations as the + # number of result ops. If there are no result ops, we need 1 shardy + # annotation. + num_sdy_shardings = max(1, len(avals_out)) + op_sharding = sharding_impls.SdyArrayShardingList(num_sdy_shardings * [ sharding_impls.SdyArraySharding( mesh_shape=(), dimension_shardings=[], @@ -592,7 +597,6 @@ def io_callback( return tree_util.tree_unflatten(out_tree, out_flat) - def is_empty_shape(s: core.Shape) -> bool: return any(d == 0 for d in s) @@ -822,55 +826,99 @@ def _wrapped_callback(*args): for result_aval in result_avals] return outputs, token, None - result_types = mlir.flatten_ir_types([mlir.aval_to_ir_type(aval) for aval in result_avals]) - if token: + if xla_extension_version <= 320: + result_types = mlir.flatten_ir_types([mlir.aval_to_ir_type(aval) for aval in result_avals]) + if token: + + callback_without_token = _wrapped_callback + def _wrapped_callback(token, *args): # type: ignore # pylint: disable=function-redefined + return (token, *callback_without_token(*args)) + + operand_shapes = [ + xla.aval_to_xla_shapes(core.abstract_token)[0], *operand_shapes + ] + result_shapes = [ + xla.aval_to_xla_shapes(core.abstract_token)[0], *result_shapes + ] + operands = [token, *operands] + result_types = [mlir.token_type(), *result_types] + operand_mlir_layouts = [_layout_to_mlir_layout(None), *operand_mlir_layouts] + result_mlir_layouts = [_layout_to_mlir_layout(None), *result_mlir_layouts] + callback_descriptor, ifrt_callback = ( + backend.get_emit_python_callback_descriptor(_wrapped_callback, + operand_shapes, + result_shapes)) + ctx.module_context.add_host_callback(ifrt_callback) + descriptor_operand = mlir.ir_constant(callback_descriptor) + callback_operands = [descriptor_operand, *operands] + if operand_mlir_layouts is not None: + operand_mlir_layouts = [_layout_to_mlir_layout([]), *operand_mlir_layouts] + result_type = ir.TupleType.get_tuple(result_types) + call_target_name = ("xla_python_gpu_callback" + if platform in {"cuda", "rocm"} else "xla_python_cpu_callback") + result = hlo.CustomCallOp( + [result_type], + callback_operands, + call_target_name=ir.StringAttr.get(call_target_name), + has_side_effect=ir.BoolAttr.get(has_side_effect), + api_version=mlir.i32_attr(2), + called_computations=ir.ArrayAttr.get([]), + backend_config=ir.StringAttr.get(str(callback_descriptor)), + operand_layouts=( + None if operand_mlir_layouts is None + else ir.ArrayAttr.get(operand_mlir_layouts)), + result_layouts=( + None if result_mlir_layouts is None + else ir.ArrayAttr.get(result_mlir_layouts))) + if sharding is not None: + mlir.set_sharding(result, sharding) + results = [ + hlo.get_tuple_element(result, mlir.i32_attr(i)) + for i in range(len(result_types)) + ] + else: + call_target_name = ( + "xla_ffi_python_gpu_callback" + if platform in {"cuda", "rocm"} + else "xla_ffi_python_cpu_callback" + ) + if token: + callback_without_token = _wrapped_callback + def _wrapped_callback(token, *args): # type: ignore # pylint: disable=function-redefined + return (token, *callback_without_token(*args)) + operands = [token, *operands] + if ( + config.use_shardy_partitioner.value + and sharding is not None + and len(ctx.avals_out) > 0 + and isinstance(sharding, sharding_impls.SdyArrayShardingList) + ): + # Add a sharding annotation for the token if we have at least one + # output. Otherwise, the single shardy annotation required of all ops + # (even those without any results) can annotate the token. + sharding = sharding_impls.SdyArrayShardingList( + [*sharding.shardings, sharding.shardings[-1]] + ) + ctx = dataclasses.replace( + ctx, + avals_in=[core.abstract_token, *ctx.avals_in], + avals_out=[core.abstract_token, *ctx.avals_out], + ) - callback_without_token = _wrapped_callback - def _wrapped_callback(token, *args): # type: ignore # pylint: disable=function-redefined - return (token, *callback_without_token(*args)) + # TODO(dsuo): Remove this line once we deprecate the XLA custom call + # handler. + ifrt_callback = _wrapped_callback + ctx.module_context.add_host_callback(ifrt_callback) + index = np.uint64(len(ctx.module_context.host_callbacks) - 1) + result = ffi.build_ffi_lowering_function( # type: ignore + call_target_name, + has_side_effect=has_side_effect, + )(ctx, *operands, index=np.uint64(index)) - operand_shapes = [ - xla.aval_to_xla_shapes(core.abstract_token)[0], *operand_shapes - ] - result_shapes = [ - xla.aval_to_xla_shapes(core.abstract_token)[0], *result_shapes - ] - operands = [token, *operands] - result_types = [mlir.token_type(), *result_types] - operand_mlir_layouts = [_layout_to_mlir_layout(None), *operand_mlir_layouts] - result_mlir_layouts = [_layout_to_mlir_layout(None), *result_mlir_layouts] - callback_descriptor, ifrt_callback = ( - backend.get_emit_python_callback_descriptor(_wrapped_callback, - operand_shapes, - result_shapes)) - ctx.module_context.add_host_callback(ifrt_callback) - descriptor_operand = mlir.ir_constant(callback_descriptor) - callback_operands = [descriptor_operand, *operands] - if operand_mlir_layouts is not None: - operand_mlir_layouts = [_layout_to_mlir_layout([]), *operand_mlir_layouts] - result_type = ir.TupleType.get_tuple(result_types) - call_target_name = ("xla_python_gpu_callback" - if platform in {"cuda", "rocm"} else "xla_python_cpu_callback") - result = hlo.CustomCallOp( - [result_type], - callback_operands, - call_target_name=ir.StringAttr.get(call_target_name), - has_side_effect=ir.BoolAttr.get(has_side_effect), - api_version=mlir.i32_attr(2), - called_computations=ir.ArrayAttr.get([]), - backend_config=ir.StringAttr.get(str(callback_descriptor)), - operand_layouts=( - None if operand_mlir_layouts is None - else ir.ArrayAttr.get(operand_mlir_layouts)), - result_layouts=( - None if result_mlir_layouts is None - else ir.ArrayAttr.get(result_mlir_layouts))) - if sharding is not None: - mlir.set_sharding(result, sharding) - results = [ - hlo.get_tuple_element(result, mlir.i32_attr(i)) - for i in range(len(result_types)) - ] + if sharding is not None: + mlir.set_sharding(result, sharding) + + results = result.results # type: ignore if token: token, *results = results return results, token, ifrt_callback diff --git a/jax_plugins/cuda/__init__.py b/jax_plugins/cuda/__init__.py index f6540e986024..2b02621c89f5 100644 --- a/jax_plugins/cuda/__init__.py +++ b/jax_plugins/cuda/__init__.py @@ -94,6 +94,10 @@ def initialize(): ) for _name, _value in cuda_plugin_extension.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="CUDA") + for _name, _value in cuda_plugin_extension.ffi_registrations().items(): + xla_client.register_custom_call_target( + _name, _value, platform='CUDA', api_version=1 + ) xla_client.register_custom_type_id_handler( "CUDA", functools.partial( diff --git a/jax_plugins/rocm/__init__.py b/jax_plugins/rocm/__init__.py index c48a681bf337..0699ae1e34a1 100644 --- a/jax_plugins/rocm/__init__.py +++ b/jax_plugins/rocm/__init__.py @@ -94,6 +94,10 @@ def initialize(): ) for _name, _value in rocm_plugin_extension.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="ROCM") + for _name, _value in rocm_plugin_extension.ffi_registrations().items(): + xla_client.register_custom_call_target( + _name, _value, platform='ROCM', api_version=1 + ) xla_client.register_custom_type_id_handler( "ROCM", functools.partial( diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index ee32888864dd..5cd7283ea3fc 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -668,19 +668,28 @@ cc_library( features = ["-use_header_modules"], deps = [ ":cuda_vendor", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep "@xla//xla:comparison_util", + "@xla//xla:shape_util", + "@xla//xla/ffi:ffi_api", + "@xla//xla/ffi/api:ffi", "@xla//xla/pjrt:exceptions", "@xla//xla/pjrt:host_callback", "@xla//xla/pjrt:transpose", "@xla//xla/python:callback", "@xla//xla/python:nb_numpy", + "@xla//xla/python:types", "@xla//xla/service:custom_call_status", "@xla//xla/service:custom_call_target_registry", "@xla//xla/service:platform_util", diff --git a/jaxlib/cuda/cuda_plugin_extension.cc b/jaxlib/cuda/cuda_plugin_extension.cc index 6655128b9842..63375921e3be 100644 --- a/jaxlib/cuda/cuda_plugin_extension.cc +++ b/jaxlib/cuda/cuda_plugin_extension.cc @@ -48,12 +48,23 @@ nb::dict Registrations() { jax::EncapsulateFunction(jax::cuda::XlaPythonGpuCallback); return dict; } +nb::dict FfiRegistrations() { + nb::dict dict; + nb::dict gpu_callback_dict; + gpu_callback_dict["instantiate"] = + jax::EncapsulateFfiHandler(jax::cuda::kGpuTransposePlanCacheInstantiate); + gpu_callback_dict["execute"] = + jax::EncapsulateFfiHandler(jax::cuda::kXlaFfiPythonGpuCallback); + dict["xla_ffi_python_gpu_callback"] = gpu_callback_dict; + return dict; +} } // namespace NB_MODULE(cuda_plugin_extension, m) { BuildGpuPluginExtension(m); m.def("registrations", &Registrations); + m.def("ffi_registrations", &FfiRegistrations); m.def( "get_device_ordinal", diff --git a/jaxlib/gpu/py_client_gpu.cc b/jaxlib/gpu/py_client_gpu.cc index 3e140411770d..71d327ffdb28 100644 --- a/jaxlib/gpu/py_client_gpu.cc +++ b/jaxlib/gpu/py_client_gpu.cc @@ -13,6 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "jaxlib/gpu/py_client_gpu.h" + +#include + #include #include #include @@ -20,24 +24,31 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/base/casts.h" +#include "absl/container/inlined_vector.h" #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/ascii.h" #include "absl/strings/numbers.h" +#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "nanobind/nanobind.h" #include "jaxlib/gpu/vendor.h" +#include "xla/ffi/api/ffi.h" +#include "xla/ffi/ffi_api.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/host_callback.h" #include "xla/pjrt/transpose.h" #include "xla/primitive_util.h" #include "xla/python/callback.h" #include "xla/python/nb_numpy.h" +#include "xla/python/types.h" #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_target_registry.h" +#include "xla/shape_util.h" namespace nb = nanobind; @@ -155,5 +166,162 @@ void XlaPythonGpuCallback(gpuStream_t stream, void** buffers, XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( "xla_python_gpu_callback", &XlaPythonGpuCallback, absl::AsciiStrToUpper(JAX_GPU_PLUGIN_NAME)); + +struct GpuTransposePlanCache { + static xla::ffi::TypeId id; + explicit GpuTransposePlanCache(int capacity) : cache(capacity) {} + xla::TransposePlanCache cache; +}; +xla::ffi::TypeId GpuTransposePlanCache::id = {}; + +XLA_FFI_REGISTER_TYPE(xla::ffi::GetXlaFfiApi(), "GpuTransposePlanCache", + &GpuTransposePlanCache::id); + +static xla::ffi::ErrorOr> +GpuTransposePlanCacheInstantiate(uint64_t index) { + return std::make_unique(16); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + kGpuTransposePlanCacheInstantiate, GpuTransposePlanCacheInstantiate, + xla::ffi::Ffi::BindInstantiate().Attr("index")); +xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, + xla::FfiLoadedHostCallbacks* callbacks, + GpuTransposePlanCache* transpose_cache, + uint64_t index, + xla::ffi::RemainingArgs args, + xla::ffi::RemainingRets rets) { + size_t arity = args.size(); + std::vector host_input_buffers(arity); + // Copy input GPU buffers to host + for (size_t i = 0; i < arity; ++i) { + auto arg = args.get(i); + auto ptype = static_cast(arg->element_type()); + if (ptype == xla::TOKEN) { + host_input_buffers[i] = nullptr; + continue; + } + void* buf = new char[arg->size_bytes()]; + host_input_buffers[i] = buf; + // TODO(b/238441608): Use pinned memory here to speed up the transfer. + auto gpu_res = + gpuMemcpyAsync(buf, arg.value().untyped_data(), arg->size_bytes(), + gpuMemcpyDeviceToHost, stream); + CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; + } + CHECK_EQ(gpuStreamSynchronize(stream), gpuSuccess) + << "Failed to gpuStreamSynchronize"; + nb::gil_scoped_acquire gil; + auto callback = nb::borrow( + static_cast(callbacks->callbacks[index])); + nb::tuple host_input_arrays = nb::steal(PyTuple_New(arity)); + for (size_t i = 0; i < arity; ++i) { + auto arg = args.get(i); + auto ptype = static_cast(arg->element_type()); + if (ptype == xla::TOKEN) { + PyTuple_SET_ITEM(host_input_arrays.ptr(), i, nb::none().inc_ref().ptr()); + continue; + } + nb::capsule base(host_input_buffers[i], [](void* ptr) noexcept { + delete[] static_cast(ptr); + }); + auto maybe_dtype = PrimitiveTypeToNbDtype(ptype); + if (!maybe_dtype.ok()) { + return xla::ffi::Error::Internal(maybe_dtype.status().ToString()); + } + auto dtype = maybe_dtype.value(); + auto dims = absl::Span(arg->dimensions().begin(), + arg->dimensions().size()); + auto array = xla::nb_numpy_ndarray(dtype, dims, std::nullopt, + host_input_buffers[i], base); + array.attr("flags").attr("writeable") = nb::bool_(false); + PyTuple_SET_ITEM(host_input_arrays.ptr(), i, array.inc_ref().ptr()); + } + + xla::EnterHostCallback(); + // TODO(dsuo): Change this to use the Python vectorcall protocol, which allows + // you to avoid constructing a tuple for the arguments. + nb::tuple result_tuple; + try { + auto result_object = callback(*nb::borrow(host_input_arrays)); + result_tuple = nb::cast(result_object); + } catch (nb::python_error& e) { + return xla::ffi::Error::Internal( + absl::StrFormat("CpuCallback error calling callback: %s", e.what())); + } + xla::LeaveHostCallback(); + + std::vector temp_buffers; + for (size_t i = 0; i < rets.size(); ++i) { + auto ret = rets.get(i).value(); + auto ptype = static_cast(ret->element_type()); + if (ptype == xla::TOKEN) continue; + nb::object output = + nb::borrow(PyTuple_GetItem(result_tuple.ptr(), i)); + auto array = xla::nb_numpy_ndarray::ensure(std::move(output)); + absl::Span strides( + reinterpret_cast(array.strides()), array.ndim()); + // We expect the output to be in default numpy layout. + auto dims = absl::Span(ret->dimensions().begin(), + ret->dimensions().size()); + auto maybe_expected_shape = xla::ShapeUtil::MakeValidatedShape(ptype, dims); + if (!maybe_expected_shape.ok()) { + return xla::ffi::Error::Internal( + maybe_expected_shape.status().ToString()); + } + auto expected_shape = maybe_expected_shape.value(); + auto expected_strides = xla::ByteStridesForShape(expected_shape); + if (strides == expected_strides) { + auto gpu_res = + gpuMemcpyAsync(ret->untyped_data(), array.data(), ret->size_bytes(), + gpuMemcpyHostToDevice, stream); + CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; + continue; + } + void* temp = new char[ret->size_bytes()]; + temp_buffers.push_back(temp); + xla::TransposePlan::Options options; + options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); + options.dims = absl::Span( + reinterpret_cast(array.shape()), array.ndim()); + absl::InlinedVector reversed_layout; + reversed_layout.resize(expected_shape.rank()); + absl::c_reverse_copy(expected_shape.layout().minor_to_major(), + reversed_layout.begin()); + options.permutation = reversed_layout; + options.input_layout = xla::TransposePlan::Striding{strides}; + auto maybe_plan = transpose_cache->cache.GetOrCreate(options); + if (!maybe_plan.ok()) { + return xla::ffi::Error::Internal(maybe_plan.status().ToString()); + } + auto plan = maybe_plan.value(); + plan->Execute(array.data(), temp); + auto gpu_res = gpuMemcpyAsync(ret->untyped_data(), temp, ret->size_bytes(), + gpuMemcpyHostToDevice, stream); + CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; + } + nb::gil_scoped_release release; + CHECK_EQ(gpuStreamSynchronize(stream), gpuSuccess) + << "Failed to gpuStreamSynchronize"; + for (int i = 0; i < temp_buffers.size(); ++i) { + delete[] static_cast(temp_buffers[i]); + } + return xla::ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + kXlaFfiPythonGpuCallback, XlaFfiPythonGpuCallback, + xla::ffi::Ffi::Bind() + .Ctx>() + .Ctx>() + .Ctx>() + .Attr("index") + .RemainingArgs() + .RemainingRets()); +XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), + "xla_ffi_python_gpu_callback", + absl::AsciiStrToUpper(JAX_GPU_PLUGIN_NAME), + {kGpuTransposePlanCacheInstantiate, nullptr, nullptr, + kXlaFfiPythonGpuCallback}); } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/gpu/py_client_gpu.h b/jaxlib/gpu/py_client_gpu.h index e9454504f5d9..06a955365c0b 100644 --- a/jaxlib/gpu/py_client_gpu.h +++ b/jaxlib/gpu/py_client_gpu.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "jaxlib/gpu/vendor.h" +#include "xla/ffi/api/ffi.h" #include "xla/service/custom_call_status.h" namespace jax { @@ -28,6 +29,8 @@ void XlaPythonGpuCallback(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len, XlaCustomCallStatus* status); +XLA_FFI_DECLARE_HANDLER_SYMBOL(kGpuTransposePlanCacheInstantiate); +XLA_FFI_DECLARE_HANDLER_SYMBOL(kXlaFfiPythonGpuCallback); } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index 99df757018f3..522aa8da0145 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -566,19 +566,28 @@ cc_library( features = ["-use_header_modules"], deps = [ ":hip_vendor", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep "@xla//xla:comparison_util", + "@xla//xla:shape_util", + "@xla//xla/ffi:ffi_api", + "@xla//xla/ffi/api:ffi", "@xla//xla/pjrt:exceptions", "@xla//xla/pjrt:host_callback", "@xla//xla/pjrt:transpose", "@xla//xla/python:callback", "@xla//xla/python:nb_numpy", + "@xla//xla/python:types", "@xla//xla/service:custom_call_status", "@xla//xla/service:custom_call_target_registry", "@xla//xla/service:platform_util", diff --git a/jaxlib/rocm/rocm_plugin_extension.cc b/jaxlib/rocm/rocm_plugin_extension.cc index 454f4741d667..642467a9afef 100644 --- a/jaxlib/rocm/rocm_plugin_extension.cc +++ b/jaxlib/rocm/rocm_plugin_extension.cc @@ -72,12 +72,24 @@ nb::dict Registrations() { jax::EncapsulateFunction(jax::hip::XlaPythonGpuCallback); return dict; } +nb::dict FfiRegistrations() { + nb::dict dict; + nb::dict gpu_callback_dict; + gpu_callback_dict["instantiate"] = + jax::EncapsulateFfiHandler(jax::hip::kGpuTransposePlanCacheInstantiate); + gpu_callback_dict["execute"] = + jax::EncapsulateFfiHandler(jax::hip::kXlaFfiPythonGpuCallback); + dict["xla_ffi_python_gpu_callback"] = gpu_callback_dict; + return dict; +} } // namespace NB_MODULE(rocm_plugin_extension, m) { BuildGpuPluginExtension(m); m.def("registrations", &Registrations); + m.def("ffi_registrations", &FfiRegistrations); + m.def( "get_device_ordinal", [](std::intptr_t data_value) { diff --git a/jaxlib/xla/xla_client.py b/jaxlib/xla/xla_client.py index a111c14232de..0e4eebdfb26f 100644 --- a/jaxlib/xla/xla_client.py +++ b/jaxlib/xla/xla_client.py @@ -50,7 +50,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 320 +_version = 321 # Version number for MLIR:Python components. mlir_api_version = 58 From a58592ebb0c217cee4279c2afcec0f2d73688f0b Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 25 Mar 2025 06:46:19 -0700 Subject: [PATCH 144/483] Finalize some deprecations from jax.lib.xla_client --- CHANGELOG.md | 3 +++ jax/lib/xla_client.py | 32 +++++++++++++------------------- 2 files changed, 16 insertions(+), 19 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 17fb421fcc06..1acb2b48eab6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * {func}`jax.tree_util.build_tree` is deprecated. Use {func}`jax.tree.unflatten` instead. + * Several previously-deprecated APIs have been removed, including: + * From `jax.lib.xla_client`: `FftType`, `PaddingType`, `dtype_to_etype`, + and `shape_from_pyval`. ## jax 0.5.3 (Mar 19, 2025) diff --git a/jax/lib/xla_client.py b/jax/lib/xla_client.py index 86e7307c804b..07c6914a1f59 100644 --- a/jax/lib/xla_client.py +++ b/jax/lib/xla_client.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jax._src.lax.fft import FftType as _FftType from jax._src.lib import xla_client as _xc get_topology_for_devices = _xc.get_topology_for_devices @@ -48,23 +47,27 @@ ), None, ), - # Added Oct 10 2024 + # Finalized 2025-03-25; remove after 2025-06-25 "FftType": ( - "jax.lib.xla_client.FftType is deprecated; use jax.lax.FftType.", - _FftType, + "jax.lib.xla_client.FftType was removed in JAX v0.6.0; use jax.lax.FftType.", + None, ), "PaddingType": ( ( - "jax.lib.xla_client.PaddingType is deprecated; this type is unused" - " by JAX so there is no replacement." + "jax.lib.xla_client.PaddingType was removed in JAX v0.6.0;" + " this type is unused by JAX so there is no replacement." ), - _xc.PaddingType, + None, ), - # Added Oct 11 2024 "dtype_to_etype": ( - "dtype_to_etype is deprecated; use StableHLO instead.", - _xc.dtype_to_etype, + "dtype_to_etype was removed in JAX v0.6.0; use StableHLO instead.", + None, + ), + "shape_from_pyval": ( + "shape_from_pyval was removed in JAX v0.6.0; use StableHLO instead.", + None, ), + # Added Oct 11 2024 "ops": ( "ops is deprecated; use StableHLO instead.", _xc.ops, @@ -74,10 +77,6 @@ "(https://jax.readthedocs.io/en/latest/ffi.html)", _xc.register_custom_call_target, ), - "shape_from_pyval": ( - "shape_from_pyval is deprecated; use StableHLO instead.", - _xc.shape_from_pyval, - ), "PrimitiveType": ( "PrimitiveType is deprecated; use StableHLO instead.", _xc.PrimitiveType, @@ -104,14 +103,10 @@ import typing as _typing if _typing.TYPE_CHECKING: - dtype_to_etype = _xc.dtype_to_etype ops = _xc.ops register_custom_call_target = _xc.register_custom_call_target - shape_from_pyval = _xc.shape_from_pyval ArrayImpl = _xc.ArrayImpl Device = _xc.Device - FftType = _FftType - PaddingType = _xc.PaddingType PrimitiveType = _xc.PrimitiveType Shape = _xc.Shape XlaBuilder = _xc.XlaBuilder @@ -123,5 +118,4 @@ __getattr__ = _deprecation_getattr(__name__, _deprecations) del _deprecation_getattr del _typing -del _FftType del _xc From 3c63f600000423df181f114345d1fe56821dcd94 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 25 Mar 2025 07:17:05 -0700 Subject: [PATCH 145/483] [JAX] [XLA:Python] Migrate py_socket_transfer to JAX. Also in passing fix up some header guards and authorship comments. PiperOrigin-RevId: 740337166 --- jaxlib/xla/BUILD | 39 ++- jaxlib/xla/config.h | 6 +- jaxlib/xla/custom_call_sharding.h | 6 +- jaxlib/xla/dlpack.h | 6 +- jaxlib/xla/jax_jit.h | 6 +- jaxlib/xla/mlir.h | 6 +- jaxlib/xla/pjit.h | 6 +- jaxlib/xla/pmap_lib.h | 6 +- jaxlib/xla/py_socket_transfer.cc | 409 ++++++++++++++++++++++++++++++ jaxlib/xla/py_socket_transfer.h | 26 ++ jaxlib/xla/pytree.cc | 2 +- jaxlib/xla/pytree.h | 2 +- jaxlib/xla/sdy.h | 6 +- jaxlib/xla/weakref_lru_cache.h | 6 +- jaxlib/xla/xla.cc | 2 +- jaxlib/xla/xla_compiler.h | 6 +- 16 files changed, 506 insertions(+), 34 deletions(-) create mode 100644 jaxlib/xla/py_socket_transfer.cc create mode 100644 jaxlib/xla/py_socket_transfer.h diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD index 2edc183bc49b..e562cb7e84ea 100644 --- a/jaxlib/xla/BUILD +++ b/jaxlib/xla/BUILD @@ -134,10 +134,10 @@ nanobind_extension( ], "@xla//xla/tsl:windows": [], "//conditions:default": [ + ":py_socket_transfer", "@gloo//:transport_tcp", "@xla//xla/backends/cpu/collectives:gloo_collectives", "@xla//xla/backends/cpu/collectives:gloo_kv_store", - "@xla//xla/python/transfer:py_socket_transfer", ], }) + select({ # mpitrampoline does not build on windows @@ -414,6 +414,43 @@ cc_library( ], ) +cc_library( + name = "py_socket_transfer", + srcs = ["py_socket_transfer.cc"], + hdrs = ["py_socket_transfer.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@nanobind", + "@xla//xla:util", + "@xla//xla/pjrt:pjrt_client", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python:nb_class_ptr", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:py_client", + "@xla//xla/python:traceback", + "@xla//xla/python:types", + "@xla//xla/python/ifrt", + "@xla//xla/python/pjrt_ifrt", + "@xla//xla/python/pjrt_ifrt:pjrt_dtype", + "@xla//xla/python/transfer:event_loop", + "@xla//xla/python/transfer:socket-server", + "@xla//xla/python/transfer:socket_bulk_transport", + "@xla//xla/python/transfer:streaming", + "@xla//xla/python/transfer:streaming_ifrt", + "@xla//xla/tsl/concurrency:ref_count", + "@xla//xla/tsl/platform:statusor", + ], +) + proto_library( name = "pytree_proto", srcs = ["pytree.proto"], diff --git a/jaxlib/xla/config.h b/jaxlib/xla/config.h index 40847bf4a370..2a9281f498b4 100644 --- a/jaxlib/xla/config.h +++ b/jaxlib/xla/config.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_CONFIG_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_CONFIG_H_ +#ifndef JAXLIB_XLA_CONFIG_H_ +#define JAXLIB_XLA_CONFIG_H_ #include @@ -31,4 +31,4 @@ void BuildConfigSubmodule(nanobind::module_& m); } // namespace jax -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_CONFIG_H_ +#endif // JAXLIB_XLA_CONFIG_H_ diff --git a/jaxlib/xla/custom_call_sharding.h b/jaxlib/xla/custom_call_sharding.h index c3470901f53e..5a5f3776cc30 100644 --- a/jaxlib/xla/custom_call_sharding.h +++ b/jaxlib/xla/custom_call_sharding.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_CUSTOM_CALL_SHARDING_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_CUSTOM_CALL_SHARDING_H_ +#ifndef JAXLIB_XLA_CUSTOM_CALL_SHARDING_H_ +#define JAXLIB_XLA_CUSTOM_CALL_SHARDING_H_ // placeholder for index annotation headers #include "nanobind/nanobind.h" @@ -25,4 +25,4 @@ void BuildCustomCallShardingPybindAPI(nanobind::module_& m); } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_CUSTOM_CALL_SHARDING_H_ +#endif // JAXLIB_XLA_CUSTOM_CALL_SHARDING_H_ diff --git a/jaxlib/xla/dlpack.h b/jaxlib/xla/dlpack.h index 5d7fd7c10bf8..d0079b1d4914 100644 --- a/jaxlib/xla/dlpack.h +++ b/jaxlib/xla/dlpack.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DLPACK_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_DLPACK_H_ +#ifndef JAXLIB_XLA_DLPACK_H_ +#define JAXLIB_XLA_DLPACK_H_ #include #include @@ -54,4 +54,4 @@ absl::StatusOr PrimitiveTypeToNbDLDataType( } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DLPACK_H_ +#endif // JAXLIB_XLA_DLPACK_H_ diff --git a/jaxlib/xla/jax_jit.h b/jaxlib/xla/jax_jit.h index 254ed11ba78c..9e6f8e34f1e9 100644 --- a/jaxlib/xla/jax_jit.h +++ b/jaxlib/xla/jax_jit.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_ +#ifndef JAXLIB_XLA_JAX_JIT_H_ +#define JAXLIB_XLA_JAX_JIT_H_ #include @@ -263,4 +263,4 @@ void BuildJaxjitSubmodule(nanobind::module_& m); } // namespace jax -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_ +#endif // JAXLIB_XLA_JAX_JIT_H_ diff --git a/jaxlib/xla/mlir.h b/jaxlib/xla/mlir.h index f0bfd69bca6b..ee95f5f95921 100644 --- a/jaxlib/xla/mlir.h +++ b/jaxlib/xla/mlir.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_MLIR_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_MLIR_H_ +#ifndef JAXLIB_XLA_MLIR_H_ +#define JAXLIB_XLA_MLIR_H_ // placeholder for index annotation headers #include "nanobind/nanobind.h" @@ -25,4 +25,4 @@ void BuildMlirSubmodule(nanobind::module_& m); } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_MLIR_H_ +#endif // JAXLIB_XLA_MLIR_H_ diff --git a/jaxlib/xla/pjit.h b/jaxlib/xla/pjit.h index 545fb2307783..8d47347ab9a2 100644 --- a/jaxlib/xla/pjit.h +++ b/jaxlib/xla/pjit.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_PJIT_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_PJIT_H_ +#ifndef JAXLIB_XLA_PJIT_H_ +#define JAXLIB_XLA_PJIT_H_ // placeholder for index annotation headers #include "nanobind/nanobind.h" @@ -24,4 +24,4 @@ namespace jax { void BuildPjitSubmodule(nanobind::module_& m); } -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_PJIT_H_ +#endif // JAXLIB_XLA_PJIT_H_ diff --git a/jaxlib/xla/pmap_lib.h b/jaxlib/xla/pmap_lib.h index 9ad60a03daf6..e02311e03c73 100644 --- a/jaxlib/xla/pmap_lib.h +++ b/jaxlib/xla/pmap_lib.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_PMAP_LIB_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_PMAP_LIB_H_ +#ifndef JAXLIB_XLA_PMAP_LIB_H_ +#define JAXLIB_XLA_PMAP_LIB_H_ #include #include @@ -34,4 +34,4 @@ void BuildPmapSubmodule(nanobind::module_& m); } // namespace jax -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_PMAP_LIB_H_ +#endif // JAXLIB_XLA_PMAP_LIB_H_ diff --git a/jaxlib/xla/py_socket_transfer.cc b/jaxlib/xla/py_socket_transfer.cc new file mode 100644 index 000000000000..dd2c02898e18 --- /dev/null +++ b/jaxlib/xla/py_socket_transfer.cc @@ -0,0 +1,409 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "jaxlib/xla/py_socket_transfer.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/synchronization/mutex.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/array.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/array_spec.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/python/pjrt_ifrt/pjrt_device.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" +#include "xla/python/pjrt_ifrt/pjrt_memory.h" +#include "xla/python/py_array.h" +#include "xla/python/py_client.h" +#include "xla/python/to_ifrt_sharding.h" +#include "xla/python/traceback.h" +#include "xla/python/transfer/event_loop.h" +#include "xla/python/transfer/socket-server.h" +#include "xla/python/transfer/socket_bulk_transport.h" +#include "xla/python/transfer/streaming.h" +#include "xla/python/transfer/streaming_ifrt.h" +#include "xla/python/types.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" + +namespace aux { + +namespace nb = nanobind; + +absl::StatusOr MemorySpaceFromSharding( + const xla::ifrt::Sharding& sharding) { + if (sharding.devices()->devices().size() != 1) { + return xla::InvalidArgument( + "Can only convert SingleDeviceSharding to MemorySpace not %s", + sharding.DebugString()); + } + auto* device = sharding.devices()->devices()[0]; + if (sharding.memory_kind().memory_kind().has_value()) { + // Find `PjRtMemorySpace` that is associated with the sharding's device + // and matches the sharding's memory_kind. + xla::ifrt::Memory* memory = nullptr; + for (xla::ifrt::Memory* ms : device->Memories()) { + if (ms->Kind() == sharding.memory_kind()) { + memory = ms; + break; + } + } + if (memory == nullptr) { + return xla::InvalidArgument( + "Invalid memory kind: %s; available memory kinds: %s", + *sharding.memory_kind().memory_kind(), + absl::StrJoin(sharding.devices()->devices().front()->Memories(), ", ", + [](std::string* out, xla::ifrt::Memory* ms) { + absl::StrAppend(out, *ms->Kind().memory_kind()); + })); + } + return tensorflow::down_cast(memory)->pjrt_memory(); + } else { + if (!device->IsAddressable()) { + return xla::InvalidArgument( + "Cannot copy array to non-addressable device %s", + device->DebugString()); + } + return tensorflow::down_cast(device) + ->pjrt_device() + ->default_memory_space(); + } +} + +class IfrtArrayEntry : public PullTable::Entry { + public: + struct BufferRef { + tsl::RCReference arr; + xla::PjRtBuffer* buffer; + size_t buf_size; + }; + explicit IfrtArrayEntry(std::vector arrs, + std::shared_ptr state, + size_t xfer_size) + : arrs_(std::move(arrs)), state_(state), xfer_size_(xfer_size) {} + bool Handle(tsl::RCReference state, + const SocketTransferPullRequest& req, + size_t base_req_id) override { + for (uint64_t bid : req.buffer_ids()) { + auto req_id = base_req_id; + ++base_req_id; + for (size_t i = 0; i * xfer_size_ < arrs_[bid].buf_size; ++i) { + DmaCopyChunk blob; + blob.arr = std::move(arrs_[bid].arr); + blob.buffer = arrs_[bid].buffer; + blob.buffer_id = bid; + blob.offset = i * xfer_size_; + blob.size = std::min(xfer_size_, arrs_[bid].buf_size - blob.offset); + bool is_largest = blob.size + blob.offset == arrs_[bid].buf_size; + state_->ScheduleCopy( + blob, [req_id, state, copier_state = state_, is_largest]( + PremappedCopierState* copier_state_ptr, void* buf, + const DmaCopyChunk& chunk) { + state->Send( + req_id, buf, chunk.offset, chunk.size, is_largest, + [copier_state, buf]() { copier_state->ReturnBuffer(buf); }); + }); + } + } + + num_consumed_bufs_ += req.buffer_ids().size(); + return num_consumed_bufs_ == arrs_.size(); + } + + private: + absl::Mutex mu_; + size_t num_consumed_bufs_ = 0; + std::vector arrs_; + std::shared_ptr state_; + size_t xfer_size_; +}; + +absl::StatusOr> CreatePullEntry( + const std::vector>& arrs, + std::shared_ptr state, size_t xfer_size) { + std::vector refs; + for (auto& arr : arrs) { + auto* pjrt_arr = llvm::dyn_cast_or_null(arr.get()); + if (pjrt_arr == nullptr) { + return absl::InvalidArgumentError( + "Cannot remote transfer non-pjrt arrays."); + } + for (auto& pjrt_buf : pjrt_arr->pjrt_buffers()) { + TF_ASSIGN_OR_RETURN(size_t buf_size, pjrt_buf->GetOnDeviceSizeInBytes()); + refs.push_back({arr, pjrt_buf.get(), buf_size}); + } + } + return tsl::MakeRef(std::move(refs), state, xfer_size); +} + +class PyTransferServerConnection { + public: + explicit PyTransferServerConnection( + tsl::RCReference conn) + : conn_(std::move(conn)) {} + + void Pull(uint64_t uuid, std::vector buffer_ids, + std::vector> pull_dests) { + for (size_t i = 0; i < buffer_ids.size(); ++i) { + conn_->Pull(uuid, buffer_ids[i], std::move(pull_dests[i])); + } + } + + private: + tsl::RCReference conn_; +}; + +class PyTransferServer { + public: + PyTransferServer() = default; + absl::Status Start(xla::ifrt::Client* client, size_t max_num_parallel_copies, + size_t xfer_size, const SocketAddress& addr, + const std::vector& transport_addresses) { + std::shared_ptr factory; + if (transport_addresses.empty()) { + factory = BulkTransportFactory::CreateLocal(); + } else { + auto tmp = xla::ValueOrThrow( + AllocateAlignedMemory(xfer_size * max_num_parallel_copies)); + SlabAllocator uallocator(xla::ValueOrThrow(MapPjrtMemory( + client, tmp->data(), tmp->size(), tmp)), + xfer_size); + factory = xla::ValueOrThrow(CreateSocketBulkTransportFactory( + transport_addresses, std::nullopt, uallocator)); + } + + server_ = std::make_shared(); + + TF_ASSIGN_OR_RETURN(auto mem, + AllocateAndMapPjrtMemory( + client, max_num_parallel_copies * xfer_size * 2)); + premapped_copier_ = std::make_shared( + mem, max_num_parallel_copies, xfer_size); + xfer_size_ = xfer_size; + return server_->Start(addr, factory); + } + std::string address() { return server_->addr().ToString(); } + + PyTransferServerConnection Connect(const std::string& saddr) { + return PyTransferServerConnection( + server_->Connect(xla::ValueOrThrow(SocketAddress::Parse(saddr)))); + } + + void AwaitPull(uint64_t uuid, + const std::vector>& arrs) { + server_->AwaitPull(uuid, xla::ValueOrThrow(CreatePullEntry( + arrs, premapped_copier_, xfer_size_))); + } + + size_t xfer_size() { return xfer_size_; } + + std::shared_ptr premapped_copier() { + return premapped_copier_; + } + + private: + std::shared_ptr server_; + std::shared_ptr premapped_copier_; + size_t xfer_size_; +}; + +absl::StatusOr ArraySpecFromShapeDtypeStruct( + nb::handle aval) { + TF_ASSIGN_OR_RETURN(xla::ifrt::DType dtype, + xla::DtypeToIfRtDType( + nb::borrow(aval.attr("dtype").ptr()))); + auto shape_dims = nb::cast>(aval.attr("shape")); + auto shape = xla::ifrt::Shape( + xla::ifrt::Shape::Dimensions(shape_dims.begin(), shape_dims.end())); + TF_ASSIGN_OR_RETURN(auto sharding, + xla::GetIfrtHloSharding(aval.attr("sharding"), shape)); + return xla::ifrt::ArraySpec{dtype, std::move(shape), std::move(sharding)}; +} + +struct BufferSource { + tsl::RCReference arr; + xla::PjRtBuffer* buffer; +}; + +struct CopyDests { + std::vector shape_specs; + xla::PjRtMemorySpace* memory_space; +}; + +void RegisterTransferServerTypes(nanobind::module_& m) { + nb::class_(m, "TransferConnection") + .def("_pull_flat", [](PyTransferServerConnection& self, uint64_t uuid, + xla::nb_class_ptr py_client, + std::vector py_avals) { + auto* ifrt_client = llvm::dyn_cast_or_null( + py_client->ifrt_client()); + if (ifrt_client == nullptr) { + xla::ThrowIfError(absl::InvalidArgumentError( + "_pull_flat only supported on pjrt-ifrt clients.")); + } + + std::vector avals; + std::vector shardings; + shardings.reserve(py_avals.size()); + avals.reserve(py_avals.size()); + for (const auto& py_aval : py_avals) { + avals.push_back( + xla::ValueOrThrow(ArraySpecFromShapeDtypeStruct(py_aval))); + shardings.push_back(py_aval.attr("sharding")); + } + + std::vector dests; + std::vector> fetch_idxs; + absl::flat_hash_map mapping; + std::vector>> buffer_list; + + for (auto& aval : avals) { + std::vector> buf_list; + auto prim_type = + xla::ValueOrThrow(xla::ifrt::ToPrimitiveType(aval.dtype)); + auto shards = xla::ValueOrThrow(aval.sharding->Disassemble( + aval.shape, + xla::ifrt::SingleDeviceShardSemantics::kAddressableShards)); + buf_list.reserve(shards.size()); + for (auto& shard : shards) { + auto* mem_space = + xla::ValueOrThrow(MemorySpaceFromSharding(*shard.second)); + int dest_idx = + mapping.emplace(mem_space, static_cast(dests.size())) + .first->second; + if (dest_idx == dests.size()) { + dests.emplace_back(); + dests.back().memory_space = mem_space; + } + fetch_idxs.push_back( + {dest_idx, + static_cast(dests[dest_idx].shape_specs.size())}); + buf_list.push_back(fetch_idxs.back()); + dests[dest_idx].shape_specs.push_back( + {prim_type, xla::DimensionVector(shard.first.dims().begin(), + shard.first.dims().end())}); + } + buffer_list.push_back(std::move(buf_list)); + } + + std::vector< + std::shared_ptr> + atms; + atms.reserve(dests.size()); + + for (auto& dest : dests) { + atms.push_back(xla::ValueOrThrow( + py_client->pjrt_client()->CreateBuffersForAsyncHostToDevice( + dest.shape_specs, std::nullopt, dest.memory_space))); + } + + std::vector> pull_dests; + std::vector buffer_ids; + pull_dests.reserve(fetch_idxs.size()); + buffer_ids.reserve(fetch_idxs.size()); + for (auto& fetch_idx : fetch_idxs) { + auto& atm = atms[fetch_idx.first]; + pull_dests.push_back(MakeDmaDestination( + atm, fetch_idx.second, atm->buffer_size(fetch_idx.second))); + buffer_ids.push_back(static_cast(buffer_ids.size())); + } + + self.Pull(uuid, buffer_ids, std::move(pull_dests)); + + std::vector out; + auto traceback = xla::Traceback::Get(); + for (size_t i = 0; i < buffer_list.size(); ++i) { + xla::ifrt::PjRtArray::PjRtBuffers buffers; + buffers.reserve(buffer_list[i].size()); + for (auto& v : buffer_list[i]) { + buffers.push_back(atms[v.first]->RetrieveBuffer(v.second)); + } + auto arr = xla::ValueOrThrow(xla::ifrt::PjRtArray::Create( + ifrt_client, avals[i].dtype, avals[i].shape, avals[i].sharding, + std::move(buffers), avals[i].layout)); + out.push_back(xla::PyArray::MakeFromIfrtArrayAndSharding( + py_client, traceback, std::move(arr), shardings[i], false, true, + /*skip_checks=*/false)); + } + + return out; + }); + + nb::class_(m, "TransferServer") + .def("address", [](PyTransferServer& self) { return self.address(); }) + .def("_await_pull_flat", + [](PyTransferServer& self, uint64_t uuid, + std::vector inputs) { + std::vector> arrs; + arrs.reserve(inputs.size()); + for (const xla::PyArray& input : inputs) { + arrs.push_back(tsl::FormRef(input.ifrt_array())); + } + self.AwaitPull(uuid, arrs); + }) + .def("connect", [](PyTransferServer& self, const std::string& address) { + return self.Connect(address); + }); + + m.def( + "start_transfer_server", + [](xla::nb_class_ptr py_client, std::string address, + std::vector transport_addresses_str, + size_t max_num_parallel_copies, + size_t transfer_size) -> PyTransferServer { + PyTransferServer result; + std::vector transport_addresses; + transport_addresses.reserve(transport_addresses_str.size()); + for (const std::string& addr : transport_addresses_str) { + transport_addresses.push_back( + xla::ValueOrThrow(SocketAddress::Parse(addr))); + } + xla::ThrowIfError(result.Start( + py_client->ifrt_client(), max_num_parallel_copies, transfer_size, + xla::ValueOrThrow(SocketAddress::Parse(address)), + transport_addresses)); + return result; + }, + nb::arg("client"), nb::arg("address") = SocketAddress().ToString(), + nb::arg("transport_addresses") = std::vector(), + nb::arg("max_num_parallel_copies") = 8, + nb::arg("transfer_size") = 256 * 1024 * 1024); +} + +} // namespace aux diff --git a/jaxlib/xla/py_socket_transfer.h b/jaxlib/xla/py_socket_transfer.h new file mode 100644 index 000000000000..fa477f24e3e5 --- /dev/null +++ b/jaxlib/xla/py_socket_transfer.h @@ -0,0 +1,26 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef JAXLIB_XLA_TRANSFER_PY_SOCKET_TRANSFER_H_ +#define JAXLIB_XLA_TRANSFER_PY_SOCKET_TRANSFER_H_ + +#include "nanobind/nanobind.h" + +namespace aux { + +void RegisterTransferServerTypes(nanobind::module_& m); + +} // namespace aux + +#endif // JAXLIB_XLA_TRANSFER_PY_SOCKET_TRANSFER_H_ diff --git a/jaxlib/xla/pytree.cc b/jaxlib/xla/pytree.cc index dd5a0bd9cf69..7d1f7676bada 100644 --- a/jaxlib/xla/pytree.cc +++ b/jaxlib/xla/pytree.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The OpenXLA Authors. +/* Copyright 2019 The JAX Authors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/jaxlib/xla/pytree.h b/jaxlib/xla/pytree.h index 722fe41169a0..471d25af89bc 100644 --- a/jaxlib/xla/pytree.h +++ b/jaxlib/xla/pytree.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The OpenXLA Authors. +/* Copyright 2019 The JAX Authors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/jaxlib/xla/sdy.h b/jaxlib/xla/sdy.h index 5d8c8c2eb7dd..ef075855decd 100644 --- a/jaxlib/xla/sdy.h +++ b/jaxlib/xla/sdy.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_SDY_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_SDY_H_ +#ifndef JAXLIB_XLA_SDY_H_ +#define JAXLIB_XLA_SDY_H_ // placeholder for index annotation headers #include "nanobind/nanobind.h" @@ -25,4 +25,4 @@ void BuildSdySubmodule(nanobind::module_& m); } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_SDY_H_ +#endif // JAXLIB_XLA_SDY_H_ diff --git a/jaxlib/xla/weakref_lru_cache.h b/jaxlib/xla/weakref_lru_cache.h index 444e01cef575..7c75974d3d23 100644 --- a/jaxlib/xla/weakref_lru_cache.h +++ b/jaxlib/xla/weakref_lru_cache.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_WEAKREF_LRU_CACHE_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_WEAKREF_LRU_CACHE_H_ +#ifndef JAXLIB_XLA_WEAKREF_LRU_CACHE_H_ +#define JAXLIB_XLA_WEAKREF_LRU_CACHE_H_ // placeholder for index annotation headers #include "nanobind/nanobind.h" @@ -25,4 +25,4 @@ void BuildWeakrefLRUCacheAPI(nanobind::module_& m); } // namespace jax -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_WEAKREF_LRU_CACHE_H_ +#endif // JAXLIB_XLA_WEAKREF_LRU_CACHE_H_ diff --git a/jaxlib/xla/xla.cc b/jaxlib/xla/xla.cc index bd3ed3205fb2..54c94c57a734 100644 --- a/jaxlib/xla/xla.cc +++ b/jaxlib/xla/xla.cc @@ -71,9 +71,9 @@ limitations under the License. #if defined(__linux__) #include "gloo/transport/tcp/attr.h" #include "gloo/transport/tcp/device.h" +#include "jaxlib/xla/py_socket_transfer.h" #include "xla/backends/cpu/collectives/gloo_collectives.h" #include "xla/backends/cpu/collectives/gloo_kv_store.h" -#include "xla/python/transfer/py_socket_transfer.h" #elif defined(__APPLE__) #include "gloo/transport/uv/device.h" #include "xla/backends/cpu/collectives/gloo_collectives.h" // NOLINT diff --git a/jaxlib/xla/xla_compiler.h b/jaxlib/xla/xla_compiler.h index f3ffe5fe9440..ca5bc762a7d8 100644 --- a/jaxlib/xla/xla_compiler.h +++ b/jaxlib/xla/xla_compiler.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_XLA_COMPILER_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_XLA_COMPILER_H_ +#ifndef JAXLIB_XLA_XLA_COMPILER_H_ +#define JAXLIB_XLA_XLA_COMPILER_H_ // placeholder for index annotation headers #include "nanobind/nanobind.h" @@ -25,4 +25,4 @@ void BuildXlaCompilerSubmodule(nanobind::module_& m); } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_XLA_COMPILER_H_ +#endif // JAXLIB_XLA_XLA_COMPILER_H_ From 4f9571eb2bd72ab893e0ec3df1bf08777a0cc7c1 Mon Sep 17 00:00:00 2001 From: Charles Hofer Date: Mon, 24 Mar 2025 19:32:47 +0000 Subject: [PATCH 146/483] Fix auditwheel --- build/rocm/tools/build_wheels.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/build/rocm/tools/build_wheels.py b/build/rocm/tools/build_wheels.py index fd98bbb8ec04..a7ebdf86f916 100644 --- a/build/rocm/tools/build_wheels.py +++ b/build/rocm/tools/build_wheels.py @@ -226,7 +226,10 @@ def fix_wheel(path, jax_path): py_bin = "/opt/python/cp310-cp310/bin" env["PATH"] = "%s:%s" % (py_bin, env["PATH"]) - cmd = ["pip", "install", "auditwheel>=6"] + # NOTE(mrodden): auditwheel 6.0 added lddtree module, but 6.3.0 changed + # the fuction to ldd and also changed its behavior + # constrain range to 6.0 to 6.2.x + cmd = ["pip", "install", "auditwheel>=6,<6.3"] subprocess.run(cmd, check=True, env=env) fixwheel_path = os.path.join(jax_path, "build/rocm/tools/fixwheel.py") From a7d46e6acc4aecee92dcfe68a8d3d86d21b3db3c Mon Sep 17 00:00:00 2001 From: Tori Baker Date: Tue, 25 Mar 2025 07:48:50 -0700 Subject: [PATCH 147/483] Integrate Triton up to [cdb53266](https://github.com/openai/triton/commits/cdb53266e6c251d91a2c321d64e8466caff129a9) PiperOrigin-RevId: 740345806 --- jax/_src/pallas/triton/lowering.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index a0883ea589b0..c85c5f0a39c0 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -1779,6 +1779,12 @@ def _reshape(a: ir.Value, shape: Sequence[int]) -> ir.Value: ) +def get_join_type(old_type: ir.RankedTensorType): + shape = old_type.shape + shape.append(2) + return ir.RankedTensorType.get(shape, old_type.element_type, old_type.encoding) + + @register_lowering(lax.concatenate_p) def _concatenate_lowering_rule(ctx: LoweringRuleContext, *args, dimension): if len(args) != 2: @@ -1793,9 +1799,10 @@ def _concatenate_lowering_rule(ctx: LoweringRuleContext, *args, dimension): raise NotImplementedError( "Only arguments with shape [..., 1] are supported." ) - return tt_dialect.join( - _reshape(x, x_aval.shape[:-1]), _reshape(y, y_aval.shape[:-1]) - ) + lhs = _reshape(x, x_aval.shape[:-1]) + rhs = _reshape(y, y_aval.shape[:-1]) + ret_type = get_join_type(ir.RankedTensorType(rhs.type)) + return tt_dialect.join(ret_type, lhs, rhs) @register_lowering(lax.split_p) @@ -2102,10 +2109,11 @@ def _masked_load_lowering_rule( # most significant. Before jaxlib 0.5.2, the order was reversed. if is_contiguous_int4: msb_values = arith_dialect.shrui(values, _full(values.type, 4)) + join_type = get_join_type(ir.RankedTensorType(values.type)) if jaxlib_version < (0, 5, 2): - values = tt_dialect.join(msb_values, values) + values = tt_dialect.join(join_type, msb_values, values) else: - values = tt_dialect.join(values, msb_values) + values = tt_dialect.join(join_type, values, msb_values) shape = ir.RankedTensorType(values.type).shape values = _reshape(values, (*shape[:-2], shape[-2] * shape[-1])) else: From 8260ab329145155da697a276591e73100993635f Mon Sep 17 00:00:00 2001 From: Nicolas Castet Date: Tue, 25 Mar 2025 09:53:29 -0500 Subject: [PATCH 148/483] Address review comments --- jax/_src/pallas/core.py | 1 - jax/_src/pallas/mosaic/core.py | 4 +++- jax/_src/pallas/mosaic/lowering.py | 15 ++++++++++++-- jax/_src/pallas/primitives.py | 32 +++++++++-------------------- jax/experimental/pallas/__init__.py | 1 + jax/experimental/pallas/tpu.py | 15 ++++++++------ 6 files changed, 36 insertions(+), 32 deletions(-) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 389bbd3b0733..78b815820609 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -67,7 +67,6 @@ def __repr__(self): class semaphore_dtype(dtypes.extended): pass class semaphore(semaphore_dtype): pass -class dma_semaphore(semaphore_dtype): pass class barrier_semaphore(semaphore_dtype): pass @runtime_checkable diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index 5d503779f092..fc4ecbedaca5 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -112,6 +112,8 @@ def __call__(self, shape: tuple[int, ...], dtype: jnp.dtype): # A convenience function for constructing MemoryRef types. return pallas_core.MemoryRef(shape, dtype, self) +class dma_semaphore(pallas_core.semaphore_dtype): pass + class AbstractSemaphoreTyRules: @staticmethod def pallas_interpret_element_aval(_) -> jax_core.ShapedArray: @@ -141,7 +143,7 @@ class SemaphoreTy(AbstractSemaphoreTy): name = "sem" class DmaSemaphoreTy(AbstractSemaphoreTy): - type = pallas_core.dma_semaphore + type = dma_semaphore name = "dma_sem" class BarrierSemaphoreTy(AbstractSemaphoreTy): diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 3469ef4de952..00302494f67d 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -193,7 +193,7 @@ def _memory_space_to_mosaic_attribute(memory_space: MemorySpace | None def _dtype_to_ir_type(dtype: jnp.dtype, is_kernel_boundary: bool = False) -> ir.Type: if jnp.issubdtype(dtype, pallas_core.semaphore_dtype): - if jnp.issubdtype(dtype, pallas_core.dma_semaphore): + if jnp.issubdtype(dtype, tpu_core.dma_semaphore): return ir.Type.parse("!tpu.dma_semaphore") elif jnp.issubdtype(dtype, pallas_core.semaphore): return ir.Type.parse("!tpu.semaphore") @@ -3367,7 +3367,18 @@ def _semaphore_read_lowering_rule( *args, args_tree, ): - sem_aval, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in) + sem_aval, sem_transforms_avals = tree_util.tree_unflatten(args_tree, ctx.avals_in) + primitives.check_sem_avals( + sem_aval, + sem_transforms_avals, + "read", + allowed_semaphore_types={ + tpu_core.dma_semaphore, + pallas_core.semaphore, + pallas_core.barrier_semaphore, + pallas_core.SEMAPHORE_INTERPRET_DTYPE, + }, + ) sem, transforms = tree_util.tree_unflatten(args_tree, args) sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, transforms) return tpu.sem_read(sem) diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index 5d3444ef719f..4971b83a9ba2 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -1023,16 +1023,15 @@ def check_sem_avals( sem_shape = sem_transforms_avals[-1].get_indexer_shape() if sem_shape: raise ValueError(f"Cannot {name} on a non-()-shaped semaphore: {sem_shape}") - # Uncomment when semaphore type works for Mosaic-GPU lowering - # sem_dtype = sem_aval.dtype - # if not any( - # jnp.issubdtype(sem_dtype, sem_type) - # for sem_type in allowed_semaphore_types - # ): - # raise ValueError( - # f"Must {name} semaphores of the following types:" - # f" {allowed_semaphore_types}." - # ) + sem_dtype = sem_aval.dtype + if not any( + jnp.issubdtype(sem_dtype, sem_type) + for sem_type in allowed_semaphore_types + ): + raise ValueError( + f"Must {name} semaphores of the following types:" + f" {allowed_semaphore_types}." + ) def _transform_semaphore(ref_value, transforms, ref_aval): @@ -1063,18 +1062,7 @@ def _semaphore_read_abstract_eval( *avals, args_tree, ): - sem_aval, sem_transforms_avals = tree_util.tree_unflatten(args_tree, avals) - check_sem_avals( - sem_aval, - sem_transforms_avals, - "read", - allowed_semaphore_types={ - pallas_core.dma_semaphore, - pallas_core.semaphore, - pallas_core.barrier_semaphore, - pallas_core.SEMAPHORE_INTERPRET_DTYPE, - }, - ) + del avals, args_tree return jax_core.ShapedArray((), jnp.dtype("int32")) def _semaphore_read_discharge_rule(in_avals, diff --git a/jax/experimental/pallas/__init__.py b/jax/experimental/pallas/__init__.py index ea58fae3d283..fd523712fa9c 100644 --- a/jax/experimental/pallas/__init__.py +++ b/jax/experimental/pallas/__init__.py @@ -30,6 +30,7 @@ from jax._src.pallas.core import MemorySpace as MemorySpace from jax._src.pallas.core import Buffered as Buffered from jax._src.pallas.core import no_block_spec as no_block_spec +from jax._src.pallas.core import semaphore as semaphore from jax._src.pallas.core import Unblocked as Unblocked from jax._src.pallas.core import unblocked as unblocked from jax._src.pallas.cost_estimate import estimate_cost as estimate_cost diff --git a/jax/experimental/pallas/tpu.py b/jax/experimental/pallas/tpu.py index c81edaf76fa3..da054bf18309 100644 --- a/jax/experimental/pallas/tpu.py +++ b/jax/experimental/pallas/tpu.py @@ -17,11 +17,10 @@ from jax._src.pallas.mosaic import core as core from jax._src.pallas.mosaic.core import ARBITRARY as ARBITRARY from jax._src.pallas.mosaic.core import create_tensorcore_mesh as create_tensorcore_mesh -from jax._src.pallas.core import dma_semaphore as dma_semaphore +from jax._src.pallas.mosaic.core import dma_semaphore as dma_semaphore from jax._src.pallas.mosaic.core import GridDimensionSemantics as GridDimensionSemantics from jax._src.pallas.mosaic.core import PARALLEL as PARALLEL from jax._src.pallas.mosaic.core import PrefetchScalarGridSpec as PrefetchScalarGridSpec -from jax._src.pallas.core import semaphore as semaphore from jax._src.pallas.mosaic.core import SemaphoreType as SemaphoreType from jax._src.pallas.mosaic.core import TPUMemorySpace as TPUMemorySpace from jax._src.pallas.mosaic.core import TPUCompilerParams as TPUCompilerParams @@ -40,8 +39,6 @@ from jax._src.pallas.mosaic.primitives import async_remote_copy as async_remote_copy from jax._src.pallas.mosaic.primitives import bitcast as bitcast from jax._src.pallas.mosaic.primitives import delay as delay -from jax._src.pallas.primitives import device_id as device_id -from jax._src.pallas.primitives import DeviceIdType as DeviceIdType from jax._src.pallas.mosaic.primitives import get_barrier_semaphore as get_barrier_semaphore from jax._src.pallas.mosaic.primitives import make_async_copy as make_async_copy from jax._src.pallas.mosaic.primitives import make_async_remote_copy as make_async_remote_copy @@ -49,11 +46,17 @@ from jax._src.pallas.mosaic.primitives import prng_seed as prng_seed from jax._src.pallas.mosaic.primitives import repeat as repeat from jax._src.pallas.mosaic.primitives import roll as roll +from jax._src.pallas.mosaic.random import sample_block as sample_block +from jax._src.pallas.mosaic.random import to_pallas_key as to_pallas_key + +# Those primitives got moved to Pallas core. Keeping the updated imports +# here for backward compatibility. +from jax._src.pallas.core import semaphore as semaphore +from jax._src.pallas.primitives import device_id as device_id +from jax._src.pallas.primitives import DeviceIdType as DeviceIdType from jax._src.pallas.primitives import semaphore_read as semaphore_read from jax._src.pallas.primitives import semaphore_signal as semaphore_signal from jax._src.pallas.primitives import semaphore_wait as semaphore_wait -from jax._src.pallas.mosaic.random import sample_block as sample_block -from jax._src.pallas.mosaic.random import to_pallas_key as to_pallas_key import types from jax._src.pallas.mosaic.verification import assume From a9266a1521ade250f99114227f281b30235cef9c Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 25 Mar 2025 09:09:53 -0700 Subject: [PATCH 149/483] [pallas:mosaic_gpu] `PallasCallTest` now runs all tests under both Lane and WG thread semantics PiperOrigin-RevId: 740371195 --- jax/_src/pallas/mosaic_gpu/lowering.py | 4 +- .../mosaic_gpu/pallas_call_registration.py | 10 +- tests/pallas/mosaic_gpu_test.py | 414 +++++++++--------- 3 files changed, 208 insertions(+), 220 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 677f63c6674a..493d8c07b941 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1404,7 +1404,7 @@ def convert(ty, x): mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup].update({ lax.neg_p: _lower_fun(lambda x: jnp.subtract(0, x), multiple_results=False), lax.not_p: _lower_fun( - lambda x: jnp.bitwise_xor(x, -1), multiple_results=False + lambda x: jnp.astype(jnp.bitwise_xor(jnp.astype(x, int), -1), jnp.dtype(x)), multiple_results=False, ), }) @@ -1821,7 +1821,7 @@ def _debug_print_lowering_rule( return () @register_lowering_rule(primitives.debug_print_p, mgpu.ThreadSemantics.Warpgroup) -def _debug_print_lowering_rule( +def _debug_print_lowering_rule_wg( ctx: LoweringRuleContext, *args, fmt, diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py index d506349fe101..5399727878a6 100644 --- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py @@ -27,7 +27,7 @@ from jax._src.interpreters import mlir from jax._src.pallas import core as pallas_core from jax._src.pallas.mosaic_gpu import lowering -import jax.experimental.mosaic.gpu.core as mosaic_core +from jax.experimental.mosaic import gpu as mgpu def pallas_call_lowering( @@ -57,10 +57,10 @@ def pallas_call_lowering( print(grid_mapping) thread_semantics = compiler_params.get("mosaic_gpu", {}).get( - "thread_semantics", mosaic_core.ThreadSemantics.Lane + "thread_semantics", mgpu.ThreadSemantics.Warpgroup ) - if thread_semantics == mosaic_core.ThreadSemantics.Warpgroup: - mosaic_core.dialect.register_dialect(ctx.module_context.context) # pytype: disable=attribute-error + if thread_semantics == mgpu.ThreadSemantics.Warpgroup: + mgpu.dialect.register_dialect(ctx.module_context.context) # pytype: disable=attribute-error lowering_result = lowering.lower_pipelined_jaxpr_to_module( grid_mapping, @@ -77,7 +77,7 @@ def pallas_call_lowering( new_avals_out = [ jax_core.ShapedArray(t.shape, t.dtype) for t in lowering_result.out_structs ] - outs = mosaic_core._mosaic_gpu_lowering_rule( + outs = mgpu.core._mosaic_gpu_lowering_rule( ctx.replace(avals_out=new_avals_out), *args, module=module, diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index b33857df40b6..b39288252e08 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -13,24 +13,27 @@ # limitations under the License. import contextlib +import dataclasses import functools import math import operator import os import re import tempfile +from typing import ClassVar from absl.testing import absltest from absl.testing import parameterized import jax from jax import lax from jax._src import test_util as jtu -from jax._src.pallas.mosaic_gpu import pipeline as mgpu_pipeline from jax._src.pallas import pallas_call +from jax._src.pallas.mosaic_gpu import pipeline as mgpu_pipeline from jax.experimental import pallas as pl from jax.experimental.pallas import mosaic_gpu as plgpu import jax.numpy as jnp import numpy as np + try: from jax._src.lib import mosaic_gpu as mosaic_gpu_lib except ImportError: @@ -55,7 +58,16 @@ def _sum_same_dtype(x): return jnp.sum(x, dtype=x.dtype) -class PallasTest(jtu.JaxTestCase): +class PallasTestMetaclass(parameterized.TestGeneratorMetaclass): + + def __new__(mcs, *args, thread_semantics=plgpu.ThreadSemantics.Lane): + cls = super().__new__(mcs, *args) + cls.THREAD_SEMANTICS = thread_semantics + return cls + + +class PallasTest(jtu.JaxTestCase, metaclass=PallasTestMetaclass): + THREAD_SEMANTICS: ClassVar[plgpu.ThreadSemantics] def setUp(self): if not jtu.is_cuda_compute_capability_at_least("9.0"): @@ -66,6 +78,17 @@ def setUp(self): super().setUp() + def skip_if_wg_semantics(self): + if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Warpgroup: + self.skipTest("Not supported under WG semantics") + + def pallas_call(self, *args, **kwargs): + compiler_params = dataclasses.replace( + kwargs.pop("compiler_params", plgpu.GPUCompilerParams()), + thread_semantics=self.THREAD_SEMANTICS, + ) + return pl.pallas_call(*args, compiler_params=compiler_params, **kwargs) + @contextlib.contextmanager def capture_stdout(self): if mosaic_gpu_lib is None: @@ -104,17 +127,14 @@ class PallasCallTest(PallasTest): lax.log, ], approx_math=[True, False], - thread_semantics=[*plgpu.ThreadSemantics], ) - def test_unary_op(self, op, approx_math, thread_semantics): + def test_unary_op(self, op, approx_math): dtype = jnp.int32 if op is lax.bitwise_not else jnp.float32 @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], dtype), - compiler_params=plgpu.GPUCompilerParams( - approx_math=approx_math, thread_semantics=thread_semantics - ), + compiler_params=plgpu.GPUCompilerParams(approx_math=approx_math), ) def kernel(x_ref, o_ref): o_ref[...] = op(x_ref[...]) @@ -135,16 +155,10 @@ def kernel(x_ref, o_ref): jnp.maximum, ], dtype=[jnp.float32, jnp.int32, jnp.uint32], - thread_semantics=[*plgpu.ThreadSemantics], ) - def test_binary_op(self, op, dtype, thread_semantics): - + def test_binary_op(self, op, dtype): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], dtype), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], dtype) ) def kernel(x_ref, y_ref, o_ref): o_ref[...] = op(x_ref[...], y_ref[...]) @@ -165,16 +179,10 @@ def kernel(x_ref, y_ref, o_ref): ], # TODO(slebedev): Support integral types. dtype=[jnp.float32, jnp.int32, jnp.uint32], - thread_semantics=[*plgpu.ThreadSemantics], ) - def test_comparison_op(self, op, dtype, thread_semantics): - + def test_comparison_op(self, op, dtype): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], dtype), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], dtype) ) def kernel(o_ref): o_ref[...] = jnp.broadcast_to( @@ -184,8 +192,9 @@ def kernel(o_ref): np.testing.assert_array_equal(kernel(), jnp.full([256], op(42, 24), dtype)) def test_add_first(self): + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), ) def kernel(x_ref, y_ref, o_ref): @@ -195,16 +204,10 @@ def kernel(x_ref, y_ref, o_ref): y = jnp.flip(x).reshape(1, 256) np.testing.assert_array_equal(kernel(x, y), x + y[0]) - @parameterized.product( - shape=[(128,), (128, 128)], thread_semantics=[*plgpu.ThreadSemantics] - ) - def test_reduce_sum(self, shape, thread_semantics): + @parameterized.product(shape=[(128,), (128, 128)]) + def test_reduce_sum(self, shape): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), + self.pallas_call, out_shape=jax.ShapeDtypeStruct(shape, jnp.float32) ) def kernel(x_ref, o_ref): o_ref[...] = jnp.broadcast_to(_sum_same_dtype(x_ref[...]), o_ref.shape) @@ -213,11 +216,12 @@ def kernel(x_ref, o_ref): np.testing.assert_array_equal(kernel(x), jnp.sum(x)) def test_reshape(self): + self.skip_if_wg_semantics() + shape1, shape2 = (128,), (2, 16, 4) @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct(shape2, jnp.float32), + self.pallas_call, out_shape=jax.ShapeDtypeStruct(shape2, jnp.float32) ) def kernel(x_ref, out_ref): x_ref_reshaped = x_ref.reshape(shape2) @@ -228,14 +232,9 @@ def kernel(x_ref, out_ref): x = jnp.arange(math.prod(shape1)).astype(jnp.float32) np.testing.assert_array_equal(kernel(x), x.reshape(shape2)) - @parameterized.product(thread_semantics=[*plgpu.ThreadSemantics]) - def test_add_xy_indexed(self, thread_semantics): + def test_add_xy_indexed(self): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([128], jnp.float32), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.float32) ) def kernel(x_ref, y_ref, o_ref): idx = _sum_same_dtype(y_ref[...]) @@ -246,8 +245,9 @@ def kernel(x_ref, y_ref, o_ref): np.testing.assert_array_equal(kernel(x, y), x[jnp.sum(y)]) def test_add_one_grid(self): + @functools.partial( - pl.pallas_call, + self.pallas_call, in_specs=[pl.BlockSpec((128,), lambda *i: i)], out_specs=pl.BlockSpec((128,), lambda *i: i), out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.float32), @@ -260,9 +260,8 @@ def kernel(x_ref, o_ref): np.testing.assert_array_equal(kernel(x), x + 1.0) def test_add_one_grid_with_scratch(self): - @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.float32), in_specs=[pl.BlockSpec((128,), lambda *i: i)], out_specs=pl.BlockSpec((128,), lambda *i: i), @@ -278,9 +277,8 @@ def kernel(x_ref, o_ref, scratch_ref): @parameterized.product(max_concurrent_steps=[1, 2, 3, 4, 16]) def test_add_one_grid_pipelined(self, max_concurrent_steps): - @functools.partial( - pl.pallas_call, + self.pallas_call, in_specs=[pl.BlockSpec((128, 16), lambda i, j: (i, j))], out_specs=pl.BlockSpec((128, 16), lambda i, j: (i, j)), out_shape=jax.ShapeDtypeStruct([128 * 2, 64], jnp.float32), @@ -297,9 +295,8 @@ def kernel(x_ref, o_ref): np.testing.assert_array_equal(kernel(x), x + 1.0) def test_add_one_grid_pipelined_program_id(self): - @functools.partial( - pl.pallas_call, + self.pallas_call, out_specs=pl.BlockSpec((16, 16), lambda i, j: (i, j)), out_shape=jax.ShapeDtypeStruct([16, 64], jnp.int32), compiler_params=plgpu.GPUCompilerParams( @@ -317,8 +314,9 @@ def kernel(o_ref): ) def test_add_one_grid_pipelined_sequential_invariant_output(self): + @functools.partial( - pl.pallas_call, + self.pallas_call, in_specs=[pl.BlockSpec((32, 16), lambda i, j: (i, j))], out_specs=pl.BlockSpec((32, 16), lambda i, j: (i, 0)), out_shape=jax.ShapeDtypeStruct([32 * 2, 64], jnp.float32), @@ -345,30 +343,29 @@ def kernel(x_ref, o_ref): @parameterized.parameters(jnp.float32, jnp.int32, jnp.uint32) def test_iota(self, dtype): + self.skip_if_wg_semantics() + dimension = 1 + @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct((128, 128), dtype), + self.pallas_call, out_shape=jax.ShapeDtypeStruct((128, 128), dtype) ) def kernel(o_ref): - o_ref[...] = plgpu.broadcasted_iota(dtype, (128, 128), dimension, layout=plgpu.Layout.WGMMA) - - np.testing.assert_array_equal(kernel(), jax.lax.broadcasted_iota(dtype, (128, 128), dimension)) + o_ref[...] = plgpu.broadcasted_iota( + dtype, o_ref.shape, dimension, layout=plgpu.Layout.WGMMA + ) - @parameterized.product( - indexer=[..., slice(128), slice(None, 128)], - thread_semantics=[*plgpu.ThreadSemantics], - ) - def test_copy_smem_to_gmem(self, indexer, thread_semantics): + np.testing.assert_array_equal( + kernel(), jax.lax.broadcasted_iota(dtype, (128, 128), dimension) + ) + @parameterized.product(indexer=[..., slice(128), slice(None, 128)]) + def test_copy_smem_to_gmem(self, indexer): @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), scratch_shapes=[plgpu.SMEM((256,), jnp.float32)], - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), ) def kernel(x_ref, o_ref_gmem, scratch_ref): scratch_ref[...] = x_ref[...] + 1 @@ -388,8 +385,9 @@ def kernel(x_ref, o_ref_gmem, scratch_ref): "shape": (64, 64), "indexers": (4, slice(0, 64))}, ) def test_copy_smem_to_gmem_with_multiple_gmem_indexers(self, shape, indexers): + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), scratch_shapes=[plgpu.SMEM(shape, jnp.float32)], @@ -413,8 +411,9 @@ def kernel(x_ref, o_ref_gmem, scratch_ref): @parameterized.product(indexer=[..., slice(128), slice(None, 128)]) def test_copy_gmem_to_smem(self, indexer): + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), scratch_shapes=[ @@ -458,13 +457,15 @@ def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): }, ) def test_copy_gmem_to_smem_with_multiple_gmem_indexers(self, shape, indexers): + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), - scratch_shapes=[plgpu.SMEM(shape, jnp.float32), - plgpu.Barrier(num_arrivals=1), - ], + scratch_shapes=[ + plgpu.SMEM(shape, jnp.float32), + plgpu.Barrier(num_arrivals=1), + ], grid=(1,), ) def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): @@ -489,7 +490,7 @@ def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): def test_gmem_to_smem_with_multiple_smem_indexers(self): x = jax.random.uniform(jax.random.key(0), (2, 64, 64), dtype=jnp.float32) @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([64, 64], jnp.float32), in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), scratch_shapes=[ @@ -506,21 +507,31 @@ def extract_x0(x_ref_gmem, o_ref, scratch_ref, barrier_ref): np.testing.assert_array_equal(extract_x0(x), x[0]) def test_gmem_to_smem_with_multiple_smem_indexers_and_transforms(self): + self.skip_if_wg_semantics() + x = jnp.arange(512 * 512, dtype=jnp.int32).reshape(512, 512) @functools.partial( - pl.pallas_call, + self.pallas_call, grid=(4, 4), out_shape=jax.ShapeDtypeStruct((256, 128), jnp.int32), - in_specs=(plgpu.GPUBlockSpec( - block_shape=(128, 128), - index_map=lambda i, j: (i, j), - memory_space=plgpu.SMEM, - transforms=(plgpu.TilingTransform((64, 32)), - plgpu.SwizzleTransform(128))),), - out_specs=(plgpu.GPUBlockSpec( - block_shape=(64, 32), - index_map=lambda i, j: (i, j), - memory_space=plgpu.SMEM,)), + in_specs=( + plgpu.GPUBlockSpec( + block_shape=(128, 128), + index_map=lambda i, j: (i, j), + memory_space=plgpu.SMEM, + transforms=( + plgpu.TilingTransform((64, 32)), + plgpu.SwizzleTransform(128), + ), + ), + ), + out_specs=( + plgpu.GPUBlockSpec( + block_shape=(64, 32), + index_map=lambda i, j: (i, j), + memory_space=plgpu.SMEM, + ) + ), ) def kernel(x_ref, o_ref): x_sliced = x_ref.at[0:64, 32:96].at[:, 0:32] # get x_ref[0:64, 32:64] @@ -532,8 +543,9 @@ def kernel(x_ref, o_ref): @parameterized.product(indexer=[0, 1, 2, 3]) def test_copy_gmem_to_smem_with_indexed_barrier(self, indexer): + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.float32), in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), scratch_shapes=[ @@ -553,6 +565,8 @@ def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): @parameterized.named_parameters(("_g2s", False), ("_s2g", True)) def test_copy_with_transforms(self, to_smem): + self.skip_if_wg_semantics() + def kernel(x_ref, o_ref, barrier_ref): if to_smem: plgpu.copy_gmem_to_smem(x_ref, o_ref, barrier_ref) @@ -574,7 +588,7 @@ def kernel(x_ref, o_ref, barrier_ref): ) if not to_smem: in_spec, out_spec = out_spec, in_spec - f = pl.pallas_call( + f = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct([128, 128], jnp.float32), in_specs=(in_spec,), @@ -585,6 +599,8 @@ def kernel(x_ref, o_ref, barrier_ref): np.testing.assert_array_equal(f(x), x) def test_scoped_copy_with_transforms(self): + self.skip_if_wg_semantics() + ts = (plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128)) def kernel(x_ref, o_ref, barrier_ref): def body(tmp_ref): @@ -597,7 +613,7 @@ def body(tmp_ref): out_spec = plgpu.GPUBlockSpec( (128, 128), lambda: (0, 0), transforms=ts, memory_space=plgpu.SMEM, ) - f = pl.pallas_call( + f = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct([128, 128], jnp.float32), in_specs=(in_spec,), @@ -608,6 +624,8 @@ def body(tmp_ref): np.testing.assert_array_equal(f(x), x * 2) def test_copy_with_transforms_and_indexing(self): + self.skip_if_wg_semantics() + def kernel(x_ref, o_ref, barrier_ref): for i in range(2): plgpu.copy_gmem_to_smem(x_ref, o_ref.at[i], barrier_ref) @@ -624,7 +642,7 @@ def kernel(x_ref, o_ref, barrier_ref): ), memory_space=plgpu.SMEM, ) - f = pl.pallas_call( + f = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct([2, 128, 128], jnp.float32), in_specs=(in_spec,), @@ -644,30 +662,33 @@ def kernel(x_ref, o_ref, barrier_ref): ], ) def test_load_to_layout_with_indexing(self, src_memory_space, layout): + self.skip_if_wg_semantics() + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct([2, 128], jnp.float32), + in_specs=[pl.BlockSpec(memory_space=src_memory_space)], + out_specs=plgpu.GPUBlockSpec( + (2, 128), + lambda: (0, 0), + memory_space=plgpu.SMEM, + ), + ) def kernel(x_ref, o_ref): for i in range(2): x = plgpu.load(x_ref, (i,), layout=layout) o_ref[i, ...] = x - in_spec = pl.BlockSpec(memory_space=src_memory_space) - out_spec = plgpu.GPUBlockSpec( - (2, 128), lambda: (0, 0), memory_space=plgpu.SMEM, - ) - f = pl.pallas_call( - kernel, - out_shape=jax.ShapeDtypeStruct([2, 128], jnp.float32), - in_specs=(in_spec,), - out_specs=out_spec, - ) x = jnp.arange(2 * 128, dtype=jnp.float32).reshape(2, 128) - np.testing.assert_array_equal(f(x), x) + np.testing.assert_array_equal(kernel(x), x) - @parameterized.product(src_memory_space=[plgpu.SMEM], - layout=[ - plgpu.Layout.WGMMA_ROW, - plgpu.Layout.WGMMA_COL, - ],) + @parameterized.product( + src_memory_space=[plgpu.SMEM], + layout=[plgpu.Layout.WGMMA_ROW, plgpu.Layout.WGMMA_COL], + ) def test_load_row_input_to_wgmma_with_transforms(self, src_memory_space, layout): + self.skip_if_wg_semantics() + m, k, n = 64, 128, 192 key1, key2 = jax.random.split(jax.random.key(42), 2) if layout == plgpu.Layout.WGMMA_ROW: @@ -694,7 +715,7 @@ def compute(acc_ref): out_spec = plgpu.GPUBlockSpec( (m, n), lambda: (0, 0), memory_space=plgpu.SMEM, ) - f = pl.pallas_call( + f = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct([m, n], jnp.float32), in_specs=( @@ -705,8 +726,9 @@ def compute(acc_ref): transforms=( plgpu.TilingTransform((64, 64)), plgpu.SwizzleTransform(128), - ) - )), + ), + ), + ), out_specs=out_spec, ) @@ -716,6 +738,8 @@ def compute(acc_ref): np.testing.assert_allclose(f(a, b), out_ref, rtol=1e-3) def test_indexing_before_transpose(self): + self.skip_if_wg_semantics() + def kernel(x_ref, o_ref, barrier_ref): for i in range(2): plgpu.copy_gmem_to_smem( @@ -727,7 +751,7 @@ def kernel(x_ref, o_ref, barrier_ref): out_spec = plgpu.GPUBlockSpec( (2, 64, 2, 128), lambda: (0, 0, 0, 0), memory_space=plgpu.SMEM, ) - f = pl.pallas_call( + f = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct([2, 64, 2, 128], jnp.float32), in_specs=(in_spec,), @@ -739,8 +763,9 @@ def kernel(x_ref, o_ref, barrier_ref): np.testing.assert_array_equal(f(x), np.stack([xt, xt], axis=0)) def test_copy_gmem_to_smem_in_run_scoped(self): + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), ) @@ -757,8 +782,9 @@ def inner_body(scratch_ref): np.testing.assert_array_equal(kernel(x), x + 1.0) def test_add_doubled_sum(self): + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.float32), ) def kernel(x_ref, o_ref): @@ -767,26 +793,6 @@ def kernel(x_ref, o_ref): x = jnp.arange(128).astype(jnp.float32) np.testing.assert_array_equal(kernel(x), x + x.sum()*2) - @parameterized.named_parameters( - ("rsqrt", jax.lax.rsqrt, ), - ("log", jax.lax.log, 5e-7), - ("exp", jax.lax.exp, ), - ("exp2", jax.lax.exp2, 5e-7), - ("logistic", jax.lax.logistic, ), - ("tanh", jax.lax.tanh, 5e-7), - ) - def test_approx_math_unary_op(self, unary_op, rtol=1e-7): - @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([128], jnp.float32), - compiler_params=plgpu.GPUCompilerParams(approx_math=True), - ) - def kernel(x_ref, o_ref): - o_ref[...] = unary_op(x_ref[...]) - - x = jnp.arange(128).astype(jnp.float32) / 128 - np.testing.assert_allclose(kernel(x), unary_op(x), rtol=rtol, atol=1e-5) - @parameterized.product(input_factor=[0.001, 1, 10, 100, 100]) def test_layer_norm(self, input_factor): eps = 1e-5 @@ -794,7 +800,7 @@ def test_layer_norm(self, input_factor): beta = 1.0 @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), ) def layer_norm(x_ref, o_ref): @@ -822,8 +828,9 @@ def layer_norm_np(x): np.testing.assert_allclose(layer_norm(x), layer_norm_np(x), rtol=5e-5) def test_print(self): + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), ) def kernel(x_ref, o_ref): @@ -836,16 +843,32 @@ def kernel(x_ref, o_ref): self.assertEqual(output(), "It works!\n") def test_print_wgmma_tiled_layout(self): + self.skip_if_wg_semantics() + shape = (128, 64) size = math.prod(shape) + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), + in_specs=[ + plgpu.GPUBlockSpec( + shape, + lambda: (0, 0), + transforms=( + plgpu.TilingTransform((64, 32)), + plgpu.SwizzleTransform(128), + ), + ) + ], + ) def kernel(x_ref, o_ref): + del o_ref # Unused. pl.debug_print("prefix {}", x_ref[...]) - spec = plgpu.GPUBlockSpec(shape, lambda: (0, 0), transforms=(plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128))) - x = jnp.arange(size, dtype=jnp.float32).reshape(shape) - f = pl.pallas_call(kernel, out_shape=x, in_specs=[spec], out_specs=spec) + x = jnp.arange(size, dtype=jnp.float32).reshape(shape) with self.capture_stdout() as get_output: - jax.block_until_ready(f(x)) + jax.block_until_ready(kernel(x)) output = get_output() results = re.findall(r"prefix \[(\d+), (\d+)\]: (\d+).?\d*", output) @@ -855,8 +878,10 @@ def kernel(x_ref, o_ref): self.assertEqual(v, i * shape[1] + j) def test_print_scalar(self): + self.skip_if_wg_semantics() + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32), ) def kernel(x_ref, o_ref): @@ -870,8 +895,10 @@ def kernel(x_ref, o_ref): self.assertIn(f"x.sum() = {x.sum()}", output()) def test_print_scalar_array(self): + self.skip_if_wg_semantics() + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32), ) def kernel(x_ref, o_ref): @@ -885,10 +912,12 @@ def kernel(x_ref, o_ref): self.assertIn(f"x.sum() = {x.sum() + 1}", output()) def test_print_array(self): + self.skip_if_wg_semantics() + in_shape = [2, 1, 64, 64] @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct(in_shape, jnp.int32), ) def kernel(x_ref, o_ref): @@ -913,9 +942,11 @@ def kernel(x_ref, o_ref): np.testing.assert_array_equal(kernel(jnp.arange(11, dtype=jnp.int32)), jnp.full((128,), 10, dtype=jnp.int32)) - @parameterized.product(thread_semantics=[*plgpu.ThreadSemantics]) - def test_run_scoped(self, thread_semantics): - + def test_run_scoped(self): + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + ) def kernel(x_ref, o_ref): def body(tmp_ref): self.assertEqual(tmp_ref.shape, (8, 128)) @@ -926,16 +957,8 @@ def body(tmp_ref): self.assertEqual(tmp.shape, (8, 128)) o_ref[...] = tmp - inp = np.ones((8, 128), jnp.float32) - f = pl.pallas_call( - kernel, - out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), - ) - o = f(inp) - np.testing.assert_array_equal(o, inp + 1.0) + x = np.ones((8, 128), jnp.float32) + np.testing.assert_array_equal(kernel(x), x + 1.0) def test_program_id(self): @functools.partial( @@ -1031,14 +1054,10 @@ def kernel(x_ref, o_ref): x = jnp.arange(128 * 128).astype(jnp.float16).reshape(128, 128) np.testing.assert_array_equal(kernel(x), x) - @parameterized.product( - force_while=[False, True], thread_semantics=[*plgpu.ThreadSemantics] - ) - def test_fori_loop_array(self, force_while, thread_semantics): + @parameterized.product(force_while=[False, True]) + def test_fori_loop_array(self, force_while): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], jnp.int32), - compiler_params=plgpu.GPUCompilerParams(thread_semantics=thread_semantics), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32) ) def kernel(x_ref, o_ref): # Equivalent to x_ref[...] + 2 + 3. @@ -1047,14 +1066,10 @@ def kernel(x_ref, o_ref): x = jnp.arange(256, dtype=jnp.int32) np.testing.assert_array_equal(kernel(x), x + 2 + 3) - @parameterized.product( - force_while=[False, True], thread_semantics=[*plgpu.ThreadSemantics] - ) - def test_fori_loop_scalar(self, force_while, thread_semantics): + @parameterized.product(force_while=[False, True]) + def test_fori_loop_scalar(self, force_while): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], jnp.int32), - compiler_params=plgpu.GPUCompilerParams(thread_semantics=thread_semantics), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32) ) def kernel(o_ref): # Equivalent to 2 + 3. @@ -1066,7 +1081,6 @@ def kernel(o_ref): np.testing.assert_array_equal(kernel(), jnp.full([256], 5, jnp.int32)) def test_fori_loop_dynamic_bounds(self): - @functools.partial( pl.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32), @@ -1081,16 +1095,10 @@ def kernel(o_ref): np.testing.assert_array_equal(kernel(), jnp.full([256], 5, dtype=jnp.int32)) - @parameterized.product( - force_while=[False, True], thread_semantics=[*plgpu.ThreadSemantics] - ) - def test_fori_loop_tuple(self, force_while, thread_semantics): + @parameterized.product(force_while=[False, True]) + def test_fori_loop_tuple(self, force_while): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], jnp.int32), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32) ) def kernel(o_ref): def body(step, xs): @@ -1109,16 +1117,11 @@ def body(step, xs): kernel(), jnp.full([256], 3 * (0 + 1), jnp.int32) ) - @parameterized.product( - force_while=[False, True], thread_semantics=[*plgpu.ThreadSemantics] - ) - def test_fori_loop_indexed_store(self, force_while, thread_semantics): + @parameterized.product(force_while=[False, True]) + def test_fori_loop_indexed_store(self, force_while): @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([4, 128], jnp.float32), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), ) def kernel(x_ref, y_ref, o_ref): def body(idx, _): @@ -1131,17 +1134,11 @@ def body(idx, _): y = x + 1 np.testing.assert_array_equal(kernel(x, y), x + y) - @parameterized.product(thread_semantics=[*plgpu.ThreadSemantics]) - def test_while_loop(self, thread_semantics): - if thread_semantics == plgpu.ThreadSemantics.Warpgroup: - self.skipTest("WG lowering does not support reduce_sum_p needed for this test") + def test_while_loop(self): + self.skip_if_wg_semantics() @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([128], jnp.int32), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.int32) ) def kernel(x_ref, o_ref): o_ref[...] = jnp.zeros(o_ref.shape, dtype=jnp.int32) @@ -1182,12 +1179,9 @@ def body(acc): with self.assertRaisesRegex(ValueError, "has layout .*, when it should be"): kernel() - @parameterized.parameters([*plgpu.ThreadSemantics]) - def test_cond(self, thread_semantics): + def test_cond(self): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], jnp.int32), - compiler_params=plgpu.GPUCompilerParams(thread_semantics=thread_semantics), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32) ) def kernel(x_ref, o_ref): jax.lax.cond( @@ -1203,14 +1197,9 @@ def kernel(x_ref, o_ref): self.assertIn("acc % 2", output()) - @parameterized.parameters([*plgpu.ThreadSemantics]) - def test_cond_returning_array(self, thread_semantics): + def test_cond_returning_array(self): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], jnp.int32), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32) ) def kernel(x_ref, o_ref): acc_sum = _sum_same_dtype(x_ref[...]) @@ -1341,20 +1330,13 @@ def kernel(x_ref, o_ref): (jnp.uint32, jnp.int32), (jnp.int32, jnp.uint32), ], - thread_semantics=[*plgpu.ThreadSemantics], ) - def test_bitcast_convert_type(self, dtypes, thread_semantics): + def test_bitcast_convert_type(self, dtypes): in_dtype, out_dtype = dtypes m, n = 16, 8 out_shape = jax.ShapeDtypeStruct((m, n), out_dtype) - @functools.partial( - pl.pallas_call, - out_shape=out_shape, - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), - ) + @functools.partial(self.pallas_call, out_shape=out_shape) def convert(x_ref, y_ref): y_ref[...] = jax.lax.bitcast_convert_type(x_ref[...], out_shape) @@ -1364,6 +1346,12 @@ def convert(x_ref, y_ref): ) +class PallasCallWGTest( + PallasCallTest, thread_semantics=plgpu.ThreadSemantics.Warpgroup +): + ... + + class PallasCallSm90ATest(PallasSm90ATest): @parameterized.parameters(False, True) From 336852c57bcccef7cc22db0ac6b499f5c277e78b Mon Sep 17 00:00:00 2001 From: Seunghoon Park Date: Tue, 25 Mar 2025 09:11:35 -0700 Subject: [PATCH 150/483] Expose jax.lax.shape_as_value(). PiperOrigin-RevId: 740371651 --- jax/lax/__init__.py | 1 + tests/lax_test.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index 4e376fb666d1..6f2163c424a6 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -198,6 +198,7 @@ select as select, select_n as select_n, select_n_p as select_n_p, + shape_as_value as shape_as_value, shift_left as shift_left, shift_left_p as shift_left_p, shift_right_arithmetic as shift_right_arithmetic, diff --git a/tests/lax_test.py b/tests/lax_test.py index 40f2eb8f3588..14b6c852c61f 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -29,6 +29,7 @@ import jax from jax._src import core +from jax import export from jax import jvp, grad from jax import lax import jax.numpy as jnp @@ -3621,6 +3622,37 @@ def f(x): g = jax.grad(f)(5.) # doesn't crash self.assertAllClose(g, 3., check_dtypes=False) + def test_shape_as_value_handles_static_shapes(self): + result = lax.shape_as_value(()) + self.assertArraysEqual(result, lax.full((0,), np.array(0, np.int64))) + + result = lax.shape_as_value((2,)) + self.assertArraysEqual(result, np.asarray((2,), np.int64)) + + result = lax.shape_as_value((2, 3)) + self.assertArraysEqual(result, np.asarray((2, 3), np.int64)) + + def test_shape_as_value_handles_polymorphic_shapes(self): + @jax.jit + def f(x): + return lax.shape_as_value(x.shape) + + exported = export.export(f)( + jax.ShapeDtypeStruct(export.symbolic_shape("a"), jnp.float32) + ) + result = exported.call(np.ones((1), dtype=np.float32)) + self.assertArraysEqual(result, np.asarray((1,), np.int64)) + result = exported.call(np.ones((2), dtype=np.float32)) + self.assertArraysEqual(result, np.asarray((2,), np.int64)) + + exported = export.export(f)( + jax.ShapeDtypeStruct(export.symbolic_shape("a, b"), jnp.float32) + ) + result = exported.call(np.ones((1, 2), dtype=np.float32)) + self.assertArraysEqual(result, np.asarray((1, 2), np.int64)) + result = exported.call(np.ones((3, 4), dtype=np.float32)) + self.assertArraysEqual(result, np.asarray((3, 4), np.int64)) + class LazyConstantTest(jtu.JaxTestCase): def _Check(self, make_const, expected): From b088b3aef83d1254c82bc0d1780a305588868e59 Mon Sep 17 00:00:00 2001 From: Michael Whittaker Date: Tue, 25 Mar 2025 09:37:52 -0700 Subject: [PATCH 151/483] Fixed broken JAX distributed tests. PiperOrigin-RevId: 740379562 --- tests/BUILD | 11 +++++++ tests/distributed_initialize_test.py | 44 ++++++++++++++++++++++++++++ tests/distributed_test.py | 18 ------------ 3 files changed, 55 insertions(+), 18 deletions(-) create mode 100644 tests/distributed_initialize_test.py diff --git a/tests/BUILD b/tests/BUILD index 0a8fb9459044..d706f08b8092 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -137,9 +137,20 @@ jax_multiplatform_test( srcs = ["debug_nans_test.py"], ) +jax_py_test( + name = "distributed_initialize_test", + srcs = ["distributed_initialize_test.py"], + deps = [ + "//jax", + "//jax:test_util", + ] + py_deps("portpicker"), +) + jax_multiplatform_test( name = "distributed_test", srcs = ["distributed_test.py"], + enable_backends = ["gpu"], + deps = py_deps("portpicker"), ) jax_py_test( diff --git a/tests/distributed_initialize_test.py b/tests/distributed_initialize_test.py new file mode 100644 index 000000000000..33242a41a68e --- /dev/null +++ b/tests/distributed_initialize_test.py @@ -0,0 +1,44 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from absl.testing import absltest +import jax +from jax._src import test_util as jtu + +try: + import portpicker +except ImportError: + portpicker = None + +jax.config.parse_flags_with_absl() + + +@unittest.skipIf(not portpicker, "Test requires portpicker") +class DistributedInitializeTest(jtu.JaxTestCase): + + @jtu.skip_under_pytest( + """Side effects from jax.distributed.initialize conflict with other tests + in the same process. pytest runs multiple tests in the same process.""" + ) + def test_is_distributed_initialized(self): + port = portpicker.pick_unused_port() # type: ignore + self.assertFalse(jax.distributed.is_initialized()) + jax.distributed.initialize(f"localhost:{port}", 1, 0) + self.assertTrue(jax.distributed.is_initialized()) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/distributed_test.py b/tests/distributed_test.py index 3961932dfad0..5e47228c1719 100644 --- a/tests/distributed_test.py +++ b/tests/distributed_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import subprocess -import sys import threading import unittest @@ -67,22 +65,6 @@ def task(i): for thread in threads: thread.join() - def test_is_distributed_initialized(self): - # Run in subprocess to isolate side effects from jax.distributed.initialize which conflict with other - # tests. Unfortunately this can't be avoided by calling jax.distributed.shutdown, as the XLA backend - # will be warmed up, which yields a RuntimeError on subsequent calls to initialize. - port = portpicker.pick_unused_port() # type: ignore - cmd = f"""import jax; - assert not jax.distributed.is_initialized(); - jax.distributed.initialize('localhost:{port}', 1, 0); - assert jax.distributed.is_initialized(); - """.replace("\n", ' ') - - result = subprocess.run([sys.executable, "-c", cmd], capture_output=True) - self.assertEqual( - result.returncode, 0, msg=f"Test failed with:\n{result.stdout}\n{result.stderr}" - ) - if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From d8f38ff857726e4fb57bccb169cfcfdb1eb68656 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Tue, 25 Mar 2025 11:42:24 -0700 Subject: [PATCH 152/483] [jaxlib:gpu] Clean up custom call GPU callback handling code. PiperOrigin-RevId: 740428623 --- jax_plugins/cuda/__init__.py | 2 - jax_plugins/rocm/__init__.py | 2 - jaxlib/cuda/BUILD | 4 - jaxlib/cuda/cuda_plugin_extension.cc | 7 -- jaxlib/gpu/py_client_gpu.cc | 118 --------------------------- jaxlib/gpu/py_client_gpu.h | 6 -- jaxlib/rocm/BUILD | 4 - jaxlib/rocm/rocm_plugin_extension.cc | 7 -- 8 files changed, 150 deletions(-) diff --git a/jax_plugins/cuda/__init__.py b/jax_plugins/cuda/__init__.py index 2b02621c89f5..13293de7181d 100644 --- a/jax_plugins/cuda/__init__.py +++ b/jax_plugins/cuda/__init__.py @@ -92,8 +92,6 @@ def initialize(): cuda_plugin_extension.register_custom_call_target, c_api ), ) - for _name, _value in cuda_plugin_extension.registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform="CUDA") for _name, _value in cuda_plugin_extension.ffi_registrations().items(): xla_client.register_custom_call_target( _name, _value, platform='CUDA', api_version=1 diff --git a/jax_plugins/rocm/__init__.py b/jax_plugins/rocm/__init__.py index 0699ae1e34a1..0b1b077acfcd 100644 --- a/jax_plugins/rocm/__init__.py +++ b/jax_plugins/rocm/__init__.py @@ -92,8 +92,6 @@ def initialize(): rocm_plugin_extension.register_custom_call_target, c_api ), ) - for _name, _value in rocm_plugin_extension.registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform="ROCM") for _name, _value in rocm_plugin_extension.ffi_registrations().items(): xla_client.register_custom_call_target( _name, _value, platform='ROCM', api_version=1 diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index 5cd7283ea3fc..48441632fba9 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -684,14 +684,10 @@ cc_library( "@xla//xla:shape_util", "@xla//xla/ffi:ffi_api", "@xla//xla/ffi/api:ffi", - "@xla//xla/pjrt:exceptions", "@xla//xla/pjrt:host_callback", "@xla//xla/pjrt:transpose", - "@xla//xla/python:callback", "@xla//xla/python:nb_numpy", "@xla//xla/python:types", - "@xla//xla/service:custom_call_status", - "@xla//xla/service:custom_call_target_registry", "@xla//xla/service:platform_util", ], ) diff --git a/jaxlib/cuda/cuda_plugin_extension.cc b/jaxlib/cuda/cuda_plugin_extension.cc index 63375921e3be..68230a332d95 100644 --- a/jaxlib/cuda/cuda_plugin_extension.cc +++ b/jaxlib/cuda/cuda_plugin_extension.cc @@ -42,12 +42,6 @@ static std::string ToString(CUresult result) { return absl::StrCat(error_name, ": ", error_string); } -nb::dict Registrations() { - nb::dict dict; - dict["xla_python_gpu_callback"] = - jax::EncapsulateFunction(jax::cuda::XlaPythonGpuCallback); - return dict; -} nb::dict FfiRegistrations() { nb::dict dict; nb::dict gpu_callback_dict; @@ -63,7 +57,6 @@ nb::dict FfiRegistrations() { NB_MODULE(cuda_plugin_extension, m) { BuildGpuPluginExtension(m); - m.def("registrations", &Registrations); m.def("ffi_registrations", &FfiRegistrations); m.def( diff --git a/jaxlib/gpu/py_client_gpu.cc b/jaxlib/gpu/py_client_gpu.cc index 71d327ffdb28..c39d5201f223 100644 --- a/jaxlib/gpu/py_client_gpu.cc +++ b/jaxlib/gpu/py_client_gpu.cc @@ -25,13 +25,11 @@ limitations under the License. #include #include "absl/algorithm/container.h" -#include "absl/base/casts.h" #include "absl/container/inlined_vector.h" #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/ascii.h" -#include "absl/strings/numbers.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -39,15 +37,11 @@ limitations under the License. #include "jaxlib/gpu/vendor.h" #include "xla/ffi/api/ffi.h" #include "xla/ffi/ffi_api.h" -#include "xla/pjrt/exceptions.h" #include "xla/pjrt/host_callback.h" #include "xla/pjrt/transpose.h" #include "xla/primitive_util.h" -#include "xla/python/callback.h" #include "xla/python/nb_numpy.h" #include "xla/python/types.h" -#include "xla/service/custom_call_status.h" -#include "xla/service/custom_call_target_registry.h" #include "xla/shape_util.h" namespace nb = nanobind; @@ -55,118 +49,6 @@ namespace nb = nanobind; namespace jax { namespace JAX_GPU_NAMESPACE { -void XlaPythonGpuCallback(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len, - XlaCustomCallStatus* status) { - // Ignore `descriptor` arg to callback - buffers += 1; - uint64_t descriptor; - if (!absl::SimpleAtoi(opaque, &descriptor)) { - throw xla::XlaRuntimeError("Invalid callback descriptor"); - return; - } - xla::CpuCallback* callback = - absl::bit_cast(static_cast(descriptor)); - size_t arity = callback->num_args(); - std::vector host_input_buffers(arity); - // Copy input GPU buffers to host - for (size_t i = 0; i < arity; ++i) { - const xla::CpuCallback::Arg& arg = callback->args()[i]; - if (arg.type == xla::TOKEN) { - host_input_buffers[i] = nullptr; - continue; - } - void* buf = new char[arg.size_in_bytes]; - host_input_buffers[i] = buf; - // TODO(b/238441608): Use pinned memory here to speed up the transfer. - auto gpu_res = gpuMemcpyAsync(buf, buffers[i], arg.size_in_bytes, - gpuMemcpyDeviceToHost, stream); - CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; - } - CHECK_EQ(gpuStreamSynchronize(stream), gpuSuccess) - << "Failed to gpuStreamSynchronize"; - nb::gil_scoped_acquire gil; - nb::tuple host_input_arrays = nb::steal(PyTuple_New(arity)); - for (size_t i = 0; i < arity; ++i) { - xla::CpuCallback::Arg arg = callback->args()[i]; - if (arg.type == xla::TOKEN) { - PyTuple_SET_ITEM(host_input_arrays.ptr(), i, nb::none().inc_ref().ptr()); - continue; - } - nb::capsule base(host_input_buffers[i], [](void* ptr) noexcept { - delete[] static_cast(ptr); - }); - auto array = xla::nb_numpy_ndarray(arg.dtype, arg.dims, arg.strides, - const_cast(host_input_buffers[i]), - /*base=*/base); - array.attr("flags").attr("writeable") = nb::bool_(false); - PyTuple_SET_ITEM(host_input_arrays.ptr(), i, array.inc_ref().ptr()); - } - xla::EnterHostCallback(); - absl::StatusOr maybe_result_tuple = - callback->Call(host_input_arrays); - xla::LeaveHostCallback(); - if (!maybe_result_tuple.ok()) { - absl::string_view msg = maybe_result_tuple.status().message(); - XlaCustomCallStatusSetFailure(status, msg.data(), msg.length()); - return; - } - nb::tuple result_tuple = maybe_result_tuple.value(); - std::vector temp_buffers; - for (size_t i = 0; i < callback->results().size(); ++i) { - xla::CpuCallback::Result result = callback->results()[i]; - if (result.type == xla::TOKEN) { - continue; - } - nb::object output = - nb::borrow(PyTuple_GetItem(result_tuple.ptr(), i)); - xla::nb_numpy_ndarray array = - xla::nb_numpy_ndarray::ensure(std::move(output)); - absl::Span dims( - reinterpret_cast(array.shape()), array.ndim()); - absl::Span strides( - reinterpret_cast(array.strides()), array.ndim()); - if (strides == result.expected_strides) { - auto gpu_res = - gpuMemcpyAsync(buffers[arity + i], array.data(), result.size_in_bytes, - gpuMemcpyHostToDevice, stream); - CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; - } else { - void* temp = new char[result.size_in_bytes]; - temp_buffers.push_back(temp); - xla::TransposePlan::Options options; - options.elem_size_in_bytes = xla::primitive_util::ByteWidth(result.type); - options.dims = dims; - options.permutation = result.reversed_layout; - options.input_layout = xla::TransposePlan::Striding{strides}; - absl::StatusOr> plan = - callback->transpose_cache().GetOrCreate(options); - if (!plan.ok()) { - throw xla::XlaRuntimeError(plan.status().ToString()); - } - plan.value()->Execute(array.data(), temp); - auto gpu_res = - gpuMemcpyAsync(buffers[arity + i], temp, result.size_in_bytes, - gpuMemcpyHostToDevice, stream); - CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; - } - } - nb::gil_scoped_release release; - CHECK_EQ(gpuStreamSynchronize(stream), gpuSuccess) - << "Failed to gpuStreamSynchronize"; - for (int i = 0; i < temp_buffers.size(); ++i) { - delete[] static_cast(temp_buffers[i]); - } -} - -// TODO(danfm): When compiled as part of a jaxlib plugin, this will register -// the custom call target in the plugin's registry. This won't affect -// registration via the Python API, but we should remove this once we have -// fully migrated to the plugin interface. -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( - "xla_python_gpu_callback", &XlaPythonGpuCallback, - absl::AsciiStrToUpper(JAX_GPU_PLUGIN_NAME)); - struct GpuTransposePlanCache { static xla::ffi::TypeId id; explicit GpuTransposePlanCache(int capacity) : cache(capacity) {} diff --git a/jaxlib/gpu/py_client_gpu.h b/jaxlib/gpu/py_client_gpu.h index 06a955365c0b..8c5404570919 100644 --- a/jaxlib/gpu/py_client_gpu.h +++ b/jaxlib/gpu/py_client_gpu.h @@ -20,15 +20,9 @@ limitations under the License. #include "jaxlib/gpu/vendor.h" #include "xla/ffi/api/ffi.h" -#include "xla/service/custom_call_status.h" namespace jax { namespace JAX_GPU_NAMESPACE { - -void XlaPythonGpuCallback(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len, - XlaCustomCallStatus* status); - XLA_FFI_DECLARE_HANDLER_SYMBOL(kGpuTransposePlanCacheInstantiate); XLA_FFI_DECLARE_HANDLER_SYMBOL(kXlaFfiPythonGpuCallback); } // namespace JAX_GPU_NAMESPACE diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index 522aa8da0145..258556be8b1e 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -582,14 +582,10 @@ cc_library( "@xla//xla:shape_util", "@xla//xla/ffi:ffi_api", "@xla//xla/ffi/api:ffi", - "@xla//xla/pjrt:exceptions", "@xla//xla/pjrt:host_callback", "@xla//xla/pjrt:transpose", - "@xla//xla/python:callback", "@xla//xla/python:nb_numpy", "@xla//xla/python:types", - "@xla//xla/service:custom_call_status", - "@xla//xla/service:custom_call_target_registry", "@xla//xla/service:platform_util", ], ) diff --git a/jaxlib/rocm/rocm_plugin_extension.cc b/jaxlib/rocm/rocm_plugin_extension.cc index 642467a9afef..2ba5d98ae668 100644 --- a/jaxlib/rocm/rocm_plugin_extension.cc +++ b/jaxlib/rocm/rocm_plugin_extension.cc @@ -66,12 +66,6 @@ std::string ToString(hipError_t result) { } } -nb::dict Registrations() { - nb::dict dict; - dict["xla_python_gpu_callback"] = - jax::EncapsulateFunction(jax::hip::XlaPythonGpuCallback); - return dict; -} nb::dict FfiRegistrations() { nb::dict dict; nb::dict gpu_callback_dict; @@ -87,7 +81,6 @@ nb::dict FfiRegistrations() { NB_MODULE(rocm_plugin_extension, m) { BuildGpuPluginExtension(m); - m.def("registrations", &Registrations); m.def("ffi_registrations", &FfiRegistrations); m.def( From c8ccd7570aa50dd67c80350f29477c4a44992897 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 25 Mar 2025 11:47:35 -0700 Subject: [PATCH 153/483] Add functionality that let us do a "jax" only release Introduces a new `download-jax-only-from-gcs` variable to the workflow configs. When set to 1, the test workflows will only download and install the `jax` wheel. Other artifacts such as the latest releases of `jaxlib` and the CUDA plugin dependencies will be downloaded and installed from PyPI. PiperOrigin-RevId: 740430538 --- .github/workflows/pytest_cpu.yml | 19 ++++++++++++-- .github/workflows/pytest_cuda.yml | 26 +++++++++++++++---- .github/workflows/pytest_tpu.yml | 11 +++++++- .../workflows/wheel_tests_nightly_release.yml | 8 ++++++ ci/utilities/install_wheels_locally.sh | 5 ++++ 5 files changed, 61 insertions(+), 8 deletions(-) diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index 137f49c6d8c7..c952ef9ee1a6 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -29,6 +29,11 @@ on: type: string required: true default: "0" + download-jax-only-from-gcs: + description: "Whether to download only the jax wheel from GCS (e.g for testing a jax only release)" + required: false + default: '0' + type: string gcs_download_uri: description: "GCS location prefix from where the artifacts should be downloaded" required: true @@ -92,7 +97,12 @@ jobs: run: | mkdir -p $(pwd)/dist gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ + + if [[ "${{ inputs.download-jax-only-from-gcs }}" == "1" ]]; then + echo "JAX only release. Only downloading the jax wheel from the release bucket." + else + gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ + fi - name: Download wheels from GCS (Windows runs) id: download-wheel-artifacts-w # Set continue-on-error to true to prevent actions from failing the workflow if this step @@ -106,7 +116,12 @@ jobs: @REM Use `call` so that we can run sequential gsutil commands on Windows @REM See https://github.com/GoogleCloudPlatform/gsutil/issues/233#issuecomment-196150652 call gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl dist/ - call gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*%PYTHON_MAJOR_MINOR%*%OS%*%ARCH%*.whl" dist/ + + if "${{ inputs.download-jax-only-from-gcs }}"=="1" ( + echo "JAX only release. Only downloading the jax wheel from the release bucket." + ) else ( + call gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*%PYTHON_MAJOR_MINOR%*%OS%*%ARCH%*.whl" dist/ + ) - name: Skip the test run if the wheel artifacts were not downloaded successfully if: steps.download-wheel-artifacts-nw.outcome == 'failure' || steps.download-wheel-artifacts-w.outcome == 'failure' run: | diff --git a/.github/workflows/pytest_cuda.yml b/.github/workflows/pytest_cuda.yml index ae74da53edcb..b3d1b15a0052 100644 --- a/.github/workflows/pytest_cuda.yml +++ b/.github/workflows/pytest_cuda.yml @@ -34,6 +34,11 @@ on: type: string required: true default: "0" + download-jax-only-from-gcs: + description: "Whether to download only the jax wheel from GCS (e.g for testing a jax only release)" + required: false + default: '0' + type: string gcs_download_uri: description: "GCS location prefix from where the artifacts should be downloaded" required: true @@ -88,11 +93,22 @@ jobs: # informative error message. continue-on-error: true run: | - mkdir -p $(pwd)/dist && - gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ && - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/ + mkdir -p $(pwd)/dist + gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ + + # Do not download the jaxlib and CUDA plugin artifacts if we are testing a jax only + # release. + if [[ "${{ inputs.download-jax-only-from-gcs }}" == "1" ]]; then + echo "JAX only release. Only downloading the jax wheel from the release bucket." + + # Set the env var to install the CUDA plugin and PJRT packages from PyPI. jaxlib is + # required dependency of jax so that gets installed automatically. + echo "JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI=jax_cuda_pypi">> $GITHUB_ENV + else + gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && + gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && + gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/ + fi - name: Skip the test run if the wheel artifacts were not downloaded successfully if: steps.download-wheel-artifacts.outcome == 'failure' run: | diff --git a/.github/workflows/pytest_tpu.yml b/.github/workflows/pytest_tpu.yml index a105a2feb347..0b56635a8aac 100644 --- a/.github/workflows/pytest_tpu.yml +++ b/.github/workflows/pytest_tpu.yml @@ -54,6 +54,11 @@ on: # - "pypi_latest": Use the latest libtpu wheel from PyPI. # - "oldest_supported_libtpu": Use the oldest supported libtpu wheel. default: "nightly" + download-jax-only-from-gcs: + description: "Whether to download only the jax wheel from GCS (e.g for testing a jax only release)" + required: false + default: '0' + type: string gcs_download_uri: description: "GCS location prefix from where the artifacts should be downloaded" required: true @@ -110,7 +115,11 @@ jobs: run: | mkdir -p $(pwd)/dist gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ + if [[ "${{ inputs.download-jax-only-from-gcs }}" == "1" ]]; then + echo "JAX only release. Only downloading the jax wheel from the release bucket." + else + gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ + fi - name: Skip the test run if the wheel artifacts were not downloaded successfully if: steps.download-wheel-artifacts.outcome == 'failure' run: | diff --git a/.github/workflows/wheel_tests_nightly_release.yml b/.github/workflows/wheel_tests_nightly_release.yml index adb678be9d9d..9cd48c925cf3 100644 --- a/.github/workflows/wheel_tests_nightly_release.yml +++ b/.github/workflows/wheel_tests_nightly_release.yml @@ -17,6 +17,11 @@ on: required: true default: 'gs://jax-nightly-release-transient/nightly/latest' type: string + download-jax-only-from-gcs: + description: "Whether to download only the jax wheel from GCS (e.g for testing a jax only release)" + required: true + default: '0' + type: string concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} @@ -41,6 +46,7 @@ jobs: runner: ${{ matrix.runner }} python: ${{ matrix.python }} enable-x64: ${{ matrix.enable-x64 }} + download-jax-only-from-gcs: ${{inputs.download-jax-only-from-gcs}} gcs_download_uri: ${{inputs.gcs_download_uri}} run-pytest-cuda: @@ -60,6 +66,7 @@ jobs: python: ${{ matrix.python }} cuda: ${{ matrix.cuda }} enable-x64: ${{ matrix.enable-x64 }} + download-jax-only-from-gcs: ${{inputs.download-jax-only-from-gcs}} gcs_download_uri: ${{inputs.gcs_download_uri}} run-pytest-tpu: @@ -98,4 +105,5 @@ jobs: python: ${{ matrix.python }} run-full-tpu-test-suite: "1" libtpu-version-type: ${{ matrix.libtpu-version-type }} + download-jax-only-from-gcs: ${{inputs.download-jax-only-from-gcs}} gcs_download_uri: ${{inputs.gcs_download_uri}} \ No newline at end of file diff --git a/ci/utilities/install_wheels_locally.sh b/ci/utilities/install_wheels_locally.sh index 64f88765bb75..53f070d1e0e6 100644 --- a/ci/utilities/install_wheels_locally.sh +++ b/ci/utilities/install_wheels_locally.sh @@ -26,6 +26,11 @@ for i in "${!WHEELS[@]}"; do # Append [tpu] to the jax wheel name to download the latest libtpu wheel # from PyPI. WHEELS[$i]="${WHEELS[$i]}[tpu]" + elif [[ "$JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI" == "jax_cuda_pypi" ]]; then + # Append [cuda12-local] to the jax wheel name to download the latest + # release of JAX's CUDA plugin and PJRT packages from PyPI. This is used + # when running CUDA tests for a "jax" only release. + WHEELS[$i]="${WHEELS[$i]}[cuda12-local]" fi fi done From bda37e322289c4a903e2aa490d4f376a2d370fcf Mon Sep 17 00:00:00 2001 From: Michael Whittaker Date: Tue, 25 Mar 2025 13:28:12 -0700 Subject: [PATCH 154/483] Increased sharding for `lax_scipy_spectral_dac_test_cpu_shardy`. PiperOrigin-RevId: 740464973 --- tests/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/BUILD b/tests/BUILD index d706f08b8092..2e03f331744c 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -607,7 +607,7 @@ jax_multiplatform_test( name = "lax_scipy_spectral_dac_test", srcs = ["lax_scipy_spectral_dac_test.py"], shard_count = { - "cpu": 10, + "cpu": 20, "gpu": 10, "tpu": 10, }, From 8c44b277bebf1cf801e8cc3d91890ac0f270a880 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Tue, 25 Mar 2025 13:33:13 -0700 Subject: [PATCH 155/483] [Mosaic GPU] Add warpgroup lowering for `BarrierArrive` in Pallas. PiperOrigin-RevId: 740466565 --- jax/_src/pallas/mosaic_gpu/primitives.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index fe28766cfb96..8eafa0ac8e6d 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -489,6 +489,7 @@ def _barrier_arrive_abstract_eval(barrier, *args, **params): @lowering.register_lowering_rule(barrier_arrive_p, mgpu.ThreadSemantics.Lane) +@lowering.register_lowering_rule(barrier_arrive_p, mgpu.ThreadSemantics.Warpgroup) def _barrier_arrive_lowering( ctx: lowering.LoweringRuleContext, barrier, From 85150471e283cd9b5f167aa2345a1796ec1ae0d5 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 25 Mar 2025 13:45:54 -0700 Subject: [PATCH 156/483] Support __jax_array__ in jnp.full_like & co --- jax/_src/numpy/array_creation.py | 16 +++++++++--- tests/array_extensibility_test.py | 41 +++++++++++++++++++++++++++---- 2 files changed, 49 insertions(+), 8 deletions(-) diff --git a/jax/_src/numpy/array_creation.py b/jax/_src/numpy/array_creation.py index 67418e7322c9..a0495986fcd1 100644 --- a/jax/_src/numpy/array_creation.py +++ b/jax/_src/numpy/array_creation.py @@ -244,6 +244,8 @@ def zeros_like(a: ArrayLike | DuckTypedArray, [0, 0, 0]], dtype=int32) """ if not (hasattr(a, 'dtype') and hasattr(a, 'shape')): # support duck typing + if hasattr(a, '__jax_array__'): + a = a.__jax_array__() util.check_arraylike("zeros_like", a) dtypes.check_user_dtype_supported(dtype, "zeros_like") if shape is not None: @@ -287,6 +289,8 @@ def ones_like(a: ArrayLike | DuckTypedArray, [1, 1, 1]], dtype=int32) """ if not (hasattr(a, 'dtype') and hasattr(a, 'shape')): # support duck typing + if hasattr(a, '__jax_array__'): + a = a.__jax_array__() util.check_arraylike("ones_like", a) dtypes.check_user_dtype_supported(dtype, "ones_like") if shape is not None: @@ -332,9 +336,13 @@ def empty_like(prototype: ArrayLike | DuckTypedArray, [0, 0, 0]], dtype=int32) """ if not (hasattr(prototype, 'dtype') and hasattr(prototype, 'shape')): # support duck typing - util.check_arraylike("empty_like", prototype) - dtypes.check_user_dtype_supported(dtype, "empty_like") - return zeros_like(prototype, dtype=dtype, shape=shape, device=device) + if hasattr(prototype, '__jax_array__'): + prototype = prototype.__jax_array__() + util.check_arraylike("ones_like", prototype) + dtypes.check_user_dtype_supported(dtype, "ones_like") + if shape is not None: + shape = canonicalize_shape(shape) + return lax.full_like(prototype, 0, dtype, shape, sharding=util.normalize_device_to_sharding(device)) @export @@ -382,6 +390,8 @@ def full_like(a: ArrayLike | DuckTypedArray, util.check_arraylike("full_like", 0, fill_value) else: util.check_arraylike("full_like", a, fill_value) + if hasattr(a, '__jax_array__'): + a = a.__jax_array__() dtypes.check_user_dtype_supported(dtype, "full_like") if shape is not None: shape = canonicalize_shape(shape) diff --git a/tests/array_extensibility_test.py b/tests/array_extensibility_test.py index 45c83f7967ce..3e84f6668b8d 100644 --- a/tests/array_extensibility_test.py +++ b/tests/array_extensibility_test.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import functools +from typing import Any, Callable, NamedTuple from absl.testing import absltest from absl.testing import parameterized -from typing import Any, Callable, NamedTuple +import numpy as np import jax import jax.numpy as jnp @@ -38,6 +40,15 @@ def __jax_array__(self) -> jax.Array: return jnp.asarray(self.x) +class DuckTypedArrayWithErroringJaxArray: + """Duck-typed array that provides a __jax_array__ method which fails.""" + shape = (2, 3) + dtype = np.dtype('float32') + + def __jax_array__(self): + raise ValueError("jax array was called.") + + class NumPyAPI(NamedTuple): fun: Callable[..., Any] args: list[jax.ShapeDtypeStruct] @@ -287,7 +298,7 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.dsplit, Float[3, 5, 6], indices_or_sections=2), # NumPyAPI.sig(jnp.dstack, Float[3, 5]), NumPyAPI.sig(jnp.ediff1d, Float[5]), - # NumPyAPI.sig(jnp.empty_like, Float[5]), + NumPyAPI.sig(jnp.empty_like, Float[5]), NumPyAPI.sig(jnp.equal, Float[5], Float[5]), NumPyAPI.sig(jnp.exp, Float[5]), NumPyAPI.sig(jnp.exp2, Float[5]), @@ -312,7 +323,7 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.fmin, Float[5], Float[5]), NumPyAPI.sig(jnp.fmod, Float[5], Float[5]), NumPyAPI.sig(jnp.frexp, Float[5]), - # NumPyAPI.sig(jnp.full_like, Float[5], Float[()]), + NumPyAPI.sig(jnp.full_like, Float[5], Float[()]), NumPyAPI.sig(jnp.gcd, Int[5], Int[5]), NumPyAPI.sig(jnp.greater, Float[5], Float[5]), NumPyAPI.sig(jnp.greater_equal, Float[5], Float[5]), @@ -393,7 +404,7 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.nextafter, Float[5], Float[5]), NumPyAPI.sig(jnp.nonzero, Float[5]), NumPyAPI.sig(jnp.not_equal, Float[5], Float[5]), - # NumPyAPI.sig(jnp.ones_like, Float[5]), + NumPyAPI.sig(jnp.ones_like, Float[5]), NumPyAPI.sig(jnp.outer, Float[5], Float[5]), NumPyAPI.sig(jnp.packbits, Int[5]), # NumPyAPI.sig(jnp.pad, Float[5], pad_width=2), @@ -493,7 +504,7 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.vsplit, Float[6], indices_or_sections=2), NumPyAPI.sig(jnp.vstack, [Float[5], Float[2, 5]]), NumPyAPI.sig(jnp.where, Bool[5], Float[5], Float[5]), - # NumPyAPI.sig(jnp.zeros_like, Float[5]), + NumPyAPI.sig(jnp.zeros_like, Float[5]), ] @@ -511,6 +522,26 @@ def test_numpy_api_supports_jax_array(self, api): self.assertAllClose(wrapped, expected, atol=0, rtol=0) + @parameterized.named_parameters( + {'testcase_name': func.__name__, 'func': func} + for func in [jnp.zeros_like, jnp.ones_like, jnp.empty_like, jnp.full_like] + ) + def test_array_creation_from_duck_typed_array(self, func): + # Ensure that jnp.*_like prefers shape/dtype over __jax_array__ when + # both methods are available. + if func is jnp.full_like: + func = functools.partial(func, fill_value=2.0) + obj = DuckTypedArrayWithErroringJaxArray() + + # The test relies on this failing + with self.assertRaises(ValueError): + jnp.asarray(obj) + + result = func(obj) + self.assertIsInstance(result, jax.Array) + self.assertEqual(result.shape, obj.shape) + self.assertEqual(result.dtype, obj.dtype) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 679ea6370b54818cb1ec3449924addc02a413a0b Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 25 Mar 2025 14:07:30 -0700 Subject: [PATCH 157/483] [JAX] [XLA:Python] Migrate py_client to JAX. PiperOrigin-RevId: 740478728 --- jaxlib/xla/BUILD | 190 ++- jaxlib/xla/dlpack.cc | 4 +- jaxlib/xla/dlpack.h | 2 +- jaxlib/xla/ifrt_proxy.cc | 162 ++ jaxlib/xla/ifrt_proxy.h | 31 + jaxlib/xla/jax_jit.cc | 4 +- jaxlib/xla/jax_jit.h | 4 +- jaxlib/xla/pjit.cc | 8 +- jaxlib/xla/pmap_lib.cc | 16 +- jaxlib/xla/py_array.cc | 2063 ++++++++++++++++++++++++++ jaxlib/xla/py_array.h | 360 +++++ jaxlib/xla/py_client.cc | 851 +++++++++++ jaxlib/xla/py_client.h | 270 ++++ jaxlib/xla/py_compile_only_client.cc | 131 ++ jaxlib/xla/py_compile_only_client.h | 45 + jaxlib/xla/py_device.cc | 350 +++++ jaxlib/xla/py_device.h | 82 + jaxlib/xla/py_device_list.cc | 472 ++++++ jaxlib/xla/py_device_list.h | 137 ++ jaxlib/xla/py_executable.cc | 463 ++++++ jaxlib/xla/py_executable.h | 263 ++++ jaxlib/xla/py_memory_space.cc | 102 ++ jaxlib/xla/py_memory_space.h | 64 + jaxlib/xla/py_program.cc | 291 ++++ jaxlib/xla/py_program.h | 27 + jaxlib/xla/py_socket_transfer.cc | 6 +- jaxlib/xla/py_values.cc | 745 ++++++++++ jaxlib/xla/py_values.h | 127 ++ jaxlib/xla/sharded_device_array.h | 217 +++ jaxlib/xla/sharding.cc | 346 +++++ jaxlib/xla/sharding.h | 242 +++ jaxlib/xla/to_ifrt_sharding.cc | 141 ++ jaxlib/xla/to_ifrt_sharding.h | 56 + jaxlib/xla/xla.cc | 20 +- jaxlib/xla/xla_compiler.cc | 2 +- 35 files changed, 8253 insertions(+), 41 deletions(-) create mode 100644 jaxlib/xla/ifrt_proxy.cc create mode 100644 jaxlib/xla/ifrt_proxy.h create mode 100644 jaxlib/xla/py_array.cc create mode 100644 jaxlib/xla/py_array.h create mode 100644 jaxlib/xla/py_client.cc create mode 100644 jaxlib/xla/py_client.h create mode 100644 jaxlib/xla/py_compile_only_client.cc create mode 100644 jaxlib/xla/py_compile_only_client.h create mode 100644 jaxlib/xla/py_device.cc create mode 100644 jaxlib/xla/py_device.h create mode 100644 jaxlib/xla/py_device_list.cc create mode 100644 jaxlib/xla/py_device_list.h create mode 100644 jaxlib/xla/py_executable.cc create mode 100644 jaxlib/xla/py_executable.h create mode 100644 jaxlib/xla/py_memory_space.cc create mode 100644 jaxlib/xla/py_memory_space.h create mode 100644 jaxlib/xla/py_program.cc create mode 100644 jaxlib/xla/py_program.h create mode 100644 jaxlib/xla/py_values.cc create mode 100644 jaxlib/xla/py_values.h create mode 100644 jaxlib/xla/sharded_device_array.h create mode 100644 jaxlib/xla/sharding.cc create mode 100644 jaxlib/xla/sharding.h create mode 100644 jaxlib/xla/to_ifrt_sharding.cc create mode 100644 jaxlib/xla/to_ifrt_sharding.h diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD index e562cb7e84ea..979e659a309f 100644 --- a/jaxlib/xla/BUILD +++ b/jaxlib/xla/BUILD @@ -49,10 +49,12 @@ nanobind_extension( ":config", ":custom_call_sharding", ":dlpack", + ":ifrt_proxy", ":jax_jit", ":mlir", ":pjit", ":pmap_lib", + ":py_client", ":pytree", ":sdy", ":weakref_lru_cache", @@ -105,7 +107,6 @@ nanobind_extension( "@xla//xla/python:ops", "@xla//xla/python:pprof_profile_builder", "@xla//xla/python:profiler", - "@xla//xla/python:py_client", "@xla//xla/python:python_ref_manager", "@xla//xla/python:refine_polymorphic_shapes", "@xla//xla/python:traceback", @@ -114,7 +115,6 @@ nanobind_extension( "@xla//xla/python/ifrt", "@xla//xla/python/ifrt:plugin_program", "@xla//xla/python/ifrt:plugin_program_serdes", - "@xla//xla/python/ifrt_proxy/client:py_module", "@xla//xla/python/pjrt_ifrt", "@xla//xla/python/pjrt_ifrt:pjrt_attribute_map_util", "@xla//xla/python/pjrt_ifrt:xla_ifrt", @@ -211,6 +211,7 @@ cc_library( ], features = ["-use_header_modules"], deps = [ + ":py_client", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -229,7 +230,6 @@ cc_library( "@xla//xla/pjrt:pjrt_compiler", "@xla//xla/pjrt:pjrt_layout", "@xla//xla/python:nb_class_ptr", - "@xla//xla/python:py_client", "@xla//xla/python:python_ref_manager", "@xla//xla/python:traceback", "@xla//xla/python:types", @@ -242,6 +242,37 @@ cc_library( ], ) +cc_library( + name = "ifrt_proxy", + srcs = ["ifrt_proxy.cc"], + hdrs = ["ifrt_proxy.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":py_client", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/log:log_entry", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/time", + "@nanobind", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:statusor", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python:nb_class_ptr", + "@xla//xla/python/ifrt", + "@xla//xla/python/ifrt:attribute_map", + "@xla//xla/python/ifrt_proxy/client:grpc_client", + "@xla//xla/python/ifrt_proxy/client:registry", + ], +) + cc_library( name = "jax_jit", srcs = ["jax_jit.cc"], @@ -253,6 +284,7 @@ cc_library( ], features = ["-use_header_modules"], deps = [ + ":py_client", ":pytree", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", @@ -272,7 +304,6 @@ cc_library( "@xla//xla/python:nb_absl_inlined_vector", "@xla//xla/python:nb_absl_span", "@xla//xla/python:nb_helpers", - "@xla//xla/python:py_client", "@xla//xla/python:python_ref_manager", "@xla//xla/python:types", "@xla//xla/tsl/platform:logging", @@ -331,6 +362,7 @@ cc_library( deps = [ ":config", ":jax_jit", + ":py_client", ":pytree", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", @@ -354,7 +386,6 @@ cc_library( "@xla//xla/python:nb_class_ptr", "@xla//xla/python:nb_helpers", "@xla//xla/python:nb_numpy", - "@xla//xla/python:py_client", "@xla//xla/python:python_ref_manager", "@xla//xla/python:traceback", "@xla//xla/python/ifrt", @@ -379,6 +410,7 @@ cc_library( deps = [ ":config", ":jax_jit", + ":py_client", ":pytree", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", @@ -401,7 +433,6 @@ cc_library( "@xla//xla/python:nb_class_ptr", "@xla//xla/python:nb_helpers", "@xla//xla/python:nb_numpy", - "@xla//xla/python:py_client", "@xla//xla/python:python_ref_manager", "@xla//xla/python:traceback", "@xla//xla/python:types", @@ -414,6 +445,149 @@ cc_library( ], ) +cc_library( + name = "py_client", + srcs = [ + "py_array.cc", + "py_client.cc", + "py_compile_only_client.cc", + "py_device.cc", + "py_device_list.cc", + "py_executable.cc", + "py_memory_space.cc", + "py_program.cc", + "py_values.cc", + "sharding.cc", + "to_ifrt_sharding.cc", + ], + hdrs = [ + "py_array.h", + "py_client.h", + "py_compile_only_client.h", + "py_device.h", + "py_device_list.h", + "py_executable.h", + "py_memory_space.h", + "py_program.h", + "py_values.h", + "sharded_device_array.h", + "sharding.h", + "to_ifrt_sharding.h", + ], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + visibility = jax_visibility("jaxlib/xla/py_client"), + deps = [ + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@com_google_absl//absl/types:variant", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@nanobind", + "@shardy//shardy/dialect/sdy/ir:dialect", + "@tsl//tsl/platform:casts", + "@tsl//tsl/platform:fingerprint", + "@tsl//tsl/platform:ml_dtypes", + "@tsl//tsl/profiler/lib:profiler_session", + "@tsl//tsl/profiler/lib:traceme", + "@tsl//tsl/profiler/protobuf:profiled_instructions_proto_cc", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:comparison_util", + "@xla//xla:literal", + "@xla//xla:shape_util", + "@xla//xla:status_macros", + "@xla//xla:types", + "@xla//xla:util", + "@xla//xla:xla_data_proto_cc", + "@xla//xla/hlo/builder:xla_builder", + "@xla//xla/hlo/builder:xla_computation", + "@xla//xla/hlo/builder/lib:arithmetic", + "@xla//xla/hlo/ir:hlo", + "@xla//xla/pjrt:exceptions", + "@xla//xla/pjrt:host_callback", + "@xla//xla/pjrt:host_memory_spaces", + "@xla//xla/pjrt:lru_cache", + "@xla//xla/pjrt:mlir_to_hlo", + "@xla//xla/pjrt:pjrt_client", + "@xla//xla/pjrt:pjrt_common", + "@xla//xla/pjrt:pjrt_compiler", + "@xla//xla/pjrt:pjrt_device_description", + "@xla//xla/pjrt:pjrt_executable", + "@xla//xla/pjrt:pjrt_future", + "@xla//xla/pjrt:pjrt_layout", + "@xla//xla/pjrt:status_casters", + "@xla//xla/pjrt:transpose", + "@xla//xla/pjrt/distributed", + "@xla//xla/pjrt/distributed:client", + "@xla//xla/python:aggregate_profile", + "@xla//xla/python:callback", + "@xla//xla/python:guard_lib", + "@xla//xla/python:nb_absl_span", + "@xla//xla/python:nb_class_ptr", + "@xla//xla/python:nb_helpers", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:pprof_profile_builder", + "@xla//xla/python:py_client_cpu", + "@xla//xla/python:py_host_callback", + "@xla//xla/python:py_host_callback_proto_cc", + "@xla//xla/python:python_ref_manager", + "@xla//xla/python:traceback", + "@xla//xla/python:types", + "@xla//xla/python:util", + "@xla//xla/python:xplane_to_profile_instructions", + "@xla//xla/python/compile_only_ifrt:client", + "@xla//xla/python/ifrt", + "@xla//xla/python/ifrt:attribute_map", + "@xla//xla/python/ifrt:custom_call_program", + "@xla//xla/python/ifrt:plugin_program", + "@xla//xla/python/ifrt:plugin_program_serdes", + "@xla//xla/python/ifrt:user_context", + "@xla//xla/python/ifrt/hlo:hlo_program", + "@xla//xla/python/pjrt_ifrt", + "@xla//xla/python/pjrt_ifrt:pjrt_attribute_map_util", + "@xla//xla/python/pjrt_ifrt:pjrt_dtype", + "@xla//xla/python/pjrt_ifrt:xla_host_callback_proto_cc", + "@xla//xla/python/pjrt_ifrt:xla_ifrt", + "@xla//xla/service:computation_placer_hdr", + "@xla//xla/service:custom_call_status", + "@xla//xla/service:custom_call_target_registry", + "@xla//xla/service:platform_util", + "@xla//xla/service/spmd/shardy:constants", + "@xla//xla/service/spmd/shardy:utils", + "@xla//xla/service/spmd/shardy/sdy_round_trip:pipelines", + "@xla//xla/tsl/concurrency:ref_count", + "@xla//xla/tsl/framework:allocator", + "@xla//xla/tsl/framework/mlir:status_scoped_diagnostic_handler", + "@xla//xla/tsl/platform:env", + "@xla//xla/tsl/platform:errors", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:status", + "@xla//xla/tsl/platform:statusor", + "@xla//xla/tsl/python/lib/core:numpy", + ], +) + cc_library( name = "py_socket_transfer", srcs = ["py_socket_transfer.cc"], @@ -424,6 +598,7 @@ cc_library( ], features = ["-use_header_modules"], deps = [ + ":py_client", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -435,7 +610,6 @@ cc_library( "@xla//xla/pjrt:status_casters", "@xla//xla/python:nb_class_ptr", "@xla//xla/python:nb_numpy", - "@xla//xla/python:py_client", "@xla//xla/python:traceback", "@xla//xla/python:types", "@xla//xla/python/ifrt", @@ -558,6 +732,7 @@ cc_library( features = ["-use_header_modules"], deps = [ ":dlpack", + ":py_client", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/hash", @@ -594,7 +769,6 @@ cc_library( "@xla//xla/python:nb_absl_span", "@xla//xla/python:nb_helpers", "@xla//xla/python:nb_numpy", - "@xla//xla/python:py_client", "@xla//xla/python:types", "@xla//xla/service:call_inliner", "@xla//xla/service:computation_placer", diff --git a/jaxlib/xla/dlpack.cc b/jaxlib/xla/dlpack.cc index f6605a36f02b..8b29e136f296 100644 --- a/jaxlib/xla/dlpack.cc +++ b/jaxlib/xla/dlpack.cc @@ -34,6 +34,8 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "nanobind/nanobind.h" #include "nanobind/ndarray.h" +#include "jaxlib/xla/py_array.h" +#include "jaxlib/xla/py_client.h" #include "xla/layout.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/pjrt_client.h" @@ -46,8 +48,6 @@ limitations under the License. #include "xla/python/pjrt_ifrt/pjrt_array.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/python/pjrt_ifrt/pjrt_device.h" -#include "xla/python/py_array.h" -#include "xla/python/py_client.h" #include "xla/python/python_ref_manager.h" #include "xla/python/traceback.h" #include "xla/python/types.h" diff --git a/jaxlib/xla/dlpack.h b/jaxlib/xla/dlpack.h index d0079b1d4914..e73c477b1495 100644 --- a/jaxlib/xla/dlpack.h +++ b/jaxlib/xla/dlpack.h @@ -22,9 +22,9 @@ limitations under the License. #include "absl/status/statusor.h" #include "nanobind/nanobind.h" #include "nanobind/ndarray.h" +#include "jaxlib/xla/py_client.h" #include "xla/python/ifrt/device.h" #include "xla/python/nb_class_ptr.h" -#include "xla/python/py_client.h" namespace xla { diff --git a/jaxlib/xla/ifrt_proxy.cc b/jaxlib/xla/ifrt_proxy.cc new file mode 100644 index 000000000000..e03fde194d49 --- /dev/null +++ b/jaxlib/xla/ifrt_proxy.cc @@ -0,0 +1,162 @@ +// Copyright 2023 The JAX Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "jaxlib/xla/ifrt_proxy.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/log/log_entry.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/function.h" // IWYU pragma: keep +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/unordered_map.h" // IWYU pragma: keep +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "jaxlib/xla/py_client.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/attribute_map.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt_proxy/client/registry.h" +#include "xla/python/nb_class_ptr.h" +#include "tsl/platform/env.h" +#include "tsl/platform/statusor.h" + +namespace nb = ::nanobind; + +namespace xla { +namespace ifrt { +namespace proxy { +namespace { + +struct PyClientConnectionOptions { + std::optional> on_disconnect; + std::optional> on_connection_update; + std::optional connection_timeout_in_seconds; + std::optional< + std::unordered_map>> + initialization_data; +}; + +absl::StatusOr> GetClient( + std::string proxy_server_address, + const PyClientConnectionOptions& py_options) { + DCHECK(PyGILState_Check()); + std::unique_ptr client; + + ClientConnectionOptions options; + if (py_options.on_disconnect) { + // While it is possible to pass around `py_options.on_disconnect` without + // wrapping it via a shared_ptr, copying the `py_options.on_disconnect` + // object can internally attempt to acquire the GIL [1], and can thus block + // or even deadlock. A unique_ptr or `absl::AnyInvocable` is not sufficient + // because downstream code can make copies. Reference: + // https://pybind11.readthedocs.io/en/stable/advanced/misc.html#common-sources-of-global-interpreter-lock-errors + auto py_on_disconnect = std::make_shared>( + std::move(*py_options.on_disconnect)); + + options.on_disconnect = + [on_disconnect = std::move(py_on_disconnect)](absl::Status s) mutable { + LOG(WARNING) << "Connection to server failed, calling supplied " + << "`on_disconnect` function: " << s; + tsl::Env::Default()->SchedClosure([s, on_disconnect]() mutable { + nb::gil_scoped_acquire gil_acquire; + (*on_disconnect)(s.ToString()); + on_disconnect = nullptr; + }); + }; + } + + if (py_options.on_connection_update) { + auto fn = std::make_shared>( + std::move(*py_options.on_connection_update)); + options.on_connection_update = [fn](absl::string_view log_line) -> void { + tsl::Env::Default()->SchedClosure([fn, str = std::string(log_line)] { + nb::gil_scoped_acquire gil_acquire; + (*fn)(std::string(str)); + }); + }; + } + + if (py_options.connection_timeout_in_seconds.has_value()) { + options.connection_timeout = + absl::Seconds(*py_options.connection_timeout_in_seconds); + } + + if (py_options.initialization_data.has_value()) { + AttributeMap::Map attribute_map; + for (const auto& [key, py_value] : *py_options.initialization_data) { + if (std::holds_alternative(py_value)) { + nb::bytes value = std::get(py_value); + attribute_map.insert({key, AttributeMap::StringValue(std::string( + value.c_str(), value.size()))}); + } else if (std::holds_alternative(py_value)) { + attribute_map.insert( + {key, AttributeMap::BoolValue(std::get(py_value))}); + } else { + CHECK(std::holds_alternative(py_value)); + attribute_map.insert( + {key, AttributeMap::Int64Value(std::get(py_value))}); + } + } + options.initialization_data = AttributeMap(std::move(attribute_map)); + } + + { + nb::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN(client, CreateClient(proxy_server_address, options)); + } + + // Constructing `xla::PyClient` requires GIL as it may dec-ref Python objects. + return xla::PyClient::Make(std::move(client)); +} + +} // namespace + +void BuildIfrtProxySubmodule(nb::module_& m) { + nb::module_ sub_module = m.def_submodule("ifrt_proxy", "IFRT proxy"); + + nb::class_(sub_module, "ClientConnectionOptions") + .def(nb::init<>()) + .def_rw("on_disconnect", &PyClientConnectionOptions::on_disconnect, + nb::arg().none()) + .def_rw("on_connection_update", + &PyClientConnectionOptions::on_connection_update, + nb::arg().none()) + .def_rw("connection_timeout_in_seconds", + &PyClientConnectionOptions::connection_timeout_in_seconds, + nb::arg().none()) + .def_rw("initialization_data", + &PyClientConnectionOptions::initialization_data, + nb::arg().none()); + + sub_module.def("get_client", xla::ValueOrThrowWrapper(GetClient), + nb::arg("proxy_server_address"), nb::arg("options")); +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/jaxlib/xla/ifrt_proxy.h b/jaxlib/xla/ifrt_proxy.h new file mode 100644 index 000000000000..a8fcb9e676ff --- /dev/null +++ b/jaxlib/xla/ifrt_proxy.h @@ -0,0 +1,31 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_IFRT_PROXY_CLIENT_PY_MODULE_H_ +#define JAXLIB_XLA_IFRT_PROXY_CLIENT_PY_MODULE_H_ + +#include "nanobind/nanobind.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +void BuildIfrtProxySubmodule(nanobind::module_& m); + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // JAXLIB_XLA_IFRT_PROXY_CLIENT_PY_MODULE_H_ diff --git a/jaxlib/xla/jax_jit.cc b/jaxlib/xla/jax_jit.cc index 23abe9a8404a..4645c59c7147 100644 --- a/jaxlib/xla/jax_jit.cc +++ b/jaxlib/xla/jax_jit.cc @@ -53,14 +53,14 @@ limitations under the License. #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/py_values.h" #include "jaxlib/xla/pytree.h" +#include "jaxlib/xla/sharding.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_layout.h" #include "xla/pjrt/status_casters.h" #include "xla/python/nb_absl_inlined_vector.h" // IWYU pragma: keep #include "xla/python/nb_absl_span.h" // IWYU pragma: keep -#include "xla/python/py_values.h" -#include "xla/python/sharding.h" #include "xla/python/types.h" #include "xla/tsl/platform/logging.h" #include "tsl/profiler/lib/traceme.h" diff --git a/jaxlib/xla/jax_jit.h b/jaxlib/xla/jax_jit.h index 9e6f8e34f1e9..e2c186c5d3ff 100644 --- a/jaxlib/xla/jax_jit.h +++ b/jaxlib/xla/jax_jit.h @@ -35,12 +35,12 @@ limitations under the License. #include "absl/types/span.h" #include "nanobind/nanobind.h" #include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "jaxlib/xla/py_values.h" #include "jaxlib/xla/pytree.h" +#include "jaxlib/xla/sharding.h" #include "xla/pjrt/pjrt_client.h" #include "xla/python/nb_helpers.h" -#include "xla/python/py_values.h" #include "xla/python/python_ref_manager.h" -#include "xla/python/sharding.h" #include "xla/tsl/platform/logging.h" namespace jax { diff --git a/jaxlib/xla/pjit.cc b/jaxlib/xla/pjit.cc index a13d3f0b52e3..6681f72b7b49 100644 --- a/jaxlib/xla/pjit.cc +++ b/jaxlib/xla/pjit.cc @@ -52,7 +52,11 @@ limitations under the License. #include "nanobind/stl/vector.h" // IWYU pragma: keep #include "jaxlib/xla/config.h" #include "jaxlib/xla/jax_jit.h" +#include "jaxlib/xla/py_array.h" +#include "jaxlib/xla/py_executable.h" +#include "jaxlib/xla/py_values.h" #include "jaxlib/xla/pytree.h" +#include "jaxlib/xla/sharding.h" #include "xla/layout.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/lru_cache.h" @@ -66,11 +70,7 @@ limitations under the License. #include "xla/python/nb_class_ptr.h" #include "xla/python/nb_helpers.h" #include "xla/python/nb_numpy.h" -#include "xla/python/py_array.h" -#include "xla/python/py_executable.h" -#include "xla/python/py_values.h" #include "xla/python/python_ref_manager.h" -#include "xla/python/sharding.h" #include "xla/python/traceback.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/platform/env.h" diff --git a/jaxlib/xla/pmap_lib.cc b/jaxlib/xla/pmap_lib.cc index c6849f8c25fd..3dbd736076db 100644 --- a/jaxlib/xla/pmap_lib.cc +++ b/jaxlib/xla/pmap_lib.cc @@ -46,7 +46,15 @@ limitations under the License. #include "nanobind/stl/vector.h" // IWYU pragma: keep #include "jaxlib/xla/config.h" #include "jaxlib/xla/jax_jit.h" +#include "jaxlib/xla/py_array.h" +#include "jaxlib/xla/py_client.h" +#include "jaxlib/xla/py_device.h" +#include "jaxlib/xla/py_executable.h" +#include "jaxlib/xla/py_values.h" #include "jaxlib/xla/pytree.h" +#include "jaxlib/xla/sharded_device_array.h" +#include "jaxlib/xla/sharding.h" +#include "jaxlib/xla/to_ifrt_sharding.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/status_casters.h" #include "xla/python/ifrt/array.h" @@ -59,15 +67,7 @@ limitations under the License. #include "xla/python/nb_class_ptr.h" #include "xla/python/nb_helpers.h" #include "xla/python/nb_numpy.h" -#include "xla/python/py_array.h" -#include "xla/python/py_client.h" -#include "xla/python/py_device.h" -#include "xla/python/py_executable.h" -#include "xla/python/py_values.h" #include "xla/python/python_ref_manager.h" -#include "xla/python/sharded_device_array.h" -#include "xla/python/sharding.h" -#include "xla/python/to_ifrt_sharding.h" #include "xla/python/traceback.h" #include "xla/python/types.h" #include "xla/status_macros.h" diff --git a/jaxlib/xla/py_array.cc b/jaxlib/xla/py_array.cc new file mode 100644 index 000000000000..305582a987f7 --- /dev/null +++ b/jaxlib/xla/py_array.cc @@ -0,0 +1,2063 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/py_array.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // NOLINT +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/casts.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/py_client.h" +#include "jaxlib/xla/py_device.h" +#include "jaxlib/xla/py_device_list.h" +#include "jaxlib/xla/py_values.h" +#include "jaxlib/xla/sharding.h" +#include "jaxlib/xla/to_ifrt_sharding.h" +#include "xla/layout.h" +#include "xla/layout_util.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/lru_cache.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/pjrt_future.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/pjrt/status_casters.h" +#include "xla/primitive_util.h" +#include "xla/python/guard_lib.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/array_spec.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/remap_plan.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/nb_absl_span.h" // IWYU pragma: keep +#include "xla/python/nb_class_ptr.h" +#include "xla/python/nb_helpers.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/python/pjrt_ifrt/pjrt_device.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" +#include "xla/python/python_ref_manager.h" +#include "xla/python/traceback.h" +#include "xla/python/types.h" +#include "xla/python/util.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status_macros.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/python/lib/core/numpy.h" // IWYU pragma: keep +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace { + +namespace nb = nanobind; + +PjRtBuffer* GetPjrtBuffer(ifrt::Array* ifrt_array) { + auto* arr = llvm::dyn_cast_or_null(ifrt_array); + if (arr == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + return arr->pjrt_buffers().front().get(); +} + +absl::StatusOr XlaDynamicShape(ifrt::Array* ifrt_array, + std::optional& scratch) { + auto* pjrt_buffer = GetPjrtBuffer(ifrt_array); + + if (!scratch) { + absl::Span dims; + std::optional> logical_dims_storage; + if (pjrt_buffer->has_dynamic_dimensions()) { + { + nb::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN(std::vector logical_dims, + pjrt_buffer->logical_dimensions()); + logical_dims_storage.emplace(std::move(logical_dims)); + } + dims = *logical_dims_storage; + } else { + dims = pjrt_buffer->dimensions(); + } + Shape shape = ShapeUtil::MakeShape(pjrt_buffer->element_type(), dims); + // TODO(b/327524065): fix this + *shape.mutable_layout() = pjrt_buffer->layout()->xla_layout(); + scratch = std::move(shape); + } + return &scratch.value(); +} + +tsl::RCReference CreateIfRtArrayFromSingleDeviceShardedPyArrays( + nb_dtype dtype, absl::Span shape, + absl::Span py_arrays, const nb::object& sharding) { + const ifrt::MemoryKind dst_memory_kind = xla::GetMemoryKind(sharding); + + std::vector> ifrt_arrays; + ifrt_arrays.reserve(py_arrays.size()); + absl::InlinedVector devices; + devices.reserve(py_arrays.size()); + absl::flat_hash_set device_set; + device_set.reserve(py_arrays.size()); + std::vector shapes; + shapes.reserve(py_arrays.size()); + + auto sharding_device_list = xla::GetIfrtDeviceList(sharding); + if (!sharding_device_list.ok()) { + // TODO(hyeontaek): Return a absl::Status. + throw nb::value_error(sharding_device_list.status().ToString().c_str()); + } + ifrt::Device* device = sharding_device_list.value()->devices().front(); + + // TODO(hyeontaek): Canonicalize every `ifrt::MemoryKind` at creation time to + // skip canonicalization here once JAX begins to do it for JAX shardings. + const ifrt::MemoryKind canonical_dst_memory_kind = + ifrt::CanonicalizeMemoryKind(dst_memory_kind, device); + for (const auto& py_array : py_arrays) { + if (py_array.num_shards() != 1) { + throw nb::value_error( + absl::StrFormat( + "When making an array from single-device arrays the input arrays " + "must have one shard each. An argument array had %d shard(s).", + py_array.num_shards()) + .c_str()); + } + ifrt_arrays.push_back(tsl::FormRef(py_array.ifrt_array())); + ifrt::Device* const device = + ifrt_arrays.back()->sharding().devices()->devices().front(); + devices.push_back(device); + device_set.insert(device); + shapes.push_back(ifrt_arrays.back()->shape()); + if (canonical_dst_memory_kind != + ifrt::CanonicalizeMemoryKind( + ifrt_arrays.back()->sharding().memory_kind(), device)) { + throw nb::value_error( + absl::StrFormat( + "Memory kind mismatch with PjRtBuffers. Got sharding with " + "memory kind '%v' and a buffer with memory_kind '%v'", + dst_memory_kind, ifrt_arrays.back()->sharding().memory_kind()) + .c_str()); + } + } + ifrt::DeviceListRef device_list = device->client()->MakeDeviceList(devices); + if (device_set.size() != device_list->size()) { + throw nb::value_error( + absl::StrFormat( + "When making an array from single-device arrays, the input arrays " + "must be from distinct devices, but got %v", + *device_list) + .c_str()); + } + + auto ifrt_dtype = DtypeToIfRtDType(dtype); + if (!ifrt_dtype.ok()) { + // TODO(hyeontaek): Return a absl::Status. + throw nb::value_error(ifrt_dtype.status().ToString().c_str()); + } + + absl::StatusOr> ifrt_sharding = + sharding.type().is(jax::PmapSharding::type()) + ? xla::GetIfrtConcreteSharding(sharding, ifrt::Shape(shape), + std::move(shapes)) + : xla::GetIfrtHloSharding(sharding, ifrt::Shape(shape)); + if (!ifrt_sharding.ok()) { + // TODO(hyeontaek): Return a absl::Status. + throw nb::value_error(ifrt_sharding.status().ToString().c_str()); + } + // TODO(emilyaf): Always use `ifrt_dtype` once tokens are handled correctly. + ifrt::DType array_dtype = + ifrt_arrays.empty() ? ifrt_dtype.value() : ifrt_arrays[0]->dtype(); + absl::StatusOr> ifrt_array = + device->client()->AssembleArrayFromSingleDeviceArrays( + array_dtype, ifrt::Shape(shape), *std::move(ifrt_sharding), + absl::MakeSpan(ifrt_arrays), ifrt::ArrayCopySemantics::kReuseInput, + ifrt::SingleDeviceShardSemantics::kAddressableShards); + if (!ifrt_array.ok()) { + // TODO(hyeontaek): Return a absl::Status. + throw nb::value_error(ifrt_array.status().ToString().c_str()); + } + return *std::move(ifrt_array); +} + +struct PyArrayObject { + PyObject_HEAD; +#if PY_VERSION_HEX < 0x030C0000 + PyObject* weakrefs; + PyObject* dict; +#endif // PY_VERSION_HEX < 0x030B0000 + bool initialized; + alignas(PyArray::Storage) char array_storage[sizeof(PyArray::Storage)]; +}; +static_assert(std::is_standard_layout::value); + +PyArray::Storage* GetPyArrayStorageFromObject(PyArrayObject* py_array_object) { + return std::launder( + reinterpret_cast(py_array_object->array_storage)); +} + +extern "C" PyObject* PyArray_tp_new(PyTypeObject* type, PyObject*, PyObject*) { + PyObject* self = type->tp_alloc(type, 0); + auto* obj = reinterpret_cast(self); + obj->initialized = false; + return self; +} + +extern "C" void PyArray_tp_dealloc(PyObject* self) { + PyObject_GC_UnTrack(self); + PyTypeObject* tp = Py_TYPE(self); + auto* obj = reinterpret_cast(self); + + if (obj->initialized) { + GetPyArrayStorageFromObject(obj)->~PyArray_Storage(); + } + + PyObject_ClearWeakRefs(self); +#if PY_VERSION_HEX < 0x030C0000 + PyObject*& dict = *_PyObject_GetDictPtr(self); + Py_CLEAR(dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_ClearManagedDict(self); +#else + PyObject_ClearManagedDict(self); +#endif // PY_VERSION_HEX < 0x030C0000 + + tp->tp_free(self); + Py_DECREF(tp); +} + +// dynamic_attr: Allow the garbage collector to traverse the internal instance +// `__dict__`. +extern "C" int PyArray_tp_traverse(PyObject* self, visitproc visit, void* arg) { +#if PY_VERSION_HEX < 0x030C0000 + PyObject*& dict = *_PyObject_GetDictPtr(self); + Py_VISIT(dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_VisitManagedDict(self, visit, arg); +#else + PyObject_VisitManagedDict(self, visit, arg); +#endif // PY_VERSION_HEX < 0x030C0000 + // https://docs.python.org/3/c-api/typeobj.html#c.PyTypeObject.tp_traverse + Py_VISIT(Py_TYPE(self)); + return 0; +} + +// dynamic_attr: Allow the GC to clear the dictionary. +extern "C" int PyArray_tp_clear(PyObject* self) { + switch (auto guard_level = jax::GetGarbageCollectArrayGuard(); guard_level) { + case jax::GarbageCollectionGuardLevel::kAllow: + break; + case jax::GarbageCollectionGuardLevel::kLog: + case jax::GarbageCollectionGuardLevel::kFatal: { + auto* obj = reinterpret_cast(self); + std::string traceback_str; + if (obj->initialized) { + auto traceback = GetPyArrayStorageFromObject(obj)->traceback; + if (traceback.has_value()) { + traceback_str = traceback.value()->ToString(); + } + } + auto error_msg = absl::StrCat( + "`jax.Array` was deleted by the Python garbage collector " + "instead of reference counting. Break the reference cycle " + "that delays the deletion of this `jax.Array` to avoid hogging " + "memory. Traceback: \n", + traceback_str.empty() ? "not available" : traceback_str); + if (guard_level == jax::GarbageCollectionGuardLevel::kFatal) { + Py_FatalError(error_msg.c_str()); + } else { + PyErr_SetString(PyExc_RuntimeError, error_msg.c_str()); + PyErr_Print(); + PyErr_Clear(); + } + break; + } + } +#if PY_VERSION_HEX < 0x030C0000 + PyObject*& dict = *_PyObject_GetDictPtr(self); + Py_CLEAR(dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_ClearManagedDict(self); +#else + PyObject_ClearManagedDict(self); +#endif // PY_VERSION_HEX < 0x030C0000 + return 0; +} + +template +PyArray::Storage* Construct(PyArrayObject* self, Args&&... args) { + PyArray::Storage* out = + new (self->array_storage) PyArray::Storage(std::forward(args)...); + self->initialized = true; + return out; +} + +struct ShapedArrayCacheKey { + std::vector dims; + ifrt::DType dtype{ifrt::DType::kInvalid}; + bool weak_type; + + template + friend H AbslHashValue(H h, const ShapedArrayCacheKey& value) { + return H::combine(std::move(h), value.dims, value.dtype, value.weak_type); + } + bool operator==(const ShapedArrayCacheKey& other) const { + return dims == other.dims && dtype == other.dtype && + weak_type == other.weak_type; + } +}; + +// Constructing ShapedArrays has gotten slow. Cache it. +nb::object MakeShapedArrayCached(const ShapedArrayCacheKey& key) { + using CacheT = + LRUCache>>; + static nb::ft_mutex mu; + static auto* lru_list = new CacheT::LRUList(4096); + static auto* cache = new CacheT(lru_list); + + static const nb::object* shaped_array = []() -> nb::object* { + nb::object jax_core; + try { + jax_core = nb::module_::import_("jax.core"); + } catch (nb::python_error& e) { + return nullptr; + } + return new nb::object(jax_core.attr("ShapedArray")); + }(); + if (!shaped_array) { + return nb::none(); + } + + nb::ft_lock_guard lock(mu); + auto value = + cache->GetOrCreateIfAbsent(key, [](const ShapedArrayCacheKey& key) { + return std::make_shared>(); + }); + + if (!value->has_value()) { + nb_dtype dtype = + IfrtDtypeToDtypeWithTokenCanonicalization(key.dtype).value(); + nb::object aval = (*shaped_array)( + SpanToNbTuple(absl::Span( + key.dtype.kind() == ifrt::DType::kToken ? std::vector{0} + : key.dims)), + dtype, key.weak_type); + *value = aval; + return aval; + } + return **value; +} + +// Grouping key used by BatchedCopyToDeviceWithSharding. +// Defined outside of the function as required by templatized function +// `AbslHashValue`. +struct BatchedCopyToDeviceWithShardingKey { + ifrt::DeviceListRef src_devices; + ifrt::MemoryKind src_memory_kind; + ifrt::DeviceListRef dst_devices; + ifrt::MemoryKind dst_memory_kind; + ifrt::ArrayCopySemantics array_copy_semantics; + + bool operator==(const BatchedCopyToDeviceWithShardingKey& other) const { + return *src_devices == *other.src_devices && + src_memory_kind == other.src_memory_kind && + *dst_devices == *other.dst_devices && + dst_memory_kind == other.dst_memory_kind && + array_copy_semantics == other.array_copy_semantics; + } + + template + friend H AbslHashValue(H h, const BatchedCopyToDeviceWithShardingKey& key) { + return H::combine(std::move(h), key.src_devices, key.src_memory_kind, + key.dst_devices, key.dst_memory_kind, + key.array_copy_semantics); + } +}; + +} // namespace + +PyArray_Storage::PyArray_Storage( + nb::object aval, bool weak_type, xla::nb_dtype dtype, + std::vector shape, nb::object sharding, bool committed, + nb_class_ptr py_client, std::optional traceback, + tsl::RCReference ifrt_array, xla::PjRtFuture<> result_status) + : aval(std::move(aval)), + weak_type(weak_type), + dtype(std::move(dtype)), + shape(std::move(shape)), + sharding(std::move(sharding)), + committed(committed), + py_client(std::move(py_client)), + traceback(std::move(traceback)), + ifrt_array(std::move(ifrt_array)), + result_status(std::move(result_status)) { + static_assert(PyClient::kNumArraysShards < + std::numeric_limits::max()); + thread_id_bucket = std::hash()(std::this_thread::get_id()) % + PyClient::kNumArraysShards; + + PyClient::ArraysShard& shard = this->py_client->arrays_[thread_id_bucket]; + nanobind::ft_lock_guard lock(shard.mutex); + next = shard.arrays; + shard.arrays = this; + if (next) { + next->prev = this; + } + prev = nullptr; +} + +void PyInit_helper(PyArray self, nb::object aval, nb::object sharding, + absl::Span py_arrays, bool committed) { + auto dtype = nb::cast(aval.attr("dtype")); + auto shape = nb::cast>(aval.attr("shape")); + auto py_device_list = nb::cast( + sharding.attr("_internal_device_list")); + nb_class_ptr py_client = py_device_list->py_client(); + auto ifrt_array = CreateIfRtArrayFromSingleDeviceShardedPyArrays( + dtype, shape, py_arrays, sharding); + Construct(reinterpret_cast(self.ptr()), aval, + nb::cast(aval.attr("weak_type")), std::move(dtype), + std::move(shape), std::move(sharding), committed, py_client, + Traceback::Get(), std::move(ifrt_array), xla::PjRtFuture<>()); +} + +void PyArray::PyInit(PyArray self, nb::object aval, nb::object sharding, + absl::Span py_arrays, bool committed, + bool skip_checks) { + if (skip_checks) { + PyInit_helper(self, aval, sharding, py_arrays, committed); + } else { + nb::object rearranged_arrays = + self.CheckAndRearrange(py_arrays, sharding, aval); + auto rearranged_py_arrays = + nb::cast>(rearranged_arrays); + PyInit_helper(self, aval, sharding, rearranged_py_arrays, committed); + } +} + +PyArray PyArray::MakeFromSingleDeviceArray( + nb_class_ptr py_client, std::optional traceback, + tsl::RCReference ifrt_array, bool weak_type, bool committed, + xla::PjRtFuture<> result_status) { + if (!llvm::isa(ifrt_array->sharding())) { + throw XlaRuntimeError( + InvalidArgument("Constructing single device jax.Array from non-single " + "device ifrt array.")); + } + auto shape_span = ifrt_array->shape().dims(); + ShapedArrayCacheKey key; + key.dtype = ifrt_array->dtype(); + key.dims = key.dtype.kind() == ifrt::DType::kToken + ? std::vector{0} + : std::vector(shape_span.begin(), shape_span.end()); + key.weak_type = weak_type; + auto aval = MakeShapedArrayCached(key); + auto dtype = IfrtDtypeToDtypeWithTokenCanonicalization(key.dtype).value(); + const ifrt::MemoryKind memory_kind = ifrt_array->sharding().memory_kind(); + nb::object py_memory_kind = + (memory_kind.memory_kind().has_value()) + ? nb::object(nb::str(memory_kind.memory_kind()->data(), + memory_kind.memory_kind()->size())) + : nb::none(); + nb::object sharding = make_nb_class( + py_client, ifrt_array->sharding().devices(), std::move(py_memory_kind)); + return PyArray(std::move(aval), weak_type, dtype, std::move(key.dims), + std::move(sharding), std::move(py_client), + std::move(traceback), std::move(ifrt_array), committed, + /*skip_checks=*/true, std::move(result_status)); +} + +PyArray PyArray::MakeFromIfrtArrayAndSharding( + nb_class_ptr py_client, std::optional traceback, + tsl::RCReference ifrt_array, nb::object sharding, + bool weak_type, bool committed, bool skip_checks) { + auto shape_span = ifrt_array->shape().dims(); + ShapedArrayCacheKey key; + key.dtype = ifrt_array->dtype(); + key.dims = key.dtype.kind() == ifrt::DType::kToken + ? std::vector{0} + : std::vector(shape_span.begin(), shape_span.end()); + key.weak_type = weak_type; + auto aval = MakeShapedArrayCached(key); + auto dtype = IfrtDtypeToDtypeWithTokenCanonicalization(key.dtype).value(); + return PyArray(std::move(aval), weak_type, dtype, std::move(key.dims), + std::move(sharding), std::move(py_client), + std::move(traceback), std::move(ifrt_array), committed, + skip_checks); +} + +PyArrayResultHandler::PyArrayResultHandler(nb::object aval, nb::object sharding, + bool committed, bool skip_checks) + : aval_(std::move(aval)), + sharding_(std::move(sharding)), + committed_(committed), + skip_checks_(skip_checks) { + weak_type_ = nb::cast(aval_.attr("weak_type")); + dtype_ = nb::cast(aval_.attr("dtype")); + shape_ = nb::cast>(aval_.attr("shape")); +} + +PyArray PyArrayResultHandler::Call(absl::Span py_arrays) const { + auto py_device_list = jax::GetPyDeviceList(sharding_); + if (!py_device_list.ok()) { + throw nb::value_error( + absl::StrCat("Failed to get py device list from sharding: ", + py_device_list.status().ToString()) + .c_str()); + } + return Call(py_device_list.value()->py_client(), + CreateIfRtArrayFromSingleDeviceShardedPyArrays( + dtype_, shape_, py_arrays, sharding_), + xla::PjRtFuture<>()); +} + +PyArray PyArrayResultHandler::Call(nb_class_ptr py_client, + tsl::RCReference ifrt_array, + xla::PjRtFuture<> result_status) const { + return PyArray(aval_, weak_type_, dtype_, shape_, sharding_, + std::move(py_client), Traceback::Get(), std::move(ifrt_array), + committed_, skip_checks_, std::move(result_status)); +} + +PyArray PyArrayResultHandler::Call(PyArray py_array) const { + return Call(py_array.py_client(), tsl::FormRef(py_array.ifrt_array()), + xla::PjRtFuture<>()); +} + +PyArray::PyArray(nb::object aval, bool weak_type, nb_dtype dtype, + std::vector shape, nb::object sharding, + nb_class_ptr py_client, + std::optional traceback, + tsl::RCReference ifrt_array, bool committed, + bool skip_checks, xla::PjRtFuture<> result_status) { + auto* self = + PyArray_tp_new(reinterpret_cast(type_), nullptr, nullptr); + m_ptr = self; + Construct(reinterpret_cast(self), std::move(aval), weak_type, + std::move(dtype), std::move(shape), std::move(sharding), committed, + std::move(py_client), std::move(traceback), std::move(ifrt_array), + std::move(result_status)); + + if (!skip_checks) { + this->attr("_arrays") = this->attr("_check_and_rearrange")( + this->attr("_arrays"), this->attr("_sharding"), this->attr("aval")); + } +} + +PyArray::Storage& PyArray::GetStorage() { + return *GetPyArrayStorageFromObject(reinterpret_cast(ptr())); +} + +const PyArray::Storage& PyArray::GetStorage() const { + return *GetPyArrayStorageFromObject(reinterpret_cast(ptr())); +} + +nb::object PyArray::CheckAndRearrange(const absl::Span py_arrays, + const nb::object sharding, + const nb::object aval) { + return this->attr("_check_and_rearrange")(py_arrays, sharding, aval); +} + +void PyArray::SetIfrtArray(tsl::RCReference ifrt_array) { + GetStorage().ifrt_array = std::move(ifrt_array); +} + +const std::vector& PyArray::py_arrays_cached() { + auto& py_arrays = this->py_arrays(); + + if (py_arrays.empty()) { + auto ifrt_arrays = ifrt_array()->DisassembleIntoSingleDeviceArrays( + ifrt::ArrayCopySemantics::kReuseInput, + ifrt::SingleDeviceShardSemantics::kAddressableShards); + if (!ifrt_arrays.ok()) { + throw nb::value_error( + absl::StrCat("Failed to disassemble into single-device arrays: ", + ifrt_arrays.status().ToString()) + .c_str()); + } + py_arrays.reserve(ifrt_arrays->size()); + for (auto& ifrt_array : *ifrt_arrays) { + py_arrays.push_back(PyArray::MakeFromSingleDeviceArray( + py_client(), traceback(), std::move(ifrt_array), weak_type(), + committed(), result_status())); + } + } + + return py_arrays; +} + +nb::object PyArray::arrays() { + // For performance, we only keep pjrt buffers by default. But on python side + // "_arrays" returns PyArrays instead, and subsequent calls to "_arrays" + // should return the same PyArrays (to avoid duplicate device to host + // transfers). So we create PyArrays the first time it is called and reuse + // them later. + if (ifrt_array() == nullptr || ifrt_array()->IsDeleted()) return nb::none(); + + if (llvm::isa(&ifrt_array()->sharding())) { + std::vector py_arrays; + py_arrays.push_back(*this); + return nb::cast(py_arrays); + } + + return nb::cast(py_arrays_cached()); +} + +absl::Status PyArray::set_arrays(nb::object obj) { + if (obj.is_none()) { + SetIfrtArray(tsl::RCReference()); + py_arrays().clear(); + return absl::OkStatus(); + } + + if (!nb::isinstance(obj)) { + return InvalidArgument("Unsupported arg when setting Array._arrays: %s", + nb::cast(nb::str(obj.type()))); + } + + nb::list list(obj); + + if (list.size() == 0) return absl::OkStatus(); + + SetIfrtArray(tsl::RCReference()); + py_arrays().clear(); + std::vector> ifrt_arrays; + ifrt_arrays.reserve(list.size()); + absl::InlinedVector devices; + devices.reserve(list.size()); + std::vector shapes; + shapes.reserve(list.size()); + for (nb::handle obj : list) { + if (obj.type().is(PyArray::type())) { + auto py_array = nb::borrow(obj); + if (py_array.py_client().get() != py_client().get()) { + return InvalidArgument("Client mismatch when assigning to _arrays."); + } + if (py_array.num_shards() != 1) { + return InvalidArgument("Wrong number of shards: %d", + py_array.num_shards()); + } + ifrt_arrays.push_back(tsl::FormRef(py_array.ifrt_array())); + devices.push_back( + ifrt_arrays.back()->sharding().devices()->devices().front()); + shapes.push_back(ifrt_arrays.back()->shape()); + } else { + return InvalidArgument("Unsupported arg when setting Array._arrays: %s", + nb::cast(nb::str(obj.type()))); + } + } + const ifrt::MemoryKind first_memory_kind = + ifrt_arrays.front()->sharding().memory_kind(); + // TODO(hyeontaek): Canonicalize every `ifrt::MemoryKind` at creation time to + // skip canonicalization here once JAX begins to do it for JAX shardings. + const ifrt::MemoryKind canonical_first_memory_kind = + ifrt::CanonicalizeMemoryKind( + first_memory_kind, + ifrt_arrays.front()->sharding().devices()->devices().front()); + for (const auto& ifrt_array : ifrt_arrays) { + if (canonical_first_memory_kind != + ifrt::CanonicalizeMemoryKind( + ifrt_array->sharding().memory_kind(), + ifrt_array->sharding().devices()->devices().front())) { + throw nb::value_error( + absl::StrFormat( + "Memory kind mismatch between single-device arrays. Got one " + "array with memory kind '%v' and another with memory_kind '%v'", + first_memory_kind, ifrt_array->sharding().memory_kind()) + .c_str()); + } + } + + TF_ASSIGN_OR_RETURN( + auto ifrt_sharding, + sharding().type().is(jax::PmapSharding::type()) + ? xla::GetIfrtConcreteSharding(sharding(), ifrt::Shape(shape()), + std::move(shapes)) + : xla::GetIfrtHloSharding(sharding(), ifrt::Shape(shape()))); + TF_ASSIGN_OR_RETURN( + auto array, + py_client()->ifrt_client()->AssembleArrayFromSingleDeviceArrays( + ifrt::Shape(shape()), std::move(ifrt_sharding), + absl::MakeSpan(ifrt_arrays), ifrt::ArrayCopySemantics::kReuseInput, + ifrt::SingleDeviceShardSemantics::kAddressableShards)); + SetIfrtArray(std::move(array)); + return absl::OkStatus(); +} + +absl::StatusOr PyArray::FullyReplicatedShard() { + auto& cached = GetStorage().fully_replicated_array; + if (!cached.is_none()) { + return nb::cast(cached); + } + + if (ifrt_array() == nullptr) { + return InvalidArgument( + "FullyReplicatedShard() called on deleted or donated buffer"); + } + + TF_ASSIGN_OR_RETURN(auto fully_replicated_ifrt_shard, + ifrt_array()->FullyReplicatedShard( + ifrt::ArrayCopySemantics::kReuseInput)); + auto array = MakeFromSingleDeviceArray( + py_client(), traceback(), std::move(fully_replicated_ifrt_shard), + weak_type(), committed(), result_status()); + cached = array; + return nb::cast(cached); +} + +absl::Status PyArray::BlockUntilReady() const { + nb::gil_scoped_release gil_release; + if (ifrt_array() == nullptr) { + return InvalidArgument( + "BlockHostUntilReady() called on deleted or donated buffer"); + } + ifrt::Array* ifrt_array = this->ifrt_array(); + return AwaitBuffersReady(absl::MakeConstSpan(&ifrt_array, 1)); +} + +absl::StatusOr PyArray::GetOnDeviceSizeInBytes() { + if (ifrt_array() == nullptr) { + return InvalidArgument( + "GetOnDeviceSizeInBytes() called on deleted or donated buffer"); + } + + TF_ASSIGN_OR_RETURN(size_t shard_size, + GetPjrtBuffer(ifrt_array())->GetOnDeviceSizeInBytes()); + return shard_size * nb::len(nb::object(sharding().attr("device_set"))); +} + +absl::Status PyArray::BlockUntilResultStatusIsReady() { + auto& result_status = GetStorage().result_status; + // If the result_status future is not valid, this result did not come directly + // from a computation that returns tokens, so we don't wait for the status. + if (!result_status.IsValid()) { + return absl::OkStatus(); + } + if (!result_status.IsReady()) { + // Only release the gil if we need to Await(). + nb::gil_scoped_release release_gil; + return result_status.Await(); + } + return result_status.Await(); +} + +absl::StatusOr> +PyArray::SingleDeviceArrayToNumpyArrayDidCopy() { + TF_ASSIGN_OR_RETURN(auto arr, FullyReplicatedShard()); + auto result = arr.GetStorage().host_value.AsNumPyArray( + arr.GetStorage().dynamic_shape, arr.ifrt_array()); + TF_RETURN_IF_ERROR(arr.BlockUntilResultStatusIsReady()); + return result; +} + +absl::StatusOr PyArray::SingleDeviceArrayToNumpyArray() { + TF_ASSIGN_OR_RETURN(auto result, SingleDeviceArrayToNumpyArrayDidCopy()); + return result.first; +} + +absl::Status PyArray::CopySingleDeviceArrayToHostAsync() { + TF_ASSIGN_OR_RETURN(auto arr, FullyReplicatedShard()); + return arr.GetStorage().host_value.CopyToHostAsync( + arr.GetStorage().dynamic_shape, arr.ifrt_array()); +} + +absl::StatusOr PyArray::AssertUnsharded(absl::string_view api) { + if (ifrt_array() == nullptr) { + return InvalidArgument("%s( called on deleted or donated buffer", api); + } + + if (llvm::isa(&ifrt_array()->sharding())) { + return *this; + } + + auto& py_arrays = py_arrays_cached(); + if (py_arrays.size() != 1) { + return InvalidArgument("%s() is supported only for unsharded arrays.", api); + } + return py_arrays[0]; +} + +absl::StatusOr PyArray::UnsafeBufferPointer() { + TF_ASSIGN_OR_RETURN(auto arr, AssertUnsharded("UnsafeBufferPointer")); + + return py_client()->pjrt_client()->UnsafeBufferPointer( + GetPjrtBuffer(arr.ifrt_array())); +} + +nb::dict PyArray::CudaArrayInterface() { + auto arr_or_error = AssertUnsharded("UnsafeBufferPointer"); + if (!arr_or_error.ok()) { + throw nb::attribute_error( + "__cuda_array_interface__ is only supported for unsharded arrays."); + } + auto arr = *arr_or_error; + + ifrt::Array* ifrt_array = arr.ifrt_array(); + std::optional& scratch = arr.GetStorage().dynamic_shape; + auto* pjrt_buffer = GetPjrtBuffer(ifrt_array); + if (pjrt_buffer->client()->platform_id() != CudaId()) { + throw nb::attribute_error( + "__cuda_array_interface__ is only defined for NVidia GPU buffers."); + } + if (pjrt_buffer->IsTuple()) { + throw nb::attribute_error( + "__cuda_array_interface__ is only defined for array buffers."); + } + + switch (pjrt_buffer->element_type()) { + case PrimitiveType::PRED: + case PrimitiveType::S8: + case PrimitiveType::S16: + case PrimitiveType::S32: + case PrimitiveType::S64: + case PrimitiveType::U8: + case PrimitiveType::U16: + case PrimitiveType::U32: + case PrimitiveType::U64: + case PrimitiveType::F16: + case PrimitiveType::F32: + case PrimitiveType::F64: + case PrimitiveType::C64: + case PrimitiveType::C128: + break; + + default: + throw nb::attribute_error( + absl::StrFormat( + "__cuda_array_interface__ is not supported for %s buffers.", + PrimitiveType_Name(pjrt_buffer->element_type())) + .c_str()); + } + + nb::str typestr = + ValueOrThrow(TypeDescriptorForPrimitiveType(pjrt_buffer->element_type())); + + // TODO(b/327524065): use PjRtLayout directly instead of xla::Layout + Layout xla_layout = pjrt_buffer->layout()->xla_layout(); + if (!LayoutUtil::IsMonotonicWithDim0Major(xla_layout)) { + throw nb::attribute_error( + "__cuda_array_interface__ is only currently supported for " + "buffers in row-major order."); + } + + nb::dict result; + const auto* dynamic_shape = + ValueOrThrow(XlaDynamicShape(ifrt_array, scratch)); + result["shape"] = SpanToNbTuple(dynamic_shape->dimensions()); + result["typestr"] = std::move(typestr); + std::unique_ptr external_reference_hold = + ValueOrThrow(pjrt_buffer->AcquireExternalReference()); + const void* root_ptr = + external_reference_hold->OpaqueDeviceMemoryDataPointer(); + nb::tuple data = + nb::make_tuple(nb::int_(absl::bit_cast(root_ptr)), + nb::bool_(true) /* read-only */ + ); + result["data"] = std::move(data); + result["version"] = nb::int_(2); + return result; +} + +absl::StatusOr CudaArrayInterfaceToBuffer( + const nb::dict& cai, nb_class_ptr client, + std::optional device_id) { + if (!cai.contains("data")) { + return absl::InvalidArgumentError( + "CUDA Array Interface does not define `data`"); + } + if (!cai.contains("shape")) { + return absl::InvalidArgumentError( + "CUDA Array Interface does not define `shape`"); + } + if (!cai.contains("typestr")) { + return absl::InvalidArgumentError( + "CUDA Array Interface does not define `typestr`"); + } + if (!cai.contains("version")) { + return absl::InvalidArgumentError( + "CUDA Array Interface does not define `version`"); + } + auto version = nb::cast(cai["version"]); + if (version < 2 || version > 3) { + LOG(WARNING) << "CUDA Array Interface version " << version + << " support is undefined"; + } + auto data = nb::cast(cai["data"]); + auto data_value = nb::cast(data[0]); + void* data_ptr = reinterpret_cast(data_value); + auto dimensions = nb::cast>(cai["shape"]); + if (data_value == 0 && absl::c_find(dimensions, 0) == dimensions.end()) { + return absl::InvalidArgumentError( + "CUDA Array Interface `data`(=NULL) and `shape`(no zero-valued " + "dimensions) are inconsistent"); + } + auto ndim = dimensions.size(); + TF_ASSIGN_OR_RETURN( + PrimitiveType element_type, + DtypeToPrimitiveType(nb_dtype::from_args(cai["typestr"]))); + + if (!device_id.has_value()) { + throw XlaRuntimeError( + "This operation requires CUDA support from jaxlib or jax cuda plugin."); + } + TF_ASSIGN_OR_RETURN(auto device, + client->DeviceFromLocalHardwareId(*device_id)); + bool is_default_stream = + data_value == 0 || version == 2 || + (version == 3 && (!cai.contains("stream") || cai["stream"].is_none())); + TF_ASSIGN_OR_RETURN( + std::intptr_t stream, + ([is_default_stream, cai, device]() -> absl::StatusOr { + if (is_default_stream) { + return device->GetStreamForExternalReadyEvents(); + } else { + auto stream_ = nb::cast(cai["stream"]); + if (stream_ == 0) { + return absl::InvalidArgumentError( + "CUDA Array Interface does not allow zero stream value"); + } + return stream_; + } + }())); + + std::vector minor_to_major(ndim); + if (cai.contains("strides") && !cai["strides"].is_none() && data_value != 0) { + std::iota(minor_to_major.begin(), minor_to_major.end(), 0); + auto strides = nb::cast>(cai["strides"]); + if (strides.size() != ndim) { + return absl::InvalidArgumentError( + "CUDA Array Interface `shape` and `strides` dimensionalities are " + "inconsistent"); + } + absl::c_sort(minor_to_major, [&](int a, int b) { + // If two dimensions have the same stride, prefer the major-to-minor + // interpretation of the ordering, since that's what JAX wants. + return (strides[a] == strides[b] ? b < a : strides[a] < strides[b]); + }); + int64_t stride = ShapeUtil::ByteSizeOfPrimitiveType(element_type); + for (int64_t d : minor_to_major) { + if (dimensions[d] > 1 && strides[d] != stride) { + return absl::UnimplementedError(absl::StrCat( + "Only arrays with trivial (compact) striding are supported; " + "i.e., arrays whose striding represents a transposition of the " + "underlying buffer but not broadcasting. Dimensions were: [%s], " + "strides were [%s].", + absl::StrJoin(dimensions, ","), absl::StrJoin(strides, ","))); + } + stride *= dimensions[d]; + } + } else { + std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0); + } + Shape shape = ShapeUtil::MakeShapeWithDenseLayout(element_type, dimensions, + minor_to_major); + std::function on_delete_callback = []() {}; + auto* pjrt_device = + llvm::dyn_cast_or_null(device->device()); + if (pjrt_device == nullptr) { + return InvalidArgument( + "This operation is implemented for a PjRt-compatible backend only."); + } + TF_RET_CHECK(pjrt_device->IsAddressable()); + TF_ASSIGN_OR_RETURN( + auto pjrt_buffer, + device->client()->pjrt_client()->CreateViewOfDeviceBuffer( + static_cast(data_ptr), shape, + *pjrt_device->pjrt_device()->default_memory_space(), + on_delete_callback, + stream <= 2 ? std::nullopt : std::make_optional(stream))); + auto* ifrt_client = + llvm::dyn_cast_or_null(client->ifrt_client()); + if (ifrt_client == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + TF_ASSIGN_OR_RETURN(auto ifrt_array, + ifrt_client->CreatePjRtArray(std::move(pjrt_buffer))); + return PyArray::MakeFromSingleDeviceArray(std::move(client), Traceback::Get(), + std::move(ifrt_array), false, true); +} + +absl::Status PyArray::Delete() { + for (auto& arr : py_arrays()) { + TF_RETURN_IF_ERROR(arr.Delete()); + } + py_arrays().clear(); + if (ifrt_array() != nullptr) { + // We do not wait for the deletion to complete here. + // + // (1) Skipping blocking does not affect the correctness of deletion as long + // as the runtime preserves dispatch ordering of deletion w.r.t. other + // operations. + // + // (2) Synchronously waiting for the deletion to complete is very expensive + // when the deletion can return a status only after the underlying physical + // buffer has been deleted or a request must be processed via RPC, + // especially as this deletion is done per array. + ifrt_array()->Delete(); + SetIfrtArray(tsl::RCReference()); + } + return absl::OkStatus(); +} + +bool PyArray::IsDeleted() const { + if (ifrt_array() == nullptr) { + return true; + } + + return ifrt_array()->IsDeleted(); +} + +PyArray PyArray::Clone() const { + auto array = tsl::FormRef(ifrt_array()); + auto* ifrt_client = py_client()->ifrt_client(); + tsl::RCReference out = + ifrt_client + ->CopyArrays(absl::MakeSpan(&array, 1), /*devices=*/std::nullopt, + /*memory_kind=*/std::nullopt, + ifrt::ArrayCopySemantics::kReuseInput) + .value() + .front(); + return PyArray(aval(), weak_type(), dtype(), + std::vector(shape().begin(), shape().end()), + sharding(), py_client(), traceback(), std::move(out), + committed(), /*skip_checks=*/true, result_status()); +} + +nb::handle PyArray::Storage::AsHandle() { + return reinterpret_cast(reinterpret_cast(this) - + offsetof(PyArrayObject, array_storage)); +} + +PyArray::Storage::~PyArray_Storage() { + CHECK(PyGILState_Check()); + if (py_client) { + PyClient::ArraysShard& shard = py_client->arrays_[thread_id_bucket]; + nanobind::ft_lock_guard lock(shard.mutex); + if (shard.arrays == this) { + shard.arrays = next; + } + if (prev) { + prev->next = next; + } + if (next) { + next->prev = prev; + } + } + // Release GIL and then explicitly destroy `ifrt_array` to prevent deadlock on + // CPU backend caused by interactions between argument donations and host + // callbacks. + nb::gil_scoped_release gil_release; + ifrt_array.reset(); +} + +absl::StatusOr> PyArray::BatchedCopyToDeviceWithSharding( + absl::Span py_arrays, + absl::Span dst_device_lists, + absl::Span dst_shardings, + absl::Span array_copy_semantics) { + if (py_arrays.empty()) { + return std::vector(); + } + + TF_RET_CHECK(py_arrays.size() == dst_device_lists.size()); + TF_RET_CHECK(py_arrays.size() == dst_shardings.size()); + + ifrt::Client* const client = py_arrays.front().ifrt_array()->client(); + std::vector results(py_arrays.size()); + + // Arrays to be copied, grouped by source/destination devices and memory + // kinds. The grouping is enforced by `ifrt::Client::CopyArrays()`. + struct Batch { + std::vector indexes; + std::vector> ifrt_arrays; + }; + absl::flat_hash_map batches; + + for (int i = 0; i < py_arrays.size(); ++i) { + const auto& py_array = py_arrays[i]; + const auto& dst_sharding = dst_shardings[i]; + const auto& array_cs = array_copy_semantics[i]; + + auto* ifrt_array_ptr = py_array.ifrt_array(); + const ifrt::DeviceListRef& src_devices = + ifrt_array_ptr->sharding().devices(); + const ifrt::DeviceListRef& dst_devices = dst_device_lists[i]; + + ifrt::MemoryKind src_memory_kind = + ifrt::CanonicalizeMemoryKind(ifrt_array_ptr->sharding().memory_kind(), + src_devices->devices().front()); + ifrt::MemoryKind dst_memory_kind = ifrt::CanonicalizeMemoryKind( + xla::GetMemoryKind(dst_sharding), dst_devices->devices().front()); + + if (*src_devices == *dst_devices && src_memory_kind == dst_memory_kind && + array_cs == ifrt::ArrayCopySemantics::kReuseInput) { + results[i] = py_arrays[i]; + continue; + } + + auto transfer_guard_formatter = [&py_array, &dst_sharding] { + return absl::StrCat( + "aval=", nb::cast(nb::repr(py_array.aval())), + ", sharding=", + nb::cast(nb::repr(py_array.sharding())), + ", dst_sharding=", + nb::cast(nb::repr(dst_sharding))); + }; + TF_RETURN_IF_ERROR( + jax::ApplyTransferGuardToDeviceToDevice(transfer_guard_formatter)); + + Batch& batch = batches[BatchedCopyToDeviceWithShardingKey{ + src_devices, src_memory_kind, dst_devices, dst_memory_kind, array_cs}]; + batch.indexes.push_back(i); + batch.ifrt_arrays.push_back(tsl::FormRef(ifrt_array_ptr)); + } + + std::vector>> ifrt_arrays; + { + GlobalPyRefManager()->CollectGarbage(); + nb::gil_scoped_release gil_release; + + for (auto& [key, batch] : batches) { + TF_ASSIGN_OR_RETURN( + auto copied, + client->CopyArrays( + absl::MakeSpan(batch.ifrt_arrays), + // All arrays in `batch` have the same `key.dst_devices` and + // `key.dst_memory_kind` due to the grouping above. + key.dst_devices, key.dst_memory_kind, key.array_copy_semantics)); + for (int i = 0; i < batch.indexes.size(); ++i) { + ifrt_arrays.push_back( + std::make_pair(batch.indexes[i], std::move(copied[i]))); + } + } + } + + auto traceback = Traceback::Get(); + for (auto& [i, ifrt_array] : ifrt_arrays) { + const auto& py_array = py_arrays[i]; + absl::Span shape_span = py_array.shape(); + results[i] = + PyArray(py_array.aval(), py_array.weak_type(), py_array.dtype(), + std::vector(shape_span.begin(), shape_span.end()), + dst_shardings[i], py_array.py_client(), traceback, + std::move(ifrt_array), py_array.committed(), + /*skip_checks=*/true, py_array.result_status()); + } + return results; +} + +absl::StatusOr PyArray::BatchedDevicePut( + nb::object aval, nb::object sharding, std::vector xs, + absl::Span dst_devices, bool committed, + bool force_copy, PjRtClient::HostBufferSemantics host_buffer_semantics, + bool jax_enable_x64) { + if (dst_devices.size() != xs.size()) { + throw nb::value_error( + absl::StrCat("Argument sizes (xs and devices) must match %zu vs %zu", + dst_devices.size(), xs.size()) + .c_str()); + } + for (const PyDevice* device : dst_devices) { + if (device->client().get() == nullptr) { + return InvalidArgument("Cannot copy to unattached devices."); + } + } + auto transfer_guard_formatter = [&aval, &sharding] { + return absl::StrCat( + "aval=", nb::cast(nb::repr(aval)), + ", dst_sharding=", nb::cast(nb::repr(sharding))); + }; + + GlobalPyRefManager()->CollectGarbage(); + + auto n_devices = dst_devices.size(); + + DevicePutOptions options; + options.squash_64bit_types = !jax_enable_x64; + options.allow_zero_copy = + (!force_copy && (host_buffer_semantics == + ifrt::Client::HostBufferSemantics::kImmutableZeroCopy)); + if (!dst_devices.empty()) { + options.ifrt_user_context = + dst_devices.front()->client()->ifrt_client()->CreateUserContext(); + } + + nb::list owning_pylist; + std::vector> ifrt_arrays; + + absl::InlinedVector devices; + devices.reserve(n_devices); + std::vector shapes; + shapes.reserve(n_devices); + + ifrt::MemoryKind dst_memory_kind = xla::GetMemoryKind(sharding); + + std::vector device_put_fns; + device_put_fns.reserve(xs.size()); + size_t i = 0; + for (auto& x : xs) { + if (PyArray::IsPyArray(x)) { + TF_RETURN_IF_ERROR( + jax::ApplyTransferGuardToDeviceToDevice(transfer_guard_formatter)); + } else { + TF_RETURN_IF_ERROR( + jax::ApplyTransferGuardToHostToDevice(transfer_guard_formatter)); + } + TF_ASSIGN_OR_RETURN( + device_put_fns.emplace_back(), + DevicePut(x, dst_devices[i]->client()->ifrt_client(), + dst_devices[i]->device(), options, dst_memory_kind)); + ++i; + } + std::vector device_puts; + device_puts.reserve(device_put_fns.size()); + { + nb::gil_scoped_release gil_release; + for (auto& device_put_fn : device_put_fns) { + TF_ASSIGN_OR_RETURN(auto device_put, std::move(device_put_fn)()); + device_puts.push_back(std::move(device_put)); + } + } + for (auto& device_put : device_puts) { + ifrt_arrays.push_back(std::move(device_put.ifrt_array)); + devices.push_back( + ifrt_arrays.back()->sharding().devices()->devices().front()); + shapes.push_back(ifrt_arrays.back()->shape()); + if (device_put.owning_pybuffer) { + owning_pylist.append(device_put.owning_pybuffer); + } + } + + // TODO(phawkins): it's highly suspicious to me that owning_pylist isn't + // consumed here. Look into this. + + auto weak_type = nb::cast(aval.attr("weak_type")); + auto dtype = aval.attr("dtype"); + auto shape = nb::cast>(aval.attr("shape")); + + TF_ASSIGN_OR_RETURN( + auto ifrt_sharding, + sharding.type().is(jax::PmapSharding::type()) + ? xla::GetIfrtConcreteSharding(sharding, ifrt::Shape(shape), + std::move(shapes)) + : xla::GetIfrtHloSharding(sharding, ifrt::Shape(shape))); + TF_ASSIGN_OR_RETURN(auto ifrt_dtype, DtypeToIfRtDType(dtype)); + // TODO(emilyaf): Remove the following and just use ifrt_dtype when tokens are + // supported. + ifrt::DType array_dtype = + ifrt_arrays.empty() ? ifrt_dtype : ifrt_arrays.front()->dtype(); + TF_ASSIGN_OR_RETURN(auto py_device_list, jax::GetPyDeviceList(sharding)); + TF_ASSIGN_OR_RETURN( + auto ifrt_array, + py_device_list->py_client() + ->ifrt_client() + ->AssembleArrayFromSingleDeviceArrays( + array_dtype, ifrt::Shape(shape), std::move(ifrt_sharding), + absl::MakeSpan(ifrt_arrays), + xla::ifrt::ArrayCopySemantics::kReuseInput, + xla::ifrt::SingleDeviceShardSemantics::kAddressableShards)); + + return PyArray(aval, weak_type, dtype, std::move(shape), sharding, + py_device_list->py_client(), Traceback::Get(), + std::move(ifrt_array), committed, /*skip_checks=*/true); +} + +absl::StatusOr PyArray::ReorderShards( + PyArray x, nanobind::object dst_sharding, + ifrt::ArrayCopySemantics array_copy_semantics) { + xla::ifrt::Array* ifrt_array_ptr = x.ifrt_array(); + if (ifrt_array_ptr == nullptr) { + return absl::InvalidArgumentError( + "Reorder() called on deleted or donated buffer"); + } + + ifrt::Client* const client = ifrt_array_ptr->client(); + + const auto& device_list = ifrt_array_ptr->sharding().devices(); + TF_ASSIGN_OR_RETURN(auto dst_device_list, GetIfrtDeviceList(dst_sharding)); + if (device_list->AddressableDeviceList()->size() != + dst_device_list->AddressableDeviceList()->size()) { + return absl::InvalidArgumentError(absl::StrCat( + "Array is expected to have ", + dst_device_list->AddressableDeviceList()->size(), + " addressable shards, but has ", + device_list->AddressableDeviceList()->size(), " addressable shards")); + } + + TF_ASSIGN_OR_RETURN( + std::shared_ptr dst_ifrt_sharding, + GetIfrtConcreteEvenSharding(dst_sharding, ifrt_array_ptr->dtype(), + ifrt_array_ptr->shape())); + + tsl::RCReference new_ifrt_array; + { + nb::gil_scoped_release gil_release; + + const absl::Span addressable_devices = + device_list->AddressableDeviceList()->devices(); + const absl::Span dst_addressable_devices = + dst_device_list->AddressableDeviceList()->devices(); + + absl::flat_hash_map device_id_to_array_shard_index; + device_id_to_array_shard_index.reserve(dst_addressable_devices.size()); + for (int i = 0; i < dst_addressable_devices.size(); ++i) { + const int device_id = dst_addressable_devices[i]->Id().value(); + const bool inserted = + device_id_to_array_shard_index.insert({device_id, i}).second; + if (!inserted) { + return absl::InvalidArgumentError( + absl::StrCat("Sharding contains duplicate device id=", device_id)); + } + } + + std::vector from_shard_indices; + from_shard_indices.reserve(addressable_devices.size()); + std::vector to_shard_indices; + to_shard_indices.reserve(dst_addressable_devices.size()); + for (int i = 0; i < dst_addressable_devices.size(); ++i) { + from_shard_indices.push_back(i); + const int shard_device_id = addressable_devices[i]->Id().value(); + const auto it = device_id_to_array_shard_index.find(shard_device_id); + if (it == device_id_to_array_shard_index.end()) { + return absl::InvalidArgumentError(absl::StrCat( + "Array shard ", i, " is on device id=", shard_device_id, + ", but sharding does not have a shard on that device.")); + } + to_shard_indices.push_back(it->second); + } + + auto mappings = + std::make_shared>(); + { + auto& mapping = mappings->emplace_back(); + mapping.in_array = 0; + mapping.out_array = 0; + mapping.from.reserve(dst_addressable_devices.size()); + mapping.to.reserve(dst_addressable_devices.size()); + for (int64_t i = 0; i < dst_addressable_devices.size(); ++i) { + mapping.from.push_back(xla::ifrt::RemapPlan::Interval{ + from_shard_indices[i], from_shard_indices[i] + 1, 1}); + mapping.to.push_back(xla::ifrt::RemapPlan::Interval{ + to_shard_indices[i], to_shard_indices[i] + 1, 1}); + } + } + + xla::ifrt::RemapPlan plan = { + /*input_specs=*/{xla::ifrt::ArraySpec{ + /*dtype=*/ifrt_array_ptr->dtype(), + /*shape=*/ifrt_array_ptr->shape(), + /*sharding=*/ifrt_array_ptr->shared_ptr_sharding()}}, + /*output_specs=*/ + {xla::ifrt::ArraySpec{/*dtype=*/ifrt_array_ptr->dtype(), + /*shape=*/ifrt_array_ptr->shape(), + /*sharding=*/std::move(dst_ifrt_sharding)}}, + /*mappings=*/std::move(mappings), + }; + DCHECK_OK(plan.Validate()); + std::vector> input; + input.push_back(tsl::FormRef(ifrt_array_ptr)); + TF_ASSIGN_OR_RETURN( + auto remapped, + client->RemapArrays(plan, absl::MakeSpan(input), array_copy_semantics)); + + TF_RET_CHECK(remapped.size() == 1); + new_ifrt_array = std::move(remapped.front()); + } + + return xla::PyArray(nb::borrow(x.aval().ptr()), x.weak_type(), + nb::borrow(x.dtype().ptr()), + std::vector(x.shape().begin(), x.shape().end()), + std::move(dst_sharding), x.py_client(), x.traceback(), + std::move(new_ifrt_array), + /*committed=*/true, + /*skip_checks=*/true); +} + +absl::Status PyArray::BatchedBlockUntilReady(std::vector objs) { + // Create ready futures for all arrays before blocking on their readiness. + // This helps reduce the latency in some backend implementations where + // querying readiness of an array is not free. + + std::vector ifrt_arrays; + ifrt_arrays.reserve(objs.size()); + for (nb::handle obj : objs) { + if (obj.type().is(PyArray::type())) { + auto py_array = nb::borrow(obj); + ifrt::Array* const ifrt_array = py_array.ifrt_array(); + if (ifrt_array == nullptr) { + return absl::InvalidArgumentError( + "BlockHostUntilReady() called on deleted or donated buffer"); + } + ifrt_arrays.push_back(ifrt_array); + } else { + return absl::InvalidArgumentError( + "PyArray::BatchedBlockUntilReady can take PyArray only"); + } + } + + GlobalPyRefManager()->CollectGarbage(); + nb::gil_scoped_release gil_release; + return AwaitBuffersReady(absl::MakeConstSpan(ifrt_arrays)); +} + +std::vector PyClient::LiveArrays() const { + std::vector result; + for (auto& shard : arrays_) { + nb::ft_lock_guard lock(shard.mutex); + for (PyArray::Storage* array = shard.arrays; array; array = array->next) { + bool all_deleted = + (array->ifrt_array == nullptr || array->ifrt_array->IsDeleted()); + if (!all_deleted) { + result.push_back(nb::borrow(array->AsHandle())); + } + } + } + return result; +} + +// PEP 3118 buffer protocol implementation. + +namespace { + +// Extra data to be kept alive by the consumer of the buffer protocol. +struct ExtraBufferInfo { + explicit ExtraBufferInfo( + std::shared_ptr buffer, + std::unique_ptr external_reference_hold) + : buffer(std::move(buffer)), + external_reference_hold(std::move(external_reference_hold)) {} + + std::vector strides; + // We keep an external reference hold to the PjRtBuffer. This prevents a + // use-after-free in the event that Delete() is called on a buffer with an + // live buffer protocol view. It does however mean that Delete() sometimes + // won't actually delete immediately. + std::shared_ptr buffer; + std::unique_ptr external_reference_hold; +}; + +// The default layout of a non-tuple array should have major-to-minor layout +// and no tiles. +bool HasDefaultLayout(const Layout& layout) { + return LayoutUtil::IsMonotonicWithDim0Major(layout) && layout.tiles().empty(); +} + +int PyArray_bf_getbuffer(PyObject* exporter, Py_buffer* view, int flags) { + absl::Status status = [&]() -> absl::Status { + PyArray py_array = nb::borrow(exporter); + if (py_array.ifrt_array() == nullptr) { + // TODO(phawkins): why is this happening? + return InvalidArgument("Array is null"); + } + if (!llvm::isa(py_array.ifrt_array())) { + return InvalidArgument("Only local arrays are supported, got %s", + py_array.ifrt_array()->DebugString()); + } + auto* array = + static_cast(py_array.ifrt_array()); + absl::Span> buffers = + array->pjrt_buffers(); + + PjRtBuffer& buffer = *buffers.front(); + if (!buffer.IsOnCpu()) { + return InvalidArgument( + "Python buffer protocol is only defined for CPU buffers."); + } + + if (buffers.size() != 1) { + return InvalidArgument( + "Python buffer protocol is only defined for buffers with a single " + "shard."); + } + if (!py_array.sharding().type().is(jax::SingleDeviceSharding::type())) { + return InvalidArgument( + "Python buffer protocol is only defined for single-device sharded " + "buffers."); + } + + const char* format = + PEP3118FormatDescriptorForPrimitiveType(buffer.element_type()); + // It isn't an option for us to export unknown types as, say, bytes. When + // converting an object to an ndarray, NumPy tries the buffer protocol + // first. We very much want NumPy to fail and fall back to using + // __array__, which allows us to handle custom dtypes correctly. + if (!format) { + return InvalidArgument( + "Buffers of type %s are not supported by the Python buffer protocol.", + PrimitiveType_Name(buffer.element_type())); + } + + std::unique_ptr external_reference_hold; + { + // We call BlockHostUntilReady() below, which may block. + nb::gil_scoped_release gil_release; + + if (buffer.IsTuple()) { + return InvalidArgument( + "Python buffer protocol is only defined for array buffers."); + } + if ((flags & PyBUF_WRITEABLE) == PyBUF_WRITEABLE) { + return InvalidArgument("XLA buffers are read-only."); + } + TF_ASSIGN_OR_RETURN(external_reference_hold, + buffer.AcquireExternalReference()); + if (buffer.IsDeleted()) { + return InvalidArgument("Deleted buffer used in buffer protocol."); + } + + // TODO(b/327524065): use PjRtLayout directly instead of xla::Layout + Layout xla_layout = buffer.layout()->xla_layout(); + + if (((flags & PyBUF_C_CONTIGUOUS) == PyBUF_C_CONTIGUOUS || + (flags & PyBUF_STRIDES) == PyBUF_ND) && + !LayoutUtil::IsMonotonicWithDim0Major(xla_layout)) { + return InvalidArgument("Buffer is not in C-contiguous layout."); + } else if ((flags & PyBUF_F_CONTIGUOUS) == PyBUF_F_CONTIGUOUS && + !LayoutUtil::IsMonotonicWithDim0Minor(xla_layout)) { + return InvalidArgument("Buffer is not in F-contiguous layout."); + } else if ((flags & PyBUF_ANY_CONTIGUOUS) == PyBUF_ANY_CONTIGUOUS && + !LayoutUtil::IsMonotonicWithDim0Major(xla_layout) && + !LayoutUtil::IsMonotonicWithDim0Minor(xla_layout)) { + return InvalidArgument("Buffer is not in contiguous layout."); + } else if (!HasDefaultLayout(xla_layout)) { + // Fail and fall back to using __array__ if the CPU buffer has a device + // specific layout. For instance, this happens for host buffers in + // pinned memories of the TPU device. + return InvalidArgument( + "Buffer is potentially a device buffer with non default layout."); + } + TF_RETURN_IF_ERROR(buffer.GetReadyFuture().Await()); + } + + // We must hold the GIL (or at least prevent Python GC) while writing to the + // view object, see https://github.com/python/cpython/issues/130409. + std::memset(view, 0, sizeof(Py_buffer)); + const void* root_ptr = + external_reference_hold->OpaqueDeviceMemoryDataPointer(); + view->buf = const_cast(root_ptr); + auto extra = std::make_unique( + buffers.front(), std::move(external_reference_hold)); + view->itemsize = ShapeUtil::ByteSizeOfPrimitiveType(buffer.element_type()); + TF_ASSIGN_OR_RETURN(view->len, buffer.GetOnDeviceSizeInBytes()); + view->readonly = 1; + if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) { + view->format = const_cast(format); + } + if ((flags & PyBUF_ND) == PyBUF_ND) { + view->ndim = buffer.dimensions().size(); + static_assert(sizeof(int64_t) == sizeof(Py_ssize_t), + "Py_ssize_t must be 64 bits"); + if (view->ndim != 0) { + view->shape = reinterpret_cast( + const_cast(buffer.dimensions().data())); + if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) { + extra->strides = + ByteStridesForShape(buffer.element_type(), buffer.dimensions(), + buffer.layout()->xla_layout()); + view->strides = reinterpret_cast( + const_cast(extra->strides.data())); + } + } + } + view->internal = extra.release(); + return absl::OkStatus(); + }(); + if (!status.ok()) { + // numpy.asarray(...) eats the PyExc_BufferError. Adding a log here helps + // debugging when the error really occurs. + VLOG(1) << "Buffer Protocol Error: " << status; + PyErr_SetString(PyExc_BufferError, status.ToString().c_str()); + return -1; + } + view->obj = exporter; + Py_INCREF(view->obj); + return 0; +} + +void PyArray_bf_releasebuffer(PyObject*, Py_buffer* buffer) { + auto extra = static_cast(buffer->internal); + delete extra; +} + +// Returns if shape has a major-to-minor layout. +bool HasMajorToMinorLayout(const xla::Shape& shape) { + if (shape.has_layout()) { + for (int i = 0; i < shape.layout().minor_to_major_size(); ++i) { + if (shape.layout().minor_to_major(i) != + shape.layout().minor_to_major_size() - 1 - i) { + return false; + } + } + } + return true; +} + +// Returns byte_strides if shape has a non-major-to-minor layout. +std::optional> ByteStridesOrDefaultForShapeInt64( + const Shape& shape) { + if (!shape.has_layout() || HasMajorToMinorLayout(shape)) { + return std::nullopt; + } + return ByteStridesForShape(shape); +} + +bool IsZeroCopyableCpuBuffer(const PjRtBuffer* buf) { + // For CPU buffers with device-specific layouts, we must delinearize + // to unpack the array. This could happen for the host buffer + // pre-mapped to the TPU device, a.k.a., pinned host buffers for the + // device. + bool has_default_layout = + buf->layout() == nullptr || HasDefaultLayout(buf->layout()->xla_layout()); + // On CPU for values >= 8 bits, we can return the value in a zero-copy way. + // For sub-byte values, we must copy in order to unpack the array. + return buf->IsOnCpu() && + !primitive_util::IsSubByteNonPredType(buf->element_type()) && + has_default_layout; +} +} // namespace + +PyHostValue::PyHostValue() = default; +PyHostValue::~PyHostValue() = default; + +absl::StatusOr> PyHostValue::AsNumPyArray( + std::optional& dynamic_shape_holder, ifrt::Array* ifrt_array) { + if (ifrt_array->IsDeleted()) { + return InvalidArgument("DeviceArray has been deleted."); + } + // The only `jax.Array` with token-shape buffer is the one wrapped by + // `jax.core.Token`. Since it is an internal implementation detail, we + // don't support converting it to a numpy array. + if (ifrt_array->dtype().kind() == ifrt::DType::kToken) { + return InvalidArgument( + "Cannot convert a token-shape buffer to a numpy array."); + } + auto* arr = llvm::dyn_cast_or_null(ifrt_array); + if (arr != nullptr) { + auto* pjrt_buffer = arr->pjrt_buffers().front().get(); + TF_RET_CHECK(!pjrt_buffer->IsTuple()); + // On CPU for values >= 8 bits, we can return the value in a zero-copy way. + // For sub-byte values, we must copy in order to unpack the array. + if (IsZeroCopyableCpuBuffer(pjrt_buffer)) { + TF_ASSIGN_OR_RETURN(const auto* shape, + XlaDynamicShape(ifrt_array, dynamic_shape_holder)); + TF_ASSIGN_OR_RETURN(nb_dtype dtype, + PrimitiveTypeToNbDtype(shape->element_type())); + // Objects that must be kept alive while the array is alive. + struct Hold { + tsl::RCReference buffer; + std::unique_ptr external_reference_hold; + }; + auto hold = std::make_unique(); + hold->buffer = tsl::FormRef(ifrt_array); + auto* hold_ptr = hold.release(); + nb::capsule hold_capsule( + hold_ptr, [](void* h) noexcept { delete static_cast(h); }); + { + // Release the GIL as `AcquireExternalReference` may block. + nb::gil_scoped_release gil; + TF_ASSIGN_OR_RETURN(hold_ptr->external_reference_hold, + pjrt_buffer->AcquireExternalReference()); + TF_RETURN_IF_ERROR(ifrt_array->GetReadyFuture().Await()); + } + void* data = + hold_ptr->external_reference_hold->OpaqueDeviceMemoryDataPointer(); + nb_numpy_ndarray array(dtype, shape->dimensions(), + ByteStridesForShape(*shape), data, hold_capsule); + array.attr("flags").attr("writeable") = nb::bool_(false); + return std::make_pair(array, false); + } + } + + TF_RETURN_IF_ERROR(CopyToHostAsync(dynamic_shape_holder, ifrt_array)); + if (!ready_.IsReady()) { + nb::gil_scoped_release gil; + TF_RETURN_IF_ERROR(ready_.Await()); + } else { + TF_RETURN_IF_ERROR(ready_.Await()); + } + if (string_array_contents_ != nullptr) { + TF_RETURN_IF_ERROR(ConvertStringArrayContentsToNumpyArray(ifrt_array)); + } + return std::make_pair(value_, true); +} + +absl::Status PyHostValue::ConvertStringArrayContentsToNumpyArray( + ifrt::Array* ifrt_array) { +#ifdef NPY_2_0_API_VERSION + if (PyArray_RUNTIME_VERSION < NPY_2_0_API_VERSION) { + return absl::FailedPreconditionError( + absl::StrCat("String arrays are not supported in NumPy version: ", + PyArray_RUNTIME_VERSION)); + } + auto numpy_dtype = nb::steal( + reinterpret_cast(PyArray_DescrFromType(NPY_VSTRING))); + value_ = nb_numpy_ndarray(numpy_dtype, ifrt_array->shape().dims(), + /*strides=*/std::nullopt); + + auto dst_py_array_obj = reinterpret_cast<::PyArrayObject*>(value_.ptr()); + auto iter = + nb::steal(PyArray_IterNew(reinterpret_cast(dst_py_array_obj))); + for (auto& cord : *string_array_contents_) { + absl::string_view input_str_view = cord.Flatten(); + auto py_unicode = nb::steal(PyUnicode_FromStringAndSize( + input_str_view.data(), input_str_view.size())); + if (py_unicode.ptr() == nullptr) { + return absl::InternalError("PyUnicode_FromStringAndSize failed"); + } + if (PyArray_SETITEM(dst_py_array_obj, + static_cast(PyArray_ITER_DATA(iter.ptr())), + py_unicode.ptr()) != 0) { + return absl::InternalError("PyArray_SETITEM failed"); + } + PyArray_ITER_NEXT(iter.ptr()); + } + + value_.attr("flags").attr("writeable") = nb::bool_(false); + + string_array_contents_.reset(); + + return absl::OkStatus(); +#else + return absl::FailedPreconditionError( + "String arrays are not supported in this NumPy version."); +#endif +} + +absl::Status PyHostValue::CopyStringArrayToHostAsync( + std::optional& dynamic_shape_holder, ifrt::Array* ifrt_array) { + auto transfer_guard_formatter = [ifrt_array] { + return absl::StrCat( + "shape=(", absl::StrJoin(ifrt_array->shape().dims(), ","), + "), dtype=", ifrt_array->dtype().DebugString(), ", device=", + ifrt_array->sharding().devices()->devices().front()->DebugString()); + }; + TF_RETURN_IF_ERROR( + jax::ApplyTransferGuardToDeviceToHost(transfer_guard_formatter)); + + TF_ASSIGN_OR_RETURN(nb_dtype dtype, IfrtDtypeToNbDtype(ifrt_array->dtype())); + auto shape = ifrt_array->shape(); + + // Allocate a vector of cords to hold the contents of the array until + // they are until they are ultimately converted to a numpy array as part + // of the `AsNumPyArray` call. + string_array_contents_ = + std::make_shared>(shape.num_elements()); + ready_ = ifrt_array->CopyToHostBuffer(string_array_contents_->data(), + /*byte_strides=*/std::nullopt, + ifrt::ArrayCopySemantics::kAlwaysCopy); + + ready_.OnReady( + [string_array_contents = string_array_contents_](absl::Status) { + }); // Keeps the cords alive until the copy is done. + + return absl::OkStatus(); +} + +absl::Status PyHostValue::CopyToHostAsync( + std::optional& dynamic_shape_holder, ifrt::Array* ifrt_array) { + if (ready_.IsValid()) { + // The array value has been populated, so CopyToHostAsync has been called. + return absl::OkStatus(); + } + + // Copying in Arrays of type kString requires some special handling + if (ifrt_array->dtype().kind() == ifrt::DType::kString) { + return CopyStringArrayToHostAsync(dynamic_shape_holder, ifrt_array); + } + + auto* arr = llvm::dyn_cast_or_null(ifrt_array); + if (arr != nullptr && !arr->pjrt_buffers().front()->IsTuple() && + IsZeroCopyableCpuBuffer(arr->pjrt_buffers().front().get())) { + return absl::OkStatus(); + } + auto transfer_guard_formatter = [ifrt_array] { + return absl::StrCat( + "shape=(", absl::StrJoin(ifrt_array->shape().dims(), ","), + "), dtype=", ifrt_array->dtype().DebugString(), ", device=", + ifrt_array->sharding().devices()->devices().front()->DebugString()); + }; + TF_RETURN_IF_ERROR( + jax::ApplyTransferGuardToDeviceToHost(transfer_guard_formatter)); + + // TODO(b/182461453): This is a blocking call. If we further implemented + // populating dynamic shape metadata while fetching the literal, we wouldn't + // need this static approach. + const xla::Shape* dynamic_shape; + std::optional shape_holder; + if (llvm::isa(ifrt_array)) { + TF_ASSIGN_OR_RETURN(dynamic_shape, + XlaDynamicShape(ifrt_array, dynamic_shape_holder)); + } else { + // Skip querying the dynamic shape for a non-PjRt Array. + TF_ASSIGN_OR_RETURN(xla::PrimitiveType type, + ifrt::ToPrimitiveType(ifrt_array->dtype())); + shape_holder = ShapeUtil::MakeShapeWithDescendingLayout( + type, ifrt_array->shape().dims()); + dynamic_shape = &*shape_holder; + } + + xla::Shape host_shape = ShapeUtil::DeviceShapeToHostShape(*dynamic_shape); + + auto strides = ByteStridesOrDefaultForShapeInt64(host_shape); + TF_ASSIGN_OR_RETURN(nb_dtype dtype, + PrimitiveTypeToNbDtype(host_shape.element_type())); + value_ = nb_numpy_ndarray(dtype, host_shape.dimensions(), strides); + // TODO(hyeontaek): Several PjRt runtimes assume that the host buffer uses + // the same transposition as the device buffer. This is different from + // PjRtBuffer::ToLiteral()'s semantics that the runtime respects the layout + // of the host buffer literal. On the other hand, the runtime often knows + // better about an efficient layout for the host buffer. It will be useful + // to revisit the semantics of PjRtBuffer::ToLiteral() to see if it is + // desirable for the runtime to choose the layout. + ready_ = ifrt_array->CopyToHostBuffer(value_.mutable_data(), strides, + ifrt::ArrayCopySemantics::kReuseInput); + // Make sure the destination of the copy remains alive until the copy is done. + value_.inc_ref(); + ready_.OnReady([array{value_.ptr()}](absl::Status status) { + GlobalPyRefManager()->AddGarbage(nb::steal(array)); + }); + value_.attr("flags").attr("writeable") = nb::bool_(false); + return absl::OkStatus(); +} + +namespace { +PyGetSetDef PyArray_tp_getset[] = { + {"__dict__", PyObject_GenericGetDict, PyObject_GenericSetDict, nullptr, + nullptr}, + {nullptr, nullptr, nullptr, nullptr, nullptr}, +}; + +PyMemberDef PyArray_members[] = { +#if PY_VERSION_HEX < 0x030C0000 + {"__weaklistoffset__", T_PYSSIZET, + static_cast(offsetof(PyArrayObject, weakrefs)), READONLY, + nullptr}, + {"__dictoffset__", T_PYSSIZET, + static_cast(offsetof(PyArrayObject, dict)), READONLY, nullptr}, +#endif // PY_VERSION_HEX < 0x030C0000 + {nullptr, 0, 0, 0, nullptr}, +}; // namespace xla + +PyType_Slot PyArray_slots[] = { + {Py_tp_new, reinterpret_cast(PyArray_tp_new)}, + {Py_tp_dealloc, reinterpret_cast(PyArray_tp_dealloc)}, + {Py_tp_members, reinterpret_cast(PyArray_members)}, + {Py_tp_traverse, reinterpret_cast(PyArray_tp_traverse)}, + {Py_tp_clear, reinterpret_cast(PyArray_tp_clear)}, + {Py_tp_getset, reinterpret_cast(PyArray_tp_getset)}, + {Py_bf_getbuffer, reinterpret_cast(PyArray_bf_getbuffer)}, + {Py_bf_releasebuffer, reinterpret_cast(PyArray_bf_releasebuffer)}, + {0, nullptr}, +}; + +} // namespace + +absl::Status PyArray::RegisterTypes(nb::module_& m) { + std::string name = + absl::StrCat(nb::cast(m.attr("__name__")), ".ArrayImpl"); + + PyType_Spec PyArray_spec = { +#if PY_VERSION_HEX < 0x030B0000 + // Work around for https://github.com/python/cpython/issues/89478 + // CPython 3.10 and earlier assume that the .name value remains alive + // forever. + /*.name=*/strdup(name.c_str()), +#else + /*.name=*/name.c_str(), +#endif // PY_VERSION_HEX < 0x030B0000 + /*.basicsize=*/static_cast(sizeof(PyArrayObject)), + /*.itemsize=*/0, +#if PY_VERSION_HEX < 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, +#else // PY_VERSION_HEX >= 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_MANAGED_DICT | Py_TPFLAGS_MANAGED_WEAKREF, +#endif // PY_VERSION_HEX >= 0x030C0000 + /*.slots=*/PyArray_slots, + }; + + type_ = PyType_FromSpec(&PyArray_spec); + if (!type_) { + throw nb::python_error(); + } + auto type = nb::borrow(type_); + m.attr("ArrayImpl") = type; + + type.attr("__init__") = nb::cpp_function( + [](PyArray self, nb::object aval, nb::object sharding, nb::list arrays, + bool committed, bool skip_checks) { + if (!(arrays.size() == 0 || arrays[0].type().is(PyArray::type()))) { + throw nb::type_error( + absl::StrCat( + "Unsupported type for elements in `arrays`: ", + nb::cast(nb::str(arrays[0].type()))) + .c_str()); + } + auto py_arrays = nb::cast>(arrays); + PyArray::PyInit(self, std::move(aval), std::move(sharding), py_arrays, + committed, skip_checks); + }, + nb::is_method(), nb::arg("aval"), nb::arg("sharding"), nb::arg("arrays"), + nb::arg("committed"), nb::arg("_skip_checks") = false); + type.attr("delete") = nb::cpp_function( + [](PyArray& self) { xla::ThrowIfError(self.Delete()); }, nb::is_method()); + type.attr("_sharding") = nb_property_readonly(&PyArray::sharding); + type.attr("aval") = nb_property(&PyArray::aval, &PyArray::set_aval); + type.attr("_arrays") = + nb_property(&PyArray::arrays, [](PyArray& self, nb::object obj) { + xla::ThrowIfError(self.set_arrays(obj)); + }); + type.attr("_fully_replicated_shard") = nb::cpp_function( + [](PyArray self) { + return xla::ValueOrThrow(self.FullyReplicatedShard()); + }, + nb::is_method()); + type.attr("_npy_value") = + nb_property(&PyArray::npy_value, &PyArray::set_npy_value); + type.attr("_committed") = nb_property_readonly(&PyArray::committed); + type.attr("unsafe_buffer_pointer") = nb::cpp_function( + [](PyArray self) { + return xla::ValueOrThrow(self.UnsafeBufferPointer()); + }, + nb::is_method()); + type.attr("__cuda_array_interface__") = nb_property_readonly( + [](PyArray self) { return self.CudaArrayInterface(); }); + type.attr("_pjrt_layout") = + nb_property_readonly(xla::ValueOrThrowWrapper(&PyArray::layout)); + type.attr("on_device_size_in_bytes") = nb::cpp_function( + xla::ValueOrThrowWrapper(&PyArray::GetOnDeviceSizeInBytes), + nb::is_method()); + type.attr("_single_device_array_to_np_array_did_copy") = nb::cpp_function( + xla::ValueOrThrowWrapper(&PyArray::SingleDeviceArrayToNumpyArrayDidCopy), + nb::is_method()); + type.attr("_copy_single_device_array_to_host_async") = nb::cpp_function( + [](PyArray& self) { + xla::ThrowIfError(self.CopySingleDeviceArrayToHostAsync()); + }, + nb::is_method()); + type.attr("block_until_ready") = nb::cpp_function( + [](PyArray self) -> nb::object { + xla::ThrowIfError(self.BlockUntilReady()); + return self; + }, + nb::is_method()); + type.attr("platform") = nb::cpp_function( + [](PyArray self) { + if (self.ifrt_array()->client()->platform_name() == "cuda" || + self.ifrt_array()->client()->platform_name() == "rocm") { + return absl::string_view("gpu"); + } else { + return self.ifrt_array()->client()->platform_name(); + } + }, + nb::is_method()); + type.attr("is_ready") = nb::cpp_function( + [](PyArray self) { return xla::ValueOrThrow(self.IsReady()); }, + nb::is_method()); + type.attr("is_deleted") = + nb::cpp_function(&PyArray::IsDeleted, nb::is_method()); + type.attr("traceback") = nb_property_readonly(&PyArray::traceback); + type.attr("clone") = nb::cpp_function(&PyArray::Clone, nb::is_method()); + type.attr("__module__") = m.attr("__name__"); + + m.attr("batched_copy_array_to_devices_with_sharding") = nb::cpp_function( + [](absl::Span arrays, + absl::Span> dst_device_lists, + absl::Span shardings, + absl::Span array_copy_semantics) { + if (arrays.empty()) { + return std::vector(); + } + auto* client = arrays[0].ifrt_array()->client(); + std::vector device_lists; + device_lists.reserve(dst_device_lists.size()); + for (const auto& dst_devices : dst_device_lists) { + absl::InlinedVector devices; + devices.reserve(dst_devices.size()); + for (auto& d : dst_devices) { + devices.push_back(d->device()); + } + device_lists.push_back(client->MakeDeviceList(devices)); + } + return xla::ValueOrThrow(PyArray::BatchedCopyToDeviceWithSharding( + arrays, device_lists, shardings, array_copy_semantics)); + }); + m.attr("array_result_handler") = nb::cpp_function( + [](nb::object aval, nb::object sharding, bool committed, + bool skip_checks) -> nb_class_ptr { + return make_nb_class( + std::move(aval), std::move(sharding), committed, skip_checks); + }, + nb::arg("aval"), nb::arg("sharding"), nb::arg("committed"), + nb::arg("_skip_checks") = false); + + nb::class_(m, "ResultHandler") + .def("__call__", [](const PyArrayResultHandler& self, + PyArray arg) { return self.Call(arg); }) + .def("__call__", + [](const PyArrayResultHandler& self, + std::vector py_arrays) { return self.Call(py_arrays); }); + + return absl::OkStatus(); +} + +} // namespace xla diff --git a/jaxlib/xla/py_array.h b/jaxlib/xla/py_array.h new file mode 100644 index 000000000000..f914639e383f --- /dev/null +++ b/jaxlib/xla/py_array.h @@ -0,0 +1,360 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_PY_ARRAY_H_ +#define JAXLIB_XLA_PY_ARRAY_H_ + +#include + +#include +#include +#include +#include +#include +#include + +// placeholder for index annotation headers +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "nanobind/nanobind.h" +#include "jaxlib/xla/py_client.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_future.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/python/traceback.h" +#include "xla/shape.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/util.h" + +namespace xla { + +// Private to PyArray, but you cannot forward declare member classes. +// Not thread safe; assumes the GIL is held. +class PyHostValue { + public: + PyHostValue(); + ~PyHostValue(); + + PyHostValue(const PyHostValue&) = delete; + PyHostValue(PyHostValue&&) = delete; + PyHostValue& operator=(const PyHostValue&) = delete; + PyHostValue& operator=(PyHostValue&&) = delete; + + absl::Status CopyToHostAsync(std::optional& dynamic_shape_holder, + ifrt::Array* ifrt_array); + + absl::StatusOr> AsNumPyArray( + std::optional& dynamic_shape_holder, ifrt::Array* ifrt_array); + + private: + absl::Status CopyStringArrayToHostAsync( + std::optional& dynamic_shape_holder, ifrt::Array* ifrt_array); + + absl::Status ConvertStringArrayContentsToNumpyArray(ifrt::Array* ifrt_array); + + ifrt::Future<> ready_; + nb_numpy_ndarray value_; + + // Optional field, only used for arrays of type kString. This vector of cords + // serves as input buffer for the CopyToHostBuffer call. It holds these + // contents until it is lazily converted it to a numpy array when the user + // calls `AsNumPyArray`. + std::shared_ptr> string_array_contents_; +}; + +// Private to PyArray, but you cannot forward declare member classes. +struct PyArray_Storage { + PyArray_Storage(nanobind::object aval, bool weak_type, nb_dtype dtype, + std::vector shape, nanobind::object sharding, + bool committed, nb_class_ptr py_client, + std::optional traceback, + tsl::RCReference ifrt_array, + xla::PjRtFuture<> result_status); + + ~PyArray_Storage(); + nanobind::handle AsHandle(); + + nanobind::object aval; + bool weak_type = false; + nb_dtype dtype; + std::vector shape; + + nanobind::object sharding; + nanobind::object npy_value = nanobind::none(); + bool committed = false; + + nb_class_ptr py_client; + std::optional traceback; + tsl::RCReference ifrt_array; + nanobind::object fully_replicated_array = nanobind::none(); + + // optional field, used only in python + std::vector py_arrays; + PyHostValue host_value; // Protected by the GIL. + std::optional dynamic_shape = std::nullopt; + // Only set if this Array was generated by a computation that has effects. + // This is the result status of the XLA computation that generated this + // array. + xla::PjRtFuture<> result_status; + + // Doubly-linked list of all PyArrays known to the client. Protected by the + // GIL. Since multiple PyArrays may share the same PjRtBuffer, there may be + // duplicate PjRtBuffers in this list. + PyArray_Storage* next; + PyArray_Storage* prev; + + uint8_t thread_id_bucket; +}; + +// The C++ implementation of jax.Array. A few key methods and data members are +// implemented in C++ for performance, while most of the functionalities are +// still implemented in python. +class PyArray : public nanobind::object { + public: + NB_OBJECT(PyArray, nanobind::object, "Array", PyArray::IsPyArray); + PyArray() = default; + + // "__init__" methods. Only used in python + static void PyInit(PyArray self, nanobind::object aval, + nanobind::object sharding, + absl::Span py_arrays, bool committed, + bool skip_checks); + + // Only used in C++. `skip_checks` should only be set for Arrays created by + // jax that cannot possibly have consistency issues (e.g. `sharding` devices + // different than `ifrt_array` devices). Arrays created by users should be + // checked. + PyArray(nanobind::object aval, bool weak_type, nb_dtype dtype, + std::vector shape, nanobind::object sharding, + nb_class_ptr py_client, + std::optional traceback, + tsl::RCReference ifrt_array, bool committed, + bool skip_checks, + xla::PjRtFuture<> result_status = xla::PjRtFuture<>()); + + static PyArray MakeFromSingleDeviceArray( + nb_class_ptr py_client, std::optional traceback, + tsl::RCReference ifrt_array, bool weak_type, bool committed, + xla::PjRtFuture<> result_status = xla::PjRtFuture<>()); + + static PyArray MakeFromIfrtArrayAndSharding( + nb_class_ptr py_client, std::optional traceback, + tsl::RCReference ifrt_array, nanobind::object sharding, + bool weak_type, bool committed, bool skip_checks); + + static absl::Status RegisterTypes(nanobind::module_& m); + + static PyArray borrow(PyObject* ptr) { + return nanobind::borrow(ptr); + } + + using Storage = PyArray_Storage; + + const nanobind::object& aval() const { return GetStorage().aval; } + void set_aval(nanobind::object aval) { GetStorage().aval = std::move(aval); } + + bool weak_type() const { return GetStorage().weak_type; } + + const nb_dtype& dtype() const { return GetStorage().dtype; } + absl::Span shape() const { return GetStorage().shape; } + + const nanobind::object& sharding() const { return GetStorage().sharding; } + + absl::StatusOr> layout() { + return ifrt_array()->layout(); + } + + bool committed() const { return GetStorage().committed; } + + const nanobind::object& npy_value() const { return GetStorage().npy_value; } + void set_npy_value(nanobind::object v) { + GetStorage().npy_value = std::move(v); + } + + const nb_class_ptr& py_client() const { + return GetStorage().py_client; + } + + const std::optional& traceback() const { + return GetStorage().traceback; + } + + // Returns xla::InvalidArgument if the buffer has been deleted. + // See `PjRtFuture` for the semantics of `IsReady` and `IsKnownReady`. + absl::StatusOr IsReady() { + ifrt::Array* ifrt_array_ptr = ifrt_array(); + if (ifrt_array_ptr->IsDeleted()) { + return InvalidArgument("Array has been deleted."); + } + return ifrt_array_ptr->GetReadyFuture().IsReady(); + } + + const xla::PjRtFuture<>& result_status() const { + return GetStorage().result_status; + } + + ifrt::Array* ifrt_array() const { return GetStorage().ifrt_array.get(); } + + // Short-term escape hatch to get PjRtBuffers from PyArray. + // TODO(hyeontaek): Migrate all users of this method to be agnostic of PjRt. + absl::Span> pjrt_buffers() const { + ifrt::Array* ifrt_array_ptr = ifrt_array(); + if (ifrt_array_ptr == nullptr) { + return {}; + } + auto* arr = + llvm::dyn_cast_or_null(ifrt_array_ptr); + if (arr == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + return arr->pjrt_buffers(); + } + + int num_addressable_shards() const { + ifrt::Array* ifrt_array_ptr = ifrt_array(); + if (ifrt_array_ptr == nullptr) { + return 0; + } + auto* arr = + llvm::dyn_cast_or_null(ifrt_array_ptr); + if (arr == nullptr) { + // TODO(hyeontaek): Add num_addressable_shards to ifrt. + return num_shards(); + } + return arr->pjrt_buffers().size(); + } + + std::vector& py_arrays() { return GetStorage().py_arrays; } + const std::vector& py_arrays() const { + return GetStorage().py_arrays; + } + const std::vector& py_arrays_cached(); + + nanobind::object arrays(); + absl::Status set_arrays(nanobind::object obj); + absl::StatusOr FullyReplicatedShard(); + + int num_shards() const { + ifrt::Array* ifrt_array_ptr = ifrt_array(); + if (ifrt_array_ptr == nullptr) { + return 0; + } + return ifrt_array_ptr->sharding().devices()->size(); + } + + static nanobind::handle type() { + DCHECK(type_); + return nanobind::handle(type_); + } + + static bool IsPyArray(nanobind::handle arg) { + return arg.type().is(PyArray::type()); + } + + absl::Status BlockUntilReady() const; + + absl::Status BlockUntilResultStatusIsReady(); + + absl::StatusOr GetOnDeviceSizeInBytes(); + absl::StatusOr> + SingleDeviceArrayToNumpyArrayDidCopy(); + absl::StatusOr SingleDeviceArrayToNumpyArray(); + absl::Status CopySingleDeviceArrayToHostAsync(); + nanobind::dict CudaArrayInterface(); + absl::StatusOr UnsafeBufferPointer(); + + absl::Status Delete(); + + bool IsDeleted() const; + + PyArray Clone() const; + + static absl::StatusOr> BatchedCopyToDeviceWithSharding( + absl::Span py_arrays, + absl::Span dst_device_lists, + absl::Span dst_shardings, + absl::Span array_copy_semantics); + + static absl::StatusOr BatchedDevicePut( + nanobind::object aval, nanobind::object sharding, + std::vector xs, + absl::Span dst_devices, bool committed, + bool force_copy, PjRtClient::HostBufferSemantics host_buffer_semantics, + bool jax_enable_x64); + + static absl::StatusOr ReorderShards( + PyArray x, nanobind::object dst_sharding, + ifrt::ArrayCopySemantics array_copy_semantics); + + static absl::Status BatchedBlockUntilReady( + std::vector objs); + + private: + absl::StatusOr AssertUnsharded(absl::string_view api); + + nanobind::object CheckAndRearrange(absl::Span py_arrays, + nanobind::object sharding, + nanobind::object aval); + + void SetIfrtArray(tsl::RCReference ifrt_array); + + Storage& GetStorage(); + const Storage& GetStorage() const; + + inline static PyObject* type_ = nullptr; +}; + +class PyArrayResultHandler { + public: + PyArrayResultHandler(nanobind::object aval, nanobind::object sharding, + bool committed, bool skip_checks); + + PyArray Call(absl::Span py_arrays) const; + PyArray Call(PyArray py_array) const; + + PyArray Call(nb_class_ptr py_client, + tsl::RCReference ifrt_array, + xla::PjRtFuture<> result_status = xla::PjRtFuture<>()) const; + + private: + nanobind::object aval_; + nanobind::object sharding_; + bool weak_type_; + bool committed_; + bool skip_checks_; + + nb_dtype dtype_; + std::vector shape_; +}; + +absl::StatusOr CudaArrayInterfaceToBuffer( + const nanobind::dict& cai, nb_class_ptr cuda_client, + std::optional device_id); + +} // namespace xla + +#endif // JAXLIB_XLA_PY_ARRAY_H_ diff --git a/jaxlib/xla/py_client.cc b/jaxlib/xla/py_client.cc new file mode 100644 index 000000000000..434077b0824f --- /dev/null +++ b/jaxlib/xla/py_client.cc @@ -0,0 +1,851 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/py_client.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/Pass/PassManager.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/py_array.h" +#include "jaxlib/xla/py_device.h" +#include "jaxlib/xla/py_executable.h" +#include "jaxlib/xla/py_memory_space.h" +#include "jaxlib/xla/py_values.h" +#include "xla/literal.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/mlir_to_hlo.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/callback.h" +#include "xla/python/guard_lib.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/hlo/hlo_program.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/program.h" +#include "xla/python/nb_absl_span.h" // IWYU pragma: keep +#include "xla/python/nb_class_ptr.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/python/pjrt_ifrt/xla_compiler.h" +#include "xla/python/pprof_profile_builder.h" +#include "xla/python/py_host_callback.h" +#include "xla/python/python_ref_manager.h" +#include "xla/python/traceback.h" +#include "xla/python/types.h" +#include "xla/service/custom_call_target_registry.h" +#include "xla/service/platform_util.h" // IWYU pragma: keep +#include "xla/shape.h" +#include "xla/status_macros.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" + +namespace xla { + +namespace nb = nanobind; + +/*static*/ nb_class_ptr PyClient::Make( + std::shared_ptr ifrt_client) { + auto client = make_nb_class(std::move(ifrt_client)); + Initialize(client); + return client; +} + +PyClient::PyClient(std::shared_ptr ifrt_client) + : ifrt_client_(std::move(ifrt_client)), + client_attributes_(ifrt_client_->Attributes()) { + CHECK(ifrt_client_); +} + +/* static */ void PyClient::Initialize(nb_class_ptr client) { + for (ifrt::Device* device : client->ifrt_client()->devices()) { + client->devices_[device] = make_nb_class(client, device); + + for (ifrt::Memory* memory : device->Memories()) { + auto& py_memory = client->memory_spaces_[memory]; + if (py_memory.get() == nullptr) { + py_memory = make_nb_class(client, memory); + } + } + } +} + +PyClient::~PyClient() { + nb::gil_scoped_release gil; + ifrt_client_ = nullptr; +} + +nb_class_ptr PyClient::GetPyDevice(ifrt::Device* device) { + auto& py_device = devices_[device]; + if (py_device.get() == nullptr) { + py_device = make_nb_class( + nb::borrow>(nb::find(this)), device); + } + return py_device; +} + +nb_class_ptr PyClient::GetPyMemorySpace( + ifrt::Memory* memory_space) { + auto& py_memory = memory_spaces_[memory_space]; + if (py_memory.get() == nullptr) { + py_memory = make_nb_class( + nb::borrow>(nb::find(this)), memory_space); + } + return py_memory; +} + +std::vector> PyClient::Devices() { + std::vector> devices; + auto span = ifrt_client_->devices(); + devices.reserve(span.size()); + for (ifrt::Device* device : span) { + devices.push_back(GetPyDevice(device)); + } + return devices; +} + +std::vector> PyClient::LocalDevices() { + std::vector> devices; + devices.reserve(ifrt_client_->addressable_devices().size()); + for (ifrt::Device* device : ifrt_client_->addressable_devices()) { + devices.push_back(GetPyDevice(device)); + } + return devices; +} + +std::vector> PyClient::GetAllDevices() { + std::vector> devices; + devices.reserve(ifrt_client_->GetAllDevices().size()); + for (ifrt::Device* device : ifrt_client_->GetAllDevices()) { + devices.push_back(GetPyDevice(device)); + } + return devices; +} + +absl::StatusOr> PyClient::DeviceFromLocalHardwareId( + int local_hardware_id) { + TF_ASSIGN_OR_RETURN(ifrt::Device * device, + ifrt_client_->LookupAddressableDevice(local_hardware_id)); + return GetPyDevice(device); +} + +nb::list PyClient::LiveExecutables() { + CHECK(PyGILState_Check()); + nb::ft_lock_guard lock(executables_mutex_); + nb::list executables; + for (PyLoadedExecutable* exec = executables_; exec; exec = exec->next_) { + if (!exec->is_deleted()) { + executables.append(nb::find(exec)); + } + } + return executables; +} + +absl::Status PyClient::Defragment() { + CHECK(PyGILState_Check()); + if (!llvm::isa(ifrt_client_.get())) { + return absl::UnimplementedError( + "Defragmentation is not supported on this runtime."); + } + ifrt::PlatformId platform_id = ifrt_client_->platform_id(); + bool is_gpu_client = platform_id == CudaId() || platform_id == RocmId() || + platform_id == SyclId(); + + if (!is_gpu_client) { + return pjrt_client()->Defragment(); + } + + struct TmpBuffer { + // Non-empty for buffers found in a PyArray_Storage. Multiple Arrays + // can reference the same PjRtBuffer. + std::vector*> pjrt_buffer_ptrs; + // TODO(skyewm): maybe use py_buffer's HostValue + std::shared_ptr host_copy; + }; + + // Synchronously copy all buffers to host + absl::flat_hash_map pjrt_buf_to_tmp_buffer; + + std::vector arrays = LiveArrays(); + for (const PyArray& array : arrays) { + // TODO(hyeontaek): Support non-PjRt Arrays. + // TODO(hyeontaek): Re-construct ifrt::Array with new PjRtBuffer so that + // std::shared_ptr does not need to be updated in-place. + if (array.ifrt_array() == nullptr) { + continue; + } + auto* arr = + llvm::dyn_cast_or_null(array.ifrt_array()); + if (arr == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend " + "only."); + } + TF_ASSIGN_OR_RETURN(absl::Span> pjrt_buffers, + arr->mutable_pjrt_buffers()); + for (int i = 0; i < pjrt_buffers.size(); ++i) { + std::shared_ptr& pjrt_buf_ptr = pjrt_buffers[i]; + if (pjrt_buf_ptr->IsDeleted()) { + continue; + } + auto [iter, inserted] = + pjrt_buf_to_tmp_buffer.insert({pjrt_buf_ptr.get(), TmpBuffer()}); + if (inserted) { + TF_ASSIGN_OR_RETURN(iter->second.host_copy, + pjrt_buf_ptr->ToLiteralSync()); + } + iter->second.pjrt_buffer_ptrs.push_back(&pjrt_buf_ptr); + } + } + + // All buffers successfully copied to host, delete on-device copies. + // + // Use blocking delete operation to ensure all memory is actually cleared + // before we start rewriting buffers. + // + // Die instead of returning a bad status because program presumably can't + // continue if we fail to reconstitute device buffers. + for (const auto& it : pjrt_buf_to_tmp_buffer) { + PjRtBuffer* pjrt_buf = it.first; + TF_CHECK_OK(pjrt_buf + ->ReleaseDeviceMemoryOwnership( + /*wait_for_operations_to_complete=*/true) + .status()); + } + + // Copy host copies back to device and update PyArrays in-place. + for (auto& it : pjrt_buf_to_tmp_buffer) { + PjRtBuffer* pjrt_buf = it.first; + TmpBuffer& tmp_buffer = it.second; + std::unique_ptr new_copy = + pjrt_client() + ->BufferFromHostLiteral(*tmp_buffer.host_copy, + pjrt_buf->memory_space()) + .value(); + TF_CHECK_OK(new_copy->GetReadyFuture().Await()); + + std::shared_ptr new_pjrt_buf_ptr(new_copy.release()); + for (std::shared_ptr* pjrt_buffer_ptr : + tmp_buffer.pjrt_buffer_ptrs) { + *pjrt_buffer_ptr = new_pjrt_buf_ptr; + } + } + + // TODO(skyewm): delete executables? + return absl::OkStatus(); +} + +/* static */ absl::StatusOr PyClient::BufferFromPyval( + nb_class_ptr client, nb::handle argument, ifrt::Device* device, + bool force_copy, ifrt::Client::HostBufferSemantics host_buffer_semantics) { + if (device == nullptr) { + TF_RET_CHECK(!client->ifrt_client_->addressable_devices().empty()); + device = client->ifrt_client_->addressable_devices().front(); + } + CHECK(device != nullptr); + + auto transfer_guard_formatter = [&argument, dst_device = device] { + auto type = nb::cast(nb::str(argument.type())); + // Catch exceptions because shape and dtype properties convertible to str + // are not guaranteed to present in an arbitrary argument. + std::string shape; + std::string dtype; + try { + shape = + nb::cast(nb::str(nb::object(argument.attr("shape")))); + } catch (const std::exception& e) { + shape = ""; + } + try { + dtype = + nb::cast(nb::str(nb::object(argument.attr("dtype")))); + } catch (const std::exception& e) { + dtype = ""; + } + return absl::StrCat("type=", type, ", shape=", shape, ", dtype=", dtype, + ", dst_device=", dst_device->DebugString()); + }; + TF_RETURN_IF_ERROR( + jax::ApplyTransferGuardToHostToDevice(transfer_guard_formatter)); + + TF_ASSIGN_OR_RETURN(ifrt::Device * found_device, + client->ifrt_client_->LookupDevice(device->Id())); + if (found_device != device) { + return InvalidArgument("Cannot copy value to device '%s' with '%s' backend", + device->DebugString(), + client->ifrt_client_->platform_name()); + } + GlobalPyRefManager()->CollectGarbage(); + + DevicePutOptions options; + options.squash_64bit_types = false; + options.allow_zero_copy = + (!force_copy && (host_buffer_semantics == + ifrt::Client::HostBufferSemantics::kImmutableZeroCopy)); + TF_ASSIGN_OR_RETURN(auto put_fn, + DevicePut(argument, client->ifrt_client_.get(), device, + options, ifrt::MemoryKind())); + TF_ASSIGN_OR_RETURN(auto put, [&]() { + // Must release the GIL before calling IFRT because backends may + // decide to block/sleep for device buffer allocation. + nb::gil_scoped_release gil_release; + return std::move(put_fn)(); + }()); + + if (put.ifrt_array) { + auto traceback = Traceback::Get(); + return PyArray::MakeFromSingleDeviceArray( + std::move(client), std::move(traceback), std::move(put.ifrt_array), + /*weak_type=*/false, + /*committed=*/false); + } else { + return put.owning_pybuffer; + } +} + +namespace { + +// Makes IFRT `CompileOptions` from XLA `CompileOptions` and optional host +// callbacks. +std::unique_ptr MakeIfrtCompileOptions( + CompileOptions options, std::vector host_callbacks) { + std::vector> + ifrt_loaded_host_callbacks; + ifrt_loaded_host_callbacks.reserve(host_callbacks.size()); + // Extract `ifrt::LoadedHostCallback`s from host callback capsules that were + // created by `PyClient::MakePythonCallbackUsingHostSendAndRecv()` or + // `PyClient::GetEmitPythonCallbackDescriptor()`. + for (auto& host_callback : host_callbacks) { + ifrt_loaded_host_callbacks.push_back(tsl::FormRef( + static_cast(host_callback.data()))); + } + return std::make_unique( + std::move(options), std::move(ifrt_loaded_host_callbacks)); +} + +// Makes IFRT `DeserializeExecutableOptions` from XLA `CompileOptions` and +// optional host callbacks. +std::unique_ptr +MakeIfrtDeserializeExecutableOptions(std::optional options, + std::vector host_callbacks) { + std::vector> + ifrt_loaded_host_callbacks; + ifrt_loaded_host_callbacks.reserve(host_callbacks.size()); + // Extract `ifrt::LoadedHostCallback`s from host callback capsules that were + // created by `PyClient::MakePythonCallbackUsingHostSendAndRecv()` or + // `PyClient::GetEmitPythonCallbackDescriptor()`. + for (auto& host_callback : host_callbacks) { + ifrt_loaded_host_callbacks.push_back(tsl::FormRef( + static_cast(host_callback.data()))); + } + return std::make_unique( + std::move(options), std::move(ifrt_loaded_host_callbacks)); +} + +} // namespace + +/* static */ absl::StatusOr> +PyClient::CompileIfrtProgram( + nb_class_ptr client, std::unique_ptr ifrt_program, + std::unique_ptr ifrt_options) { + auto* pjrt_compatible_client = + llvm::dyn_cast_or_null( + client->ifrt_client_.get()); + auto* ifrt_xla_options = + llvm::dyn_cast_or_null(ifrt_options.get()); + // For XLA programs, pass allocated device memory size to compile options for + // pjrt compatible backends. + if (pjrt_compatible_client != nullptr && ifrt_xla_options != nullptr) { + xla::CompileOptions& options = ifrt_xla_options->compile_options; + auto addressable_devices = + pjrt_compatible_client->pjrt_client()->addressable_devices(); + if (!addressable_devices.empty()) { + int device_ordinal = options.executable_build_options.device_ordinal(); + if (device_ordinal < 0) { + device_ordinal = 0; + } + CHECK_LT(device_ordinal, addressable_devices.size()); + auto stats = addressable_devices[device_ordinal]->GetAllocatorStats(); + if (stats.ok() && stats->bytes_limit) { + options.executable_build_options.set_device_memory_size( + *stats->bytes_limit); + } + } + + if (pjrt_compatible_client->pjrt_client()->key_value_store().has_value()) { + options.executable_build_options.set_key_value_store( + *pjrt_compatible_client->pjrt_client()->key_value_store()); + } + } + + std::unique_ptr ifrt_loaded_executable; + std::optional fingerprint; + { + nb::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN(ifrt_loaded_executable, + client->ifrt_client_->GetDefaultCompiler()->Compile( + std::move(ifrt_program), std::move(ifrt_options))); + TF_RETURN_IF_ERROR(ifrt_loaded_executable->GetReadyFuture().Await()); + TF_ASSIGN_OR_RETURN(fingerprint, ifrt_loaded_executable->Fingerprint()); + } + auto traceback = Traceback::Get(); + return make_nb_class( + std::move(client), std::move(ifrt_loaded_executable), + std::move(traceback), std::move(fingerprint)); +} + +/* static */ absl::StatusOr> PyClient::Compile( + nb_class_ptr client, std::string mlir_module, + CompileOptions options, std::vector host_callbacks) { + mlir::MLIRContext context; + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ParseMlirModuleString(mlir_module, context)); + if (options.executable_build_options.use_shardy_partitioner()) { + // Since Shardy is located in the middle of the XLA pipeline, we need to + // export it before going to HLO while preserving Shardy ops and attrs. + TF_RETURN_IF_ERROR(ExportShardyForHloRoundTrip(*module)); + } + return CompileIfrtProgram( + client, std::make_unique(module.get()), + MakeIfrtCompileOptions(std::move(options), std::move(host_callbacks))); +} + +/* static */ absl::StatusOr> PyClient::Compile( + nb_class_ptr client, std::string mlir_module, + CompileOptions options, std::vector host_callbacks) { + mlir::MLIRContext context; + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ParseMlirModuleString(mlir_module, context)); + if (options.executable_build_options.use_shardy_partitioner()) { + // Since Shardy is located in the middle of the XLA pipeline, we need to + // export it before going to HLO while preserving Shardy ops and attrs. + TF_RETURN_IF_ERROR(ExportShardyForHloRoundTrip(*module)); + } + + std::vector> + ifrt_loaded_host_callbacks; + ifrt_loaded_host_callbacks.reserve(host_callbacks.size()); + // Extract `ifrt::LoadedHostCallback`s from host callback capsules that were + // created by `PyClient::MakePythonCallbackUsingHostSendAndRecv()` or + // `PyClient::GetEmitPythonCallbackDescriptor()`. + for (auto& host_callback : host_callbacks) { + auto callback = tsl::MakeRef( + client->ifrt_client(), std::move(host_callback)); + ifrt_loaded_host_callbacks.push_back(callback); + } + auto compile_options = std::make_unique( + std::move(options), std::move(ifrt_loaded_host_callbacks)); + return CompileIfrtProgram( + client, std::make_unique(module.get()), + std::move(compile_options)); +} + +absl::StatusOr PyClient::SerializeExecutable( + const PyLoadedExecutable& executable) const { + TF_ASSIGN_OR_RETURN(auto serialized, + executable.ifrt_loaded_executable()->Serialize()); + return nb::bytes(serialized.data(), serialized.size()); +} + +/* static */ absl::StatusOr> +PyClient::DeserializeExecutable(nb_class_ptr client, + nb::bytes serialized, + std::optional options, + std::vector host_callbacks) { + std::unique_ptr ifrt_loaded_executable; + std::optional fingerprint; + auto ifrt_deserialize_options = MakeIfrtDeserializeExecutableOptions( + std::move(options), std::move(host_callbacks)); + { + nb::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN( + ifrt_loaded_executable, + client->ifrt_client_->GetDefaultCompiler()->DeserializeLoadedExecutable( + absl::string_view(serialized.c_str(), serialized.size()), + std::move(ifrt_deserialize_options))); + } + TF_ASSIGN_OR_RETURN(fingerprint, ifrt_loaded_executable->Fingerprint()); + auto traceback = Traceback::Get(); + return make_nb_class( + std::move(client), std::move(ifrt_loaded_executable), + std::move(traceback), std::move(fingerprint)); +} + +namespace { + +struct HeapProfileKey { + Traceback* traceback; + int64_t size; + xla::PjRtDevice* device; + bool operator==(const HeapProfileKey& other) const; +}; + +bool HeapProfileKey::operator==(const HeapProfileKey& other) const { + if (size != other.size || device != other.device) { + return false; + } + if ((traceback == nullptr) != (other.traceback == nullptr)) { + return false; + } + if (traceback && traceback->raw_frames() != other.traceback->raw_frames()) { + return false; + } + return true; +} + +template +H AbslHashValue(H h, const HeapProfileKey& key) { + if (key.traceback) { + h = H::combine(std::move(h), key.traceback->raw_frames()); + } + h = H::combine(std::move(h), key.size, key.device); + return h; +} + +} // namespace + +absl::StatusOr PyClient::HeapProfile() { + CHECK(PyGILState_Check()); + absl::flat_hash_set buffer_set; + absl::flat_hash_map entries; + + auto add_buffer_to_profile = [&](PjRtBuffer* buffer, Traceback* traceback) { + // We only wish to count each PjRtBuffer once, even though they may be + // shared by multiple PyArrays. + if (!buffer->IsDeleted() && buffer_set.insert(buffer).second) { + TF_ASSIGN_OR_RETURN(size_t size, buffer->GetOnDeviceSizeInBytes()); + HeapProfileKey key{traceback, static_cast(size), + buffer->device()}; + ++entries[key]; + } + return absl::OkStatus(); + }; + + std::vector arrays = LiveArrays(); + for (const PyArray& array : arrays) { + if (array.ifrt_array() == nullptr) { + continue; + } + auto* arr = + llvm::dyn_cast_or_null(array.ifrt_array()); + // TODO(hyeontaek): Support non-PjRt Arrays. + if (arr == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend " + "only."); + } + for (const auto& buffer : arr->pjrt_buffers()) { + TF_RETURN_IF_ERROR(add_buffer_to_profile( + buffer.get(), + array.traceback() ? array.traceback()->get() : nullptr)); + } + } + + for (PyLoadedExecutable* executable = executables_; executable; + executable = executable->next_) { + if (!executable->is_deleted()) { + HeapProfileKey key{ + executable->traceback() ? executable->traceback()->get() : nullptr, + executable->SizeOfGeneratedCodeInBytes(), nullptr}; + ++entries[key]; + } + } + + PprofProfileBuilder builder; + auto* allocations = builder.profile().add_sample_type(); + allocations->set_type(builder.StringId("allocations")); + allocations->set_unit(builder.StringId("count")); + auto* space = builder.profile().add_sample_type(); + space->set_type(builder.StringId("space")); + space->set_unit(builder.StringId("bytes")); + + const int kind_string_id = builder.StringId("kind"); + const int buffer_string_id = builder.StringId("buffer"); + const int executable_string_id = builder.StringId("executable"); + const int device_string_id = builder.StringId("device"); + for (const auto& entry : entries) { + auto* sample = builder.profile().add_sample(); + if (entry.first.traceback) { + for (const auto& frame : entry.first.traceback->raw_frames()) { + sample->add_location_id(builder.LocationId(frame.first, frame.second)); + } + } + sample->add_value(entry.second); + sample->add_value(entry.first.size * entry.second); + + auto* kind_label = sample->add_label(); + kind_label->set_key(kind_string_id); + if (entry.first.device) { + kind_label->set_str(buffer_string_id); + auto* device_label = sample->add_label(); + device_label->set_key(device_string_id); + std::string device_label_str(entry.first.device->DebugString()); + device_label->set_str(builder.StringId(device_label_str)); + } else { + kind_label->set_str(executable_string_id); + } + } + std::string serialized = builder.profile().SerializeAsString(); + return nb::bytes(serialized.data(), serialized.size()); +} + +absl::StatusOr PyClient::MakePythonCallbackUsingHostSendAndRecv( + nb::callable callable, absl::Span operand_shapes, + absl::Span result_shapes, + absl::Span send_channel_ids, + absl::Span recv_channel_ids, nb::callable serializer) { + TF_ASSIGN_OR_RETURN( + auto loaded_host_callback, + PyHostSendAndRecvLoadedHostCallback::Create( + ifrt_client(), std::move(callable), operand_shapes, result_shapes, + send_channel_ids, recv_channel_ids, std::move(serializer))); + nb::capsule callback_capsule( + loaded_host_callback.release(), [](void* ptr) noexcept { + static_cast(ptr)->DropRef(); + }); + return callback_capsule; +} + +// TODO(b/394595987): Remove this API method once we remove the call from +// mlir.py's get_emit_python_callback. +absl::StatusOr> +PyClient::GetEmitPythonCallbackDescriptor( + nb::callable callable, absl::Span operand_shapes, + absl::Span result_shapes) { + TF_ASSIGN_OR_RETURN( + auto loaded_host_callback, + PyCpuLoadedHostCallback::Create(ifrt_client(), std::move(callable), + operand_shapes, result_shapes)); + const uint64_t descriptor = loaded_host_callback->descriptor(); + + nb::capsule callback_capsule( + loaded_host_callback.release(), [](void* ptr) noexcept { + static_cast(ptr)->DropRef(); + }); + return std::make_pair(descriptor, nb::object(std::move(callback_capsule))); +} + +XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("xla_python_cpu_callback", + &XlaPythonCpuCallback); + +/* static */ int PyClient::tp_traverse(PyObject* self, visitproc visit, + void* arg) { + PyClient* c = nb::inst_ptr(self); + for (const auto& [ifrt_device, py_device] : c->devices_) { + Py_VISIT(py_device.ptr()); + } + for (const auto& [ifrt_memory, py_memory] : c->memory_spaces_) { + Py_VISIT(py_memory.ptr()); + } + return 0; +} + +/* static */ int PyClient::tp_clear(PyObject* self) { + PyClient* c = nb::inst_ptr(self); + absl::flat_hash_map> devices; + std::swap(devices, c->devices_); + absl::flat_hash_map> memory_spaces; + std::swap(memory_spaces, c->memory_spaces_); + return 0; +} + +PyType_Slot PyClient::slots_[] = { + {Py_tp_traverse, (void*)PyClient::tp_traverse}, + {Py_tp_clear, (void*)PyClient::tp_clear}, + {0, nullptr}, +}; + +/* static */ void PyClient::RegisterPythonTypes(nb::module_& m) { + nb::enum_(m, "HostBufferSemantics") + .value("IMMUTABLE_ONLY_DURING_CALL", + PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall) + .value("IMMUTABLE_UNTIL_TRANSFER_COMPLETES", + PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes) + .value("ZERO_COPY", PjRtClient::HostBufferSemantics::kImmutableZeroCopy); + + nb::class_ py_local_client(m, "Client", nb::is_weak_referenceable(), + nb::type_slots(PyClient::slots_)); + py_local_client.def_prop_ro("platform", &PyClient::platform_name) + .def_prop_ro("_raw_platform", &PyClient::raw_platform_name) + .def_prop_ro("platform_version", &PyClient::platform_version) + .def_prop_ro("runtime_type", &PyClient::runtime_type) + .def("device_count", &PyClient::device_count) + .def("local_device_count", &PyClient::addressable_device_count) + .def("devices", &PyClient::Devices) + .def("local_devices", &PyClient::LocalDevices) + // TODO(hyeontaek): Remove this method once we have a unified API for + // enumerating devices with different criteria. + .def("_get_all_devices", &PyClient::GetAllDevices) + .def("device_from_local_hardware_id", + xla::ValueOrThrowWrapper(&PyClient::DeviceFromLocalHardwareId)) + .def("live_executables", &PyClient::LiveExecutables) + .def("live_arrays", &PyClient::LiveArrays) + .def("live_buffers", &PyClient::LiveArrays) + .def("process_index", &PyClient::process_index) + .def("host_id", &PyClient::process_index) + .def("task_id", &PyClient::process_index) + .def( + "buffer_from_pyval", + [](nb_class_ptr client, nb::handle argument, + PyDevice* device, bool force_copy, + PjRtClient::HostBufferSemantics host_buffer_semantics) { + return ValueOrThrow( + PyClient::BufferFromPyval(std::move(client), argument, + device ? device->device() : nullptr, + force_copy, host_buffer_semantics)); + }, + nb::arg("argument"), nb::arg("device").none() = nullptr, + nb::arg("force_copy") = false, + nb::arg("host_buffer_semantics") = + PjRtClient::HostBufferSemantics::kImmutableZeroCopy) + .def( + "compile", + [](nb_class_ptr client, nb::bytes mlir_module, + CompileOptions options, std::vector host_callbacks) { + return ValueOrThrow(PyClient::Compile( + std::move(client), + std::string(mlir_module.c_str(), mlir_module.size()), + std::move(options), std::move(host_callbacks))); + }, + nb::arg("computation"), nb::arg("compile_options") = CompileOptions(), + nb::arg("host_callbacks") = std::vector()) + .def( + "compile", + [](nb_class_ptr client, nb::bytes mlir_module, + CompileOptions options, std::vector host_callbacks) { + return ValueOrThrow(PyClient::Compile( + std::move(client), + std::string(mlir_module.c_str(), mlir_module.size()), + std::move(options), std::move(host_callbacks))); + }, + nb::arg("computation"), nb::arg("compile_options") = CompileOptions(), + nb::arg("host_callbacks") = std::vector()) + .def( + "compile", + [](nb_class_ptr client, std::string mlir_module, + CompileOptions options, std::vector host_callbacks) { + return ValueOrThrow(PyClient::Compile( + std::move(client), std::move(mlir_module), std::move(options), + std::move(host_callbacks))); + }, + nb::arg("computation"), nb::arg("compile_options") = CompileOptions(), + nb::arg("host_callbacks") = std::vector()) + .def( + "compile", + [](nb_class_ptr client, std::string mlir_module, + CompileOptions options, std::vector host_callbacks) { + return ValueOrThrow(PyClient::Compile( + std::move(client), std::move(mlir_module), std::move(options), + std::move(host_callbacks))); + }, + nb::arg("computation"), nb::arg("compile_options") = CompileOptions(), + nb::arg("host_callbacks") = std::vector()) + .def("compile_ifrt_program", + xla::ValueOrThrowWrapper(PyClient::CompileIfrtProgram)) + .def("serialize_executable", + xla::ValueOrThrowWrapper(&PyClient::SerializeExecutable)) + .def( + "deserialize_executable", + [](nb_class_ptr client, nb::bytes serialized, + std::optional options, + std::vector host_callbacks) { + return ValueOrThrow(PyClient::DeserializeExecutable( + std::move(client), std::move(serialized), std::move(options), + std::move(host_callbacks))); + }, + nb::arg("serialized"), nb::arg("compile_options").none() = nb::none(), + nb::arg("host_callbacks") = std::vector()) + .def("heap_profile", xla::ValueOrThrowWrapper(&PyClient::HeapProfile)) + // TODO(zhangqiaorjc): Experimental. + .def("defragment", + [](PyClient& self) { xla::ThrowIfError(self.Defragment()); }) + .def("get_emit_python_callback_descriptor", + xla::ValueOrThrowWrapper(&PyClient::GetEmitPythonCallbackDescriptor), + nb::arg("callable"), nb::arg("operand_shapes"), + nb::arg("result_shapes").none() = nb::none()) + .def("make_python_callback_from_host_send_and_recv", + xla::ValueOrThrowWrapper( + &PyClient::MakePythonCallbackUsingHostSendAndRecv), + nb::arg("callable"), nb::arg("operand_shapes"), + nb::arg("result_shapes"), nb::arg("send_channel_ids"), + nb::arg("recv_channel_ids"), + nb::arg("serializer").none() = nb::none()) + .def( + "get_default_layout", + [](PyClient& self, nb_dtype dtype, nb::sequence shard_shape, + nb_class_ptr device) + -> std::shared_ptr { + ifrt::DType ifrt_type = xla::ValueOrThrow(DtypeToIfRtDType(dtype)); + std::vector dims = SequenceToVector(shard_shape); + return xla::ValueOrThrow(self.ifrt_client()->GetDefaultLayout( + ifrt_type, dims, device->device(), xla::ifrt::MemoryKind())); + }, + nb::arg("dtype"), nb::arg("shard_shape"), nb::arg("device")) + .def("__getattr__", + [](PyClient& client, absl::string_view name) -> nb::object { + const auto& attrs = client.Attributes().map(); + auto it = attrs.find(name); + if (it != attrs.end()) { + return std::visit([](auto&& v) { return nb::cast(v.value); }, + it->second); + } + throw nb::attribute_error( + absl::StrCat("Unknown attribute ", name).c_str()); + }); +} + +} // namespace xla diff --git a/jaxlib/xla/py_client.h b/jaxlib/xla/py_client.h new file mode 100644 index 000000000000..9b9d43d90228 --- /dev/null +++ b/jaxlib/xla/py_client.h @@ -0,0 +1,270 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_PY_CLIENT_H_ +#define JAXLIB_XLA_PY_CLIENT_H_ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "nanobind/nanobind.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/python/ifrt/attribute_map.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/program.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/shape.h" + +namespace xla { + +class PyClient; +class PyLoadedExecutable; +class PyArray; +class PyDevice; +class PyMemorySpace; +struct PyArray_Storage; + +// Python wrapper around PjRtClient. +// We use a wrapper class to add Python-specific functionality. +class PyClient { + public: + static nb_class_ptr Make(std::shared_ptr ifrt_client); + + // Do not call the constructor directly. Use `PyClient::Make` instead. + explicit PyClient(std::shared_ptr ifrt_client); + virtual ~PyClient(); + + ifrt::Client* ifrt_client() const { return ifrt_client_.get(); } + const std::shared_ptr& shared_ptr_ifrt_client() const { + return ifrt_client_; + } + + // Short-term escape hatch to get PjRtClient from PyClient. + // TODO(hyeontaek): Migrate all users of this method to be agnostic of PjRt. + xla::PjRtClient* pjrt_client() const { + auto* pjrt_client = + llvm::dyn_cast_or_null(ifrt_client_.get()); + if (pjrt_client == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + return pjrt_client->pjrt_client(); + } + std::shared_ptr shared_ptr_pjrt_client() { + auto* pjrt_client = + llvm::dyn_cast_or_null(ifrt_client_.get()); + if (pjrt_client == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + return pjrt_client->shared_ptr_pjrt_client(); + } + + // Legacy alises. + std::shared_ptr shared_pjrt_client() { + return shared_ptr_pjrt_client(); + } + + absl::string_view platform_name() const { + // TODO(phawkins): this is a temporary backwards compatibility shim. We + // changed the name PJRT reports for GPU platforms to "cuda" or "rocm", but + // we haven't yet updated JAX clients that expect "gpu". Migrate users and + // remove this code. + if (ifrt_client_->platform_name() == "cuda" || + ifrt_client_->platform_name() == "rocm") { + return "gpu"; + } else { + return ifrt_client_->platform_name(); + } + } + absl::string_view raw_platform_name() const { + // TODO(parkers): Once platform_name() is the same, remove this. + return ifrt_client_->platform_name(); + } + absl::string_view platform_version() const { + return ifrt_client_->platform_version(); + } + absl::string_view runtime_type() const { + return ifrt_client_->runtime_type(); + } + + // Returns implementation-specific attributes about this client, e.g. the PJRT + // C API version if applicable. + const xla::ifrt::AttributeMap& Attributes() const { + return client_attributes_; + } + + int addressable_device_count() const { + return ifrt_client_->addressable_device_count(); + } + int device_count() const { return ifrt_client_->device_count(); } + int process_index() const { return ifrt_client_->process_index(); } + + std::vector> Devices(); + std::vector> LocalDevices(); + // Returns all devices in the client. Private API; only use this method for + // implementing backend._get_all_devices(). + // TODO(hyeontaek): Remove this method once we have a unified API for + // enumerating devices with different criteria. + std::vector> GetAllDevices(); + absl::StatusOr> DeviceFromLocalHardwareId( + int local_hardware_id); + + // Returns the PyDevice associated with the given ifrt::Device. + nb_class_ptr GetPyDevice(ifrt::Device* device); + + // Returns the PyMemorySpace associated with the given ifrt::Memory. + nb_class_ptr GetPyMemorySpace(ifrt::Memory* memory_space); + + // Returns a vector of live PyArray objects. PyArray objects may share + // PjRtBuffers, so there may be duplicates of the same underlying device + // buffer. + std::vector LiveBuffersOnDevice(ifrt::Device* device); + + nanobind::list LiveExecutables(); + + // TODO(zhangqiaorjc): Remove when we have transparent defragmentation. + absl::Status Defragment(); + + static absl::StatusOr BufferFromPyval( + nb_class_ptr client, nanobind::handle argument, + ifrt::Device* device, bool force_copy, + ifrt::Client::HostBufferSemantics host_buffer_semantics); + + static absl::StatusOr> CompileIfrtProgram( + nb_class_ptr client, + std::unique_ptr ifrt_program, + std::unique_ptr ifrt_options); + + static absl::StatusOr> Compile( + nb_class_ptr client, std::string mlir_module, + CompileOptions options, std::vector host_callbacks); + + static absl::StatusOr> Compile( + nb_class_ptr client, std::string mlir_module, + CompileOptions options, std::vector host_callbacks); + + absl::StatusOr SerializeExecutable( + const PyLoadedExecutable& executable) const; + static absl::StatusOr> DeserializeExecutable( + nb_class_ptr client, nanobind::bytes serialized, + std::optional options, + std::vector host_callbacks); + + absl::StatusOr HeapProfile(); + + // `GetEmitPythonCallbackDescriptor` takes in an input Python callable that + // takes in arguments of shapes `operand_shapes` and returns values of shapes + // `result_shapes`. It returns a pair of a `uint64_t` descriptor and a Python + // object whose reference will keep the Python callback alive. The descriptor + // should be passed into a 'xla_python_cpu_callback' or + // 'xla_python_gpu_callback' CustomCall as its first argument. Typically the + // callback may be kept alive by attaching the keep-alive object to the + // executable built from this computation. + // + // The callable receives as arguments NumPy arrays for arguments with array + // types, and None for Token argument. The callable must return a tuple of + // either arrays or None values. + absl::StatusOr> + GetEmitPythonCallbackDescriptor(nanobind::callable callable, + absl::Span operand_shapes, + absl::Span result_shapes); + + // `MakePythonCallbackUsingHostSendAndRecv` takes in an input Python callable + // that takes in arguments of shapes `operand_shapes` and returns results of + // shapes `result_shapes`. The arguments correspond to Send ops in the HLO + // program through `send_channel_ids` and the results correspond to Recv ops + // through `recv_channel_ids`. It returns the host callback as an opaque + // object whose reference will keep the Python callback alive. The host + // callback can be passed to `PyClient::Compile` or + // `PyClient::DeserializeExecutable`. The corresponding Send/Recv ops in the + // XLA computation can trigger the execution of this host callback. + // `serializer` is a function that takes `callable` as an argument and returns + // a serialized callable as a string. + // + // The callable receives as arguments NumPy arrays for arguments with array + // types, and None for Token argument. The callable must return a tuple of + // either arrays or None values. + absl::StatusOr MakePythonCallbackUsingHostSendAndRecv( + nanobind::callable callable, absl::Span operand_shapes, + absl::Span result_shapes, + absl::Span send_channel_ids, + absl::Span recv_channel_ids, + nanobind::callable serializer); + + std::vector LiveArrays() const; + + static void RegisterPythonTypes(nanobind::module_& m); + + protected: + static void Initialize(nb_class_ptr client); + + private: + friend class PyLoadedExecutable; + friend class PyArray; + friend struct PyArray_Storage; + + static int tp_traverse(PyObject* self, visitproc visit, void* arg); + static int tp_clear(PyObject* self); + static PyType_Slot slots_[]; + + std::shared_ptr ifrt_client_; + xla::ifrt::AttributeMap client_attributes_; + // Pointers to intrusive doubly-linked lists of arrays and executables, used + // to iterate over all known objects when heap profiling. The list structure + // is protected by the GIL. + + nanobind::ft_mutex executables_mutex_; + // List guarded by executables_mutex_. + PyLoadedExecutable* executables_ = nullptr; + +#ifdef NB_FREE_THREADING + static constexpr size_t kNumArraysShards = 16; +#else + static constexpr size_t kNumArraysShards = 1; +#endif + struct ArraysShard { + mutable nanobind::ft_mutex mutex; + PyArray_Storage* arrays; + }; + std::array arrays_; + + absl::flat_hash_map> devices_; + absl::flat_hash_map> + memory_spaces_; +}; + +} // namespace xla + +#endif // JAXLIB_XLA_PY_CLIENT_H_ diff --git a/jaxlib/xla/py_compile_only_client.cc b/jaxlib/xla/py_compile_only_client.cc new file mode 100644 index 000000000000..6319c70f91b0 --- /dev/null +++ b/jaxlib/xla/py_compile_only_client.cc @@ -0,0 +1,131 @@ +/* Copyright 2023 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/py_compile_only_client.h" + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/py_client.h" +#include "xla/pjrt/mlir_to_hlo.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/compile_only_ifrt/client.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/python/pjrt_ifrt/pjrt_executable.h" +#include "xla/python/pjrt_ifrt/pjrt_topology.h" +#include "xla/python/pjrt_ifrt/xla_compiler.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/python/lib/core/numpy.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace nb = nanobind; + +namespace xla { + +namespace { + +class CompileOnlyPyClient : public PyClient { + public: + using PyClient::PyClient; + + static nb_class_ptr Make( + std::shared_ptr topology) { + auto client = + nb::borrow>(make_nb_class( + std::make_unique(std::move(topology)))); + CompileOnlyPyClient::Initialize(client); + return client; + } + + absl::StatusOr> CompileUnloaded( + absl::string_view mlir_module, CompileOptions options, + std::vector host_callbacks) { + if (!host_callbacks.empty()) { + return Unimplemented( + "Compiling with host_callbacks not available with compile-only " + "client."); + } + nb::gil_scoped_release gil_release; + mlir::MLIRContext context; + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ParseMlirModuleString(mlir_module, context)); + if (options.executable_build_options.use_shardy_partitioner()) { + // Since Shardy is located in the middle of the XLA pipeline, we need to + // export it before going to HLO while preserving Shardy ops and attrs. + TF_RETURN_IF_ERROR(ExportShardyForHloRoundTrip(*module)); + } + auto* ifrt_client = + llvm::dyn_cast_or_null(this->ifrt_client()); + CHECK(ifrt_client) << "CompileOnlyPyClient requires ifrt_client be a " + "CompileOnlyIfRtClient"; + auto xla_options = std::make_unique(options); + TF_ASSIGN_OR_RETURN(auto executable, + PjRtCompile(std::move(options), module.get(), + *ifrt_client->topology().description())); + TF_ASSIGN_OR_RETURN(auto ifrt_executable, + ifrt::PjRtExecutable::Create(std::move(executable))); + return std::shared_ptr(std::move(ifrt_executable)); + } + + private: + static void Initialize(nb_class_ptr client) { + PyClient::Initialize(client); + } +}; + +} // namespace + +nb_class_ptr MakeCompileOnlyClient( + std::shared_ptr topology) { + return CompileOnlyPyClient::Make(std::move(topology)); +} + +void RegisterCompileOnlyClient(nb::module_& m) { + nb::class_(m, "CompileOnlyPyClient") + .def( + "compile", + [](CompileOnlyPyClient& self, nb::bytes mlir_module, + CompileOptions options, std::vector host_callbacks) { + return ValueOrThrow(self.CompileUnloaded( + absl::string_view(mlir_module.c_str(), mlir_module.size()), + std::move(options), std::move(host_callbacks))); + }, + nb::arg("computation"), nb::arg("compile_options") = CompileOptions(), + nb::arg("host_callbacks") = std::vector()) + .def( + "compile", ValueOrThrowWrapper(&CompileOnlyPyClient::CompileUnloaded), + nb::arg("computation"), nb::arg("compile_options") = CompileOptions(), + nb::arg("host_callbacks") = std::vector()); +} + +} // namespace xla diff --git a/jaxlib/xla/py_compile_only_client.h b/jaxlib/xla/py_compile_only_client.h new file mode 100644 index 000000000000..721830d6f52e --- /dev/null +++ b/jaxlib/xla/py_compile_only_client.h @@ -0,0 +1,45 @@ +/* Copyright 2023 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_PY_COMPILE_ONLY_CLIENT_H_ +#define JAXLIB_XLA_PY_COMPILE_ONLY_CLIENT_H_ + +#include + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" +#include "jaxlib/xla/py_client.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/python/pjrt_ifrt/pjrt_topology.h" + +namespace xla { + +// This is a workaround for AOT compilation until topologies and device +// descriptions are better integrated into jax's Python code. It returns a +// PyClient that will return errors for all non-AOT methods. It also exposes a +// different compile method that returns an unloaded executable (vs. PyClient +// usually returns a loaded executable). RegisterCompileOnlyClient() overloads +// the Python "compile" method to return the unloaded executable, and we rely on +// Python duck typing to treat the unloaded executable like a loaded executable +// (except it will raise errors if you try to run it, which is what we want for +// AOT environments). +nb_class_ptr MakeCompileOnlyClient( + std::shared_ptr); + +void RegisterCompileOnlyClient(nanobind::module_& m); + +} // namespace xla + +#endif // JAXLIB_XLA_PY_COMPILE_ONLY_CLIENT_H_ diff --git a/jaxlib/xla/py_device.cc b/jaxlib/xla/py_device.cc new file mode 100644 index 000000000000..20c257bb7d1a --- /dev/null +++ b/jaxlib/xla/py_device.cc @@ -0,0 +1,350 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/py_device.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "llvm/Support/Casting.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/py_client.h" +#include "jaxlib/xla/py_memory_space.h" +#include "xla/layout_util.h" +#include "xla/literal.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/python/nb_helpers.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/python/pjrt_ifrt/pjrt_device.h" +#include "xla/python/python_ref_manager.h" +#include "xla/python/types.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/tsl/framework/allocator.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" + +namespace nb = ::nanobind; + +namespace xla { + +PyDevice::PyDevice(nb_class_ptr client, ifrt::Device* device) + : client_(std::move(client)), device_(device) {} + +int PyDevice::id() const { return device_->Id().value(); } + +int PyDevice::process_index() const { return device_->ProcessIndex(); } + +absl::string_view PyDevice::platform() const { + // TODO(phawkins): this is a temporary backwards + // compatibility shim. We changed the name PJRT + // reports for GPU platforms to "cuda" or "rocm", + // but we haven't yet updated JAX clients that + // expect "gpu". Migrate users and remove this + // code. + if (client_->platform_name() == "cuda" || + client_->platform_name() == "rocm") { + return absl::string_view("gpu"); + } else { + return client_->platform_name(); + } +} + +absl::string_view PyDevice::device_kind() const { return device_->Kind(); } + +std::optional PyDevice::local_hardware_id() const { + // TODO(phawkins): consider supporting this for non-PJRT devices. + ifrt::PjRtDevice* device = llvm::dyn_cast(device_); + if (device == nullptr || !device->IsAddressable()) { + return std::nullopt; + } + int local_hardware_id = device->pjrt_device()->local_hardware_id().value(); + if (local_hardware_id == -1) { + return std::nullopt; + } + return local_hardware_id; +} + +absl::string_view PyDevice::Str() const { return device_->DebugString(); } + +absl::string_view PyDevice::Repr() const { return device_->ToString(); } + +absl::Status PyDevice::TransferToInfeed(LiteralSlice literal) { + GlobalPyRefManager()->CollectGarbage(); + nb::gil_scoped_release gil_release; + auto client = llvm::dyn_cast(client_->ifrt_client()); + auto device = llvm::dyn_cast(device_); + if (client == nullptr || device == nullptr) { + return xla::InvalidArgument( + "TransferToInfeed is only supported for PjRt devices."); + } + return client->TransferToInfeed(device, literal); +} + +absl::StatusOr PyDevice::TransferFromOutfeed(Shape shape) { + GlobalPyRefManager()->CollectGarbage(); + std::shared_ptr literal; + { + nb::gil_scoped_release gil_release; + auto client = llvm::dyn_cast(client_->ifrt_client()); + auto device = llvm::dyn_cast(device_); + if (client == nullptr || device == nullptr) { + return xla::InvalidArgument( + "TransferFromOutfeed is only supported for PjRt devices."); + } + ShapeUtil::ForEachMutableSubshape( + &shape, [](Shape* subshape, const ShapeIndex&) { + if (!subshape->has_layout()) { + LayoutUtil::SetToDefaultLayout(subshape); + } + }); + literal = std::make_shared(shape); + TF_RETURN_IF_ERROR(client->TransferFromOutfeed(device, literal.get())); + } + return LiteralToPython(std::move(literal)); +} + +absl::StatusOr> PyDevice::Memory( + absl::string_view kind) const { + ifrt::Memory* result_memory_space = nullptr; + for (auto* memory_space : device_->Memories()) { + if (memory_space->Kind().memory_kind() == kind) { + if (result_memory_space != nullptr) { + std::string memories = absl::StrJoin( + device_->Memories(), ", ", + [](std::string* out, const auto& memory_space) { + absl::StrAppend(out, *memory_space->Kind().memory_kind()); + }); + auto device_kind = device_->Kind(); + return xla::InvalidArgument( + "Found more than one addressable memory for " + "kind %s which is not allowed. There can only " + "be one memory for each " + "kind. Device %s can address the following " + "memory kinds: %s", + kind, device_kind, memories); + } + result_memory_space = memory_space; + } + } + if (result_memory_space == nullptr) { + std::string memories = absl::StrJoin( + device_->Memories(), ", ", + [](std::string* out, const auto& memory_space) { + absl::StrAppend(out, *memory_space->Kind().memory_kind()); + }); + auto device_kind = device_->Kind(); + return xla::InvalidArgument( + "Could not find memory addressable by device %s. Device %s " + "can address the following memory kinds: %s. " + "Got memory kind: %s", + device_kind, device_kind, memories, kind); + } + return client_->GetPyMemorySpace(result_memory_space); +} + +absl::StatusOr> PyDevice::DefaultMemory() const { + TF_ASSIGN_OR_RETURN(auto* memory_space, device_->DefaultMemory()); + return client_->GetPyMemorySpace(memory_space); +} + +nb::list PyDevice::AddressableMemories() const { + nb::list memory_spaces; + for (auto* memory_space : device_->Memories()) { + memory_spaces.append(client_->GetPyMemorySpace(memory_space)); + } + return memory_spaces; +} + +absl::StatusOr> PyDevice::MemoryStats() const { + GlobalPyRefManager()->CollectGarbage(); + ifrt::PjRtDevice* device = llvm::dyn_cast(device_); + if (device == nullptr || !device->IsAddressable()) { + return xla::InvalidArgument( + "MemoryStats is only supported for addressable PjRt devices."); + } + absl::StatusOr maybe_stats = + device->pjrt_device()->GetAllocatorStats(); + if (absl::IsUnimplemented(maybe_stats.status())) { + return std::nullopt; + } + // Raise error if any status other than Unimplemented is returned. + ThrowIfError(maybe_stats.status()); + + nb::dict result; + result["num_allocs"] = maybe_stats->num_allocs; + result["bytes_in_use"] = maybe_stats->bytes_in_use; + result["peak_bytes_in_use"] = maybe_stats->peak_bytes_in_use; + result["largest_alloc_size"] = maybe_stats->largest_alloc_size; + if (maybe_stats->bytes_limit) { + result["bytes_limit"] = *maybe_stats->bytes_limit; + } + result["bytes_reserved"] = maybe_stats->bytes_reserved; + result["peak_bytes_reserved"] = maybe_stats->peak_bytes_reserved; + if (maybe_stats->bytes_reservable_limit) { + result["bytes_reservable_limit"] = *maybe_stats->bytes_reservable_limit; + } + result["largest_free_block_bytes"] = maybe_stats->largest_free_block_bytes; + if (maybe_stats->pool_bytes) { + result["pool_bytes"] = *maybe_stats->pool_bytes; + } + if (maybe_stats->peak_pool_bytes) { + result["peak_pool_bytes"] = *maybe_stats->peak_pool_bytes; + } + return result; +} + +absl::StatusOr PyDevice::GetStreamForExternalReadyEvents() + const { + ifrt::PjRtDevice* device = llvm::dyn_cast(device_); + if (device == nullptr || !device->IsAddressable()) { + return xla::InvalidArgument( + "GetStreamForExternalReadyEvents is only supported for addressable " + "PjRt devices."); + } + return device->pjrt_device()->GetStreamForExternalReadyEvents(); +} + +/* static */ int PyDevice::tp_traverse(PyObject* self, visitproc visit, + void* arg) { + PyDevice* d = nb::inst_ptr(self); + Py_VISIT(d->client().ptr()); + return 0; +} + +/* static */ int PyDevice::tp_clear(PyObject* self) { + PyDevice* d = nb::inst_ptr(self); + nb_class_ptr client; + std::swap(client, d->client_); + return 0; +} + +PyType_Slot PyDevice::slots_[] = { + {Py_tp_traverse, (void*)PyDevice::tp_traverse}, + {Py_tp_clear, (void*)PyDevice::tp_clear}, + {0, nullptr}, +}; + +/* static */ void PyDevice::RegisterPythonType(nb::module_& m) { + nb::class_ device( + m, "Device", nb::type_slots(PyDevice::slots_), + "A descriptor of an available device.\n\nSubclasses are used to " + "represent specific types of devices, e.g. CPUs, GPUs. Subclasses may " + "have additional properties specific to that device type."); + device + .def_prop_ro( + "id", &PyDevice::id, + "Integer ID of this device.\n\nUnique across all available devices " + "of this type, including remote devices on multi-host platforms.") + .def_prop_ro("process_index", &PyDevice::process_index, + "Integer index of this device's process.\n\n" + "This is always 0 except on multi-process platforms.") + .def_prop_ro("host_id", &PyDevice::process_index, + "Deprecated; please use process_index") + .def_prop_ro("task_id", &PyDevice::process_index, + "Deprecated; please use process_index") + .def_prop_ro("platform", &PyDevice::platform) + .def_prop_ro("device_kind", &PyDevice::device_kind) + .def_prop_ro("client", &PyDevice::client) + .def_prop_ro( + "local_hardware_id", &PyDevice::local_hardware_id, + "Opaque hardware ID, e.g., the CUDA device number. In general, not " + "guaranteed to be dense, and not guaranteed to be defined on all " + "platforms.") + .def("__str__", &PyDevice::Str) + .def("__repr__", &PyDevice::Repr) + .def("transfer_to_infeed", + ThrowIfErrorWrapper(&PyDevice::TransferToInfeed)) + .def("transfer_from_outfeed", + ValueOrThrowWrapper(&PyDevice::TransferFromOutfeed)) + .def("memory", ValueOrThrowWrapper(&PyDevice::Memory), nb::arg("kind")) + .def("default_memory", ValueOrThrowWrapper(&PyDevice::DefaultMemory), + "Returns the default memory of a device.") + .def("addressable_memories", &PyDevice::AddressableMemories, + "Returns all the memories that a device can address.") + + .def("live_buffers", + [](nb::handle device) { + PythonDeprecationWarning( + /*stacklevel=*/1, + "Per device live_buffers() is deprecated. Please " + "use the jax.live_arrays() for jax.Arrays instead."); + return nb::list(); + }) + .def( + "memory_stats", ValueOrThrowWrapper(&PyDevice::MemoryStats), + "Returns memory statistics for this device keyed by name. May not " + "be implemented on all platforms, and different platforms may return " + "different stats, or -1 for unavailable stats. 'bytes_in_use' is " + "usually available. Intended for diagnostic use.") + .def( + "get_stream_for_external_ready_events", + xla::ValueOrThrowWrapper(&PyDevice::GetStreamForExternalReadyEvents)); + static PyMethodDef get_attr_method = { + "__getattr__", + +[](PyObject* self, PyObject* args) -> PyObject* { + PyObject* key; + if (!PyArg_ParseTuple(args, "O", &key)) { + PyErr_SetString(PyExc_TypeError, "__getattr__ must take 1 argument."); + return nullptr; + } + try { + auto device = nb::cast(nb::handle(self)); + auto name = nb::cast(nb::handle(key)); + const auto& attrs = device->device_->Attributes().map(); + auto it = attrs.find(name); + if (it != attrs.end()) { + auto result = std::visit([](auto&& v) { return nb::cast(v.value); }, + it->second); + return result.release().ptr(); + } + PyErr_SetNone(PyExc_AttributeError); + return nullptr; + } catch (std::exception& e) { + PyErr_Format(PyExc_SystemError, "Unhandled nanobind exception: %s", + e.what()); + return nullptr; + } catch (...) { + PyErr_SetString(PyExc_SystemError, "Unhandled nanobind exception."); + return nullptr; + } + }, + METH_VARARGS, + nullptr, + }; + device.attr("__getattr__") = nb::steal(PyDescr_NewMethod( + reinterpret_cast(device.ptr()), &get_attr_method)); +} + +} // namespace xla diff --git a/jaxlib/xla/py_device.h b/jaxlib/xla/py_device.h new file mode 100644 index 000000000000..6d2b3893dea8 --- /dev/null +++ b/jaxlib/xla/py_device.h @@ -0,0 +1,82 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_PY_DEVICE_H_ +#define JAXLIB_XLA_PY_DEVICE_H_ + +#include + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "nanobind/nanobind.h" +#include "jaxlib/xla/py_client.h" +#include "xla/literal.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/shape.h" + +namespace xla { + +class PyDevice { + public: + PyDevice(nb_class_ptr client, ifrt::Device* device); + + // Devices are compared using Python object identity, so we don't allow them + // to be copied or moved. + PyDevice(const PyDevice&) = delete; + PyDevice(PyDevice&&) = delete; + PyDevice& operator=(const PyDevice&) = delete; + PyDevice& operator=(PyDevice&&) = delete; + + const nb_class_ptr& client() const { return client_; } + ifrt::Device* device() const { return device_; } + + int id() const; + int process_index() const; + absl::string_view platform() const; + absl::string_view device_kind() const; + std::optional local_hardware_id() const; + + absl::string_view Str() const; + absl::string_view Repr() const; + + absl::Status TransferToInfeed(LiteralSlice literal); + absl::StatusOr TransferFromOutfeed(Shape shape); + + absl::StatusOr> Memory( + absl::string_view kind) const; + absl::StatusOr> DefaultMemory() const; + nanobind::list AddressableMemories() const; + absl::StatusOr> MemoryStats() const; + + absl::StatusOr GetStreamForExternalReadyEvents() const; + + static void RegisterPythonType(nanobind::module_& m); + + private: + static int tp_traverse(PyObject* self, visitproc visit, void* arg); + static int tp_clear(PyObject* self); + static PyType_Slot slots_[]; + + nb_class_ptr client_; + ifrt::Device* device_; +}; + +} // namespace xla + +#endif // JAXLIB_XLA_PY_DEVICE_H_ diff --git a/jaxlib/xla/py_device_list.cc b/jaxlib/xla/py_device_list.cc new file mode 100644 index 000000000000..593a86ccbe42 --- /dev/null +++ b/jaxlib/xla/py_device_list.cc @@ -0,0 +1,472 @@ +/* Copyright 2023 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/py_device_list.h" + +#include + +#include +#include +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "nanobind/make_iterator.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "jaxlib/xla/py_client.h" +#include "jaxlib/xla/py_device.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/python/nb_helpers.h" +#include "xla/python/python_ref_manager.h" +#include "xla/python/types.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/util.h" + +namespace jax { + +namespace nb = ::nanobind; + +PyDeviceList::PyDeviceList(xla::nb_class_ptr py_client, + xla::ifrt::DeviceListRef device_list) + : py_client_(std::move(py_client)), device_list_(std::move(device_list)) {} + +PyDeviceList::PyDeviceList(nb::tuple py_device_assignment) + : device_list_(py_device_assignment) { + // Attempt to convert to Python devices into `ifrt::DeviceList`. + if (py_device_assignment.size() == 0) { + return; + } + absl::InlinedVector devices; + devices.reserve(py_device_assignment.size()); + for (nb::handle obj : py_device_assignment) { + if (!nb::isinstance(obj.ptr())) { + // Non-`xla::PyDevice` is used on an alternative JAX backend with device + // duck typing. Use Python device objects already set in `device_list_`. + return; + } + auto py_device = nb::cast(obj); + if (py_client_.get() == nullptr) { + py_client_ = py_device->client(); + } else if (py_device->client().get() != py_client_.get()) { + // If the list contains multiple clients, fall back to device duck typing. + return; + } + devices.push_back(py_device->device()); + } + device_list_ = py_client_->ifrt_client()->MakeDeviceList(devices); +} + +PyDeviceList::~PyDeviceList() { + if (device_list_.index() == 1) { + xla::GlobalPyRefManager()->AddGarbage( + std::move(std::get<1>(std::move(device_list_)))); + } +} + +absl::StatusOr PyDeviceList::ifrt_device_list() + const { + switch (device_list_.index()) { + case 0: + return std::get<0>(device_list_); + case 1: + return xla::InvalidArgument("DeviceList contains non-IFRT devices"); + default: + return xla::InvalidArgument("Unrecognized DeviceList type"); + } +} + +int64_t PyDeviceList::Hash() { + if (!hash_.has_value()) { + switch (device_list_.index()) { + case 0: + hash_ = absl::HashOf(std::get<0>(device_list_)); + break; + case 1: + hash_ = nb::hash(std::get<1>(device_list_)); + break; + default: + throw nb::value_error("Unrecognized DeviceList type"); + } + } + return *hash_; +} + +/*static*/ bool PyDeviceList::Equal(xla::nb_class_ptr self, + nb::handle other) { + if (!nb::isinstance(other)) { + return false; + } + auto o = nb::cast(other); + // Fast-path using a pointer equality check. + if (self.get() == o) { + return true; + } + int64_t h1, h2; + { + nb::ft_object_guard lock(self); + h1 = self->Hash(); + } + { + nb::ft_object_guard lock(other); + h2 = o->Hash(); + } + if (h1 != h2) { + return false; + } + if (self->device_list_.index() == 0 && o->device_list_.index() == 0) { + nb::gil_scoped_release gil_release; + return *std::get<0>(self->device_list_) == *std::get<0>(o->device_list_); + } else { + return self->AsTuple().equal(o->AsTuple()); + } +} + +/*static*/ bool PyDeviceList::NotEqual(xla::nb_class_ptr self, + nb::handle other) { + return !Equal(std::move(self), other); +} + +int PyDeviceList::Len() const { + switch (device_list_.index()) { + case 0: + return std::get<0>(device_list_)->size(); + case 1: + return nb::len(std::get<1>(device_list_)); + default: + throw nb::value_error("Unrecognized DeviceList type"); + } +} + +nb::object PyDeviceList::GetItem(int index) { + switch (device_list_.index()) { + case 0: { + const xla::ifrt::DeviceListRef& device_list = std::get<0>(device_list_); + if (index < -device_list->size() || index >= device_list->size()) { + throw nb::index_error(); + } else if (index < 0) { + index += device_list->size(); + } + return py_client_->GetPyDevice(device_list->devices()[index]); + } + case 1: + return std::get<1>(device_list_).attr("__getitem__")(index); + default: + throw nb::value_error("Unrecognized DeviceList type"); + } +} + +nb::object PyDeviceList::GetSlice(nb::slice slice) { + switch (device_list_.index()) { + case 0: { + const xla::ifrt::DeviceListRef& device_list = std::get<0>(device_list_); + const absl::Span devices = + device_list->devices(); + Py_ssize_t start, stop, step, slicelength; + if (PySlice_GetIndicesEx(slice.ptr(), devices.size(), &start, &stop, + &step, &slicelength) != 0) { + throw nb::python_error(); + } + nb::tuple out = nb::steal(PyTuple_New(slicelength)); + for (size_t i = 0; i < slicelength; ++i) { + nb::object d = py_client_->GetPyDevice(devices[start]); + PyTuple_SET_ITEM(out.ptr(), i, d.release().ptr()); + start += step; + } + return std::move(out); + } + case 1: + return std::get<1>(device_list_).attr("__getitem__")(slice); + default: + throw nb::value_error("Unrecognized DeviceList type"); + } +} + +nb::tuple PyDeviceList::AsTuple() const { + switch (device_list_.index()) { + case 0: { + const xla::ifrt::DeviceListRef& device_list = std::get<0>(device_list_); + nb::tuple out = nb::steal(PyTuple_New(device_list->size())); + int i = 0; + for (xla::ifrt::Device* device : device_list->devices()) { + nb::object d = py_client_->GetPyDevice(device); + PyTuple_SET_ITEM(out.ptr(), i, d.release().ptr()); + ++i; + } + return out; + } + case 1: + return std::get<1>(device_list_); + default: + throw nb::value_error("Unrecognized DeviceList type"); + } +} + +nb::iterator PyDeviceList::Iter() { + switch (device_list_.index()) { + case 0: { + // Iterator whose deference converts `xla::ifrt::Device*` into JAX + // `PjRtDevice`. + struct Iterator { + void operator++() { ++it; } + bool operator==(const Iterator& other) const { return it == other.it; } + xla::nb_class_ptr operator*() const { + return py_client->GetPyDevice(*it); + } + xla::nb_class_ptr py_client; + absl::Span::const_iterator it; + }; + return nb::make_iterator( + nb::type(), "ifrt_device_iterator", + Iterator{py_client_, std::get<0>(device_list_)->devices().cbegin()}, + Iterator{py_client_, std::get<0>(device_list_)->devices().cend()}); + } + case 1: + return nb::make_iterator( + nb::type(), "python_device_iterator", + std::get<1>(device_list_).begin(), std::get<1>(device_list_).end()); + default: + throw nb::value_error("Unrecognized DeviceList type"); + } +} + +std::string PyDeviceList::Str() { + return nb::cast(nb::str(AsTuple())); +} + +nb::tuple PyDeviceList::Dump() const { return AsTuple(); } + +bool PyDeviceList::IsFullyAddressable() { + if (!is_fully_addressable_.has_value()) { + is_fully_addressable_ = true; + switch (device_list_.index()) { + case 0: { + const int process_index = py_client_ ? py_client_->process_index() : 0; + for (const xla::ifrt::Device* device : + std::get<0>(device_list_)->devices()) { + if (device->ProcessIndex() != process_index) { + is_fully_addressable_ = false; + break; + } + } + break; + } + case 1: { + for (nb::handle device : std::get<1>(device_list_)) { + if (nb::cast(device.attr("process_index")) != + nb::cast(device.attr("client").attr("process_index")())) { + is_fully_addressable_ = false; + break; + } + } + break; + } + default: + throw nb::value_error("Unrecognized DeviceList type"); + } + } + return *is_fully_addressable_; +} + +/*static*/ xla::nb_class_ptr PyDeviceList::AddressableDeviceList( + xla::nb_class_ptr self) { + nb::ft_object_guard lock(self); + if (self->IsFullyAddressable()) { + // Do not cache this result in `addressable_device_list_`. Otherwise, it + // will create a cycle that prevents deletion of this object. + return self; + } + if (!self->addressable_device_list_.has_value()) { + switch (self->device_list_.index()) { + case 0: { + absl::InlinedVector addressable_devices; + const int process_index = + self->py_client_ ? self->py_client_->process_index() : 0; + for (xla::ifrt::Device* device : + std::get<0>(self->device_list_)->devices()) { + if (device->ProcessIndex() == process_index) { + addressable_devices.push_back(device); + } + } + self->addressable_device_list_ = xla::make_nb_class( + self->py_client_, self->py_client_->ifrt_client()->MakeDeviceList( + addressable_devices)); + break; + } + case 1: { + auto device_list = std::get<1>(self->device_list_); + std::vector addressable_devices; + for (size_t i = 0; i < device_list.size(); ++i) { + nb::object device = device_list[i]; + if (nb::cast(device.attr("process_index")) == + nb::cast(device.attr("client").attr("process_index")())) { + addressable_devices.push_back(std::move(device)); + } + } + self->addressable_device_list_ = xla::make_nb_class( + xla::MutableSpanToNbTuple(absl::MakeSpan(addressable_devices))); + break; + } + default: + throw nb::value_error("Unrecognized DeviceList type"); + } + } + return *self->addressable_device_list_; +} + +void PyDeviceList::PopulateMemoryKindInfo() { + if (device_list_.index() == 1) { + // Handle Python duck-type devices in a separate function for readability. + PopulateMemoryKindInfoForDuckTypedDevices(); + return; + } + if (device_list_.index() != 0) { + throw nb::value_error("Unrecognized DeviceList type"); + } + MemoryKindInfo info; + xla::ifrt::Device* addressable_device = nullptr; + const int process_index = py_client_ ? py_client_->process_index() : 0; + for (xla::ifrt::Device* device : std::get<0>(device_list_)->devices()) { + if (device->ProcessIndex() == process_index) { + addressable_device = device; + break; + } + } + if (addressable_device == nullptr) { + info.default_memory_kind = nb::none(); + memory_kind_info_ = std::move(info); + return; + } + + auto default_memory = addressable_device->DefaultMemory(); + if (!default_memory.ok()) { + // Cache the error. + memory_kind_info_ = default_memory.status(); + return; + } + info.default_memory_kind = nb::cast(*(*default_memory)->Kind().memory_kind()); + nb::tuple memory_kinds = + nb::steal(PyTuple_New(addressable_device->Memories().size())); + for (size_t i = 0; i < addressable_device->Memories().size(); ++i) { + auto* memory = addressable_device->Memories()[i]; + nb::str s = nb::str(memory->Kind().memory_kind()->data(), + memory->Kind().memory_kind()->size()); + PyTuple_SET_ITEM(memory_kinds.ptr(), i, s.release().ptr()); + } + info.memory_kinds = std::move(memory_kinds); + memory_kind_info_ = std::move(info); +} + +void PyDeviceList::PopulateMemoryKindInfoForDuckTypedDevices() { + MemoryKindInfo info; + try { + nb::handle addressable_device; + for (nb::handle device : std::get<1>(device_list_)) { + if (nb::cast(device.attr("process_index")) == + nb::cast(device.attr("client").attr("process_index")())) { + addressable_device = device; + break; + } + } + if (!addressable_device) { + info.default_memory_kind = nb::none(); + // info.memory_kinds is default-initialized to an empty tuple. + memory_kind_info_ = std::move(info); + return; + } + auto default_memory = addressable_device.attr("default_memory")(); + info.default_memory_kind = default_memory.attr("kind"); + info.memory_kinds = nb::tuple( + nb::object(addressable_device.attr("addressable_memories")())); + memory_kind_info_ = std::move(info); + } catch (nb::python_error& e) { + // Cache the error. + memory_kind_info_ = xla::InvalidArgument("%s", e.what()); + } +} + +/*static*/ absl::StatusOr PyDeviceList::MemoryKinds( + xla::nb_class_ptr self) { + nb::ft_object_guard lock(self); + if (!self->memory_kind_info_.has_value()) { + self->PopulateMemoryKindInfo(); + } + if (!self->memory_kind_info_->ok()) { + return self->memory_kind_info_->status(); + } + return (*self->memory_kind_info_)->memory_kinds; +} + +/*static*/ absl::StatusOr PyDeviceList::DefaultMemoryKind( + xla::nb_class_ptr self) { + nb::ft_object_guard lock(self); + if (!self->memory_kind_info_.has_value()) { + self->PopulateMemoryKindInfo(); + } + if (!self->memory_kind_info_->ok()) { + return self->memory_kind_info_->status(); + } + return (*self->memory_kind_info_)->default_memory_kind; +} + +/*static*/ void PyDeviceList::Register(nb::module_& m) { + nb::class_(m, "DeviceList") + .def(nb::init()) + .def("__hash__", &PyDeviceList::Hash, nb::lock_self()) + .def("__eq__", &PyDeviceList::Equal) + .def("__ne__", &PyDeviceList::NotEqual) + .def("__len__", &PyDeviceList::Len) + .def("__getitem__", &PyDeviceList::GetItem) + .def("__getitem__", &PyDeviceList::GetSlice) + .def("__iter__", &PyDeviceList::Iter, nb::keep_alive<0, 1>()) + .def("__str__", &PyDeviceList::Str) + .def("__repr__", &PyDeviceList::Str) + .def("__getstate__", [](const PyDeviceList& l) { return l.Dump(); }) + .def("__setstate__", + [](PyDeviceList& self, nb::tuple t) { + new (&self) PyDeviceList(std::move(t)); + }) + .def_prop_ro("is_fully_addressable", &PyDeviceList::IsFullyAddressable, + nb::lock_self()) + .def_prop_ro("addressable_device_list", + &PyDeviceList::AddressableDeviceList) + // `xla::ValueOrThrowWrapper` does not work with + // `def_prop_ro()`. Manually convert an error into an exception. + .def_prop_ro("default_memory_kind", + [](xla::nb_class_ptr l) { + auto kind = DefaultMemoryKind(l); + if (!kind.ok()) { + throw nb::value_error(kind.status().ToString().c_str()); + } + return *kind; + }) + .def_prop_ro("memory_kinds", [](xla::nb_class_ptr l) { + auto kinds = MemoryKinds(l); + if (!kinds.ok()) { + throw nb::value_error(kinds.status().ToString().c_str()); + } + return *kinds; + }); +} + +} // namespace jax diff --git a/jaxlib/xla/py_device_list.h b/jaxlib/xla/py_device_list.h new file mode 100644 index 000000000000..ea574c5dc5a2 --- /dev/null +++ b/jaxlib/xla/py_device_list.h @@ -0,0 +1,137 @@ +/* Copyright 2023 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_PY_DEVICE_LIST_H_ +#define JAXLIB_XLA_PY_DEVICE_LIST_H_ + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "nanobind/nanobind.h" +#include "jaxlib/xla/py_client.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/tsl/concurrency/ref_count.h" + +namespace jax { + +// Device list with various caching and direct access to IFRT DeviceList. +class PyDeviceList { + public: + PyDeviceList(xla::nb_class_ptr py_client, + xla::ifrt::DeviceListRef device_list); + explicit PyDeviceList(nanobind::tuple py_device_assignment); + ~PyDeviceList(); + + PyDeviceList(const PyDeviceList&) = delete; + PyDeviceList(PyDeviceList&&) = delete; + PyDeviceList& operator=(const PyDeviceList&) = delete; + PyDeviceList& operator=(PyDeviceList&&) = delete; + + static nanobind::handle type() { + static auto type = nanobind::type(); + return type; + } + + // These two methods are safe to call from C++ without GIL. + xla::nb_class_ptr py_client() const { return py_client_; } + absl::StatusOr ifrt_device_list() const; + + int Len() const; // Requires the GIL in GIL mode. + nanobind::object GetItem(int index); // Requires the GIL in GIL mode. + + // Requires the GIL in GIL mode. Acquires the self lock in non-GIL mode. + static xla::nb_class_ptr AddressableDeviceList( + xla::nb_class_ptr self); + + // Requires the GIL in GIL mode. Acquires the self lock in non-GIL mode. + static absl::StatusOr DefaultMemoryKind( + xla::nb_class_ptr self); + + // Requires the GIL in GIL mode. Acquires the self lock in non-GIL mode. + static absl::StatusOr MemoryKinds( + xla::nb_class_ptr self); + + // go/pywald-pybind-annotation BEGIN + // refs { + // module_path: "third_party/py/jax/jaxlib/xla/xla.cc" + // module_arg {} + // } + // go/pywald-pybind-annotation END + static void Register(nanobind::module_& m); + + private: + nanobind::tuple AsTuple() const; + + // Methods below require GIL. + nanobind::object GetSlice(nanobind::slice slice); + nanobind::iterator Iter(); + + std::string Str(); + + nanobind::tuple Dump() const; + + int64_t Hash(); // Mutates hash_, needs self lock. + + static bool Equal(xla::nb_class_ptr self, + nanobind::handle other); + static bool NotEqual(xla::nb_class_ptr self, + nanobind::handle other); + + // Finds the memory kind info from an addressable device. Requires the GIL + // or self lock. + void PopulateMemoryKindInfo(); + // Same as `PopulateMemoryKindInfo()`, but uses `py_device_assignment_` + // instead of `ifrt_device_list_` to support duck-typed device objects. + // Requires the GIL or self lock. + void PopulateMemoryKindInfoForDuckTypedDevices(); + + // Requires the self lock or GIL is held. + bool IsFullyAddressable(); + + // Valid only if `device_list_` contains `xla::ifrt::DeviceList` and + // non-empty. + xla::nb_class_ptr py_client_; + + // Either C++ `ifrt::DeviceList` or Python duck-type devices. + // TODO(hyeontaek): Remove support for Python duck-type devices once all + // JAX backends and tests are migrated to use an `xla::ifrt::Device` type + // for JAX devices. + // Immutable after constructor; no locking needed. + std::variant device_list_; + + // Populated on demand. Guarded by the object's self lock. + std::optional hash_; + // TODO(hyeontaek): Make the following property cached within + // `xla::ifrt::DeviceList`. + // Populated on demand. Guarded by the object's self lock. + std::optional is_fully_addressable_; + // Populated on demand. Guarded by the object's self lock. + std::optional> addressable_device_list_; + + struct MemoryKindInfo { + nanobind::object default_memory_kind; + nanobind::tuple memory_kinds; + }; + // Populated on demand. Guarded by the object's self lock. + std::optional> memory_kind_info_; +}; + +} // namespace jax + +#endif // JAXLIB_XLA_PY_DEVICE_LIST_H_ diff --git a/jaxlib/xla/py_executable.cc b/jaxlib/xla/py_executable.cc new file mode 100644 index 000000000000..5a02a8f6dd20 --- /dev/null +++ b/jaxlib/xla/py_executable.cc @@ -0,0 +1,463 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/py_executable.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "jaxlib/xla/py_array.h" +#include "jaxlib/xla/py_client.h" +#include "jaxlib/xla/py_device.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/pjrt/pjrt_future.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/python/traceback.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/fingerprint.h" +#include "tsl/profiler/lib/traceme.h" + +namespace xla { + +namespace nb = nanobind; + +absl::Status PyToken::Await() { + CHECK(future_.IsValid()); + nb::gil_scoped_release gil_release; + return future_.Await(); +} + +absl::Status PyShardedToken::Await() { + nb::gil_scoped_release gil_release; + absl::Status status = absl::OkStatus(); + for (auto& future : futures_) { + auto s = future.Await(); + if (!s.ok()) status = std::move(s); + } + return status; +} + +PyLoadedExecutable::PyLoadedExecutable( + nb_class_ptr client, + std::shared_ptr ifrt_loaded_executable, + std::optional traceback, + std::optional fingerprint) + : client_(std::move(client)), + ifrt_loaded_executable_(std::move(ifrt_loaded_executable)), + traceback_(std::move(traceback)), + fingerprint_(std::move(fingerprint)), + next_launch_id_( + fingerprint_.has_value() ? tsl::Fingerprint32(*fingerprint_) : 1) { + CHECK(PyGILState_Check()); + if (fingerprint_) { + VLOG(1) << "Fingerprint for executable " << ifrt_loaded_executable_->name() + << ": " << *fingerprint_; + } + nb::ft_lock_guard lock(client_->executables_mutex_); + next_ = client_->executables_; + client_->executables_ = this; + prev_ = nullptr; + if (next_) { + next_->prev_ = this; + } +} + +PyLoadedExecutable::~PyLoadedExecutable() { + CHECK(PyGILState_Check()); + nb::ft_lock_guard lock(client_->executables_mutex_); + if (client_->executables_ == this) { + client_->executables_ = next_; + } + if (prev_) { + prev_->next_ = next_; + } + if (next_) { + next_->prev_ = prev_; + } +} + +std::vector> PyLoadedExecutable::AddressableDevices() + const { + std::vector> devices; + devices.reserve(ifrt_loaded_executable_->addressable_devices().size()); + for (ifrt::Device* device : ifrt_loaded_executable_->addressable_devices()) { + devices.push_back(client_->GetPyDevice(device)); + } + return devices; +} + +namespace { + +// Traits classes of common methods for std::vector. +template +struct ShardedBufferAdapter; + +template <> +struct ShardedBufferAdapter { + static int num_devices(const ExecuteShardedArg& arg) { + if (std::holds_alternative(arg)) { + return std::get(arg).num_addressable_shards(); + } else { + return std::get>(arg).size(); + } + } + static tsl::RCReference GetIfRtArray( + const ExecuteShardedArg& arg) { + if (std::holds_alternative(arg)) { + return tsl::FormRef(std::get(arg).ifrt_array()); + } + auto& arg_vector = std::get>(arg); + + // TODO(hyeontaek): This on-demand Array creation is not efficient and has + // insufficient information about the shape (a dummy shape is used). This + // should be removed if possible and only be used in the context where the + // shape information is unused. + std::vector> ifrt_arrays; + ifrt_arrays.reserve(arg_vector.size()); + absl::InlinedVector devices; + devices.reserve(arg_vector.size()); + for (auto& arr : arg_vector) { + CHECK_EQ(arr.ifrt_array()->sharding().devices()->size(), 1) + << arr.ifrt_array()->sharding().DebugString(); + ifrt_arrays.push_back(tsl::FormRef(arr.ifrt_array())); + devices.push_back( + arr.ifrt_array()->sharding().devices()->devices().front()); + } + CHECK(!ifrt_arrays.empty()); + // Use a dummy shape. + // TODO(hyeontaek): Find a way to compute a correct shape. + // TODO(yashkatariya): Plumb sharding or memory_kind here. + ifrt::Client* client = ifrt_arrays.front()->client(); + auto ifrt_array = client->AssembleArrayFromSingleDeviceArrays( + ifrt_arrays.front()->shape(), + ifrt::OpaqueSharding::Create(client->MakeDeviceList(devices), + ifrt::MemoryKind()), + absl::MakeSpan(ifrt_arrays), ifrt::ArrayCopySemantics::kReuseInput, + ifrt::SingleDeviceShardSemantics::kAddressableShards); + TF_CHECK_OK(ifrt_array.status()); + return *ifrt_array; + } +}; + +void PopulateExecuteShardedResults( + const nb_class_ptr& client, + std::vector> ifrt_arrays, + const PjRtFuture<>& result_status, int num_computations, + std::vector>& outputs) { + auto traceback = Traceback::Get(); + DCHECK_GT(num_computations, 0); + int num_output_buffers = ifrt_arrays.size(); + outputs.resize(num_output_buffers); + for (int buffer_id = 0; buffer_id < num_output_buffers; ++buffer_id) { + outputs[buffer_id].reserve(num_computations); + auto exploded_arrays = + ifrt_arrays[buffer_id]->DisassembleIntoSingleDeviceArrays( + ifrt::ArrayCopySemantics::kReuseInput, + ifrt::SingleDeviceShardSemantics::kAddressableShards); + TF_CHECK_OK(exploded_arrays.status()); + for (auto& exploded_array : *exploded_arrays) { + outputs[buffer_id].push_back(PyArray::MakeFromSingleDeviceArray( + client, traceback, std::move(exploded_array), false, true, + result_status)); + } + } +} + +template > +absl::StatusOr ExecuteShardedOnLocalDevicesInternal( + const ifrt::ExecuteOptions& options, const nb_class_ptr& client, + ifrt::LoadedExecutable* ifrt_loaded_executable, absl::Span args, + std::optional>>& returned_futures) { + std::vector> output_arrays; + std::unique_ptr> returned_future; + int num_computations = ifrt_loaded_executable->addressable_devices().size(); + PjRtFuture<> result_status; + { + nb::gil_scoped_release gil_release; + for (const auto& arg : args) { + if (ArgAdapter::num_devices(arg) != num_computations) { + return InvalidArgument( + "Expected args to execute_sharded_on_local_devices to have %d " + "shards, got: [%s]", + num_computations, + absl::StrJoin(args, ", ", [](std::string* out, const ArgT& arg) { + out->append(std::to_string(ArgAdapter::num_devices(arg))); + })); + } + } + std::vector> arg_arrays(args.size()); + absl::c_transform(args, arg_arrays.begin(), [&](const ArgT& arg) mutable { + return ArgAdapter::GetIfRtArray(arg); + }); + TF_ASSIGN_OR_RETURN(auto result, ifrt_loaded_executable->Execute( + absl::MakeSpan(arg_arrays), options, + /*devices=*/std::nullopt)); + output_arrays = std::move(result.outputs); + // options.fill_status is only supposed to be true when the computation has + // tokens. + if (options.fill_status) { + result_status = result.status; + if (returned_futures.has_value()) { + returned_futures->resize(num_computations, std::move(result.status)); + } + } + } + + // TODO(b/240696624): Although the PjRt interface require `returned_futures` + // to be resized correctly if it is not nullopt, some implementation does not + // implement this. So we have to check whether returned_futures is empty. + // Remove this check once the implementation is fixed. + auto py_sharded_token = returned_futures.has_value() + ? PyShardedToken(std::move(*returned_futures)) + : PyShardedToken(); + + return PyExecuteResults(client, std::move(output_arrays), num_computations, + std::move(py_sharded_token), result_status); +} + +} // namespace + +PyExecuteResults::PyExecuteResults( + const nb_class_ptr& client, + std::vector> ifrt_arrays, + int num_computations, PyShardedToken token, PjRtFuture<> result_status) + : client_(client), + ifrt_arrays_(std::move(ifrt_arrays)), + num_computations_(num_computations), + token_(std::move(token)), + result_status_(std::move(result_status)) {} + +void PyExecuteResults::CheckNotDisassembled() const { + if (is_exploded_) { + throw nb::value_error("ExecuteResults already exploded."); + } +} + +std::vector> PyExecuteResults::Consume() { + CheckNotDisassembled(); + is_exploded_ = true; + return std::move(ifrt_arrays_); +} + +PyShardedToken PyExecuteResults::ConsumeToken() { + if (token_consumed_) { + throw nb::value_error("ExecuteResults token already consumed."); + } + token_consumed_ = true; + return std::move(token_); +} + +std::vector> +PyExecuteResults::DisassembleIntoSingleDeviceArrays() { + std::vector> outputs; + PopulateExecuteShardedResults( + client_, Consume(), + result_status_.IsValid() ? result_status_ : PjRtFuture<>(), + num_computations_, outputs); + return outputs; +} + +std::vector> +PyExecuteResults::DisassemblePrefixIntoSingleDeviceArrays(size_t n) { + CheckNotDisassembled(); + if (n > ifrt_arrays_.size()) { + throw nb::value_error( + absl::StrCat("In DisassemblePrefixIntoSingleDeviceArrays: ", n, " > ", + ifrt_arrays_.size()) + .c_str()); + } + std::vector> ifrt_arrays; + ifrt_arrays.reserve(ifrt_arrays_.size() - n); + for (size_t i = n; i < ifrt_arrays_.size(); ++i) { + ifrt_arrays.push_back(std::move(ifrt_arrays_[i])); + } + ifrt_arrays_.erase(ifrt_arrays_.begin() + n, ifrt_arrays_.end()); + std::swap(ifrt_arrays_, ifrt_arrays); + std::vector> outputs; + PopulateExecuteShardedResults( + client_, std::move(ifrt_arrays), + result_status_.IsValid() ? result_status_ : PjRtFuture<>(), + num_computations_, outputs); + return outputs; +} + +std::vector PyExecuteResults::ConsumeWithHandlers( + std::vector> + out_handlers) { + std::vector outputs; + auto ifrt_arrays = Consume(); + auto traceback = Traceback::Get(); + DCHECK_GT(num_computations_, 0); + int num_output_buffers = ifrt_arrays.size(); + outputs.reserve(num_output_buffers); + if (out_handlers.size() != num_output_buffers) { + throw nb::value_error( + absl::StrCat("Mismatch between out_handlers and num_results: ", + out_handlers.size(), " vs ", num_output_buffers) + .c_str()); + } + for (int buffer_id = 0; buffer_id < num_output_buffers; ++buffer_id) { + auto& handler = out_handlers[buffer_id]; + if (std::holds_alternative(handler)) { + outputs.push_back(std::get(handler)->Call( + client_, std::move(ifrt_arrays[buffer_id]), + result_status_.IsValid() ? result_status_ : PjRtFuture<>())); + } else { + tsl::profiler::TraceMe traceme("ConsumeWithHandlers fallback."); + auto disassembled_arrays = + ifrt_arrays[buffer_id]->DisassembleIntoSingleDeviceArrays( + ifrt::ArrayCopySemantics::kReuseInput, + ifrt::SingleDeviceShardSemantics::kAddressableShards); + TF_CHECK_OK(disassembled_arrays.status()); + nb::list bufs = + nb::steal(PyList_New(disassembled_arrays->size())); + int i = 0; + for (auto& disassembled_array : *disassembled_arrays) { + nb::object array = PyArray::MakeFromSingleDeviceArray( + client_, traceback, std::move(disassembled_array), false, true, + result_status_.IsValid() ? result_status_ : PjRtFuture<>()); + PyList_SET_ITEM(bufs.ptr(), i, array.release().ptr()); + ++i; + } + outputs.push_back(std::get(handler)(std::move(bufs))); + } + } + return outputs; +} + +absl::StatusOr>> +PyLoadedExecutable::ExecuteShardedOnLocalDevices( + absl::Span args) { + xla::ifrt::ExecuteOptions options = options_; + options.launch_id = GetNextLaunchId(); + options.fill_status = false; + options.execution_stream_id = tsl::Env::Default()->GetCurrentThreadId(); + std::optional>> returned_futures; + TF_ASSIGN_OR_RETURN(auto outputs_and_tokens, + ExecuteShardedOnLocalDevicesInternal( + options, client_, ifrt_loaded_executable_.get(), args, + returned_futures)); + return outputs_and_tokens.DisassembleIntoSingleDeviceArrays(); +} + +absl::StatusOr>, PyShardedToken>> +PyLoadedExecutable::ExecuteShardedOnLocalDevicesWithTokens( + absl::Span args) { + xla::ifrt::ExecuteOptions options = options_; + options.launch_id = GetNextLaunchId(); + options.fill_status = true; + options.execution_stream_id = tsl::Env::Default()->GetCurrentThreadId(); + std::optional>> returned_futures; + returned_futures.emplace(); + TF_ASSIGN_OR_RETURN(auto outputs_and_tokens, + ExecuteShardedOnLocalDevicesInternal( + options, client_, ifrt_loaded_executable_.get(), args, + returned_futures)); + return std::make_pair(outputs_and_tokens.DisassembleIntoSingleDeviceArrays(), + outputs_and_tokens.ConsumeToken()); +} + +absl::StatusOr PyLoadedExecutable::ExecuteSharded( + std::vector args, bool with_tokens) { + xla::ifrt::ExecuteOptions options = options_; + options.launch_id = GetNextLaunchId(); + options.fill_status = with_tokens; + options.execution_stream_id = tsl::Env::Default()->GetCurrentThreadId(); + std::optional>> returned_futures; + if (with_tokens) { + returned_futures.emplace(); + } + absl::Span span_args = args; + return ExecuteShardedOnLocalDevicesInternal(options, client_, + ifrt_loaded_executable_.get(), + span_args, returned_futures); +} + +absl::StatusOr>> +PyLoadedExecutable::HloModules() const { + nb::gil_scoped_release gil_release; + return ifrt_loaded_executable_->GetHloModules(); +} + +absl::StatusOr>> +PyLoadedExecutable::GetOutputMemoryKinds() const { + nb::gil_scoped_release gil_release; + return ifrt_loaded_executable_->GetOutputMemoryKinds(); +} + +absl::StatusOr>> +PyLoadedExecutable::GetParameterLayouts() const { + nb::gil_scoped_release gil_release; + return ifrt_loaded_executable_->GetParameterLayouts(); +} + +absl::StatusOr>> +PyLoadedExecutable::GetOutputLayouts() const { + nb::gil_scoped_release gil_release; + return ifrt_loaded_executable_->GetOutputLayouts(); +} + +std::optional> +PyLoadedExecutable::GetParameterShardings() const { + nb::gil_scoped_release gil_release; + return ifrt_loaded_executable_->GetParameterShardings(); +} + +std::optional> PyLoadedExecutable::GetOutputShardings() + const { + nb::gil_scoped_release gil_release; + return ifrt_loaded_executable_->GetOutputShardings(); +} + +int64_t PyLoadedExecutable::GetNextLaunchId() { + return next_launch_id_.fetch_add(1, std::memory_order_relaxed); +} + +void PyLoadedExecutable::KeepAlive(nb::object obj) { + keepalives_.push_back(std::move(obj)); +} + +} // namespace xla diff --git a/jaxlib/xla/py_executable.h b/jaxlib/xla/py_executable.h new file mode 100644 index 000000000000..214431f9472e --- /dev/null +++ b/jaxlib/xla/py_executable.h @@ -0,0 +1,263 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_PY_EXECUTABLE_H_ +#define JAXLIB_XLA_PY_EXECUTABLE_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "nanobind/nanobind.h" +#include "jaxlib/xla/py_array.h" +#include "jaxlib/xla/py_client.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/layout.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_common.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/pjrt_future.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/attribute_map.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/python/pjrt_ifrt/pjrt_executable.h" +#include "xla/python/traceback.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/status.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +class PyToken { + public: + PyToken() = default; + explicit PyToken(PjRtFuture<> future) : future_(std::move(future)) {} + + static PyToken ReadyPyToken() { + return PyToken(PjRtFuture<>(absl::OkStatus())); + } + + absl::Status Await(); + + private: + PjRtFuture<> future_; +}; + +// PyShardedToken contains a PyToken for each device's execution. +class PyShardedToken { + public: + // Default construction creates a always-ready token. + PyShardedToken() = default; + explicit PyShardedToken(std::vector> futures) + : futures_(std::move(futures)) {} + + PyToken GetPyToken(int device_id) const { + if (futures_.empty()) return PyToken::ReadyPyToken(); + return PyToken(futures_.at(device_id)); + } + + absl::Status Await(); + + private: + std::vector> futures_; +}; + +class PyExecuteResults { + public: + PyExecuteResults(const nb_class_ptr& client, + std::vector> ifrt_arrays, + int num_computations, PyShardedToken token, + PjRtFuture<> result_status = PjRtFuture<>()); + + std::vector> DisassembleIntoSingleDeviceArrays(); + + std::vector> DisassemblePrefixIntoSingleDeviceArrays( + size_t n); + + std::vector ConsumeWithHandlers( + std::vector> + out_handlers); + + std::vector> Consume(); + + PyShardedToken ConsumeToken(); + + size_t Size() const { + CheckNotDisassembled(); + return ifrt_arrays_.size(); + } + + void CheckNotDisassembled() const; + + private: + bool is_exploded_ = false; + bool token_consumed_ = false; + nb_class_ptr client_; + std::vector> ifrt_arrays_; + int num_computations_; + PyShardedToken token_; + // Only set if the computation has tokens. + PjRtFuture<> result_status_; +}; + +using ExecuteShardedArg = std::variant>; + +// Python wrapper around PjRtExecutable. We use a wrapper class: +// a) to keep the PyClient alive via a std::shared_ptr<> +// b) to add Python-specific functionality. +class PyLoadedExecutable { + public: + PyLoadedExecutable( + nb_class_ptr client, + std::shared_ptr ifrt_loaded_executable, + std::optional traceback, + std::optional fingerprint); + ~PyLoadedExecutable(); + + nb_class_ptr client() const { return client_; } + ifrt::LoadedExecutable* ifrt_loaded_executable() const { + return ifrt_loaded_executable_.get(); + } + + std::shared_ptr shared_ifrt_loaded_executable() { + return ifrt_loaded_executable_; + } + + std::vector> AddressableDevices() const; + + int64_t SizeOfGeneratedCodeInBytes() const { + return ifrt_loaded_executable_->SizeOfGeneratedCodeInBytes(); + } + + absl::StatusOr GetCompiledMemoryStats() const { + nanobind::gil_scoped_release scope; + return ifrt_loaded_executable_->GetCompiledMemoryStats(); + } + + absl::StatusOr GetCostAnalysis() const { + return ifrt_loaded_executable_->GetCostAnalysis(); + } + + void Delete() { + // TODO(hyeontaek): Return absl::Status. + TF_CHECK_OK(ifrt_loaded_executable_->Delete().Await()); + } + + bool is_deleted() { return ifrt_loaded_executable_->IsDeleted(); } + + // Takes args indexed by argid then deviceid, transposes them, and passes to + // PjRtExecutable::Execute. The result is similarly transposed back into the + // argid,deviceid format. + // args is [num_args x num_devices]. + absl::StatusOr>> + ExecuteShardedOnLocalDevices(absl::Span args); + + absl::StatusOr>, PyShardedToken>> + ExecuteShardedOnLocalDevicesWithTokens( + absl::Span args); + + absl::StatusOr ExecuteSharded( + std::vector args, bool with_tokens); + + absl::StatusOr>> HloModules() const; + + absl::StatusOr>> + GetOutputMemoryKinds() const; + + absl::StatusOr>> + GetParameterLayouts() const; + + absl::StatusOr>> + GetOutputLayouts() const; + + std::optional> GetParameterShardings() const; + + std::optional> GetOutputShardings() const; + + const std::optional& traceback() { return traceback_; } + + ifrt::LoadedExecutable* ifrt_executable() const { + return ifrt_loaded_executable_.get(); + } + + // Short-term escape hatch to get PjRtLoadedExecutable from PyExecutable. + // TODO(hyeontaek): Migrate all users of this method to be agnostic of PjRt. + std::shared_ptr shared_ptr_pjrt_executable() { + auto* exec = llvm::dyn_cast_or_null( + ifrt_loaded_executable_.get()); + if (exec == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + return exec->shared_ptr_pjrt_loaded_executable(); + } + + // Returns a template of execute options to pass to + // `ifrt_executable()->Execute()`. Note that the caller may need to override + // some options such as `launch_id` that change at each execution. + const ifrt::ExecuteOptions& options() const { return options_; } + + // Returns a unique launch ID to use for the next execution. + int64_t GetNextLaunchId(); + + const std::optional& fingerprint() const { return fingerprint_; } + + // Keep `obj` alive as long as PyLoadedExecutable. + void KeepAlive(nanobind::object obj); + + private: + friend class PyClient; + + nb_class_ptr client_; + std::shared_ptr ifrt_loaded_executable_; + std::optional traceback_; + + // Identical executables (i.e. representing the same program) will have the + // same fingerprint. nullopt on platforms or executables where fingerprints + // aren't implemented. + std::optional fingerprint_; + + // Launch ID to use for the next execution. + std::atomic next_launch_id_; + + // The options to pass to `executable_.Execute`. + ifrt::ExecuteOptions options_; + + // Python objects to keep alive as requested by user. + std::vector keepalives_; + + // Doubly-linked list of all executables known to the client. Protected by the + // GIL. + PyLoadedExecutable* next_; + PyLoadedExecutable* prev_; +}; + +} // namespace xla + +#endif // JAXLIB_XLA_PY_EXECUTABLE_H_ diff --git a/jaxlib/xla/py_memory_space.cc b/jaxlib/xla/py_memory_space.cc new file mode 100644 index 000000000000..f365dd25dfb6 --- /dev/null +++ b/jaxlib/xla/py_memory_space.cc @@ -0,0 +1,102 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/py_memory_space.h" + +#include + +#include + +#include "absl/strings/string_view.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "jaxlib/xla/py_client.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/nb_class_ptr.h" + +namespace nb = ::nanobind; + +namespace xla { + +PyMemorySpace::PyMemorySpace(nb_class_ptr client, + ifrt::Memory* memory) + : client_(std::move(client)), memory_(memory) {} + +int PyMemorySpace::process_index() const { return client_->process_index(); } + +absl::string_view PyMemorySpace::platform() const { + // TODO(phawkins): this is a temporary backwards + // compatibility shim. We changed the name PJRT + // reports for GPU platforms to "cuda" or "rocm", + // but we haven't yet updated JAX clients that + // expect "gpu". Migrate users and remove this + // code. + if (client_->platform_name() == "cuda" || + client_->platform_name() == "rocm") { + return absl::string_view("gpu"); + } else { + return client_->platform_name(); + } +} + +absl::string_view PyMemorySpace::kind() const { + return *memory_->Kind().memory_kind(); +} + +absl::string_view PyMemorySpace::Str() const { return memory_->DebugString(); } + +absl::string_view PyMemorySpace::Repr() const { return memory_->ToString(); } + +nb::list PyMemorySpace::AddressableByDevices() const { + nb::list devices; + for (ifrt::Device* device : memory_->Devices()) { + devices.append(client_->GetPyDevice(device)); + } + return devices; +} + +/* static */ int PyMemorySpace::tp_traverse(PyObject* self, visitproc visit, + void* arg) { + PyMemorySpace* d = nb::inst_ptr(self); + Py_VISIT(d->client().ptr()); + return 0; +} + +/* static */ int PyMemorySpace::tp_clear(PyObject* self) { + PyMemorySpace* d = nb::inst_ptr(self); + nb_class_ptr client; + std::swap(client, d->client_); + return 0; +} + +PyType_Slot PyMemorySpace::slots_[] = { + {Py_tp_traverse, (void*)PyMemorySpace::tp_traverse}, + {Py_tp_clear, (void*)PyMemorySpace::tp_clear}, + {0, nullptr}, +}; + +/* static */ void PyMemorySpace::RegisterPythonType(nb::module_& m) { + nb::class_ device(m, "Memory", + nb::type_slots(PyMemorySpace::slots_)); + device.def_prop_ro("process_index", &PyMemorySpace::process_index) + .def_prop_ro("platform", &PyMemorySpace::platform) + .def_prop_ro("kind", &PyMemorySpace::kind) + .def("__str__", &PyMemorySpace::Str) + .def("__repr__", &PyMemorySpace::Repr) + .def("addressable_by_devices", &PyMemorySpace::AddressableByDevices, + "Returns devices that can address this memory."); +} + +} // namespace xla diff --git a/jaxlib/xla/py_memory_space.h b/jaxlib/xla/py_memory_space.h new file mode 100644 index 000000000000..4ad7b852f416 --- /dev/null +++ b/jaxlib/xla/py_memory_space.h @@ -0,0 +1,64 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_PY_MEMORY_SPACE_H_ +#define JAXLIB_XLA_PY_MEMORY_SPACE_H_ + +#include + +#include "nanobind/nanobind.h" +#include "jaxlib/xla/py_client.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/nb_class_ptr.h" + +namespace xla { + +class PyMemorySpace { + public: + PyMemorySpace(nb_class_ptr client, ifrt::Memory* memory_space); + + // Memory spaces are compared using Python object identity, so we don't allow + // them to be copied or moved. + PyMemorySpace(const PyMemorySpace&) = delete; + PyMemorySpace(PyMemorySpace&&) = delete; + PyMemorySpace& operator=(const PyMemorySpace&) = delete; + PyMemorySpace& operator=(PyMemorySpace&&) = delete; + + const nb_class_ptr& client() const { return client_; } + ifrt::Memory* memory_space() const { return memory_; } + + int process_index() const; + absl::string_view platform() const; + absl::string_view kind() const; + + absl::string_view Str() const; + absl::string_view Repr() const; + + nanobind::list AddressableByDevices() const; + + static void RegisterPythonType(nanobind::module_& m); + + private: + static int tp_traverse(PyObject* self, visitproc visit, void* arg); + static int tp_clear(PyObject* self); + static PyType_Slot slots_[]; + + nb_class_ptr client_; + ifrt::Memory* memory_; +}; + +} // namespace xla + +#endif // JAXLIB_XLA_PY_MEMORY_SPACE_H_ diff --git a/jaxlib/xla/py_program.cc b/jaxlib/xla/py_program.cc new file mode 100644 index 000000000000..ec82292a50cd --- /dev/null +++ b/jaxlib/xla/py_program.cc @@ -0,0 +1,291 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/py_program.h" + +#include +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/py_device.h" +#include "jaxlib/xla/py_device_list.h" +#include "jaxlib/xla/sharding.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/pjrt/mlir_to_hlo.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/array_spec.h" +#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/custom_call_program.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/hlo/hlo_program.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/plugin_program.h" +#include "xla/python/ifrt/program.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/pjrt_ifrt/xla_compiler.h" +#include "xla/python/pjrt_ifrt/xla_sharding.h" +#include "xla/python/python_ref_manager.h" +#include "xla/python/types.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/statusor.h" + +namespace xla { + +namespace nb = ::nanobind; + +namespace { + +// Gets `ifrt::DeviceList` from a sequence of JAX devices. +absl::StatusOr GetDeviceList(nb::sequence devices) { + ifrt::DeviceListRef ifrt_device_list; + if (devices.type().is(jax::PyDeviceList::type())) { + return nb::cast(devices)->ifrt_device_list(); + } else { + auto py_devices = nb::cast>>(devices); + if (py_devices.empty()) { + return absl::InvalidArgumentError( + "Colocated Python program requires at least one device"); + } + absl::InlinedVector ifrt_devices; + ifrt_devices.reserve(py_devices.size()); + for (const nb_class_ptr& py_device : py_devices) { + ifrt_devices.push_back(py_device->device()); + } + return py_devices.front()->client()->ifrt_client()->MakeDeviceList( + ifrt_devices); + } +} + +// Gets `xla::HloSharding` from a JAX Sharding. +xla::HloSharding GetXlaHloSharding(nb::handle sharding, + int64_t num_dimensions) { + if (sharding.type().is(jax::GSPMDSharding::type())) { + return nb::cast(sharding)->hlo_sharding(); + } else { + return nb::cast( + sharding.attr("_to_xla_hlo_sharding")(num_dimensions)); + } +} + +// Gets `ifrt::DeviceList` from a JAX Sharding. +absl::StatusOr GetIfrtDeviceList(nb::handle sharding) { + if (sharding.type().is(jax::NamedSharding::type())) { + TF_ASSIGN_OR_RETURN( + auto ns_device_list, + nb::cast(sharding)->internal_device_list()); + return ns_device_list->ifrt_device_list(); + } else if (sharding.type().is(jax::SingleDeviceSharding::type())) { + return nb::cast(sharding) + ->internal_device_list() + ->ifrt_device_list(); + } else if (sharding.type().is(jax::PmapSharding::type())) { + return nb::cast(sharding) + ->internal_device_list() + ->ifrt_device_list(); + } else if (sharding.type().is(jax::GSPMDSharding::type())) { + return nb::cast(sharding) + ->internal_device_list() + ->ifrt_device_list(); + } else { + return nb::cast( + sharding.attr("_internal_device_list")) + ->ifrt_device_list(); + } +} + +// Gets `ifrt::MemoryKind` from a JAX Sharding. +ifrt::MemoryKind GetIfrtMemoryKind(nb::handle sharding) { + auto memory_kind = sharding.attr("memory_kind"); + if (memory_kind.is_none()) { + return ifrt::MemoryKind(); + } else { + return ifrt::MemoryKind(nb::cast(memory_kind)); + } +} + +// Makes `ifrt::Sharding` from a JAX Sharding. It requires the number of shape +// dimensions, which may become necessary when building an HLO sharding. +absl::StatusOr> GetIfrtSharding( + nb::handle sharding, int64_t num_dimensions) { + auto ifrt_memory_kind = GetIfrtMemoryKind(sharding); + std::shared_ptr ifrt_sharding; + if (sharding.type().is(jax::SingleDeviceSharding::type())) { + TF_ASSIGN_OR_RETURN(auto ifrt_device_list, + nb::cast(sharding) + ->internal_device_list() + ->ifrt_device_list()); + return ifrt::SingleDeviceSharding::Create( + ifrt_device_list->devices().front(), ifrt_memory_kind); + } else { + TF_ASSIGN_OR_RETURN(auto ifrt_device_list, GetIfrtDeviceList(sharding)); + auto xla_hlo_sharding = GetXlaHloSharding(sharding, num_dimensions); + return ifrt::HloSharding::Create(std::move(ifrt_device_list), + ifrt_memory_kind, + std::move(xla_hlo_sharding)); + } +} + +// Gets `ifrt::ArraySpec`s from a sequence of JAX avals (e.g., +// `jax.ShapeDtypeStruct`). +absl::StatusOr> GetIfrtArraySpecs( + nb::sequence avals) { + std::vector ifrt_array_specs; + ifrt_array_specs.reserve(nb::len(avals)); + for (nb::handle aval : avals) { + ifrt::Shape ifrt_shape(nb::cast>(aval.attr("shape"))); + TF_ASSIGN_OR_RETURN( + auto ifrt_dtype, + DtypeToIfRtDType(nb::cast(aval.attr("dtype")))); + TF_ASSIGN_OR_RETURN( + auto ifrt_sharding, + GetIfrtSharding(aval.attr("sharding"), ifrt_shape.dims().size())); + ifrt_array_specs.push_back(ifrt::ArraySpec{ + ifrt_dtype, std::move(ifrt_shape), std::move(ifrt_sharding)}); + } + return ifrt_array_specs; +} + +absl::StatusOr> MakePluginProgramFromString( + std::string data) { + auto plugin_program = std::make_unique(); + plugin_program->data = std::move(data); + return plugin_program; +} + +absl::StatusOr> MakePluginProgramFromBytes( + nb::bytes data) { + auto plugin_program = std::make_unique(); + plugin_program->data = std::string(data.c_str(), data.size()); + return plugin_program; +} + +absl::StatusOr> +MakeColocatedPythonCompileOptions() { + return std::make_unique(); +} + +absl::StatusOr> +MakePluginCompileOptions() { + return std::make_unique(); +} + +absl::StatusOr> MakeHloProgram( + absl::string_view mlir_module) { + auto context = std::make_unique(); + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ParseMlirModuleString(mlir_module, *context)); + return std::make_unique(std::move(context), + std::move(module)); +} + +absl::StatusOr> MakeHloProgramFromString( + std::string mlir_module) { + return MakeHloProgram(mlir_module); +} + +absl::StatusOr> MakeHloProgramFromBytes( + nb::bytes mlir_module) { + return MakeHloProgram( + absl::string_view(mlir_module.c_str(), mlir_module.size())); +} + +absl::StatusOr> MakeXlaCompileOptions( + CompileOptions options, std::vector host_callbacks) { + std::vector> + ifrt_loaded_host_callbacks; + ifrt_loaded_host_callbacks.reserve(host_callbacks.size()); + // Extract `ifrt::LoadedHostCallback`s from host callback capsules that were + // created by `PyClient::MakePythonCallbackUsingHostSendAndRecv()` or + // `PyClient::GetEmitPythonCallbackDescriptor()`. + for (auto& host_callback : host_callbacks) { + ifrt_loaded_host_callbacks.push_back(tsl::FormRef( + static_cast(host_callback.data()))); + } + return std::make_unique( + std::move(options), std::move(ifrt_loaded_host_callbacks)); +} + +constexpr absl::string_view kColocatedPythonProgramType = + "jax_colocated_python_v0.0.1"; + +absl::StatusOr> MakeColocatedPythonProgram( + std::string name, nb::bytes picked_function, nb::sequence devices, + nb::sequence input_avals, nb::sequence output_avals) { + auto ifrt_serialized_program_text = absl::MakeCordFromExternal( + absl::string_view(reinterpret_cast(picked_function.data()), + picked_function.size()), + /*releaser=*/[picked_function](absl::string_view) mutable { + GlobalPyRefManager()->AddGarbage(std::move(picked_function)); + }); + TF_ASSIGN_OR_RETURN(auto ifrt_device_list, GetDeviceList(devices)); + TF_ASSIGN_OR_RETURN(auto ifrt_input_specs, GetIfrtArraySpecs(input_avals)); + TF_ASSIGN_OR_RETURN(auto ifrt_output_specs, GetIfrtArraySpecs(output_avals)); + return std::make_unique( + std::string(kColocatedPythonProgramType), std::move(name), + std::move(ifrt_serialized_program_text), std::move(ifrt_device_list), + std::move(ifrt_input_specs), std::move(ifrt_output_specs)); +} + +} // namespace + +void BuildIfrtProgramsSubmodule(nanobind::module_& m) { + auto sub_module = m.def_submodule("ifrt_programs"); + nb::class_ ifrt_program_base_class(sub_module, "Program"); + nb::class_ ifrt_compile_options_base_class( + sub_module, "CompileOptions"); + sub_module + .def("make_hlo_program", ValueOrThrowWrapper(MakeHloProgramFromString), + nb::arg("mlir_module")) + .def("make_hlo_program", ValueOrThrowWrapper(MakeHloProgramFromBytes), + nb::arg("mlir_module")) + .def("make_colocated_python_program", + ValueOrThrowWrapper(MakeColocatedPythonProgram), nb::arg("name"), + nb::arg("pickled_function"), nb::arg("devices"), + nb::arg("input_avals"), nb::arg("output_avals")) + .def("make_plugin_program", + ValueOrThrowWrapper(MakePluginProgramFromString), nb::arg("data")) + .def("make_plugin_program", + ValueOrThrowWrapper(MakePluginProgramFromBytes), nb::arg("data")) + .def("make_xla_compile_options", + ValueOrThrowWrapper(MakeXlaCompileOptions), nb::arg("options"), + nb::arg("host_callbacks")) + .def("make_colocated_python_compile_options", + ValueOrThrowWrapper(MakeColocatedPythonCompileOptions)) + .def("make_plugin_compile_options", + ValueOrThrowWrapper(MakePluginCompileOptions)); +} + +} // namespace xla diff --git a/jaxlib/xla/py_program.h b/jaxlib/xla/py_program.h new file mode 100644 index 000000000000..9fd30eeeed2f --- /dev/null +++ b/jaxlib/xla/py_program.h @@ -0,0 +1,27 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_PY_PROGRAM_H_ +#define JAXLIB_XLA_PY_PROGRAM_H_ + +#include "nanobind/nanobind.h" + +namespace xla { + +void BuildIfrtProgramsSubmodule(nanobind::module_& m); + +} // namespace xla + +#endif // JAXLIB_XLA_PY_PROGRAM_H_ diff --git a/jaxlib/xla/py_socket_transfer.cc b/jaxlib/xla/py_socket_transfer.cc index dd2c02898e18..b1c4fbcc541f 100644 --- a/jaxlib/xla/py_socket_transfer.cc +++ b/jaxlib/xla/py_socket_transfer.cc @@ -34,6 +34,9 @@ limitations under the License. #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/py_array.h" +#include "jaxlib/xla/py_client.h" +#include "jaxlib/xla/to_ifrt_sharding.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/status_casters.h" #include "xla/python/ifrt/array.h" @@ -48,9 +51,6 @@ limitations under the License. #include "xla/python/pjrt_ifrt/pjrt_device.h" #include "xla/python/pjrt_ifrt/pjrt_dtype.h" #include "xla/python/pjrt_ifrt/pjrt_memory.h" -#include "xla/python/py_array.h" -#include "xla/python/py_client.h" -#include "xla/python/to_ifrt_sharding.h" #include "xla/python/traceback.h" #include "xla/python/transfer/event_loop.h" #include "xla/python/transfer/socket-server.h" diff --git a/jaxlib/xla/py_values.cc b/jaxlib/xla/py_values.cc new file mode 100644 index 000000000000..9375dd5440c6 --- /dev/null +++ b/jaxlib/xla/py_values.cc @@ -0,0 +1,745 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/py_values.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/complex.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "jaxlib/xla/py_array.h" +#include "jaxlib/xla/sharding.h" +#include "xla/primitive_util.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" +#include "xla/python/python_ref_manager.h" +#include "xla/python/types.h" +#include "xla/shape.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/python/lib/core/numpy.h" +#include "xla/types.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/ml_dtypes.h" +#include "tsl/profiler/lib/traceme.h" + +namespace nb = nanobind; + +namespace xla { + +namespace { + +absl::StatusOr> StringDTypeArrayToCords( + PyArrayObject* py_array_obj) { + if (PyArray_SIZE(py_array_obj) == 0) { + return absl::InvalidArgumentError("empty numpy array"); + } + + std::vector cords; + cords.reserve(PyArray_SIZE(py_array_obj)); + + auto iter = + nb::steal(PyArray_IterNew(reinterpret_cast(py_array_obj))); + while (PyArray_ITER_NOTDONE(iter.ptr())) { + auto* iter_data = PyArray_ITER_DATA(iter.ptr()); + auto* item = PyArray_GETITEM(py_array_obj, static_cast(iter_data)); + if (!item) { + return absl::InternalError( + "Failed to get elements out of the ndarray iter."); + } + Py_ssize_t len; + auto str = PyUnicode_AsUTF8AndSize(item, &len); + cords.push_back(absl::Cord(absl::string_view(str, len))); + PyArray_ITER_NEXT(iter.ptr()); + } + return cords; +} + +using DevicePutFunc = std::function( + nb::handle, ifrt::Client*, ifrt::Device*, const DevicePutOptions& options, + ifrt::MemoryKind to_memory_kind)>; + +template +absl::StatusOr HandlePythonScalar( + nb::handle obj, ifrt::Client* client, ifrt::Device* to_device, + const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) { + T value; + try { + value = nb::cast(obj); + } catch (const std::exception& e) { + return InvalidArgument( + "Unable to convert Python scalar to %s. This most likely means the " + "value (%s) overflows the range of the type.", + PrimitiveType_Name(primitive_util::NativeToPrimitiveType()), + nb::cast(nb::repr(obj))); + } + + std::variant data; + Shape shape; + PrimitiveType type; + if (std::is_same() || !options.squash_64bit_types) { + data.template emplace<0>(value); + type = primitive_util::NativeToPrimitiveType(); + } else { + // TODO(phawkins): we should check for overflow here, e.g., because of bugs + // like https://github.com/google/jax/issues/2006 + data.template emplace<1>(static_cast(value)); + type = primitive_util::NativeToPrimitiveType(); + } + + return [client, data, type, to_device, to_memory_kind, + options]() -> absl::StatusOr { + const void* ptr = std::visit( + [](const auto& v) { return static_cast(&v); }, data); + TF_ASSIGN_OR_RETURN(auto ifrt_dtype, xla::ifrt::ToDType(type)); + // TODO(yashkatariya): Plumb sharding or memory_kind here. + TF_ASSIGN_OR_RETURN( + auto ifrt_array, + client->MakeArrayFromHostBuffer( + ptr, ifrt_dtype, /*shape=*/ifrt::Shape({}), /*byte_strides=*/{}, + ifrt::SingleDeviceSharding::Create(to_device, to_memory_kind), + ifrt::Client::HostBufferSemantics::kImmutableOnlyDuringCall, + /*on_done_with_host_buffer=*/{}, options.ifrt_user_context)); + return DevicePutResult(std::move(ifrt_array), /*weak_type=*/true); + }; +} + +absl::StatusOr HandlePythonInt( + nb::handle obj, ifrt::Client* client, ifrt::Device* to_device, + const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) { + PrimitiveType type; + std::variant data; + + if (options.squash_64bit_types) { + try { + data.emplace<1>(nb::cast(obj)); + } catch (const std::exception& e) { + return InvalidArgument( + "Unable to convert Python scalar to %s. This most likely means the " + "value (%s) overflows the range of the type.", + PrimitiveType_Name(primitive_util::NativeToPrimitiveType()), + nb::cast(nb::repr(obj))); + } + type = S32; + } else { + try { + data.emplace<0>(nb::cast(obj)); + } catch (const std::exception& e) { + return InvalidArgument( + "Unable to convert Python scalar to %s. This most likely means the " + "value (%s) overflows the range of the type.", + PrimitiveType_Name(primitive_util::NativeToPrimitiveType()), + nb::cast(nb::repr(obj))); + } + type = S64; + } + return [client, data, type, to_device, to_memory_kind, + options]() -> absl::StatusOr { + const void* ptr = std::visit( + [](const auto& v) { return static_cast(&v); }, data); + TF_ASSIGN_OR_RETURN(auto ifrt_dtype, xla::ifrt::ToDType(type)); + // TODO(yashkatariya): Plumb sharding or memory_kind here. + TF_ASSIGN_OR_RETURN( + auto ifrt_array, + client->MakeArrayFromHostBuffer( + ptr, ifrt_dtype, /*shape=*/xla::ifrt::Shape({}), + /*byte_strides=*/{}, + ifrt::SingleDeviceSharding::Create(to_device, to_memory_kind), + ifrt::Client::HostBufferSemantics::kImmutableOnlyDuringCall, + /*on_done_with_host_buffer=*/nullptr, options.ifrt_user_context)); + return DevicePutResult(std::move(ifrt_array), /*weak_type=*/true); + }; +} + +template +absl::StatusOr HandleNumpyScalar( + nb::handle h, ifrt::Client* client, ifrt::Device* to_device, + const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) { + std::variant data; + PrimitiveType type; + // For extension types, ScalarAsCtype returns a pointer to the data. + if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = S2; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = S4; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = U2; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = U4; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = BF16; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F4E2M1FN; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E3M4; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E4M3; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E4M3FN; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E4M3B11FNUZ; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E5M2; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E4M3FNUZ; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E5M2FNUZ; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E8M0FNU; + } else if (std::is_same() || !options.squash_64bit_types) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<0>()); + type = primitive_util::NativeToPrimitiveType(); + } else { + T value; + PyArray_ScalarAsCtype(h.ptr(), &value); + data.template emplace<1>(static_cast(value)); + type = primitive_util::NativeToPrimitiveType(); + } + std::shared_ptr py_buffer_ref; + if (data.index() == 2) { + py_buffer_ref = + GlobalPyRefManager()->ManageReference(nb::cast(h)); + } + return [client, data, py_buffer_ref, type, to_device, options, + to_memory_kind]() mutable -> absl::StatusOr { + const void* ptr = std::visit( + [](const auto& v) -> const void* { + if constexpr (std::is_same_v, void*>) { + return v; + } else { + return static_cast(&v); + } + }, + data); + TF_ASSIGN_OR_RETURN(auto ifrt_dtype, xla::ifrt::ToDType(type)); + // TODO(yashkatariya): Plumb sharding or memory_kind here. + TF_ASSIGN_OR_RETURN( + auto ifrt_array, + client->MakeArrayFromHostBuffer( + ptr, ifrt_dtype, /*shape=*/xla::ifrt::Shape({}), + /*byte_strides=*/{}, + ifrt::SingleDeviceSharding::Create(to_device, to_memory_kind), + ifrt::Client::HostBufferSemantics::kImmutableOnlyDuringCall, + /*on_done_with_host_buffer=*/ + [py_buffer_ref = std::move( + py_buffer_ref)]() { /* keeps py_buffer_ref alive */ }, + options.ifrt_user_context)); + return DevicePutResult(std::move(ifrt_array), /*weak_type=*/false); + }; +} + +absl::StatusOr HandleStringNumpyArray( + nb::handle h, ifrt::Client* client, ifrt::Device* to_device, + const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) { + xla::nb_numpy_ndarray array = nb::cast(h); + auto py_array_obj = reinterpret_cast(array.ptr()); + TF_ASSIGN_OR_RETURN(auto cords, StringDTypeArrayToCords(py_array_obj)); + + // Assemble all the parameters of MakeArrayFromHostBuffer + void* data = cords.data(); + + // Make an explicit copy of the shape elements so we won't run into complex + // endianness and precision issues that might arise if we reinterpret-casted + // from npy_intp, that can be just 32 bits-wide in some environments + // such as macos_arm64 to const int64_t* that must be 64 bits-wide. + ifrt::Shape::Dimensions dims; + dims.reserve(array.ndim()); + for (int i = 0; i < array.ndim(); ++i) { + dims.push_back(array.shape(i)); + } + ifrt::Shape shape(std::move(dims)); + + std::shared_ptr sharding = + xla::ifrt::SingleDeviceSharding::Create(to_device, to_memory_kind); + + auto on_done_with_host_buffer = [cords = std::move(cords)] {}; + + return [client, data = data, shape = std::move(shape), + sharding = std::move(sharding), + on_done_with_host_buffer = std::move(on_done_with_host_buffer), + options]() mutable -> absl::StatusOr { + TF_ASSIGN_OR_RETURN( + auto ifrt_array, + client->MakeArrayFromHostBuffer( + data, ifrt::DType(ifrt::DType::kString), std::move(shape), + /*byte_strides=*/std::nullopt, std::move(sharding), + ifrt::Client::HostBufferSemantics::kImmutableUntilTransferCompletes, + std::move(on_done_with_host_buffer), options.ifrt_user_context)); + + return DevicePutResult(std::move(ifrt_array), /*weak_type=*/false); + }; +} + +absl::StatusOr HandleNumpyArray( + nb::handle h, ifrt::Client* client, ifrt::Device* to_device, + const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) { + xla::nb_numpy_ndarray array = nb::cast(h); + + // String numpy arrays require substantially different processing. + if (array.dtype().char_() == (int)'T' || array.dtype().kind() == 'T') { + return HandleStringNumpyArray(h, client, to_device, options, + to_memory_kind); + } + + TF_ASSIGN_OR_RETURN(PrimitiveType type, DtypeToPrimitiveType(array.dtype())); + + PrimitiveType squashed_type; + if (options.squash_64bit_types) { + squashed_type = Squash64BitTypes(type); + if (squashed_type != type) { + TF_ASSIGN_OR_RETURN(xla::nb_dtype squashed_dtype, + PrimitiveTypeToNbDtype(squashed_type)); + array = nb::steal(PyArray_CastToType( + reinterpret_cast(array.ptr()), + reinterpret_cast(squashed_dtype.release().ptr()), + /*fortran=*/0)); + } + } else { + squashed_type = type; + } + + absl::InlinedVector dims(array.ndim()); + absl::InlinedVector byte_strides(array.ndim()); + for (int i = 0; i < array.ndim(); ++i) { + dims[i] = array.shape(i); + byte_strides[i] = array.strides(i); + } + const void* data = array.data(); + std::shared_ptr py_buffer_ref = + GlobalPyRefManager()->ManageReference(std::move(array)); + return [client, data, squashed_type, dims = std::move(dims), + byte_strides = std::move(byte_strides), + py_buffer_ref = std::move(py_buffer_ref), options, to_device, + to_memory_kind]() mutable -> absl::StatusOr { + TF_ASSIGN_OR_RETURN(auto ifrt_dtype, xla::ifrt::ToDType(squashed_type)); + + ifrt::Client::HostBufferSemantics host_buffer_semantics = + ifrt::Client::HostBufferSemantics::kImmutableOnlyDuringCall; + std::function on_done_with_host_buffer; + if (options.allow_zero_copy) { + on_done_with_host_buffer = + [py_buffer_ref{ + std::move(py_buffer_ref)}]() { /* keeps py_buffer_ref alive */ }; + host_buffer_semantics = + ifrt::Client::HostBufferSemantics::kImmutableZeroCopy; + } + + TF_ASSIGN_OR_RETURN( + auto ifrt_array, + client->MakeArrayFromHostBuffer( + data, ifrt_dtype, ifrt::Shape(dims), byte_strides, + xla::ifrt::SingleDeviceSharding::Create(to_device, to_memory_kind), + host_buffer_semantics, std::move(on_done_with_host_buffer), + options.ifrt_user_context)); + return DevicePutResult(std::move(ifrt_array), /*weak_type=*/false); + }; +} + +absl::StatusOr HandlePyArray( + nb::handle obj, ifrt::Client* client, ifrt::Device* to_device, + const DevicePutOptions& options, ifrt::MemoryKind to_memory_kind) { + auto py_array = nb::borrow(obj); + + // We only allow single device case for PyArray in device put. + if (py_array.num_shards() != 1) { + return InvalidArgument( + "device_put expects an array with exactly one shard, got an array with " + "with %d shards.", + py_array.num_shards()); + } + + ifrt::Array* ifrt_array = py_array.ifrt_array(); + if (ifrt_array == nullptr) { + return InvalidArgument("Array has been deleted."); + } + + // Fallback to python for non-matching clients or pmap sharding. + if (py_array.sharding().type().ptr() == jax::PmapSharding::type().ptr() || + ifrt_array->sharding().devices()->devices().front()->client() != + to_device->client()) { + return HandleNumpyArray(obj.attr("_value"), client, to_device, options, + to_memory_kind); + } + + if (ifrt_array->sharding().devices()->devices().front() == to_device && + (!to_memory_kind.memory_kind().has_value() || + !ifrt_array->sharding().memory_kind().memory_kind().has_value() || + ifrt_array->sharding().memory_kind() == to_memory_kind)) { + DevicePutResult result(tsl::FormRef(ifrt_array), py_array.weak_type(), + /*owning_pybuffer=*/nb::borrow(obj)); + return [result = std::move(result)]() mutable { return std::move(result); }; + } else { + return [ifrt_array = tsl::FormRef(ifrt_array), to_device, to_memory_kind, + owning_pybuffer = py_array.weak_type()]() mutable + -> absl::StatusOr { + auto* ifrt_client = ifrt_array->client(); + TF_ASSIGN_OR_RETURN( + auto copied_ifrt_arrays, + ifrt_client->CopyArrays(absl::MakeSpan(&ifrt_array, 1), + ifrt_client->MakeDeviceList({to_device}), + to_memory_kind, + ifrt::ArrayCopySemantics::kReuseInput)); + return DevicePutResult(std::move(copied_ifrt_arrays[0]), + std::move(owning_pybuffer)); + }; + } +} + +} // namespace + +absl::StatusOr DevicePut(nb::handle arg, + ifrt::Client* client, + ifrt::Device* to_device, + const DevicePutOptions& options, + ifrt::MemoryKind to_memory_kind) { + tsl::profiler::TraceMe traceme("DevicePut"); + static const absl::flat_hash_map* const handlers = + [] { + auto p = new absl::flat_hash_map(); + const NumpyScalarTypes& dtypes = GetNumpyScalarTypes(); + // Python scalar types. + static_assert(sizeof(bool) == 1, + "Conversion code assumes bool is 1 byte"); + (*p)[reinterpret_cast(&PyBool_Type)] = + HandlePythonScalar; + (*p)[reinterpret_cast(&PyLong_Type)] = HandlePythonInt; + (*p)[reinterpret_cast(&PyFloat_Type)] = + HandlePythonScalar; + (*p)[reinterpret_cast(&PyComplex_Type)] = + HandlePythonScalar; + + (*p)[reinterpret_cast(&PyArray_Type)] = HandleNumpyArray; + + // Numpy scalar types. For some of them, we share the handler with + // Python types (np_int64, np_float64, np_complex128). + (*p)[dtypes.np_bool.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_int4.ptr()] = HandleNumpyScalar; + if (dtypes.np_int2.has_value()) { + (*p)[dtypes.np_int2->ptr()] = HandleNumpyScalar; + } + (*p)[dtypes.np_int8.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_int16.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_int32.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_int64.ptr()] = HandleNumpyScalar; + if (dtypes.np_uint2.has_value()) { + (*p)[dtypes.np_uint2->ptr()] = HandleNumpyScalar; + } + (*p)[dtypes.np_uint4.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_uint8.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_uint16.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_uint32.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_uint64.ptr()] = HandleNumpyScalar; + if (dtypes.np_float4_e2m1fn.has_value()) { + (*p)[dtypes.np_float4_e2m1fn->ptr()] = + HandleNumpyScalar; + } + if (dtypes.np_float8_e3m4.has_value()) { + (*p)[dtypes.np_float8_e3m4->ptr()] = + HandleNumpyScalar; + } + if (dtypes.np_float8_e4m3.has_value()) { + (*p)[dtypes.np_float8_e4m3->ptr()] = + HandleNumpyScalar; + } + (*p)[dtypes.np_float8_e4m3fn.ptr()] = + HandleNumpyScalar; + (*p)[dtypes.np_float8_e4m3b11fnuz.ptr()] = + HandleNumpyScalar; + (*p)[dtypes.np_float8_e5m2.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_float8_e4m3fnuz.ptr()] = + HandleNumpyScalar; + (*p)[dtypes.np_float8_e5m2fnuz.ptr()] = + HandleNumpyScalar; + if (dtypes.np_float8_e8m0fnu.has_value()) { + (*p)[dtypes.np_float8_e8m0fnu->ptr()] = + HandleNumpyScalar; + } + (*p)[dtypes.np_bfloat16.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_float16.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_float32.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_float64.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_complex64.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_complex128.ptr()] = + HandleNumpyScalar; + static_assert(sizeof(long long) == sizeof(int64_t), // NOLINT + "long long must be the same size as int64_t"); + (*p)[dtypes.np_longlong.ptr()] = HandleNumpyScalar; + static_assert(sizeof(int) == sizeof(int32_t), + "int must be the same size as int32_t"); + (*p)[dtypes.np_intc.ptr()] = HandleNumpyScalar; + + return p; + }(); + + if (arg.type().ptr() == PyArray::type().ptr()) { + auto array = nb::borrow(arg); + return HandlePyArray(arg, client, to_device, options, to_memory_kind); + } + + auto res = handlers->find(arg.type().ptr()); + if (res == handlers->end()) { + for (auto base_class : arg.type().attr("__mro__")) { + res = handlers->find(base_class.ptr()); + if (res != handlers->end()) { + return res->second(arg, client, to_device, options, to_memory_kind); + } + } + return InvalidArgument( + "%s", absl::StrCat( + "Not supported: The C++ jax jit execution path, only accepts " + "DeviceArray, Numpy arrays scalars of supported types " + "(see implementation), or Python scalars. Got type ", + nb::cast(nb::str(arg.type())))); + } + return res->second(arg, client, to_device, options, to_memory_kind); +} + +bool IsFloat0(xla::nb_numpy_ndarray arg) { + static const auto* dtypes_module = + new nb::module_(nb::module_::import_("jax.dtypes")); + static const auto* float0_dtype = + new nb::handle(dtypes_module->attr("float0")); + return float0_dtype->is(arg.attr("dtype")); +} + +std::string PyArgSignature::DebugString() const { + std::string result = ""; + if (weak_type) { + absl::StrAppend(&result, "weak_"); + } + absl::StrAppend(&result, xla::PrimitiveType_Name(dtype)); + absl::StrAppend(&result, "[", absl::StrJoin(shape, ","), "]"); + return result; +} + +using ToPyArgSignatureHandler = + std::function(nb::handle, bool)>; + +absl::StatusOr PyArgSignatureOfValue(nb::handle arg, + bool jax_enable_x64) { + static const absl::flat_hash_map* const + handlers = [] { + auto p = new absl::flat_hash_map(); + + const NumpyScalarTypes& dtypes = GetNumpyScalarTypes(); + + // The 4 Python native types. + ToPyArgSignatureHandler bool_handler = + [](nb::handle, bool) -> absl::StatusOr { + return PyArgSignature(PrimitiveType::PRED, {}, true); + }; + ToPyArgSignatureHandler int_handler = + [](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { + // TODO(phawkins): we should consider checking for integer overflow. + if (jax_enable_x64) { + return PyArgSignature(PrimitiveType::S64, {}, true); + } else { + return PyArgSignature(PrimitiveType::S32, {}, true); + } + }; + ToPyArgSignatureHandler float_handler = + [&dtypes](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { + // Only Python native types has a True weak_type. + bool weak_type = !nb::isinstance(h, dtypes.np_float64); + if (jax_enable_x64) { + return PyArgSignature(PrimitiveType::F64, {}, weak_type); + } else { + return PyArgSignature(PrimitiveType::F32, {}, weak_type); + } + }; + ToPyArgSignatureHandler complex_handler = + [&dtypes](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { + // Note that this branch is also taken for np.complex128: + // isinstance(np.complex128(3), complex) returns True + // isinstance(np.complex64(3), complex) returns False + bool weak_type = !nb::isinstance(h, dtypes.np_complex128); + if (jax_enable_x64) { + return PyArgSignature(PrimitiveType::C128, {}, weak_type); + } else { + return PyArgSignature(PrimitiveType::C64, {}, weak_type); + } + }; + + (*p)[reinterpret_cast(&PyBool_Type)] = bool_handler; + (*p)[reinterpret_cast(&PyLong_Type)] = int_handler; + (*p)[reinterpret_cast(&PyFloat_Type)] = float_handler; + (*p)[reinterpret_cast(&PyComplex_Type)] = complex_handler; + + ToPyArgSignatureHandler numpy_handler = + [](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { + xla::nb_numpy_ndarray numpy_array = + nb::cast(h); + TF_ASSIGN_OR_RETURN(PrimitiveType dtype, + DtypeToPrimitiveType(numpy_array.dtype())); + if (!jax_enable_x64) { + dtype = Squash64BitTypes(dtype); + } + // We use reinterpret_cast<> to defend against environments where + // ssize_t may not be precisely the same type as int64_t, even if it + // is the same size (long vs long long). + static_assert(sizeof(int64_t) == sizeof(ssize_t), + "Code assumes ssize_t is the same as int64_t"); + return PyArgSignature( + dtype, + absl::MakeConstSpan( + reinterpret_cast(numpy_array.shape()), + numpy_array.ndim()), + /*weak_type=*/false); + }; + (*p)[reinterpret_cast(&PyArray_Type)] = numpy_handler; + + ToPyArgSignatureHandler np_uint64_handler = + [](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { + if (jax_enable_x64) { + return PyArgSignature(PrimitiveType::U64, {}, /*weak_type=*/false); + } else { + return PyArgSignature(PrimitiveType::U32, {}, /*weak_type=*/false); + } + }; + ToPyArgSignatureHandler np_int_handler = + [](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { + if (jax_enable_x64) { + return PyArgSignature(PrimitiveType::S64, {}, /*weak_type=*/false); + } else { + return PyArgSignature(PrimitiveType::S32, {}, /*weak_type=*/false); + } + }; + ToPyArgSignatureHandler numpy_array_handler = + [](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { + // This block deals with all numpy scalar types, except for int64_dt, + // float64_dt and complex128_dt which are taken care of in previous if + // blocks. + TF_ASSIGN_OR_RETURN(auto dtype, + DtypeToPrimitiveType(h.attr("dtype"))); + return PyArgSignature(dtype, {}, /*weak_type=*/false); + }; + + // This block deals with all numpy scalar types, except for int64_dt, + // float64_dt and complex128_dt which are taken care of in previous if + // blocks. + (*p)[dtypes.np_bool.ptr()] = numpy_array_handler; + (*p)[dtypes.np_int8.ptr()] = numpy_array_handler; + (*p)[dtypes.np_int16.ptr()] = numpy_array_handler; + (*p)[dtypes.np_int32.ptr()] = numpy_array_handler; + (*p)[dtypes.np_int64.ptr()] = np_int_handler; + (*p)[dtypes.np_uint8.ptr()] = numpy_array_handler; + (*p)[dtypes.np_uint16.ptr()] = numpy_array_handler; + (*p)[dtypes.np_uint32.ptr()] = numpy_array_handler; + (*p)[dtypes.np_uint64.ptr()] = np_uint64_handler; + // TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. + // (*p)[dtypes.np_float4_e2m1fn.ptr()] = numpy_array_handler; + // (*p)[dtypes.np_float8_e3m4.ptr()] = numpy_array_handler; + // (*p)[dtypes.np_float8_e4m3.ptr()] = numpy_array_handler; + // (*p)[dtypes.np_float8_e8m0fnu.ptr()] = numpy_array_handler; + (*p)[dtypes.np_float8_e4m3fn.ptr()] = numpy_array_handler; + (*p)[dtypes.np_float8_e4m3b11fnuz.ptr()] = numpy_array_handler; + (*p)[dtypes.np_float8_e5m2.ptr()] = numpy_array_handler; + (*p)[dtypes.np_float8_e4m3fnuz.ptr()] = numpy_array_handler; + (*p)[dtypes.np_float8_e5m2fnuz.ptr()] = numpy_array_handler; + (*p)[dtypes.np_float16.ptr()] = numpy_array_handler; + (*p)[dtypes.np_bfloat16.ptr()] = numpy_array_handler; + (*p)[dtypes.np_float32.ptr()] = numpy_array_handler; + (*p)[dtypes.np_float64.ptr()] = float_handler; + (*p)[dtypes.np_complex64.ptr()] = numpy_array_handler; + (*p)[dtypes.np_complex128.ptr()] = complex_handler; + (*p)[dtypes.np_longlong.ptr()] = np_int_handler; + (*p)[dtypes.np_intc.ptr()] = numpy_array_handler; + + return p; + }(); + + if (arg.type().ptr() == PyArray::type().ptr()) { + auto array = nb::borrow(arg); + ifrt::Array* ifrt_array = array.ifrt_array(); + if (ifrt_array == nullptr) { + return xla::InvalidArgument("Array has been deleted."); + } + TF_ASSIGN_OR_RETURN(auto primitive_type, + ifrt::ToPrimitiveType(ifrt_array->dtype())); + return PyArgSignature(primitive_type, array.shape(), array.weak_type()); + } + + auto res = handlers->find(arg.type().ptr()); + if (res == handlers->end()) { + // We attempt to look at the MRO classes + for (auto base_class : arg.type().attr("__mro__")) { + res = handlers->find(base_class.ptr()); + if (res != handlers->end()) { + return res->second(arg, jax_enable_x64); + } + } + return InvalidArgument( + "%s", + absl::StrCat("Not supported: The C++ ToPyArgSignature only accepts " + "Buffer/DeviceArray, Numpy " + "arrays scalars of supported types " + "(see implementation), or Python scalars. Got type ", + nb::cast(nb::str(arg.type())))); + } + return res->second(arg, jax_enable_x64); +} + +} // namespace xla diff --git a/jaxlib/xla/py_values.h b/jaxlib/xla/py_values.h new file mode 100644 index 000000000000..b64895100d8c --- /dev/null +++ b/jaxlib/xla/py_values.h @@ -0,0 +1,127 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Helpers for converting Python values into buffers. + +#ifndef JAXLIB_XLA_PY_VALUES_H_ +#define JAXLIB_XLA_PY_VALUES_H_ + +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/functional/any_invocable.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/user_context.h" +#include "xla/python/nb_numpy.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +struct DevicePutResult { + explicit DevicePutResult( + tsl::RCReference ifrt_array, bool weak_type, + nanobind::object owning_pybuffer = nanobind::object()) + : ifrt_array(std::move(ifrt_array)), + weak_type(weak_type), + owning_pybuffer(owning_pybuffer) {} + + // Disallow copy since copying `DevicePutResult` without holding GIL may be + // dangerous due to `owning_pybuffer`. + DevicePutResult(const DevicePutResult&) = delete; + DevicePutResult& operator=(const DevicePutResult&) = delete; + DevicePutResult(DevicePutResult&&) noexcept = default; + DevicePutResult& operator=(DevicePutResult&&) noexcept = default; + + // Points to the on-device array. Not owned. + tsl::RCReference ifrt_array; + bool weak_type; + + nanobind::object owning_pybuffer; +}; + +// Copies a buffer-like object to be on device. +// +// If `arg` is not convertible to a `PjRtBuffer` from C++, an error will be +// returned; float0s are not supported yet. +// If the value is known to be a PyBuffer object, py_buffer can be passed as +// an optimization to avoid a Python->C++ cast. +// +// This function performs Python work inline but postpones C++ work until the +// returned function is called. The returned function must be called after +// releasing GIL. Useful for batching GIL release when there are many device_put +// to execute. +// +// May throw exceptions from nanobind in addition to failing via an error +// absl::Status. (We could catch these if needed, but there seems little point.) +struct DevicePutOptions { + bool squash_64bit_types = false; + bool allow_zero_copy = true; + tsl::RCReference ifrt_user_context; +}; +using DevicePutResultFn = + absl::AnyInvocable() &&>; +absl::StatusOr DevicePut(nanobind::handle arg, + ifrt::Client* client, + ifrt::Device* to_device, + const DevicePutOptions& options, + ifrt::MemoryKind to_memory_kind); + +// Returns `true` if `arg` is a JAX float0 array. +bool IsFloat0(xla::nb_numpy_ndarray arg); + +// Describes the abstract shape and dtype of an argument. +struct PyArgSignature { + PyArgSignature(PrimitiveType dtype, absl::Span shape, + bool weak_type) + : dtype(dtype), shape(shape.begin(), shape.end()), weak_type(weak_type) {} + // This is the XLA dtype of the object. + const PrimitiveType dtype; + const absl::InlinedVector shape; + // JAX arguments can be of weak type, if and only if they are Python scalars + // or `DeviceArray` values such that `aval.weak_type` is true. + const bool weak_type; + bool operator==(const PyArgSignature& other) const { + return std::tie(dtype, weak_type, shape) == + std::tie(other.dtype, other.weak_type, other.shape); + } + bool operator!=(const PyArgSignature& other) const { + return !(*this == other); + } + std::string DebugString() const; +}; + +// Returns the PyArgSignature associated with an argument. Returns an error if +// the argument is not supported. +absl::StatusOr PyArgSignatureOfValue(nanobind::handle arg, + bool jax_enable_x64); + +template +H AbslHashValue(H h, const xla::PyArgSignature& s) { + h = H::combine(std::move(h), s.dtype); + h = H::combine_contiguous(std::move(h), s.shape.data(), s.shape.size()); + return h; +} +} // namespace xla + +#endif // JAXLIB_XLA_PY_VALUES_H_ diff --git a/jaxlib/xla/sharded_device_array.h b/jaxlib/xla/sharded_device_array.h new file mode 100644 index 000000000000..1b0ca20aa1fc --- /dev/null +++ b/jaxlib/xla/sharded_device_array.h @@ -0,0 +1,217 @@ +/* Copyright 2021 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_SHARDED_DEVICE_ARRAY_H_ +#define JAXLIB_XLA_SHARDED_DEVICE_ARRAY_H_ + +#include +#include +#include + +#include "absl/types/variant.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "xla/python/types.h" + +// TODO(jblespiau): The current implementation moves the Python logic to C++, +// as a preliminary step to executing the `pmap` execution path from C++. +// It implements the current Python behavior (thus, it may not be optimal, and +// we will be able to modify it later). + +namespace jax { + +// High level introduction. +// +// pmap and other parallel computation functions distribute some computation on +// several devices. On December 2020, the devices mesh (i.e. N-dimentional array +// of devices on which we map the computation) is defined by the user. +// +// We describe how to shard the inputs, and how to map it to the mesh of devices +// using `ShardingSpec`. It's mainly based on 2 components: +// - `sharding`, which specifies how to shard the inputs. +// - `mesh_mapping`, which specifies how to map shards to devices. +// +// The 3 following structs define how to shard one dimension of an ndarry. +// +// `NoSharding` (`None` in Python) means no sharding. +struct NoSharding { + bool operator==(const NoSharding& other) const { return true; } + bool operator!=(const NoSharding& other) const { return false; } +}; + +template +H AbslHashValue(H h, const NoSharding& key) { + return h; +} + +// `Chunked` means that the dimension is split into np.prod(chunks) chunks +// and the split dimension itself is preserved inside the map. +// Those chunks are distributed over `len(chunks)` ShardedAxes axes +// (major-to-minor). +// For example, for a tensor `t` of shape [N] sharded using [Chunked([p])] (with +// p dividing N, let S = N // p) the tensor will be split into p chunks of +// shape [S], such sharded_t[k] = t[k * S: (k+1)*S] (left included, right +// excluded) for k in {0, ... p-1}. +struct Chunked { + public: + explicit Chunked(std::vector chunks_) : chunks(std::move(chunks_)) {} + // The number of chunks per axis. + std::vector chunks; + + bool operator==(const Chunked& other) const { return chunks == other.chunks; } + bool operator!=(const Chunked& other) const { return chunks != other.chunks; } +}; + +template +H AbslHashValue(H h, const Chunked& key) { + h = H::combine(std::move(h), key.chunks); + return h; +} + +// `Unstacked` means that the dimension is split into chunks of size 1, and +// doesn't appear inside the map. `size` is always the dimension size. +// For example, a Tensor t of shape [N] will be sharded into N tensors of shape +// [], when using `Unstacked(N)`. +struct Unstacked { + public: + explicit Unstacked(int sz) : size(sz) {} + int size; + + bool operator==(const Unstacked& other) const { return size == other.size; } + bool operator!=(const Unstacked& other) const { return size != other.size; } +}; + +template +H AbslHashValue(H h, const Unstacked& key) { + h = H::combine(std::move(h), key.size); + return h; +} + +using AvalDimSharding = std::variant; + +// Assigns sharded axes to mesh dimensions. +// +// The devices will be for each dimension which has a sharded `AvalDimSharding` +// When no axis is assigned, the data is replicated. +// As indices are 0-indexed, `ShardedAxis(1)` refers to the second actually +// sharded axis (i.e. counting as if the None dimensions of sharding were +// filtered out). +// For example, given the sharding `[Unstacked(n), None, Chunked(m)]`, an entry +// of `ShardedAxis(1)` refers to the `Chunked(m)` axis, not the `None`. + +struct ShardedAxis { + int axis; + bool operator==(const ShardedAxis& other) const { return axis == other.axis; } + bool operator!=(const ShardedAxis& other) const { return axis != other.axis; } +}; + +template +H AbslHashValue(H h, const ShardedAxis& key) { + h = H::combine(std::move(h), key.axis); + return h; +} + +struct Replicated { + int replicas; + bool operator==(const Replicated& other) const { + return replicas == other.replicas; + } + bool operator!=(const Replicated& other) const { + return replicas != other.replicas; + } +}; + +template +H AbslHashValue(H h, const Replicated& key) { + h = H::combine(std::move(h), key.replicas); + return h; +} + +using MeshDimAssignment = std::variant; + +// Describes how each axis is sharded (if it is), and how it's mapped to the +// devices mesh. See Jax pxla.py for the documentation. +// +// ShardingSpec is shared across pmap, pjit and xpmap. For pmap, an input +// `sharding` is composed of `NoSharding` and at most one `Unstacked`. +// If `axis_size=None`, at least one the inputs has a dimension associated to +// `Unstacked`. +// +// Examples: +// +// 1. For pmap, with a tensor of shape [8, 2, 2], to unstack along the first +// dimension into [8] devices: +// +// sharding = [Unstacked(8), NoSharding, NoSharding] +// mesh_mapping = [ShardedAxis(0)] +// +// 2. With an input array of shape [6], that we want to chunk into [2, 3] +// Assuming an device mesh [3, 4, 2] of devices, we will have: +// +// sharding = [Chunked([2, 3])] +// mesh_mapping = [ShardedAxis(1), Replicated, ShardedAxis(0)] +// +// In particular, in the above example, the ShardedAxis refers to indices +// of the sharded shape [2, 3]. (only the `Chunked` sharding can produce more +// than one dimension). +class ShardingSpec { + public: + ShardingSpec(std::vector sharding, + std::vector mesh_mapping) + : sharding_(std::move(sharding)), + mesh_mapping_(std::move(mesh_mapping)) {} + ShardingSpec(nanobind::iterable py_sharding, + nanobind::iterable py_mesh_mapping) + : sharding_(xla::IterableToVector(py_sharding)), + mesh_mapping_( + xla::IterableToVector(py_mesh_mapping)) {} + + const std::vector& GetSharding() const { return sharding_; } + const std::vector& GetMeshMapping() const { + return mesh_mapping_; + } + + bool operator==(const ShardingSpec& other) const { + return sharding_ == other.sharding_ && mesh_mapping_ == other.mesh_mapping_; + } + + bool operator!=(const ShardingSpec& other) const { return !(*this == other); } + + template + friend H AbslHashValue(H h, const ShardingSpec& key); + + private: + // `sharding` specifies how the array is supposed to get partitioned into + // chunks. Its length matchs the rank of the array. See the docstring + // of `AvalDimSharding` for the supported partitioning schemes. + std::vector sharding_; + // `mesh_mapping` describes an assignments of the array chunks created by + // `sharding` to a logical device mesh. The length of the tuple is equal to + // the rank of the mesh. Each mesh dimension can either get partitions of + // data varying along one of the sharded dimensions, or the data can be + // replicated. + std::vector mesh_mapping_; +}; + +template +H AbslHashValue(H h, const ShardingSpec& key) { + h = H::combine(std::move(h), key.sharding_); + h = H::combine(std::move(h), key.mesh_mapping_); + return h; +} + +} // namespace jax + +#endif // JAXLIB_XLA_SHARDED_DEVICE_ARRAY_H_ diff --git a/jaxlib/xla/sharding.cc b/jaxlib/xla/sharding.cc new file mode 100644 index 000000000000..9952c31bd393 --- /dev/null +++ b/jaxlib/xla/sharding.cc @@ -0,0 +1,346 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/sharding.h" + +#include + +#include +#include +#include +#include + +#include "absl/hash/hash.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "jaxlib/xla/py_client.h" +#include "jaxlib/xla/py_device_list.h" +#include "jaxlib/xla/sharded_device_array.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/python/nb_numpy.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/xla_data.pb.h" + +namespace jax { + +namespace nb = nanobind; + +// Gets `jax::PyDeviceList` from a JAX Sharding. +absl::StatusOr> GetPyDeviceList( + nb::handle sharding_py) { + nb::handle sharding(sharding_py.ptr()); + if (sharding.type().is(jax::NamedSharding::type())) { + TF_ASSIGN_OR_RETURN( + auto ns_device_list, + nb::cast(sharding)->internal_device_list()); + return ns_device_list; + } else if (sharding.type().is(jax::SingleDeviceSharding::type())) { + return nb::cast(sharding) + ->internal_device_list(); + } else if (sharding.type().is(jax::PmapSharding::type())) { + return nb::cast(sharding)->internal_device_list(); + } else if (sharding.type().is(jax::GSPMDSharding::type())) { + return nb::cast(sharding) + ->internal_device_list(); + } else { + return nb::cast>( + sharding.attr("_internal_device_list")); + } +} + +nb::object CheckAndCanonicalizeMemoryKind( + nb::object memory_kind, + const xla::nb_class_ptr& device_list) { + if (!memory_kind.is_none()) { + // If memory kind is not None, check if it's supported by the devices + // mentioned in the Sharding. + auto supported_memory_kinds = PyDeviceList::MemoryKinds(device_list); + if (!supported_memory_kinds.ok()) { + supported_memory_kinds = nb::tuple(); + } + for (nb::handle supported_memory_kind : *supported_memory_kinds) { + if (supported_memory_kind.equal(memory_kind)) { + return memory_kind; + } + } + auto addressable_device_list = + PyDeviceList::AddressableDeviceList(device_list); + if (addressable_device_list->Len() == 0) { + // If the device list is not addressable, we can't check if the memory + // kind is supported, so we assume it is. + return memory_kind; + } + nb::object device_kind = + addressable_device_list->GetItem(0).attr("device_kind"); + absl::string_view device_kind_str = + nb::cast(device_kind); + auto py_str_formatter = [](std::string* out, nb::handle h) { + *out += nb::cast(nb::str(h)); + }; + throw nb::value_error( + absl::StrCat( + "Could not find memory addressable by device ", device_kind_str, + ". Device ", device_kind_str, + " can address the following memory kinds: ", + absl::StrJoin(*supported_memory_kinds, ", ", py_str_formatter), + ". Got memory kind: ", nb::cast(memory_kind)) + .c_str()); + } + // If memory kind is None, canonicalize to default memory. + absl::StatusOr default_memory_kind = + PyDeviceList::DefaultMemoryKind(device_list); + if (!default_memory_kind.ok()) { + return nb::none(); + } + return *std::move(default_memory_kind); +} + +int Sharding::SafeNumDevices(nb::handle sharding) { + const jax::Sharding* cpp_sharding; + if (nb::try_cast(sharding, cpp_sharding)) { + if (cpp_sharding->num_devices_.has_value()) { + return (*cpp_sharding->num_devices_); + } + } + nb::set device_set = sharding.attr("device_set"); + return device_set.size(); +} + +size_t ShardingHash(nb::handle sharding) { + auto type = sharding.type(); + + if (type.is(NamedSharding::type())) { + const auto* named_sharding = nb::inst_ptr(sharding); + return absl::Hash()(named_sharding->mesh().ptr()); + } + + if (type.is(GSPMDSharding::type())) { + auto* gspmd_sharding = nb::inst_ptr(sharding); + return gspmd_sharding->Hash(); + } + + if (type.is(SingleDeviceSharding::type())) { + auto* single_device_sharding = nb::inst_ptr(sharding); + return absl::Hash()(single_device_sharding->device().ptr()); + } + + return nb::hash(sharding); +} + +bool ShardingEqual(nb::handle a, nb::handle b) { + if (a.ptr() == b.ptr()) return true; + + auto a_type = a.type(); + auto b_type = b.type(); + + if (!a_type.is(b_type)) return false; + + if (a_type.is(NamedSharding::type())) { + auto* a_named_sharding = nb::inst_ptr(a); + auto* b_named_sharding = nb::inst_ptr(b); + + return a_named_sharding->mesh().ptr() == b_named_sharding->mesh().ptr() && + a_named_sharding->spec().equal(b_named_sharding->spec()) && + a_named_sharding->memory_kind().equal( + b_named_sharding->memory_kind()) && + a_named_sharding->manual_axes().equal( + b_named_sharding->manual_axes()) && + a_named_sharding->logical_device_ids().equal( + b_named_sharding->logical_device_ids()); + } + + if (a_type.is(GSPMDSharding::type())) { + auto* a_gspmd_sharding = nb::inst_ptr(a); + auto* b_gspmd_sharding = nb::inst_ptr(b); + + return a_gspmd_sharding == b_gspmd_sharding; + } + + if (a_type.is(SingleDeviceSharding::type())) { + auto* a_single_device_sharding = + nb::inst_ptr(a); + auto* b_single_device_sharding = + nb::inst_ptr(b); + + return a_single_device_sharding->device().ptr() == + b_single_device_sharding->device().ptr() && + a_single_device_sharding->memory_kind().equal( + b_single_device_sharding->memory_kind()); + } + + return a.equal(b); +} + +NamedSharding::NamedSharding(nb::object mesh, nb::object spec, + nb::object memory_kind, nb::object manual_axes, + nb::object logical_device_ids) + : Sharding(/*num_devices=*/[&mesh]() { + return nb::cast(mesh.attr("size")); + }()), + mesh_(std::move(mesh)), + spec_(std::move(spec)), + memory_kind_(std::move(memory_kind)), + manual_axes_(std::move(manual_axes)), + logical_device_ids_(std::move(logical_device_ids)) { + if (spec_.is_none()) { + throw nb::type_error( + "Unexpected None passed as spec for NamedSharding. Did you mean P()?"); + } + nb::object idl = nb::object(mesh_.attr("_internal_device_list")); + if (idl.is_none()) { + internal_device_list_ = std::nullopt; + } else { + internal_device_list_ = nb::cast>(idl); + } + if (internal_device_list_) { + memory_kind_ = + CheckAndCanonicalizeMemoryKind(memory_kind_, *internal_device_list_); + } else { + memory_kind_ = nb::none(); + } + + // TODO(phawkins): this leaks a reference to the check_pspec function. + // A better way to fix this would be to move PartitionSpec and this check into + // C++. + static nb::object* check_pspec = []() { + nb::module_ si = nb::module_::import_("jax._src.named_sharding"); + return new nb::object(si.attr("check_pspec")); + }(); + (*check_pspec)(mesh_, spec_, manual_axes_); +} + +SingleDeviceSharding::SingleDeviceSharding(nb::object device, + nb::object memory_kind) + : Sharding(/*num_devices=*/1), + device_(device), + memory_kind_(std::move(memory_kind)), + internal_device_list_( + xla::make_nb_class(nb::make_tuple(std::move(device)))) { + memory_kind_ = + CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_); +} + +SingleDeviceSharding::SingleDeviceSharding( + xla::nb_class_ptr client, + xla::ifrt::DeviceListRef device_list, nb::object memory_kind) + : Sharding(/*num_devices=*/1), + device_(client->GetPyDevice(device_list->devices().front())), + memory_kind_(std::move(memory_kind)), + internal_device_list_(xla::make_nb_class( + std::move(client), std::move(device_list))) { + memory_kind_ = + CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_); +} + +PmapSharding::PmapSharding(xla::nb_numpy_ndarray devices, + ShardingSpec sharding_spec) + : Sharding(/*num_devices=*/devices.size()), + devices_(std::move(devices)), + sharding_spec_(std::move(sharding_spec)) { + nb::object flat_devices = devices_.attr("flat"); + internal_device_list_ = + xla::make_nb_class(nb::tuple(flat_devices)); +} + +GSPMDSharding::GSPMDSharding(nb::sequence devices, xla::HloSharding op_sharding, + nb::object memory_kind, nb::object device_list) + : Sharding(/*num_devices=*/nb::len(devices.ptr())), + devices_(nb::tuple(devices)), + hlo_sharding_(std::move(op_sharding)), + memory_kind_(std::move(memory_kind)) { + if (device_list.is_none()) { + internal_device_list_ = xla::make_nb_class(devices_); + } else { + internal_device_list_ = + nb::cast>(std::move(device_list)); + } + // This checks in python if the memory kind is correct for the given + // devices. Currently in python this check is optimized but we want to + // move that check to C++ after which we can remove this call. + CHECK(devices_.size() != 0) + << "Devices given to GSPMDSharding must not be empty"; + memory_kind_ = + CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_); +} + +void RegisterSharding(nb::module_& m) { + nb::class_(m, "Sharding").def(nb::init<>()); + + nb::class_(m, "NamedSharding", nb::dynamic_attr()) + .def(nb::init(), + nb::arg("mesh"), nb::arg("spec").none(), + nb::arg("memory_kind").none() = nb::none(), + nb::arg("_manual_axes") = nb::steal(PyFrozenSet_New(nullptr)), + nb::arg("_logical_device_ids").none() = nb::none()) + .def_prop_ro("mesh", &NamedSharding::mesh) + .def_prop_ro("spec", &NamedSharding::spec) + .def_prop_ro("_memory_kind", &NamedSharding::memory_kind) + .def_prop_ro("_manual_axes", &NamedSharding::manual_axes) + .def_prop_ro("_logical_device_ids", &NamedSharding::logical_device_ids) + .def_prop_ro("_internal_device_list", [](const NamedSharding& s) { + return xla::ValueOrThrow(s.internal_device_list()); + }); + + nb::class_(m, "SingleDeviceSharding", + nb::dynamic_attr()) + .def(nb::init(), nb::arg("device"), + nb::arg("memory_kind").none() = nb::none()) + .def_prop_ro("_device", &SingleDeviceSharding::device) + .def_prop_ro("_memory_kind", &SingleDeviceSharding::memory_kind) + .def_prop_ro("_internal_device_list", + &SingleDeviceSharding::internal_device_list); + + nb::class_(m, "PmapSharding", nb::dynamic_attr()) + .def( + "__init__", + [](PmapSharding* self, nb::object devices, + ShardingSpec sharding_spec) { + new (self) PmapSharding(xla::nb_numpy_ndarray::ensure(devices), + std::move(sharding_spec)); + }, + nb::arg("devices"), nb::arg("sharding_spec")) + .def_prop_ro("devices", &PmapSharding::devices) + .def_prop_ro("sharding_spec", &PmapSharding::sharding_spec) + .def_prop_ro("_internal_device_list", + &PmapSharding::internal_device_list); + + nb::class_(m, "GSPMDSharding", nb::dynamic_attr()) + .def(nb::init(), + nb::arg("devices"), nb::arg("op_sharding"), + nb::arg("memory_kind").none() = nb::none(), + nb::arg("_device_list").none() = nb::none()) + .def(nb::init(), + nb::arg("devices"), nb::arg("op_sharding"), + nb::arg("memory_kind").none() = nb::none(), + nb::arg("_device_list").none() = nb::none()) + .def_prop_ro("_devices", &GSPMDSharding::devices) + .def_prop_ro("_hlo_sharding", &GSPMDSharding::hlo_sharding) + .def_prop_ro("_memory_kind", &GSPMDSharding::memory_kind) + .def_prop_ro("_internal_device_list", + &GSPMDSharding::internal_device_list); +} + +} // namespace jax diff --git a/jaxlib/xla/sharding.h b/jaxlib/xla/sharding.h new file mode 100644 index 000000000000..572a6cd3c86e --- /dev/null +++ b/jaxlib/xla/sharding.h @@ -0,0 +1,242 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_SHARDING_H_ +#define JAXLIB_XLA_SHARDING_H_ + +#include +#include +#include + +// placeholder for index annotation headers +#include "absl/hash/hash.h" +#include "absl/status/statusor.h" +#include "nanobind/nanobind.h" +#include "jaxlib/xla/py_client.h" +#include "jaxlib/xla/py_device_list.h" +#include "jaxlib/xla/sharded_device_array.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/python/nb_numpy.h" +#include "xla/xla_data.pb.h" + +namespace jax { + +class Sharding { + public: + Sharding() = default; + + // This constructor is used in the fast path to retrieve the number of devices + // without falling back to python. This is only used in the cpp path. + explicit Sharding(int num_devices) : num_devices_(num_devices) {} + + virtual ~Sharding() = default; + + static int SafeNumDevices(nanobind::handle sharding); + + private: + std::optional num_devices_; +}; + +// Gets `jax::PyDeviceList` from a JAX Sharding. +absl::StatusOr> GetPyDeviceList( + nanobind::handle sharding_py); + +// Checks if the memory kind is valid, and canonicalizes the +// memory kind to default memory on backends that support memories. +nanobind::object CheckAndCanonicalizeMemoryKind( + nanobind::object memory_kind, + const xla::nb_class_ptr& device_list); + +// Returns a hash that may sometimes return different hashes for equal values. +// It is not a correct implementation of `__hash__` in python, but it's fine +// for jit/pjit dispatch since it only causes spurious cache misses. +size_t ShardingHash(nanobind::handle sharding); + +bool ShardingEqual(nanobind::handle a, nanobind::handle b); + +class NamedSharding : public Sharding { + public: + NamedSharding(nanobind::object mesh, nanobind::object spec, + nanobind::object memory_kind, nanobind::object manual_axes, + nanobind::object logical_device_ids); + + const nanobind::object& mesh() const { return mesh_; } + const nanobind::object& spec() const { return spec_; } + const nanobind::object& memory_kind() const { return memory_kind_; } + const nanobind::object& manual_axes() const { return manual_axes_; } + const nanobind::object& logical_device_ids() const { + return logical_device_ids_; + } + + static nanobind::handle type() { + static auto type = nanobind::type(); + return type; + } + + absl::StatusOr> internal_device_list() const { + if (internal_device_list_) { + return *internal_device_list_; + } + return xla::InvalidArgument( + "internal_device_list is not implemented for " + "`jax.sharding.AbstractMesh`"); + } + + private: + nanobind::object mesh_; + nanobind::object spec_; + nanobind::object memory_kind_; + nanobind::object manual_axes_; + nanobind::object logical_device_ids_; + std::optional> internal_device_list_; +}; + +class SingleDeviceSharding : public Sharding { + public: + explicit SingleDeviceSharding( + nanobind::object device, nanobind::object memory_kind = nanobind::none()); + + // Used only in C++ to accelerate `PyArray::MakeFromSingleDeviceArray()`. + SingleDeviceSharding(xla::nb_class_ptr client, + xla::ifrt::DeviceListRef device_list, + nanobind::object memory_kind); + + const nanobind::object& device() const { return device_; } + const nanobind::object& memory_kind() const { return memory_kind_; } + + static nanobind::handle type() { + static auto type = nanobind::type(); + return type; + } + + xla::nb_class_ptr internal_device_list() const { + return internal_device_list_; + } + + private: + nanobind::object device_; + nanobind::object memory_kind_; + xla::nb_class_ptr internal_device_list_; +}; + +// The C++ implementation of jax.PmapSharding in python. It contains a few key +// data members and methods that are performance-critical. +class PmapSharding : public Sharding { + public: + PmapSharding(xla::nb_numpy_ndarray devices, ShardingSpec sharding_spec); + + ~PmapSharding() override = default; + + xla::nb_numpy_ndarray devices() const { return devices_; } + + const ShardingSpec& sharding_spec() const { return sharding_spec_; } + + static nanobind::handle type() { + static auto type = nanobind::type(); + return type; + } + + xla::nb_class_ptr internal_device_list() const { + return internal_device_list_; + } + + private: + xla::nb_numpy_ndarray devices_; + ShardingSpec sharding_spec_; + xla::nb_class_ptr internal_device_list_; +}; + +class GSPMDSharding : public Sharding { + public: + GSPMDSharding(nanobind::sequence devices, xla::OpSharding op_sharding, + nanobind::object memory_kind, nanobind::object device_list) + : GSPMDSharding( + std::move(devices), + xla::ValueOrThrow(xla::HloSharding::FromProto(op_sharding)), + std::move(memory_kind), std::move(device_list)) {} + + GSPMDSharding(nanobind::sequence devices, xla::HloSharding op_sharding, + nanobind::object memory_kind, nanobind::object device_list); + + const nanobind::tuple& devices() const { return devices_; } + const nanobind::object& memory_kind() const { return memory_kind_; } + + size_t Hash() { + if (!hash_.has_value()) { + hash_ = CalculateHash(); + } + return *hash_; + } + + static nanobind::handle type() { + static auto type = nanobind::type(); + return type; + } + + const xla::HloSharding& hlo_sharding() const { return hlo_sharding_; } + + bool operator==(const GSPMDSharding& other) const { + return AreOpShardingsEqual(*this, other) && + this->devices().equal(other.devices()) && + this->memory_kind().equal(other.memory_kind()); + } + + xla::nb_class_ptr internal_device_list() const { + return internal_device_list_; + } + + private: + size_t CalculateHash() const { + // We only hash `hlo_sharding_` here for performance. + return absl::Hash()(hlo_sharding_); + } + + static bool AreOpShardingsEqual(const GSPMDSharding& a, + const GSPMDSharding& b) { + // If the OpSharding object is the same, return true + if (&a.hlo_sharding() == &b.hlo_sharding()) { + return true; + } + // If both OpShardings are replicated, return true + if (a.IsOpShardingReplicated() && b.IsOpShardingReplicated()) { + return true; + } + return a.hlo_sharding() == b.hlo_sharding(); + } + + bool IsOpShardingReplicated() const { + // For JAX, shardings with 1 device are considered as replicated in its + // semantics so that downstream things continue to work. + if (hlo_sharding_.tile_assignment().num_elements() == 1) { + return true; + } + return hlo_sharding().IsReplicated(); + } + + nanobind::tuple devices_; + xla::HloSharding hlo_sharding_; + nanobind::object memory_kind_; + std::optional hash_; + xla::nb_class_ptr internal_device_list_; +}; + +void RegisterSharding(nanobind::module_& m); + +} // namespace jax + +#endif // JAXLIB_XLA_SHARDING_H_ diff --git a/jaxlib/xla/to_ifrt_sharding.cc b/jaxlib/xla/to_ifrt_sharding.cc new file mode 100644 index 000000000000..96ec9c77071d --- /dev/null +++ b/jaxlib/xla/to_ifrt_sharding.cc @@ -0,0 +1,141 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/to_ifrt_sharding.h" + +#include +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "jaxlib/xla/py_device_list.h" +#include "jaxlib/xla/sharding.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" +#include "xla/python/pjrt_ifrt/xla_sharding.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/statusor.h" + +namespace xla { + +namespace nb = ::nanobind; + +// Gets `xla::HloSharding` from a JAX Sharding. +xla::HloSharding GetXlaHloSharding(nb::handle sharding, + int64_t num_dimensions) { + if (sharding.type().is(nb::handle(jax::GSPMDSharding::type().ptr()))) { + return nb::cast(nb::handle(sharding.ptr())) + ->hlo_sharding(); + } else { + return nb::cast( + sharding.attr("_to_xla_hlo_sharding")(num_dimensions)); + } +} + +// Gets `xla::ifrt::DeviceList` from a JAX Sharding. +absl::StatusOr GetIfrtDeviceList( + nb::handle sharding_py) { + TF_ASSIGN_OR_RETURN(auto py_device_list, jax::GetPyDeviceList(sharding_py)); + return py_device_list->ifrt_device_list(); +} + +// Gets `xla::ifrt::MemoryKind` from a JAX Sharding. +xla::ifrt::MemoryKind GetMemoryKind(nb::handle sharding) { + nb::object py_memory_kind = nb::none(); + + // sharding.attr("memory_kind") can crash if sharding was originally created + // from C++ and casted into a Python Sharding object. Thus, we cast sharding + // to a C++ type and use C++ `memory_kind()` method, which bypasses any Python + // attribute access. + nb::handle type = sharding.type(); + if (type.is(jax::NamedSharding::type())) { + py_memory_kind = + nb::cast(sharding)->memory_kind(); + } else if (type.is(jax::SingleDeviceSharding::type())) { + py_memory_kind = + nb::cast(sharding)->memory_kind(); + } else if (type.is(jax::GSPMDSharding::type())) { + py_memory_kind = + nb::cast(sharding)->memory_kind(); + } else { + py_memory_kind = sharding.attr("memory_kind"); + } + + if (py_memory_kind.is_none()) { + return xla::ifrt::MemoryKind(); + } + return xla::ifrt::MemoryKind(nb::cast(py_memory_kind)); +} + +// Converts a JAX Sharding into `xla::ifrt::HloSharding`. +absl::StatusOr> GetIfrtHloSharding( + nb::handle sharding, const xla::ifrt::Shape& shape) { + TF_ASSIGN_OR_RETURN(xla::ifrt::DeviceListRef device_list, + GetIfrtDeviceList(sharding)); + xla::ifrt::MemoryKind memory_kind = GetMemoryKind(sharding.ptr()); + xla::HloSharding hlo_sharding = + GetXlaHloSharding(sharding, shape.dims().size()); + return xla::ifrt::HloSharding::Create( + std::move(device_list), std::move(memory_kind), std::move(hlo_sharding)); +} + +// Converts a JAX Sharding into `xla::ifrt::ConcreteEvenSharding`. +absl::StatusOr> +GetIfrtConcreteEvenSharding(nb::handle sharding, xla::ifrt::DType dtype, + const xla::ifrt::Shape& shape) { + TF_ASSIGN_OR_RETURN(xla::ifrt::DeviceListRef device_list, + GetIfrtDeviceList(sharding)); + xla::ifrt::MemoryKind memory_kind = GetMemoryKind(sharding.ptr()); + TF_ASSIGN_OR_RETURN(xla::PrimitiveType xla_primitive_type, + xla::ifrt::ToPrimitiveType(dtype)); + // The XLA shape's layout is irrelevant because we only need to know the + // tile shape, which is independent from the layout. + xla::Shape xla_shape = xla::ShapeUtil::MakeShapeWithDescendingLayout( + xla_primitive_type, shape.dims()); + xla::HloSharding hlo_sharding = + GetXlaHloSharding(sharding, shape.dims().size()); + xla::Shape tile_shape = hlo_sharding.TileShape(xla_shape); + xla::ifrt::Shape shard_shape(xla::ifrt::Shape::Dimensions( + tile_shape.dimensions().begin(), tile_shape.dimensions().end())); + return xla::ifrt::ConcreteEvenSharding::Create( + std::move(device_list), std::move(memory_kind), shape, + /*shard_shape=*/std::move(shard_shape)); +} + +// Converts a JAX Sharding into `xla::ifrt::ConcreteSharding`. +absl::StatusOr> +GetIfrtConcreteSharding(nb::handle sharding, const xla::ifrt::Shape& shape, + std::vector shard_shapes) { + TF_ASSIGN_OR_RETURN(xla::ifrt::DeviceListRef device_list, + GetIfrtDeviceList(sharding)); + xla::ifrt::MemoryKind memory_kind = GetMemoryKind(sharding.ptr()); + return xla::ifrt::ConcreteSharding::Create( + std::move(device_list), std::move(memory_kind), shape, + /*shard_shapes=*/std::move(shard_shapes)); +} + +} // namespace xla diff --git a/jaxlib/xla/to_ifrt_sharding.h b/jaxlib/xla/to_ifrt_sharding.h new file mode 100644 index 000000000000..0fa7f17c4563 --- /dev/null +++ b/jaxlib/xla/to_ifrt_sharding.h @@ -0,0 +1,56 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_TO_IFRT_SHARDING_H_ +#define JAXLIB_XLA_TO_IFRT_SHARDING_H_ + +#include "nanobind/nanobind.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/tsl/platform/statusor.h" + +namespace xla { + +// Gets `xla::HloSharding` from a JAX Sharding. +xla::HloSharding GetXlaHloSharding(nanobind::handle sharding, + int64_t num_dimensions); + +// Gets `xla::ifrt::DeviceList` from a JAX Sharding. +absl::StatusOr GetIfrtDeviceList( + nanobind::handle sharding_py); + +// Gets `xla::ifrt::MemoryKind` from a JAX Sharding. +xla::ifrt::MemoryKind GetMemoryKind(nanobind::handle sharding); + +// Converts a JAX Sharding into `xla::ifrt::HloSharding`. +absl::StatusOr> GetIfrtHloSharding( + nanobind::handle sharding, const xla::ifrt::Shape& shape); + +// Converts a JAX Sharding into `xla::ifrt::ConcreteEvenSharding`. +absl::StatusOr> +GetIfrtConcreteEvenSharding(nanobind::handle sharding, xla::ifrt::DType dtype, + const xla::ifrt::Shape& shape); + +// Converts a JAX Sharding into `xla::ifrt::ConcreteSharding`. +absl::StatusOr> +GetIfrtConcreteSharding(nanobind::handle sharding, + const xla::ifrt::Shape& shape, + std::vector shard_shapes); + +} // namespace xla + +#endif // JAXLIB_XLA_TO_IFRT_SHARDING_H_ diff --git a/jaxlib/xla/xla.cc b/jaxlib/xla/xla.cc index 54c94c57a734..0e1ba031670f 100644 --- a/jaxlib/xla/xla.cc +++ b/jaxlib/xla/xla.cc @@ -46,6 +46,9 @@ limitations under the License. #include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep #include "nanobind/stl/variant.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/ifrt_proxy.h" +#include "jaxlib/xla/py_client.h" +#include "jaxlib/xla/py_program.h" #include "jaxlib/xla/sdy.h" #include "xla/backends/cpu/collectives/cpu_collectives.h" #include "xla/pjrt/c/pjrt_c_api.h" @@ -61,10 +64,7 @@ limitations under the License. #include "xla/python/ifrt/device_list.h" #include "xla/python/ifrt/executable.h" #include "xla/python/ifrt/topology.h" -#include "xla/python/ifrt_proxy/client/py_module.h" #include "xla/python/pjrt_ifrt/pjrt_attribute_map_util.h" -#include "xla/python/py_client.h" -#include "xla/python/py_program.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/python/lib/core/numpy.h" // NOLINT @@ -91,7 +91,14 @@ limitations under the License. #include "jaxlib/xla/mlir.h" #include "jaxlib/xla/pjit.h" #include "jaxlib/xla/pmap_lib.h" +#include "jaxlib/xla/py_array.h" +#include "jaxlib/xla/py_compile_only_client.h" +#include "jaxlib/xla/py_device.h" +#include "jaxlib/xla/py_device_list.h" +#include "jaxlib/xla/py_executable.h" +#include "jaxlib/xla/py_memory_space.h" #include "jaxlib/xla/pytree.h" +#include "jaxlib/xla/sharding.h" #include "jaxlib/xla/weakref_lru_cache.h" #include "jaxlib/xla/xla_compiler.h" #include "xla/pjrt/distributed/key_value_store_interface.h" @@ -113,14 +120,7 @@ limitations under the License. #include "xla/python/pjrt_ifrt/pjrt_topology.h" #include "xla/python/pprof_profile_builder.h" #include "xla/python/profiler.h" -#include "xla/python/py_array.h" -#include "xla/python/py_compile_only_client.h" -#include "xla/python/py_device.h" -#include "xla/python/py_device_list.h" -#include "xla/python/py_executable.h" -#include "xla/python/py_memory_space.h" #include "xla/python/python_ref_manager.h" -#include "xla/python/sharding.h" #include "xla/python/traceback.h" #include "xla/tsl/distributed_runtime/preemption/preemption_sync_manager.h" #include "xla/tsl/platform/status.h" diff --git a/jaxlib/xla/xla_compiler.cc b/jaxlib/xla/xla_compiler.cc index f4719b450988..00f8b4c295a7 100644 --- a/jaxlib/xla/xla_compiler.cc +++ b/jaxlib/xla/xla_compiler.cc @@ -44,6 +44,7 @@ limitations under the License. #include "nanobind/stl/variant.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep #include "jaxlib/xla/dlpack.h" +#include "jaxlib/xla/py_client.h" #include "xla/array.h" #include "xla/client/executable_build_options.h" #include "xla/debug_options_flags.h" @@ -71,7 +72,6 @@ limitations under the License. #include "xla/python/nb_absl_span.h" // IWYU pragma: keep #include "xla/python/nb_helpers.h" #include "xla/python/nb_numpy.h" -#include "xla/python/py_client.h" #include "xla/python/types.h" #include "xla/service/call_inliner.h" #include "xla/service/computation_placer.h" From 0a53c9aad23e4b63843c64f6b1af3652a22a16e4 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 25 Mar 2025 14:10:26 -0700 Subject: [PATCH 158/483] [pallas:mosaic_gpu] Updated the tests to use `plgpu.kernel` It leads to much more compact kernel definitions, just look at the diff! The combination of `pl.core_map` and `pl.run_state` is too noisy to easily follow the kernel logic. PiperOrigin-RevId: 740479934 --- jax/_src/pallas/mosaic_gpu/core.py | 2 +- jax/_src/pallas/mosaic_gpu/pipeline.py | 6 +- tests/pallas/mosaic_gpu_test.py | 336 ++++++++++++------------- 3 files changed, 161 insertions(+), 183 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 1e4a9de1830c..99a84962ae50 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -120,7 +120,7 @@ def __call__( return GPUMemoryRef(shape, dtype, memory_space=self, transforms=transforms) -def kernel(body, out_shape, compiler_params=None, **mesh_kwargs): +def kernel(body, out_shape, *, compiler_params=None, **mesh_kwargs): if unwrap_out := not isinstance(out_shape, (tuple, list)): out_shape = (out_shape,) def wrapper(*operands): diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index a48fec61b7af..d85ba4ae2a03 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -360,8 +360,8 @@ def emit_pipeline_warp_specialized( *, grid: pallas_core.StaticGrid, memory_registers: int, - in_specs: Sequence[gpu_core.GPUBlockSpec] = (), - out_specs: Sequence[gpu_core.GPUBlockSpec] = (), + in_specs: Sequence[pl.BlockSpec] = (), + out_specs: Sequence[pl.BlockSpec] = (), max_concurrent_steps: int = 2, wg_axis: str, num_compute_wgs: int, @@ -458,7 +458,7 @@ def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): gpu_core.SMEM( (slots, *spec.block_shape), # type: ignore gmem_ref.dtype, - transforms=spec.transforms, + transforms=getattr(spec, "transforms", ()), ) ) in_smem_refs, out_smem_refs = util.split_list( diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index b39288252e08..7f8cfa21e980 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1624,23 +1624,28 @@ def scope(acc_ref): class PallasCallSm100ATest(PallasSm100ATest): def test_tmem_alloc(self): - mesh = plgpu.GPUMesh(num_threads=1, axis_names=("x")) - @pl.run_state - def inner(y_ref): - @pl.core_map(mesh) - def _(): - def scope(tmem_ref, smem_ref): - # Issue a write so the TMEM load is not DCE'd. - smem_ref[...] = tmem_ref[...] - plgpu.commit_smem() - plgpu.copy_smem_to_gmem(smem_ref, y_ref) - plgpu.wait_smem_to_gmem(0) - pl.run_scoped(scope, + @functools.partial( + plgpu.kernel, + out_shape=jnp.zeros((128, 128), jnp.float32), + num_threads=1, + axis_names=("x",), + ) + def kernel(y_ref): + def scope(tmem_ref, smem_ref): + # Issue a write so the TMEM load is not DCE'd. + smem_ref[...] = tmem_ref[...] + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(smem_ref, y_ref) + plgpu.wait_smem_to_gmem(0) + + pl.run_scoped( + scope, plgpu.TMEM((128, 128), jnp.float32), - plgpu.SMEM((128, 128), jnp.float32)) - y_init = jnp.zeros((128, 128), np.float32) + plgpu.SMEM((128, 128), jnp.float32), + ) + # Test that this runs without errors. - jax.block_until_ready(inner(y_init)) + jax.block_until_ready(kernel()) class PipelineTest(PallasTest): @@ -1979,9 +1984,7 @@ class WarpSpecializedPipelineTest(PallasTest): manual_consumed_barriers=[False, True]) def test_pipelined_copy(self, m, n, manual_consumed_barriers): x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float16) - o = jnp.zeros((m, n), dtype=jnp.float16) blk_m = blk_n = 64 - o_last_block = jnp.zeros((blk_m, blk_n), dtype=jnp.float16) def copy_kernel(x_smem, o_smem, o_last_block_smem, *consumed_barriers): # TODO(justinfu): Have each wg compute a separate slice @@ -1992,11 +1995,10 @@ def copy_kernel(x_smem, o_smem, o_last_block_smem, *consumed_barriers): if manual_consumed_barriers: [x_barrier] = consumed_barriers plgpu.barrier_arrive(x_barrier) - block_spec = plgpu.GPUBlockSpec( - block_shape=(blk_m, blk_n), - index_map=lambda i, j: (i, j), - transforms=[], - ) + + spec = pl.BlockSpec( + block_shape=(blk_m, blk_n), index_map=lambda i, j: (i, j) + ) pipeline = mgpu_pipeline.emit_pipeline_warp_specialized( copy_kernel, grid=(m // blk_m, n // blk_n), @@ -2005,33 +2007,35 @@ def copy_kernel(x_smem, o_smem, o_last_block_smem, *consumed_barriers): num_compute_wgs=2, wg_axis="wg", manual_consumed_barriers=manual_consumed_barriers, - in_specs=[block_spec], - out_specs=[block_spec, - # Create an index-invariant output. - plgpu.GPUBlockSpec(block_shape=(blk_m, blk_n), - index_map=lambda i, j: (0, 0)) - ], - ) - mesh = plgpu.GPUMesh(grid=(1,), num_threads=3, axis_names=("_", "wg")) - def run(refs): - @pl.core_map( - mesh, compiler_params=plgpu.GPUCompilerParams(approx_math=True) - ) - def _kernel_entry(): - pipeline(*refs) - @jax.jit - def run_function(x, o, o_last_block): - _, out, out_last = pl.run_state(run)((x, o, o_last_block)) - return (out, out_last) - out, out_last_block = run_function(x, o, o_last_block) + in_specs=[spec], + out_specs=[ + spec, + # Create an index-invariant output. + pl.BlockSpec( + block_shape=(blk_m, blk_n), index_map=lambda i, j: (0, 0) + ), + ], + ) + kernel = plgpu.kernel( + pipeline, + out_shape=( + jax.ShapeDtypeStruct((m, n), jnp.float16), + jax.ShapeDtypeStruct((blk_m, blk_n), jnp.float16), + ), + compiler_params=plgpu.GPUCompilerParams(approx_math=True), + grid=(1,), + num_threads=3, + axis_names=("_", "wg"), + ) + out, out_last_block = kernel(x) np.testing.assert_array_equal(out, x) np.testing.assert_array_equal(out_last_block, x[-blk_m:, -blk_n:]) def test_elementwise_add(self, m=256, n=256, num_compute_wgs=2): blk_m = blk_n = 64 - x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float32) - y = jax.random.uniform(jax.random.key(1), (m, n), dtype=jnp.float32) - o = jnp.zeros((m, n), dtype=jnp.float32) + spec = pl.BlockSpec( + block_shape=(blk_m, blk_n), index_map=lambda i, j: (i, j) + ) def tiled_add_kernel(x_smem, y_smem, o_smem): # TODO(justinfu): Have each wg compute a separate slice @@ -2046,43 +2050,23 @@ def tiled_add_kernel(x_smem, y_smem, o_smem): num_compute_wgs=num_compute_wgs, memory_registers=40, wg_axis="wg", - in_specs=[ - plgpu.GPUBlockSpec( - block_shape=(blk_m, blk_n), - index_map=lambda i, j: (i, j), - transforms=[]), - plgpu.GPUBlockSpec( - block_shape=(blk_m, blk_n), - index_map=lambda i, j: (i, j), - transforms=[]), - ], - out_specs=[ - plgpu.GPUBlockSpec( - block_shape=(blk_m, blk_n), - index_map=lambda i, j: (i, j), - transforms=[])], + in_specs=[spec, spec], + out_specs=[spec], ) - mesh = plgpu.GPUMesh( - grid=(1,), num_threads=num_compute_wgs + 1, axis_names=("_", "wg") + kernel = plgpu.kernel( + pipeline, + out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32), + compiler_params=plgpu.GPUCompilerParams(approx_math=True), + grid=(1,), + num_threads=num_compute_wgs + 1, + axis_names=("_", "wg"), ) - def run(refs): - @pl.core_map( - mesh, compiler_params=plgpu.GPUCompilerParams(approx_math=True) - ) - def _kernel_entry(): - pipeline(*refs) - @jax.jit - def run_function(x, y, o): - _, _, out = pl.run_state(run)((x, y, o)) - return out - out = run_function(x, y, o) - reference = x + y - np.testing.assert_allclose(out, reference, atol=1e-4) + x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float32) + y = jax.random.uniform(jax.random.key(1), (m, n), dtype=jnp.float32) + np.testing.assert_allclose(kernel(x, y), x + y, atol=1e-4) def test_carry_accumulate(self, m=256, n=256, num_compute_wgs=2): blk_m = blk_n = 64 - x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float32) - acc_init = jnp.zeros((blk_m, blk_n), dtype=jnp.float32) def _scoped(acc_smem, x_gmem, acc_gmem): def _compute_thread(): @@ -2116,77 +2100,70 @@ def tiled_acc_kernel(x_smem, carry): wg_axis="wg", carry_coroutine=_compute_thread, in_specs=[ - plgpu.GPUBlockSpec( - block_shape=(blk_m, blk_n), - index_map=lambda i, j: (i, j), - transforms=[]), + pl.BlockSpec( + block_shape=(blk_m, blk_n), index_map=lambda i, j: (i, j) + ) ], out_specs=[], ) pipeline(x_gmem) - mesh = plgpu.GPUMesh( + @functools.partial( + plgpu.kernel, + out_shape=jax.ShapeDtypeStruct((blk_m, blk_n), jnp.float32), + compiler_params=plgpu.GPUCompilerParams(approx_math=True), grid=(1,), num_threads=num_compute_wgs + 1, - axis_names=("_", "wg",), - ) - def run(refs): - x_ref, acc_ref = refs - @pl.core_map(mesh) - def _kernel_entry(): - pl.run_scoped( - functools.partial(_scoped, x_gmem=x_ref, acc_gmem=acc_ref), - plgpu.SMEM((blk_m, blk_n), jnp.float32) - ) - @jax.jit - def run_function(x, acc): - _, out_acc = pl.run_state(run)((x, acc)) - return out_acc - out_acc = run_function(x, acc_init) + axis_names=("_", "wg"), + ) + def kernel(x_ref, acc_ref): + pl.run_scoped( + functools.partial(_scoped, x_gmem=x_ref, acc_gmem=acc_ref), + plgpu.SMEM((blk_m, blk_n), jnp.float32), + ) + + x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float32) ref = jnp.sum(jnp.stack(np.split(x, m // blk_m, axis=0)), axis=0) ref = jnp.sum(jnp.stack(np.split(ref, n // blk_n, axis=1)), axis=0) - np.testing.assert_allclose(out_acc, ref, atol=1e-4) + np.testing.assert_allclose(kernel(x), ref, atol=1e-4) class CoreMapTest(PallasTest): def test_multiple_wg(self): - mesh = plgpu.GPUMesh(num_threads=2, axis_names=("y",)) + @functools.partial( + plgpu.kernel, + out_shape=jnp.zeros((2, 128), np.int32), + num_threads=2, + axis_names=("wg",), + ) + def kernel(o_ref): + wg_idx = jax.lax.axis_index("wg") + o_ref[wg_idx] = jnp.broadcast_to(wg_idx, (128,)) - @jax.jit - def f(): - @pl.run_state - def inner(y_ref): - @pl.core_map(mesh) - def kernel(): - wg_idx = jax.lax.axis_index("y") - y_ref[wg_idx] = jnp.broadcast_to(wg_idx, (128,)) - y_init = jnp.zeros((2, 128), np.int32) - return inner(y_init) np.testing.assert_array_equal( - f(), np.repeat(np.arange(2), 128).reshape(2, 128) + kernel(), np.repeat(np.arange(2), 128).reshape(2, 128) ) def test_multiple_wg_with_grid(self): - mesh = plgpu.GPUMesh(grid=(2, 2), num_threads=2, axis_names=("x", "y", "wg")) + @functools.partial( + plgpu.kernel, + out_shape=jnp.zeros((4, 2, 128), np.int32), + grid=(2, 2), + num_threads=2, + axis_names=("x", "y", "wg"), + ) + def kernel(o_ref): + xy_idx = jax.lax.axis_index(("x", "y")) + yx_idx = jax.lax.axis_index(("y", "x")) + wg_idx = jax.lax.axis_index("wg") + num_wgs = jax.lax.psum(1, "wg") + o_ref[xy_idx, wg_idx] = jnp.broadcast_to( + yx_idx * num_wgs + wg_idx, (128,) + ) - @jax.jit - def f(): - @pl.run_state - def inner(y_ref): - @pl.core_map(mesh) - def kernel(): - xy_idx = jax.lax.axis_index(("x", "y")) - yx_idx = jax.lax.axis_index(("y", "x")) - wg_idx = jax.lax.axis_index("wg") - num_wgs = jax.lax.psum(1, "wg") - y_ref[xy_idx, wg_idx] = jnp.broadcast_to( - yx_idx * num_wgs + wg_idx, (128,) - ) - y_init = jnp.zeros((4, 2, 128), np.int32) - return inner(y_init) np.testing.assert_array_equal( - f(), np.repeat([0, 1, 4, 5, 2, 3, 6, 7], 128).reshape(4, 2, 128) + kernel(), np.repeat([0, 1, 4, 5, 2, 3, 6, 7], 128).reshape(4, 2, 128) ) def test_multiple_wg_with_squashed_grid(self): @@ -2197,70 +2174,71 @@ def test_multiple_wg_with_squashed_grid(self): y_dim = 5 z_dim = 7 num_threads = 2 - mesh = plgpu.GPUMesh(grid=(b, x_dim, y_dim, z_dim), - num_threads=num_threads, - axis_names=("b", "x", "y", "z", "wg")) - @jax.jit - def f(): - @pl.run_state - def inner(y_ref): - @pl.core_map(mesh) - def _(): - b_idx = jax.lax.axis_index("b") - x_idx = jax.lax.axis_index("x") - y_idx = jax.lax.axis_index("y") - z_idx = jax.lax.axis_index("z") - wg_idx = jax.lax.axis_index("wg") - bxyzw_idx = jax.lax.axis_index(("b", "x", "y", "z", "wg")) - y_ref[b_idx, x_idx, y_idx, z_idx, wg_idx] = jnp.broadcast_to( - bxyzw_idx, (128,) - ) - y_init = jnp.zeros((b, x_dim, y_dim, z_dim, num_threads, 128), np.int32) - return inner(y_init) - result = f()[:, :, :, :, :, 0] + @functools.partial( + plgpu.kernel, + out_shape=jnp.zeros( + (b, x_dim, y_dim, z_dim, num_threads, 128), np.int32 + ), + grid=(b, x_dim, y_dim, z_dim), + num_threads=num_threads, + axis_names=("b", "x", "y", "z", "wg"), + ) + def kernel(o_ref): + b_idx = jax.lax.axis_index("b") + x_idx = jax.lax.axis_index("x") + y_idx = jax.lax.axis_index("y") + z_idx = jax.lax.axis_index("z") + wg_idx = jax.lax.axis_index("wg") + bxyzw_idx = jax.lax.axis_index(("b", "x", "y", "z", "wg")) + o_ref[b_idx, x_idx, y_idx, z_idx, wg_idx] = jnp.broadcast_to( + bxyzw_idx, (128,) + ) + + result = kernel()[:, :, :, :, :, 0] ref = np.arange(b * x_dim * y_dim * z_dim * num_threads).reshape( - result.shape) + result.shape + ) np.testing.assert_array_equal(result, ref) def test_cross_wg_barrier(self): - mesh = plgpu.GPUMesh(num_threads=2, axis_names=("wg",)) + @functools.partial( + plgpu.kernel, + out_shape=jnp.zeros((2, 128), np.int32), + num_threads=2, + axis_names=("wg",), + ) + def kernel(y_ref): + def scoped(barrier): + plgpu.barrier_arrive(barrier) + plgpu.barrier_wait(barrier) + wg_idx = jax.lax.axis_index("wg") + y_ref[wg_idx] = jnp.broadcast_to(wg_idx, (128,)) - @jax.jit - def f(): - @pl.run_state - def inner(y_ref): - @pl.core_map(mesh) - def kernel(): - def scoped(barrier): - plgpu.barrier_arrive(barrier) - plgpu.barrier_wait(barrier) - wg_idx = jax.lax.axis_index("wg") - y_ref[wg_idx] = jnp.broadcast_to(wg_idx, (128,)) - # Each warpgroup is a single logical thread! - pl.run_scoped(scoped, plgpu.Barrier(num_arrivals=2)) - y_init = jnp.zeros((2, 128), np.int32) - return inner(y_init) - np.testing.assert_array_equal(f(), np.repeat([0, 1], 128).reshape(2, 128)) + # Each warpgroup is a single logical thread! + pl.run_scoped(scoped, plgpu.Barrier(num_arrivals=2)) + + np.testing.assert_array_equal( + kernel(), np.repeat([0, 1], 128).reshape(2, 128) + ) def test_cluster(self): - mesh = plgpu.GPUMesh(grid=(2,), cluster=(2,), axis_names=("x", "cluster")) + @functools.partial( + plgpu.kernel, + out_shape=jnp.zeros(128, np.int32), + grid=(2,), + cluster=(2,), + axis_names=("x", "cluster"), + ) + def kernel(ref): + block_idx = jax.lax.axis_index("x") + cluster_idx = jax.lax.axis_index("cluster") + pl.debug_print("block: {} cluster: {}", block_idx, cluster_idx) - @jax.jit - def f(): - @pl.run_state - def inner(ref): - @pl.core_map(mesh) - def kernel(): - block_idx = jax.lax.axis_index("x") - cluster_idx = jax.lax.axis_index("cluster") - pl.debug_print("block: {} cluster: {}", block_idx, cluster_idx) - - ref[...] = ref[...] - return inner(jnp.zeros(128, np.int32)) + ref[...] = ref[...] with self.capture_stdout() as output: - jax.block_until_ready(f()) + jax.block_until_ready(kernel()) self.assertEqual( set(output().splitlines()), { From e9fdf67ecc243c4fcf2344e7047367b6e79c9035 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Tue, 25 Mar 2025 14:45:33 -0700 Subject: [PATCH 159/483] [jaxlib:cpu] Cleaning up after callback FFI refactor. PiperOrigin-RevId: 740492139 --- jaxlib/xla/py_client.cc | 35 +++-------------------------------- jaxlib/xla/py_client.h | 17 ----------------- 2 files changed, 3 insertions(+), 49 deletions(-) diff --git a/jaxlib/xla/py_client.cc b/jaxlib/xla/py_client.cc index 434077b0824f..5fe6bc648e07 100644 --- a/jaxlib/xla/py_client.cc +++ b/jaxlib/xla/py_client.cc @@ -367,8 +367,7 @@ std::unique_ptr MakeIfrtCompileOptions( ifrt_loaded_host_callbacks; ifrt_loaded_host_callbacks.reserve(host_callbacks.size()); // Extract `ifrt::LoadedHostCallback`s from host callback capsules that were - // created by `PyClient::MakePythonCallbackUsingHostSendAndRecv()` or - // `PyClient::GetEmitPythonCallbackDescriptor()`. + // created by `PyClient::MakePythonCallbackUsingHostSendAndRecv()`. for (auto& host_callback : host_callbacks) { ifrt_loaded_host_callbacks.push_back(tsl::FormRef( static_cast(host_callback.data()))); @@ -386,8 +385,7 @@ MakeIfrtDeserializeExecutableOptions(std::optional options, ifrt_loaded_host_callbacks; ifrt_loaded_host_callbacks.reserve(host_callbacks.size()); // Extract `ifrt::LoadedHostCallback`s from host callback capsules that were - // created by `PyClient::MakePythonCallbackUsingHostSendAndRecv()` or - // `PyClient::GetEmitPythonCallbackDescriptor()`. + // created by `PyClient::MakePythonCallbackUsingHostSendAndRecv()`. for (auto& host_callback : host_callbacks) { ifrt_loaded_host_callbacks.push_back(tsl::FormRef( static_cast(host_callback.data()))); @@ -480,8 +478,7 @@ PyClient::CompileIfrtProgram( ifrt_loaded_host_callbacks; ifrt_loaded_host_callbacks.reserve(host_callbacks.size()); // Extract `ifrt::LoadedHostCallback`s from host callback capsules that were - // created by `PyClient::MakePythonCallbackUsingHostSendAndRecv()` or - // `PyClient::GetEmitPythonCallbackDescriptor()`. + // created by `PyClient::MakePythonCallbackUsingHostSendAndRecv()`. for (auto& host_callback : host_callbacks) { auto callback = tsl::MakeRef( client->ifrt_client(), std::move(host_callback)); @@ -660,28 +657,6 @@ absl::StatusOr PyClient::MakePythonCallbackUsingHostSendAndRecv( return callback_capsule; } -// TODO(b/394595987): Remove this API method once we remove the call from -// mlir.py's get_emit_python_callback. -absl::StatusOr> -PyClient::GetEmitPythonCallbackDescriptor( - nb::callable callable, absl::Span operand_shapes, - absl::Span result_shapes) { - TF_ASSIGN_OR_RETURN( - auto loaded_host_callback, - PyCpuLoadedHostCallback::Create(ifrt_client(), std::move(callable), - operand_shapes, result_shapes)); - const uint64_t descriptor = loaded_host_callback->descriptor(); - - nb::capsule callback_capsule( - loaded_host_callback.release(), [](void* ptr) noexcept { - static_cast(ptr)->DropRef(); - }); - return std::make_pair(descriptor, nb::object(std::move(callback_capsule))); -} - -XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("xla_python_cpu_callback", - &XlaPythonCpuCallback); - /* static */ int PyClient::tp_traverse(PyObject* self, visitproc visit, void* arg) { PyClient* c = nb::inst_ptr(self); @@ -813,10 +788,6 @@ PyType_Slot PyClient::slots_[] = { // TODO(zhangqiaorjc): Experimental. .def("defragment", [](PyClient& self) { xla::ThrowIfError(self.Defragment()); }) - .def("get_emit_python_callback_descriptor", - xla::ValueOrThrowWrapper(&PyClient::GetEmitPythonCallbackDescriptor), - nb::arg("callable"), nb::arg("operand_shapes"), - nb::arg("result_shapes").none() = nb::none()) .def("make_python_callback_from_host_send_and_recv", xla::ValueOrThrowWrapper( &PyClient::MakePythonCallbackUsingHostSendAndRecv), diff --git a/jaxlib/xla/py_client.h b/jaxlib/xla/py_client.h index 9b9d43d90228..8f50c6451627 100644 --- a/jaxlib/xla/py_client.h +++ b/jaxlib/xla/py_client.h @@ -184,23 +184,6 @@ class PyClient { absl::StatusOr HeapProfile(); - // `GetEmitPythonCallbackDescriptor` takes in an input Python callable that - // takes in arguments of shapes `operand_shapes` and returns values of shapes - // `result_shapes`. It returns a pair of a `uint64_t` descriptor and a Python - // object whose reference will keep the Python callback alive. The descriptor - // should be passed into a 'xla_python_cpu_callback' or - // 'xla_python_gpu_callback' CustomCall as its first argument. Typically the - // callback may be kept alive by attaching the keep-alive object to the - // executable built from this computation. - // - // The callable receives as arguments NumPy arrays for arguments with array - // types, and None for Token argument. The callable must return a tuple of - // either arrays or None values. - absl::StatusOr> - GetEmitPythonCallbackDescriptor(nanobind::callable callable, - absl::Span operand_shapes, - absl::Span result_shapes); - // `MakePythonCallbackUsingHostSendAndRecv` takes in an input Python callable // that takes in arguments of shapes `operand_shapes` and returns results of // shapes `result_shapes`. The arguments correspond to Send ops in the HLO From ec061566558ac9f3ca9c7966fa59f7e07e1b8d74 Mon Sep 17 00:00:00 2001 From: Jacob Burnim Date: Tue, 25 Mar 2025 14:47:33 -0700 Subject: [PATCH 160/483] [Pallas] A few fixes for TPU interpret mode: - Actually de-allocate buffers after a pl.run_scoped. - Periodically run an explicit garbage collection after de-allocating buffers. - Add no-op implementations for a few internal/testing mosaic primitives (prng_seed_p, prng_random_bits_p, assume_p, random_p). --- jax/_src/pallas/mosaic/BUILD | 1 + jax/_src/pallas/mosaic/interpret.py | 55 ++++++++++++++++++++++------- 2 files changed, 43 insertions(+), 13 deletions(-) diff --git a/jax/_src/pallas/mosaic/BUILD b/jax/_src/pallas/mosaic/BUILD index 24e8341046b0..fdd3a56ac7c8 100644 --- a/jax/_src/pallas/mosaic/BUILD +++ b/jax/_src/pallas/mosaic/BUILD @@ -158,6 +158,7 @@ py_library( deps = [ ":core", ":primitives", + ":verification", "//jax", "//jax:core", "//jax:source_info_util", diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py index 5acbabc673aa..13e71a1f5c56 100644 --- a/jax/_src/pallas/mosaic/interpret.py +++ b/jax/_src/pallas/mosaic/interpret.py @@ -16,6 +16,7 @@ import dataclasses import enum import functools +import gc import itertools import math import threading @@ -28,8 +29,9 @@ from jax._src.lax.control_flow import for_loop from jax._src import linear_util as lu from jax._src import source_info_util -from jax._src.pallas.mosaic import primitives as mosaic_primitives from jax._src.pallas.mosaic import core as mosaic_core +from jax._src.pallas.mosaic import primitives as mosaic_primitives +from jax._src.pallas.mosaic import verification from jax._src.pallas import core as pallas_core from jax._src.pallas import primitives from jax._src import pjit @@ -477,6 +479,8 @@ class SharedMemory: next_dma_id: int = 100 + deallocated_bytes: int = 0 + # TODO(jburnim): Do we want to support multiple instances of SharedMemory? # Maybe for running multiple distinct interpreted computations in parallel? @@ -570,8 +574,18 @@ def _deallocate_buffer(device_id, memory_space, buffer_id): shared_memory = _get_shared_memory() with shared_memory.lock: - # TODO(jburnim): Error if buffer doesn't exist? - shared_memory.mem.pop((memory_space, buffer_id, device_id), None) + buff = shared_memory.mem.pop((memory_space, buffer_id, device_id)) + shared_memory.deallocated_bytes += buff.size * buff.itemsize + del buff + + should_collect = shared_memory.deallocated_bytes > 100_000_000 + if should_collect: + shared_memory.deallocated_bytes = 0 + + if should_collect: + # Periodic garbage collection here prevents OOMs -- although it's not clear + # why arrays are not getting freed without this. + gc.collect() def _allocate_semaphores(device_id, shape): device_id = int(device_id) @@ -1067,6 +1081,21 @@ def write(var, value): ordered=True) elif prim is mosaic_primitives.delay_p: + # TODO(jburnim): Implement this properly? + out = [] + + elif prim is mosaic_primitives.prng_seed_p: + # TODO(jburnim): Implement this properly? + out = [] + + elif prim is mosaic_primitives.prng_random_bits_p: + # TODO(jburnim): Implement this properly? + out = jnp.zeros(eqn.params['shape'], jnp.int32) + + elif prim is verification.assume_p: + out = read(eqn.invars[0]) + + elif prim is verification.pretend_p: out = [] elif prim is lax.cond_p: @@ -1142,16 +1171,8 @@ def f(*args, jaxpr): out = _interpret(eqn.params['jaxpr'], *deferred_invals(), *allocs) - for a in allocs: - if isinstance(a, tuple): - callback.io_callback( - _deallocate_buffer, - None, - device_id, - TPU_MEMORY_SPACE_IDXS[v.aval.memory_space], - a, - ordered=True) - else: + for a, v in zip(allocs, eqn.params['jaxpr'].invars): + if v.aval.memory_space == mosaic_core.TPUMemorySpace.SEMAPHORE: # TODO(jburnim): De-allocate semaphores. # callback.io_callback( # _deallocate_semaphores, @@ -1160,6 +1181,14 @@ def f(*args, jaxpr): # a, # ordered=True) pass + else: + callback.io_callback( + _deallocate_buffer, + None, + device_id, + TPU_MEMORY_SPACE_IDXS[v.aval.memory_space], + a, + ordered=True) elif prim is state_primitives.get_p: invals = deferred_invals() From ed75189c921a66e1d5232923ff5b5cdbc23e766f Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 25 Mar 2025 14:47:39 -0700 Subject: [PATCH 161/483] [sharding_in_types] Add support for rng_bit_generator PiperOrigin-RevId: 740492876 --- jax/_src/internal_test_util/test_harnesses.py | 5 ++-- jax/_src/lax/control_flow/loops.py | 15 +++++++---- jax/_src/lax/lax.py | 26 ++++++++++++------- jax/experimental/jax2tf/jax2tf.py | 3 ++- tests/pjit_test.py | 21 +++++++++++++++ 5 files changed, 53 insertions(+), 17 deletions(-) diff --git a/jax/_src/internal_test_util/test_harnesses.py b/jax/_src/internal_test_util/test_harnesses.py index 48c645c4d033..02779c85977e 100644 --- a/jax/_src/internal_test_util/test_harnesses.py +++ b/jax/_src/internal_test_util/test_harnesses.py @@ -3375,8 +3375,9 @@ def _make_conv_harness(name, define( lax.rng_bit_generator_p, f"{key_dtype=}_shape={jtu.format_shape_dtype_string(shape, dtype)}_{algorithm=}", - lambda key, shape, dtype, algorithm: lax.rng_bit_generator(key, shape, dtype=dtype, - algorithm=algorithm), + lambda key, shape, dtype, algorithm, out_sharding=None: lax.rng_bit_generator( + key, shape, dtype=dtype, algorithm=algorithm, + out_sharding=out_sharding), [RandArg(key_shape, key_dtype), StaticArg(shape), StaticArg(dtype), StaticArg(algorithm)], shape=shape, diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 33e2d2cbb0c8..88af7c24e5b8 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -2264,17 +2264,22 @@ def map(f, xs): _, ys = scan(g, (), xs) return ys -def _rng_bit_generator_batching_rule(batched_args, batch_dims, *, shape, dtype, algorithm): +def _rng_bit_generator_batching_rule(batched_args, batch_dims, *, shape, dtype, + algorithm, out_sharding): keys, = batched_args bd, = batch_dims if bd is batching.not_mapped: - return lax.rng_bit_generator_p.bind(keys, shape=shape, dtype=dtype, - algorithm=algorithm), (None, None) + return lax.rng_bit_generator_p.bind( + keys, shape=shape, dtype=dtype, algorithm=algorithm, + out_sharding=out_sharding), (None, None) keys = batching.moveaxis(keys, bd, 0) batch_size = keys.shape[0] + out_s = (out_sharding.with_spec((keys.aval.sharding.spec[0], *out_sharding.spec)) + if out_sharding is not None else None) key = keys[0] - new_key, bits = lax.rng_bit_generator_p.bind(key, shape=(batch_size, *shape), - dtype=dtype, algorithm=algorithm) + new_key, bits = lax.rng_bit_generator_p.bind( + key, shape=(batch_size, *shape), dtype=dtype, algorithm=algorithm, + out_sharding=out_s) new_keys = slicing.dynamic_update_index_in_dim(keys, new_key, 0, axis=0) return (new_keys, bits), (0, 0) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 80a469ab6a11..dd6e7399321b 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -8053,15 +8053,20 @@ def _rng_uniform_lowering(ctx, a, b, *, shape): mlir.register_lowering(rng_uniform_p, _rng_uniform_lowering) -def _rng_bit_generator_shape_rule(key, *, shape, dtype, algorithm): +def _rng_bit_generator_shape_rule(key, *, shape, dtype, algorithm, out_sharding): del dtype, algorithm return (key.shape, tuple(shape)) -def _rng_bit_generator_dtype_rule(key, *, shape, dtype, algorithm): +def _rng_bit_generator_sharding_rule(key, *, shape, dtype, algorithm, + out_sharding): + return (key.sharding, out_sharding) + +def _rng_bit_generator_dtype_rule(key, *, shape, dtype, algorithm, out_sharding): del shape, algorithm return (key.dtype, dtype) -def _rng_bit_generator_weak_type_rule(key, *, shape, dtype, algorithm): +def _rng_bit_generator_weak_type_rule(key, *, shape, dtype, algorithm, + out_sharding): del shape, dtype, algorithm return (key.weak_type, False) @@ -8092,7 +8097,7 @@ def _rng_algorithm(algorithm: RandomAlgorithm): assert False def _rng_bit_generator_lowering( - ctx, key, *, shape, dtype, algorithm): + ctx, key, *, shape, dtype, algorithm, out_sharding): key_type = ir.RankedTensorType(key.type) key_shape, key_etype = key_type.shape, key_type.element_type # While the RngBitGenerator HLO accepts a u64[2] key on all backends, we @@ -8121,7 +8126,7 @@ def _rng_bit_generator_lowering( ir.RankedTensorType.get([2], u64_type), hlo.reshape(ir.RankedTensorType.get([2, 2], u32_type), key)) algorithm_attr = _rng_algorithm(algorithm) - _, out_vals_aval = ctx.avals_out + out_key_aval, out_vals_aval = ctx.avals_out if any(not core.is_constant_shape(a.shape) for a in ctx.avals_out): output_shape = mlir.shape_tensor( mlir.eval_dynamic_shape(ctx, out_vals_aval.shape)) @@ -8145,7 +8150,8 @@ def _rng_bit_generator_lowering( out_vals = hlo.convert( ir.RankedTensorType.get(ir.RankedTensorType(out_vals.type).shape, etype), out_vals) - return [out_key, out_vals] + return [mlir.lower_with_sharding_in_types(ctx, out_key, out_key_aval), + mlir.lower_with_sharding_in_types(ctx, out_vals, out_vals_aval)] rng_bit_generator_p = Primitive("rng_bit_generator") @@ -8155,7 +8161,7 @@ def _rng_bit_generator_lowering( rng_bit_generator_p.def_abstract_eval( partial(standard_multi_result_abstract_eval, rng_bit_generator_p, _rng_bit_generator_shape_rule, _rng_bit_generator_dtype_rule, - _rng_bit_generator_weak_type_rule, None)) + _rng_bit_generator_weak_type_rule, _rng_bit_generator_sharding_rule)) mlir.register_lowering(rng_bit_generator_p, _rng_bit_generator_lowering) @@ -8219,7 +8225,7 @@ def _propagate_mem_kind_copy(in_mem_kind): pxla.memory_kind_propagate_rule[copy_p] = _propagate_mem_kind_copy def rng_bit_generator(key, shape, dtype=np.uint32, - algorithm=RandomAlgorithm.RNG_DEFAULT): + algorithm=RandomAlgorithm.RNG_DEFAULT, out_sharding=None): """Stateless PRNG bit generator. Experimental and its use is discouraged. Returns uniformly distributed random bits with the specified shape and dtype @@ -8235,12 +8241,14 @@ def rng_bit_generator(key, shape, dtype=np.uint32, """ shape = core.canonicalize_shape(shape) dtype = dtypes.canonicalize_dtype(dtype) + out_sharding = canonicalize_sharding(out_sharding, 'rng_bit_generator') if np.dtype(dtype) not in {np.dtype('uint8'), np.dtype('uint16'), np.dtype('uint32'), np.dtype('uint64')}: raise TypeError(f'rng_bit_generator: unsupported dtype {dtype}') return tuple( rng_bit_generator_p.bind( - key, shape=shape, dtype=dtype, algorithm=algorithm)) + key, shape=shape, dtype=dtype, algorithm=algorithm, + out_sharding=out_sharding)) def _iota_abstract_eval(*dyn_shape, dtype, shape, dimension, sharding): diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 7f98ce433815..3d71af38388b 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -2822,7 +2822,8 @@ def _threefry2x32_jax_impl(*args: TfVal, _in_avals, _out_aval): multiple_results=False, extra_name_stack="random_gamma") -def _rng_bit_generator(key: TfVal, *, shape, dtype, algorithm) -> Sequence[TfVal]: +def _rng_bit_generator(key: TfVal, *, shape, dtype, algorithm, + out_sharding) -> Sequence[TfVal]: is_uint32_key = key.dtype == _to_tf_dtype(jnp.uint32) if is_uint32_key: key = tf.reshape(key, (2, 2)) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 2033126759e4..608c54994b5d 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7191,6 +7191,27 @@ def f(x): self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) self.assertArraysEqual(out, np.arange(8) * 2) + @jtu.with_user_mesh((2,), ('x',)) + def test_rng_bit_generator(self, mesh): + def f(key): + out = lax.rng_bit_generator(key, shape=(4, 8), out_sharding=P('x')) + self.assertEqual(out[0].aval.sharding.spec, P(None)) + self.assertEqual(out[1].aval.sharding.spec, P('x', None)) + return out + + key = np.array((1, 2, 3, 4)).astype(np.uint32) + out1 = f(key) + jit_f = jax.jit(f) + out2 = jit_f(key) + self.assertEqual(out1[0].shape, (4,)) + self.assertEqual(out1[1].shape, (4, 8)) + self.assertEqual(out2[0].sharding, NamedSharding(mesh, P())) + self.assertEqual(out2[1].sharding, NamedSharding(mesh, P('x', None))) + self.assertEqual(out1[0].sharding, out2[0].sharding) + self.assertEqual(out1[1].sharding, out2[1].sharding) + self.assertArraysEqual(out1[0], out2[0]) + self.assertArraysEqual(out1[1], out2[1]) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): From 289fa625e562c96ffaf466368a5d620e14d2659c Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 25 Mar 2025 15:28:32 -0700 Subject: [PATCH 162/483] [sharding_in_types] Add fold_in support PiperOrigin-RevId: 740505750 --- jax/_src/prng.py | 4 +++- tests/pjit_test.py | 14 ++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 5fdd673b3454..ead939d74351 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -621,7 +621,9 @@ def random_fold_in(keys, msgs): def random_fold_in_abstract_eval(keys_aval, msgs_aval): shape = lax_internal.broadcasting_shape_rule( 'random_fold_in', keys_aval, msgs_aval) - return core.ShapedArray(shape, keys_aval.dtype) + sharding = lax_internal.broadcasting_sharding_rule( + 'random_fold_in', keys_aval, msgs_aval) + return core.ShapedArray(shape, keys_aval.dtype, sharding=sharding) @random_fold_in_p.def_impl def random_fold_in_impl(keys, msgs): diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 608c54994b5d..ed1f9e9b62d8 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7212,6 +7212,20 @@ def f(key): self.assertArraysEqual(out1[0], out2[0]) self.assertArraysEqual(out1[1], out2[1]) + @jtu.with_user_mesh((2,), ('x',)) + def test_fold_in(self, mesh): + key = jax.random.key(72) + key = jax.device_put(key, NamedSharding(mesh, P())) + + @jax.jit + def f(key): + f1 = jax.random.fold_in(key, 1) + self.assertEqual(jax.random.key_data(f1).aval.sharding.spec, P(None)) + return f1 + + out = f(key) + self.assertEqual(out.sharding, NamedSharding(mesh, P())) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): From 588b6932d69aad36d312cbec336effda9735da43 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 25 Mar 2025 15:35:20 -0700 Subject: [PATCH 163/483] [JAX] [XLA:Python] Migrate more Python modules to JAX. PiperOrigin-RevId: 740507886 --- jaxlib/gpu/BUILD | 1 + jaxlib/jax.bzl | 1 + jaxlib/xla/BUILD | 176 ++++++++++++++++-- jaxlib/xla/callback.cc | 184 +++++++++++++++++++ jaxlib/xla/callback.h | 92 ++++++++++ jaxlib/xla/dlpack.cc | 2 +- jaxlib/xla/guard_lib.cc | 197 ++++++++++++++++++++ jaxlib/xla/guard_lib.h | 115 ++++++++++++ jaxlib/xla/pjit.cc | 2 +- jaxlib/xla/py_array.cc | 4 +- jaxlib/xla/py_client.cc | 6 +- jaxlib/xla/py_client_cpu.cc | 166 +++++++++++++++++ jaxlib/xla/py_client_cpu.h | 28 +++ jaxlib/xla/py_host_callback.cc | 290 ++++++++++++++++++++++++++++++ jaxlib/xla/py_host_callback.h | 170 ++++++++++++++++++ jaxlib/xla/py_host_callback.proto | 25 +++ jaxlib/xla/util.cc | 60 +++++++ jaxlib/xla/util.h | 31 ++++ jaxlib/xla/xla.cc | 2 +- 19 files changed, 1533 insertions(+), 19 deletions(-) create mode 100644 jaxlib/xla/callback.cc create mode 100644 jaxlib/xla/callback.h create mode 100644 jaxlib/xla/guard_lib.cc create mode 100644 jaxlib/xla/guard_lib.h create mode 100644 jaxlib/xla/py_client_cpu.cc create mode 100644 jaxlib/xla/py_client_cpu.h create mode 100644 jaxlib/xla/py_host_callback.cc create mode 100644 jaxlib/xla/py_host_callback.h create mode 100644 jaxlib/xla/py_host_callback.proto create mode 100644 jaxlib/xla/util.cc create mode 100644 jaxlib/xla/util.h diff --git a/jaxlib/gpu/BUILD b/jaxlib/gpu/BUILD index 3613be567533..59c0ab8dc164 100644 --- a/jaxlib/gpu/BUILD +++ b/jaxlib/gpu/BUILD @@ -84,6 +84,7 @@ proto_library( cc_proto_library( name = "triton_cc_proto", + compatible_with = None, deps = [":triton_proto"], ) diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 9b8c861404c2..560db85d6a1e 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -31,6 +31,7 @@ load("@xla//xla/tsl/platform:build_config_root.bzl", _tf_cuda_tests_tags = "tf_c cc_proto_library = _cc_proto_library cuda_library = _cuda_library rocm_library = _rocm_library +proto_library = native.proto_library nanobind_extension = _pybind_extension if_cuda_is_configured = _if_cuda_is_configured if_rocm_is_configured = _if_rocm_is_configured diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD index 979e659a309f..e4db73d6c86d 100644 --- a/jaxlib/xla/BUILD +++ b/jaxlib/xla/BUILD @@ -18,12 +18,12 @@ load( "if_oss", "jax_visibility", "nanobind_extension", + "proto_library", "py_deps", "py_strict_library", "py_strict_test", "pytype_strict_library", ) -# Placeholder: load proto_library licenses(["notice"]) @@ -49,6 +49,7 @@ nanobind_extension( ":config", ":custom_call_sharding", ":dlpack", + ":guard_lib", ":ifrt_proxy", ":jax_jit", ":mlir", @@ -57,6 +58,7 @@ nanobind_extension( ":py_client", ":pytree", ":sdy", + ":util", ":weakref_lru_cache", ":xla_compiler", "@com_google_absl//absl/base", @@ -99,7 +101,6 @@ nanobind_extension( "@xla//xla/pjrt/distributed:service", "@xla//xla/pjrt/plugin/xla_cpu:cpu_client_options", "@xla//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client", - "@xla//xla/python:guard_lib", "@xla//xla/python:logging", "@xla//xla/python:nb_absl_flat_hash_map", "@xla//xla/python:nb_absl_span", @@ -111,7 +112,6 @@ nanobind_extension( "@xla//xla/python:refine_polymorphic_shapes", "@xla//xla/python:traceback", "@xla//xla/python:types", - "@xla//xla/python:util", "@xla//xla/python/ifrt", "@xla//xla/python/ifrt:plugin_program", "@xla//xla/python/ifrt:plugin_program_serdes", @@ -147,6 +147,41 @@ nanobind_extension( }), ) +cc_library( + name = "callback", + srcs = [ + "callback.cc", + ], + hdrs = [ + "callback.h", + ], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/base", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:comparison_util", + "@xla//xla:xla_data_proto_cc", + "@xla//xla/pjrt:host_callback", + "@xla//xla/pjrt:transpose", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:python_ref_manager", + "@xla//xla/service:custom_call_status", + "@xla//xla/tsl/platform:statusor", + ], +) + cc_library( name = "config", srcs = ["config.cc"], @@ -212,6 +247,7 @@ cc_library( features = ["-use_header_modules"], deps = [ ":py_client", + ":util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -233,7 +269,6 @@ cc_library( "@xla//xla/python:python_ref_manager", "@xla//xla/python:traceback", "@xla//xla/python:types", - "@xla//xla/python:util", "@xla//xla/python/ifrt", "@xla//xla/python/pjrt_ifrt", "@xla//xla/tsl/platform:errors", @@ -242,6 +277,26 @@ cc_library( ], ) +cc_library( + name = "guard_lib", + srcs = ["guard_lib.cc"], + hdrs = ["guard_lib.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@nanobind", + "@xla//xla:util", + ], +) + cc_library( name = "ifrt_proxy", srcs = ["ifrt_proxy.cc"], @@ -361,6 +416,7 @@ cc_library( features = ["-use_header_modules"], deps = [ ":config", + ":guard_lib", ":jax_jit", ":py_client", ":pytree", @@ -382,7 +438,6 @@ cc_library( "@xla//xla:util", "@xla//xla/pjrt:exceptions", "@xla//xla/pjrt:lru_cache", - "@xla//xla/python:guard_lib", "@xla//xla/python:nb_class_ptr", "@xla//xla/python:nb_helpers", "@xla//xla/python:nb_numpy", @@ -482,6 +537,12 @@ cc_library( features = ["-use_header_modules"], visibility = jax_visibility("jaxlib/xla/py_client"), deps = [ + ":callback", + ":guard_lib", + ":py_client_cpu", + ":py_host_callback", + ":py_host_callback_cc_proto", + ":util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/cleanup", @@ -541,20 +602,14 @@ cc_library( "@xla//xla/pjrt/distributed", "@xla//xla/pjrt/distributed:client", "@xla//xla/python:aggregate_profile", - "@xla//xla/python:callback", - "@xla//xla/python:guard_lib", "@xla//xla/python:nb_absl_span", "@xla//xla/python:nb_class_ptr", "@xla//xla/python:nb_helpers", "@xla//xla/python:nb_numpy", "@xla//xla/python:pprof_profile_builder", - "@xla//xla/python:py_client_cpu", - "@xla//xla/python:py_host_callback", - "@xla//xla/python:py_host_callback_proto_cc", "@xla//xla/python:python_ref_manager", "@xla//xla/python:traceback", "@xla//xla/python:types", - "@xla//xla/python:util", "@xla//xla/python:xplane_to_profile_instructions", "@xla//xla/python/compile_only_ifrt:client", "@xla//xla/python/ifrt", @@ -588,6 +643,86 @@ cc_library( ], ) +cc_library( + name = "py_client_cpu", + srcs = ["py_client_cpu.cc"], + hdrs = ["py_client_cpu.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:shape_util", + "@xla//xla/ffi:ffi_api", + "@xla//xla/ffi/api:ffi", + "@xla//xla/pjrt:host_callback", + "@xla//xla/pjrt:transpose", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:types", + ], + alwayslink = 1, +) + +cc_library( + name = "py_host_callback", + srcs = ["py_host_callback.cc"], + hdrs = ["py_host_callback.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":callback", + ":py_host_callback_cc_proto", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@nanobind", + "@xla//xla:shape_util", + "@xla//xla:status_macros", + "@xla//xla:util", + "@xla//xla:xla_data_proto_cc", + "@xla//xla/pjrt:host_callback", + "@xla//xla/pjrt:pjrt_compiler", + "@xla//xla/python:python_ref_manager", + "@xla//xla/python:types", + "@xla//xla/python/ifrt", + "@xla//xla/python/pjrt_ifrt", + "@xla//xla/python/pjrt_ifrt:xla_host_callback_proto_cc", + "@xla//xla/tsl/concurrency:ref_count", + "@xla//xla/tsl/platform:statusor", + ], +) + +proto_library( + name = "py_host_callback_proto", + srcs = ["py_host_callback.proto"], +) + +cc_proto_library( + name = "py_host_callback_cc_proto", + visibility = jax_visibility("jaxlib/xla/py_host_callback_cc_proto"), + deps = [":py_host_callback_proto"], +) + cc_library( name = "py_socket_transfer", srcs = ["py_socket_transfer.cc"], @@ -697,6 +832,25 @@ cc_library( ], ) +cc_library( + name = "util", + srcs = ["util.cc"], + hdrs = ["util.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:span", + "@xla//xla:util", + "@xla//xla/python/ifrt", + "@xla//xla/tsl/concurrency:ref_count", + ], +) + cc_library( name = "weakref_lru_cache", srcs = ["weakref_lru_cache.cc"], diff --git a/jaxlib/xla/callback.cc b/jaxlib/xla/callback.cc new file mode 100644 index 000000000000..4eab8290c7bb --- /dev/null +++ b/jaxlib/xla/callback.cc @@ -0,0 +1,184 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/callback.h" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/casts.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "xla/pjrt/host_callback.h" +#include "xla/pjrt/transpose.h" +#include "xla/primitive_util.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/python_ref_manager.h" +#include "xla/service/custom_call_status.h" +#include "xla/tsl/platform/statusor.h" + +namespace nb = nanobind; + +namespace xla { + +CpuCallback::~CpuCallback() { + // The destructor may be called without GIL held. In that case, we defer it + // to GlobalPyRefManager. + std::vector objects; + objects.push_back(std::move(callable_)); + for (auto& arg : args_) { + objects.push_back(std::move(arg.dtype)); + } + + GlobalPyRefManager()->AddGarbage(absl::MakeSpan(objects)); +} + +absl::Status CpuCallback::PrepareAndCall(void* result, void** arg_ptrs) { + absl::Span inputs(arg_ptrs, args_.size()); + absl::Span outputs(reinterpret_cast(result), + results_.size()); + + nb::gil_scoped_acquire gil; + nb::tuple args = nb::steal(PyTuple_New(inputs.size())); + for (size_t i = 0; i < inputs.size(); ++i) { + if (args_[i].type == xla::TOKEN) { + PyTuple_SET_ITEM(args.ptr(), i, nb::none().release().ptr()); + } else { + nb_numpy_ndarray array = + nb_numpy_ndarray(args_[i].dtype, args_[i].dims, args_[i].strides, + const_cast(inputs[i])); + array.attr("flags").attr("writeable") = nb::bool_(false); + PyTuple_SET_ITEM(args.ptr(), i, array.release().ptr()); + } + } + + EnterHostCallback(); + absl::StatusOr maybe_result_tuple = Call(std::move(args)); + LeaveHostCallback(); + TF_ASSIGN_OR_RETURN(auto result_tuple, maybe_result_tuple); + + for (size_t i = 0; i < results_.size(); ++i) { + if (results_[i].type == xla::TOKEN) { + continue; + } + nb::object output = + nb::borrow(PyTuple_GetItem(result_tuple.ptr(), i)); + nb_numpy_ndarray array = nb_numpy_ndarray::ensure(std::move(output)); + absl::Span dims( + reinterpret_cast(array.shape()), array.ndim()); + absl::Span strides( + reinterpret_cast(array.strides()), array.ndim()); + if (strides == results_[i].expected_strides) { + std::memcpy(outputs[i], array.data(), results_[i].size_in_bytes); + } else { + xla::TransposePlan::Options options; + options.elem_size_in_bytes = + xla::primitive_util::ByteWidth(results_[i].type); + options.dims = dims; + options.permutation = results_[i].reversed_layout; + options.input_layout = xla::TransposePlan::Striding{strides}; + absl::StatusOr> plan = + transpose_cache_.GetOrCreate(options); + if (!plan.ok()) { + return std::move(plan).status(); + } + plan.value()->Execute(array.data(), outputs[i]); + } + } + + return absl::OkStatus(); +} + +absl::StatusOr CpuCallback::Call(nb::tuple args) { + auto py_error_to_status = [](nb::python_error& e) { + std::string error_message = e.what(); + return absl::InternalError( + absl::StrFormat("CpuCallback error: %s", error_message)); + }; + nb::object result_object; + try { + result_object = callable_(*nb::borrow(args)); + } catch (nb::python_error& e) { + return py_error_to_status(e); + } + if (!PyTuple_Check(result_object.ptr())) { + return absl::InternalError( + absl::StrFormat("CPU callback expected a tuple result, got %s", + nb::cast(nb::repr(result_object)))); + } + if (PyTuple_Size(result_object.ptr()) != results_.size()) { + return absl::InternalError( + absl::StrFormat("CPU callback expected a tuple with %d results, got %d", + results_.size(), PyTuple_Size(result_object.ptr()))); + } + nb::tuple result_tuple = nb::cast(result_object); + for (size_t i = 0; i < results_.size(); ++i) { + nb::object output = + nb::borrow(PyTuple_GetItem(result_tuple.ptr(), i)); + if (results_[i].type == xla::TOKEN) { + if (!output.is_none()) { + return absl::InternalError(absl::StrFormat( + "Token output from Python callback should be None, got %s", + nb::cast(nb::repr(output)))); + } + continue; + } + nb_numpy_ndarray array; + try { + array = nb_numpy_ndarray::from_any(output, NPY_ARRAY_ENSUREARRAY); + } catch (nb::python_error& e) { + return py_error_to_status(e); + } + static_assert(sizeof(ssize_t) == sizeof(int64_t), + "Expected ssize_t to be of equal size to int64_t"); + absl::Span dims( + reinterpret_cast(array.shape()), array.ndim()); + if (dims != results_[i].expected_dims) { + return absl::InternalError(absl::StrFormat( + "Mismatched result shape for %d-th return value from CPU callback; " + "expected array with dimensions %s, got %s", + i, absl::StrJoin(results_[i].expected_dims, ","), + absl::StrJoin(dims, ","))); + } + } + return result_tuple; +} + +void XlaPythonCpuCallback(void* output, void** inputs, + XlaCustomCallStatus* status) { + CpuCallback* callback = + absl::bit_cast(*static_cast(inputs[0])); + auto s = callback->PrepareAndCall(output, inputs + 1); + if (!s.ok()) { + auto msg = s.message(); + XlaCustomCallStatusSetFailure(status, msg.data(), msg.length()); + } +} + +} // namespace xla diff --git a/jaxlib/xla/callback.h b/jaxlib/xla/callback.h new file mode 100644 index 000000000000..b63025efe120 --- /dev/null +++ b/jaxlib/xla/callback.h @@ -0,0 +1,92 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_CALLBACK_H_ +#define JAXLIB_XLA_CALLBACK_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "nanobind/nanobind.h" +#include "xla/pjrt/transpose.h" +#include "xla/python/nb_numpy.h" +#include "xla/service/custom_call_status.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +class CpuCallback { + public: + struct Arg { + xla::PrimitiveType type; // XLA type + nb_dtype dtype; // NumPy type, for array types. + absl::InlinedVector dims; // Dimensions, for array types. + std::vector strides; // Byte strides, for array types. + size_t size_in_bytes; // Size of the array in bytes. + }; + struct Result { + xla::PrimitiveType type; // XLA type + // Expected output shape, for array types + absl::InlinedVector expected_dims; + // Expected output byte strides, for array types. If the strides do not + // match the output will be transposed into the expected layout. + std::vector expected_strides; + // The desired order of output dimensions in major-to-minor order. + absl::InlinedVector reversed_layout; + // Size of the array in bytes. + size_t size_in_bytes; + }; + + explicit CpuCallback(nanobind::callable callable, std::vector args, + std::vector results) + : callable_(std::move(callable)), + args_(std::move(args)), + results_(std::move(results)), + transpose_cache_(/*capacity=*/16) {} + + ~CpuCallback(); + + const std::vector& args() const { return args_; } + size_t num_args() const { return args_.size(); } + + const std::vector& results() const { return results_; } + size_t num_results() const { return results_.size(); } + void* callback() const { return callable_.ptr(); } + + xla::TransposePlanCache& transpose_cache() { return transpose_cache_; } + + absl::Status PrepareAndCall(void* result, void** arg_ptrs); + + absl::StatusOr Call(nanobind::tuple args); + + private: + nanobind::callable callable_; + std::vector args_; + std::vector results_; + xla::TransposePlanCache transpose_cache_; +}; + +void XlaPythonCpuCallback(void* output, void** inputs, + XlaCustomCallStatus* status); + +} // namespace xla + +#endif // JAXLIB_XLA_CALLBACK_H_ diff --git a/jaxlib/xla/dlpack.cc b/jaxlib/xla/dlpack.cc index 8b29e136f296..94d57e07c34a 100644 --- a/jaxlib/xla/dlpack.cc +++ b/jaxlib/xla/dlpack.cc @@ -36,6 +36,7 @@ limitations under the License. #include "nanobind/ndarray.h" #include "jaxlib/xla/py_array.h" #include "jaxlib/xla/py_client.h" +#include "jaxlib/xla/util.h" #include "xla/layout.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/pjrt_client.h" @@ -51,7 +52,6 @@ limitations under the License. #include "xla/python/python_ref_manager.h" #include "xla/python/traceback.h" #include "xla/python/types.h" -#include "xla/python/util.h" #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/tsl/platform/errors.h" diff --git a/jaxlib/xla/guard_lib.cc b/jaxlib/xla/guard_lib.cc new file mode 100644 index 000000000000..77866741819c --- /dev/null +++ b/jaxlib/xla/guard_lib.cc @@ -0,0 +1,197 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This files implements the configuration management for different types of +// guards. +// C++ backends are responsible for enforcing transfer guard levels. + +#include "jaxlib/xla/guard_lib.h" + +#include +#include + +#include "absl/base/attributes.h" +#include "absl/functional/function_ref.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "xla/util.h" + +namespace jax { + +namespace nb = ::nanobind; + +namespace { + +// Protected by the GIL. +GuardState& global_state = *new GuardState(); + +ABSL_CONST_INIT thread_local GuardState thread_local_state; + +// The default transfer guard level. +constexpr TransferGuardLevel kDefaultGuardLevel = TransferGuardLevel::kAllow; + +// The default garbage collection guard level. +constexpr GarbageCollectionGuardLevel kDefaultGarbageCollectionGuardLevel = + GarbageCollectionGuardLevel::kAllow; + +// Returns the transfer guard action for a transfer. +TransferGuardAction GetTransferGuardAction(TransferGuardLevel guard_level, + bool explicit_transfer) { + switch (guard_level) { + case TransferGuardLevel::kAllow: + return TransferGuardAction::kAllow; + case TransferGuardLevel::kLog: + if (explicit_transfer) { + return TransferGuardAction::kAllow; + } else { + return TransferGuardAction::kLog; + } + case TransferGuardLevel::kDisallow: + if (explicit_transfer) { + return TransferGuardAction::kAllow; + } else { + return TransferGuardAction::kDisallow; + } + case TransferGuardLevel::kLogExplicit: + return TransferGuardAction::kLog; + case TransferGuardLevel::kDisallowExplicit: + return TransferGuardAction::kDisallow; + default: + // Unreachable; gracefully handle the unexpected guard level and prevent a + // compiler warning. + return TransferGuardAction::kDisallow; + } +} + +// Returns the transfer guard action for a host-to-device transfer. +// REQUIRES: Python GIL. +TransferGuardAction GetTransferGuardActionForHostToDevice() { + return GetTransferGuardAction( + thread_local_state.host_to_device.value_or( + global_state.host_to_device.value_or(kDefaultGuardLevel)), + thread_local_state.explicit_device_put); +} + +// Returns the transfer guard action for a device-to-device transfer. +// REQUIRES: Python GIL. +TransferGuardAction GetTransferGuardActionForDeviceToDevice() { + return GetTransferGuardAction( + thread_local_state.device_to_device.value_or( + global_state.device_to_device.value_or(kDefaultGuardLevel)), + thread_local_state.explicit_device_put); +} + +// Returns the transfer guard action for a device-to-host transfer. +// REQUIRES: Python GIL. +TransferGuardAction GetTransferGuardActionForDeviceToHost() { + return GetTransferGuardAction( + thread_local_state.device_to_host.value_or( + global_state.device_to_host.value_or(kDefaultGuardLevel)), + thread_local_state.explicit_device_get); +} + +} // namespace + +absl::Status ApplyTransferGuardToHostToDevice( + absl::FunctionRef formatter) { + switch (GetTransferGuardActionForHostToDevice()) { + case TransferGuardAction::kAllow: + break; + case TransferGuardAction::kLog: + LOG(WARNING) << "host-to-device transfer: " << formatter(); + break; + case TransferGuardAction::kDisallow: + return xla::InvalidArgument("Disallowed host-to-device transfer: %s", + formatter()); + } + return absl::OkStatus(); +} + +absl::Status ApplyTransferGuardToDeviceToDevice( + absl::FunctionRef formatter) { + switch (GetTransferGuardActionForDeviceToDevice()) { + case TransferGuardAction::kAllow: + break; + case TransferGuardAction::kLog: + LOG(WARNING) << "device-to-device transfer: " << formatter(); + break; + case TransferGuardAction::kDisallow: + return xla::InvalidArgument("Disallowed device-to-device transfer: %s", + formatter()); + } + return absl::OkStatus(); +} + +absl::Status ApplyTransferGuardToDeviceToHost( + absl::FunctionRef formatter) { + switch (GetTransferGuardActionForDeviceToHost()) { + case TransferGuardAction::kAllow: + break; + case TransferGuardAction::kLog: + LOG(WARNING) << "device-to-host transfer: " << formatter(); + break; + case TransferGuardAction::kDisallow: + return xla::InvalidArgument("Disallowed device-to-host transfer: %s", + formatter()); + } + return absl::OkStatus(); +} + +GarbageCollectionGuardLevel GetGarbageCollectArrayGuard() { + return thread_local_state.garbage_collect_array.value_or( + global_state.garbage_collect_array.value_or( + kDefaultGarbageCollectionGuardLevel)); +} + +void BuildGuardSubmodule(nb::module_& m) { + nb::module_ glib = + m.def_submodule("guard_lib", "Jax support library for guards"); + + nb::enum_ tglevel(glib, "TransferGuardLevel"); + tglevel.value("ALLOW", TransferGuardLevel::kAllow); + tglevel.value("LOG", TransferGuardLevel::kLog); + tglevel.value("DISALLOW", TransferGuardLevel::kDisallow); + tglevel.value("LOG_EXPLICIT", TransferGuardLevel::kLogExplicit); + tglevel.value("DISALLOW_EXPLICIT", TransferGuardLevel::kDisallowExplicit); + + nb::enum_ gcglevel( + glib, "GarbageCollectionGuardLevel"); + gcglevel.value("ALLOW", GarbageCollectionGuardLevel::kAllow); + gcglevel.value("LOG", GarbageCollectionGuardLevel::kLog); + gcglevel.value("FATAL", GarbageCollectionGuardLevel::kFatal); + + nb::class_ tgstate(glib, "GuardState"); + tgstate.def_rw("host_to_device", &GuardState::host_to_device, + nb::arg().none()); + tgstate.def_rw("device_to_device", &GuardState::device_to_device, + nb::arg().none()); + tgstate.def_rw("device_to_host", &GuardState::device_to_host, + nb::arg().none()); + tgstate.def_rw("explicit_device_put", &GuardState::explicit_device_put); + tgstate.def_rw("explicit_device_get", &GuardState::explicit_device_get); + tgstate.def_rw("garbage_collect_array", &GuardState::garbage_collect_array, + nb::arg().none()); + + glib.def( + "global_state", [&]() { return &global_state; }, + nb::rv_policy::reference); + glib.def( + "thread_local_state", [&]() { return &thread_local_state; }, + nb::rv_policy::reference); +} + +} // namespace jax diff --git a/jaxlib/xla/guard_lib.h b/jaxlib/xla/guard_lib.h new file mode 100644 index 000000000000..8ddf6e8e892e --- /dev/null +++ b/jaxlib/xla/guard_lib.h @@ -0,0 +1,115 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_GUARD_LIB_H_ +#define JAXLIB_XLA_GUARD_LIB_H_ + +#include +#include + +// placeholder for index annotation headers +#include "absl/functional/function_ref.h" +#include "absl/status/status.h" +#include "nanobind/nanobind.h" + +namespace jax { + +// Transfer guard level chosen by the user code. +enum class TransferGuardLevel { + // Explicit transfers: allow + // Implicit transfers: allow + kAllow, + // Explicit transfers: allow + // Implicit transfers: log + kLog, + // Explicit transfers: allow + // Implicit transfers: disallow + kDisallow, + // Explicit transfers: log + // Implicit transfers: log + kLogExplicit, + // Explicit transfers: disallow + // Implicit transfers: disallow + kDisallowExplicit, +}; + +// Garbage collection guard level chose by the user code. +enum class GarbageCollectionGuardLevel { + // Silently allow the object to be garbage collected. + kAllow, + // Log and allow the object to be garbage collected. + kLog, + // Fatal crash on object garbage collection. + kFatal, +}; + +// Flags for guard levels are controlled by: +// - a global flag value, +// e.g., associated to --jax_transfer_guard_device_to_host +// which defaults to TransferGuardLevel::kAllow. +// - possibly a thread-local value, which initially is std::nullopt and +// overrides the global value if set. The thread-local state is used to +// implement context managers that locally override the global state. +// +// Explicit device_put/device_get contexts are tracked by context managers. +struct GuardState { + std::optional host_to_device; + std::optional device_to_device; + std::optional device_to_host; + bool explicit_device_put = false; + bool explicit_device_get = false; + + std::optional garbage_collect_array; +}; + +// Resulting action for a transfer given the transfer guard level and the +// transfer type. +enum class TransferGuardAction { + // Silently allow the transfer. + kAllow, + // Log and allow the transfer. + kLog, + // Disallow the transfer. + kDisallow, +}; + +// Guards a host-to-device transfer. formatter is called to describe the +// transfer in a log message or error status. +// REQUIRES: Python GIL. +absl::Status ApplyTransferGuardToHostToDevice( + absl::FunctionRef formatter); + +// Guards a device-to-device transfer. formatter is called to describe the +// transfer in a log message or error status. +// REQUIRES: Python GIL. +absl::Status ApplyTransferGuardToDeviceToDevice( + absl::FunctionRef formatter); + +// Guards a device-to-host transfer. formatter is called to describe the +// transfer in a log message or error status. +// REQUIRES: Python GIL. +absl::Status ApplyTransferGuardToDeviceToHost( + absl::FunctionRef formatter); + +// Returns the garbage collection guard level for "jax.Array" objects. +// REQUIRES: Python GIL. +GarbageCollectionGuardLevel GetGarbageCollectArrayGuard(); + +// The function to call in `xla.cc` to add the bindings for this module. +void BuildGuardSubmodule(nanobind::module_& m); + +} // namespace jax + +#endif // JAXLIB_XLA_GUARD_LIB_H_ diff --git a/jaxlib/xla/pjit.cc b/jaxlib/xla/pjit.cc index 6681f72b7b49..0409397c82de 100644 --- a/jaxlib/xla/pjit.cc +++ b/jaxlib/xla/pjit.cc @@ -51,6 +51,7 @@ limitations under the License. #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep #include "jaxlib/xla/config.h" +#include "jaxlib/xla/guard_lib.h" #include "jaxlib/xla/jax_jit.h" #include "jaxlib/xla/py_array.h" #include "jaxlib/xla/py_executable.h" @@ -60,7 +61,6 @@ limitations under the License. #include "xla/layout.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/lru_cache.h" -#include "xla/python/guard_lib.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/device_list.h" diff --git a/jaxlib/xla/py_array.cc b/jaxlib/xla/py_array.cc index 305582a987f7..a348b47454e7 100644 --- a/jaxlib/xla/py_array.cc +++ b/jaxlib/xla/py_array.cc @@ -57,12 +57,14 @@ limitations under the License. #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/guard_lib.h" #include "jaxlib/xla/py_client.h" #include "jaxlib/xla/py_device.h" #include "jaxlib/xla/py_device_list.h" #include "jaxlib/xla/py_values.h" #include "jaxlib/xla/sharding.h" #include "jaxlib/xla/to_ifrt_sharding.h" +#include "jaxlib/xla/util.h" #include "xla/layout.h" #include "xla/layout_util.h" #include "xla/pjrt/exceptions.h" @@ -73,7 +75,6 @@ limitations under the License. #include "xla/pjrt/pjrt_layout.h" #include "xla/pjrt/status_casters.h" #include "xla/primitive_util.h" -#include "xla/python/guard_lib.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/array_spec.h" #include "xla/python/ifrt/device.h" @@ -95,7 +96,6 @@ limitations under the License. #include "xla/python/python_ref_manager.h" #include "xla/python/traceback.h" #include "xla/python/types.h" -#include "xla/python/util.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" diff --git a/jaxlib/xla/py_client.cc b/jaxlib/xla/py_client.cc index 5fe6bc648e07..b74c37f28863 100644 --- a/jaxlib/xla/py_client.cc +++ b/jaxlib/xla/py_client.cc @@ -48,9 +48,12 @@ limitations under the License. #include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep #include "nanobind/stl/variant.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/callback.h" +#include "jaxlib/xla/guard_lib.h" #include "jaxlib/xla/py_array.h" #include "jaxlib/xla/py_device.h" #include "jaxlib/xla/py_executable.h" +#include "jaxlib/xla/py_host_callback.h" #include "jaxlib/xla/py_memory_space.h" #include "jaxlib/xla/py_values.h" #include "xla/literal.h" @@ -61,8 +64,6 @@ limitations under the License. #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_layout.h" #include "xla/pjrt/status_casters.h" -#include "xla/python/callback.h" -#include "xla/python/guard_lib.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/compiler.h" #include "xla/python/ifrt/device.h" @@ -79,7 +80,6 @@ limitations under the License. #include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/python/pjrt_ifrt/xla_compiler.h" #include "xla/python/pprof_profile_builder.h" -#include "xla/python/py_host_callback.h" #include "xla/python/python_ref_manager.h" #include "xla/python/traceback.h" #include "xla/python/types.h" diff --git a/jaxlib/xla/py_client_cpu.cc b/jaxlib/xla/py_client_cpu.cc new file mode 100644 index 000000000000..936a89aa3b42 --- /dev/null +++ b/jaxlib/xla/py_client_cpu.cc @@ -0,0 +1,166 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/py_client_cpu.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "xla/ffi/api/ffi.h" +#include "xla/ffi/ffi_api.h" +#include "xla/pjrt/host_callback.h" +#include "xla/pjrt/transpose.h" +#include "xla/primitive_util.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/types.h" +#include "xla/shape_util.h" + +namespace nb = nanobind; + +namespace xla { + +struct CpuTransposePlanCache { + static ffi::TypeId id; + explicit CpuTransposePlanCache(int capacity) : cache(capacity) {} + xla::TransposePlanCache cache; +}; + +ffi::TypeId CpuTransposePlanCache::id = {}; + +XLA_FFI_REGISTER_TYPE(ffi::GetXlaFfiApi(), "CpuTransposePlanCache", + &CpuTransposePlanCache::id); + +static ffi::ErrorOr> +CpuTransposePlanCacheInstantiate(uint64_t index) { + return std::make_unique(/*capacity=*/16); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + kCpuTransposePlanCacheInstantiate, CpuTransposePlanCacheInstantiate, + ffi::Ffi::BindInstantiate().Attr("index")); + +ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks, + CpuTransposePlanCache* transpose_cache, + uint64_t index, ffi::RemainingArgs args, + ffi::RemainingRets rets) { + nb::gil_scoped_acquire gil; + auto callback = nb::borrow( + static_cast(callbacks->callbacks[index])); + auto nb_args = nb::steal(PyTuple_New(args.size())); + for (size_t i = 0; i < args.size(); ++i) { + auto arg = args.get(i); + auto ptype = static_cast(arg->element_type()); + if (ptype == TOKEN) { + PyTuple_SET_ITEM(nb_args.ptr(), i, nb::none().release().ptr()); + continue; + } + auto maybe_dtype = PrimitiveTypeToNbDtype(ptype); + if (!maybe_dtype.ok()) { + return ffi::Error::Internal(maybe_dtype.status().ToString()); + } + auto dtype = maybe_dtype.value(); + auto dims = absl::Span(arg->dimensions().begin(), + arg->dimensions().size()); + // We pass in data using default numpy layout i.e., std::nullopt. + auto array = + nb_numpy_ndarray(dtype, dims, std::nullopt, arg.value().untyped_data()); + array.attr("flags").attr("writeable") = nb::bool_(false); + PyTuple_SET_ITEM(nb_args.ptr(), i, array.release().ptr()); + } + + EnterHostCallback(); + // TODO(dsuo): Change this to use the Python vectorcall protocol, which allows + // you to avoid constructing a tuple for the arguments. + nb::tuple result_tuple; + try { + auto result_object = callback(*nb::borrow(nb_args)); + result_tuple = nb::cast(result_object); + } catch (nb::python_error& e) { + return ffi::Error::Internal( + absl::StrFormat("CpuCallback error calling callback: %s", e.what())); + } + LeaveHostCallback(); + + for (size_t i = 0; i < rets.size(); ++i) { + auto ret = rets.get(i).value(); + auto ptype = static_cast(ret->element_type()); + if (ptype == TOKEN) continue; + nb::object output = + nb::borrow(PyTuple_GetItem(result_tuple.ptr(), i)); + nb_numpy_ndarray array = nb_numpy_ndarray::ensure(std::move(output)); + absl::Span strides( + reinterpret_cast(array.strides()), array.ndim()); + // We expect the output to be in default numpy layout. + auto dims = absl::Span(ret->dimensions().begin(), + ret->dimensions().size()); + auto maybe_expected_shape = ShapeUtil::MakeValidatedShape(ptype, dims); + if (!maybe_expected_shape.ok()) { + return ffi::Error::Internal(maybe_expected_shape.status().ToString()); + } + auto expected_shape = maybe_expected_shape.value(); + auto expected_strides = ByteStridesForShape(expected_shape); + if (strides == expected_strides) { + std::memcpy(ret->untyped_data(), array.data(), ret->size_bytes()); + continue; + } + xla::TransposePlan::Options options; + options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); + options.dims = absl::Span( + reinterpret_cast(array.shape()), array.ndim()); + absl::InlinedVector reversed_layout; + reversed_layout.resize(expected_shape.dimensions_size()); + absl::c_reverse_copy(expected_shape.layout().minor_to_major(), + reversed_layout.begin()); + options.permutation = reversed_layout; + options.input_layout = xla::TransposePlan::Striding{strides}; + auto maybe_plan = transpose_cache->cache.GetOrCreate(options); + if (!maybe_plan.ok()) { + return ffi::Error::Internal(maybe_plan.status().ToString()); + } + auto plan = maybe_plan.value(); + plan->Execute(array.data(), ret->untyped_data()); + } + + return ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(kXlaFfiPythonCpuCallback, XlaFfiPythonCpuCallback, + ffi::Ffi::Bind() + .Ctx>() + .Ctx>() + .Attr("index") + .RemainingArgs() + .RemainingRets()); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "xla_ffi_python_cpu_callback", + "HOST", + {kCpuTransposePlanCacheInstantiate, nullptr, nullptr, + kXlaFfiPythonCpuCallback}); +} // namespace xla diff --git a/jaxlib/xla/py_client_cpu.h b/jaxlib/xla/py_client_cpu.h new file mode 100644 index 000000000000..0035b0a361fa --- /dev/null +++ b/jaxlib/xla/py_client_cpu.h @@ -0,0 +1,28 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_PY_CLIENT_CPU_H_ +#define JAXLIB_XLA_PY_CLIENT_CPU_H_ + +#include "xla/ffi/api/ffi.h" + +namespace xla { + +XLA_FFI_DECLARE_HANDLER_SYMBOL(kCpuTransposePlanCacheInstantiate); +XLA_FFI_DECLARE_HANDLER_SYMBOL(kXlaFfiPythonCpuCallback); + +} // namespace xla + +#endif // JAXLIB_XLA_PY_CLIENT_CPU_H_ diff --git a/jaxlib/xla/py_host_callback.cc b/jaxlib/xla/py_host_callback.cc new file mode 100644 index 000000000000..9d759cc6b77c --- /dev/null +++ b/jaxlib/xla/py_host_callback.cc @@ -0,0 +1,290 @@ +/* Copyright 2023 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/py_host_callback.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "llvm/Support/ExtensibleRTTI.h" +#include "nanobind/nanobind.h" +#include "jaxlib/xla/callback.h" +#include "jaxlib/xla/py_host_callback.pb.h" +#include "xla/layout_util.h" +#include "xla/pjrt/host_callback.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/pjrt_ifrt/pjrt_host_callback.h" +#include "xla/python/pjrt_ifrt/xla_host_callback.pb.h" +#include "xla/python/python_ref_manager.h" +#include "xla/python/types.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status_macros.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace nb = nanobind; + +namespace xla { + +char PyFfiLoadedHostCallback::ID = 0; +char PyCpuLoadedHostCallback::ID = 0; +char PyHostSendAndRecvLoadedHostCallback::ID = 0; + +namespace { + +absl::StatusOr> CreateCallbackArgs( + absl::Span operand_shapes) { + std::vector callback_args(operand_shapes.size()); + for (int i = 0; i < operand_shapes.size(); ++i) { + Shape shape = operand_shapes[i]; + + if (shape.IsArray()) { + Shape layout = + (shape.has_layout() ? shape + : LayoutUtil::GetWithDefaultLayout(shape)); + callback_args[i].dims.resize(shape.dimensions_size()); + absl::c_copy(shape.dimensions(), callback_args[i].dims.begin()); + callback_args[i].strides = ByteStridesForShape(layout); + callback_args[i].type = shape.element_type(); + callback_args[i].size_in_bytes = ShapeUtil::ByteSizeOf(layout); + TF_ASSIGN_OR_RETURN(callback_args[i].dtype, + PrimitiveTypeToNbDtype(shape.element_type())); + } else if (shape.IsToken()) { + callback_args[i].type = TOKEN; + } else { + return InvalidArgument( + "Only array and token arguments to Python callbacks are supported, " + "got %s", + shape.ToString()); + } + } + return callback_args; +} + +absl::StatusOr> CreateCallbackResults( + absl::Span result_shapes) { + std::vector callback_results(result_shapes.size()); + for (int i = 0; i < result_shapes.size(); ++i) { + if (result_shapes[i].IsArray()) { + const Shape& shape = + result_shapes[i].has_layout() + ? result_shapes[i] + : LayoutUtil::GetWithDefaultLayout(result_shapes[i]); + callback_results[i].expected_dims.resize(shape.dimensions_size()); + absl::c_copy(shape.dimensions(), + callback_results[i].expected_dims.begin()); + callback_results[i].expected_strides = ByteStridesForShape(shape); + callback_results[i].type = shape.element_type(); + callback_results[i].size_in_bytes = ShapeUtil::ByteSizeOf(shape); + callback_results[i].reversed_layout.resize(shape.dimensions_size()); + absl::c_reverse_copy(shape.layout().minor_to_major(), + callback_results[i].reversed_layout.begin()); + } else if (result_shapes[i].IsToken()) { + callback_results[i].type = TOKEN; + } else { + return InvalidArgument( + "Only array and token return values from Python callbacks are " + "supported, got %s", + result_shapes[i].ToString()); + } + } + return callback_results; +} + +} // namespace + +PyFfiLoadedHostCallback::~PyFfiLoadedHostCallback() { + // The destructor may be called without GIL held. In that case, we defer it + // to GlobalPyRefManager. + std::vector objects; + objects.push_back(std::move(callable_)); + GlobalPyRefManager()->AddGarbage(absl::MakeSpan(objects)); +} + +absl::StatusOr> +PyCpuLoadedHostCallback::Create(ifrt::Client* ifrt_client, + nb::callable callable, + absl::Span operand_shapes, + absl::Span result_shapes) { + ifrt::PlatformId platform_id = ifrt_client->platform_id(); + if (platform_id != CpuId() && platform_id != CudaId() && + platform_id != RocmId() && platform_id != SyclId()) { + return Unimplemented("CpuCallback supports CPU and GPU only"); + } + + TF_ASSIGN_OR_RETURN(auto callback_args, CreateCallbackArgs(operand_shapes)); + TF_ASSIGN_OR_RETURN(auto callback_results, + CreateCallbackResults(result_shapes)); + + // `callable` will be destroyed safely with `PythonRefManager` when + // `CpuCallback` is destroyed. + auto cpu_callback = std::make_unique( + std::move(callable), callback_args, callback_results); + return tsl::RCReference( + tsl::MakeRef(ifrt_client, + std::move(cpu_callback))); +} + +absl::StatusOr PyCpuLoadedHostCallback::Serialize() const { + return Unimplemented( + "PyCpuLoadedHostCallback serialization is not supported"); +} + +absl::StatusOr> +PyHostSendAndRecvLoadedHostCallback::Create( + ifrt::Client* ifrt_client, nb::callable callable, + absl::Span operand_shapes, + absl::Span result_shapes, + absl::Span send_channel_ids, + absl::Span recv_channel_ids, nb::callable serializer) { + TF_ASSIGN_OR_RETURN(auto callback_args, CreateCallbackArgs(operand_shapes)); + TF_ASSIGN_OR_RETURN(auto callback_results, + CreateCallbackResults(result_shapes)); + + // `callable` will be destroyed safely with `PythonRefManager` when + // `CpuCallback` is destroyed. + auto cpu_callback = + std::make_shared(callable, callback_args, callback_results); + + auto host_callback = std::make_unique(); + + auto assign_arg_info = [](absl::Span shapes, + absl::Span channel_ids, + std::vector& arg_infos) { + DCHECK_EQ(shapes.size(), channel_ids.size()); + arg_infos.reserve(shapes.size()); + for (int i = 0; i < shapes.size(); ++i) { + HostCallbackArgInfo host_callback_arg_info; + host_callback_arg_info.channel_id = channel_ids[i]; + const auto& shape = shapes[i]; + Shape layout = + (shape.has_layout() ? shape + : LayoutUtil::GetWithDefaultLayout(shape)); + host_callback_arg_info.shape = layout; + arg_infos.push_back(std::move(host_callback_arg_info)); + } + }; + + assign_arg_info(operand_shapes, send_channel_ids, host_callback->operands); + assign_arg_info(result_shapes, recv_channel_ids, host_callback->results); + + host_callback->callback = [cpu_callback = std::move(cpu_callback)]( + void** outputs, void** inputs) { + return cpu_callback->PrepareAndCall(outputs, inputs); + }; + return tsl::RCReference( + tsl::MakeRef( + ifrt_client, std::move(host_callback), callable, operand_shapes, + result_shapes, send_channel_ids, recv_channel_ids, + std::move(serializer))); +} + +PyHostSendAndRecvLoadedHostCallback::PyHostSendAndRecvLoadedHostCallback( + ifrt::Client* ifrt_client, + std::unique_ptr xla_host_callback, nb::callable callable, + absl::Span operand_shapes, + absl::Span result_shapes, + absl::Span send_channel_ids, + absl::Span recv_channel_ids, nb::callable serializer) + : llvm::RTTIExtends( + ifrt_client, std::move(xla_host_callback)), + callable_(std::move(callable)), + operand_shapes_(operand_shapes.begin(), operand_shapes.end()), + result_shapes_(result_shapes.begin(), result_shapes.end()), + send_channel_ids_(send_channel_ids.begin(), send_channel_ids.end()), + recv_channel_ids_(recv_channel_ids.begin(), recv_channel_ids.end()), + serializer_(serializer) {} + +PyHostSendAndRecvLoadedHostCallback::~PyHostSendAndRecvLoadedHostCallback() { + GlobalPyRefManager()->AddGarbage( + absl::MakeSpan(static_cast(&callable_), 1)); + GlobalPyRefManager()->AddGarbage( + absl::MakeSpan(static_cast(&serializer_), 1)); +} + +absl::StatusOr PyHostSendAndRecvLoadedHostCallback::Serialize() + const { + if (serializer_.is_none()) { + return InvalidArgument( + "Host callback cannot be serialized because serializer was not " + "provided by JAX"); + } + ifrt::XlaHostCallbackProto xla_host_callback_proto; + + TF_RET_CHECK(operand_shapes_.size() == send_channel_ids_.size()); + for (int i = 0; i < operand_shapes_.size(); ++i) { + ifrt::XlaHostCallbackProto::ArgInfo* const operand = + xla_host_callback_proto.add_operands(); + operand->set_channel_id(send_channel_ids_[i]); + *operand->mutable_shape() = operand_shapes_[i].ToProto(); + } + + TF_RET_CHECK(result_shapes_.size() == recv_channel_ids_.size()); + for (int i = 0; i < result_shapes_.size(); ++i) { + ifrt::XlaHostCallbackProto::ArgInfo* const result = + xla_host_callback_proto.add_results(); + result->set_channel_id(recv_channel_ids_[i]); + *result->mutable_shape() = result_shapes_[i].ToProto(); + } + + std::string callable; + { + nb::gil_scoped_acquire gil_acquire; + try { + nb::bytes bytes = nb::cast(serializer_(callable_)); + callable = std::string(bytes.c_str(), bytes.size()); + } catch (const nb::python_error& e) { + return absl::InternalError(absl::StrCat( + "Unable to pickle the host_callback callable: ", e.what())); + } catch (const std::exception& e) { + std::exception_ptr p = std::current_exception(); + return absl::InternalError(absl::StrCat( + "Exception while pickling the host_callback callable: ", e.what())); + } catch (...) { + // Ensure to avoid leaking any exception because this method could have + // been called outside of a Python context where C++ exceptions are not + // necessarily enabled. + return absl::InternalError( + "Unknown exception while pickling the host_callback callable."); + } + } + PyHostCallbackProto py_host_callback_proto; + py_host_callback_proto.set_callable(std::move(callable)); + if (!xla_host_callback_proto.mutable_serialized_callback()->PackFrom( + py_host_callback_proto)) { + return absl::InternalError("Could not serialize a Python host callback"); + } + xla_host_callback_proto.set_use_major_to_minor_data_layout_for_callbacks( + true); + return xla_host_callback_proto.SerializeAsString(); +} + +} // namespace xla diff --git a/jaxlib/xla/py_host_callback.h b/jaxlib/xla/py_host_callback.h new file mode 100644 index 000000000000..da504d0c12ca --- /dev/null +++ b/jaxlib/xla/py_host_callback.h @@ -0,0 +1,170 @@ +/* Copyright 2023 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_PY_HOST_CALLBACK_H_ +#define JAXLIB_XLA_PY_HOST_CALLBACK_H_ + +#include +#include +#include +#include +#include + +#include "absl/base/casts.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "llvm/Support/ExtensibleRTTI.h" +#include "nanobind/nanobind.h" +#include "jaxlib/xla/callback.h" +#include "xla/pjrt/host_callback.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/pjrt_ifrt/pjrt_host_callback.h" +#include "xla/shape.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/util.h" + +namespace xla { + +using PyLoadedHostCallback = ::xla::ifrt::LoadedHostCallback; + +class PyFfiLoadedHostCallback final + : public llvm::RTTIExtends { + public: + PyFfiLoadedHostCallback(ifrt::Client* ifrt_client, + nanobind::callable callable) + : llvm::RTTIExtends(ifrt_client, + callable.ptr()), + callable_(std::move(callable)) {} + ~PyFfiLoadedHostCallback() override; + + ifrt::Client* client() const override { return ifrt_client_; } + absl::StatusOr Serialize() const override { + return Unimplemented( + "PyCpuLoadedHostCallback::callback_data() is not supported"); + }; + + static char ID; // NOLINT + + private: + ifrt::Client* ifrt_client_; + nanobind::callable callable_; +}; + +// `PyCpuLoadedHostCallback` implements a Python host callback that uses a +// descriptor (a raw pointer to JAX `CpuCallback`). The descriptor should be +// passed into a 'xla_python_cpu_callback' or 'xla_python_gpu_callback' +// CustomCall as its first argument. +// +// Serialization is not supported. Once the descriptor is embedded in +// CustomCall in an XLA computation, the computation will not be serializable. +class PyCpuLoadedHostCallback final + : public llvm::RTTIExtends { + public: + static absl::StatusOr> Create( + ifrt::Client* ifrt_client, nanobind::callable callable, + absl::Span operand_shapes, + absl::Span result_shapes); + + // Returns the descriptor of `CpuCallback`. + uint64_t descriptor() const { + return absl::bit_cast(cpu_callback_.get()); + } + + CpuCallback* cpu_callback() { return cpu_callback_.get(); } + + // LoadedHostCallback implementation. + + ~PyCpuLoadedHostCallback() override = default; + + ifrt::Client* client() const override { return ifrt_client_; } + + absl::StatusOr Serialize() const override; + + static char ID; // NOLINT + + private: + PyCpuLoadedHostCallback(ifrt::Client* ifrt_client, + std::unique_ptr cpu_callback) + : llvm::RTTIExtends( + ifrt_client, cpu_callback->callback()), + cpu_callback_(std::move(cpu_callback)) {} + + template + friend tsl::RCReference tsl::MakeRef(Args&&... args); + + ifrt::Client* ifrt_client_; + std::unique_ptr cpu_callback_; +}; + +// `PyHostSendAndRecvLoadedHostCallback` implements a Python host callback that +// uses XLA host send and recv. This object should be passed to the compiler +// when creating `xla::ifrt::LoadedExecutable`. +// +// Serialization is supported if the Python host callback using the +// `cloudpickle` third-party library. +// +// TODO(hyeontaek): Update the comment ("compiler" to "client") after splitting +// compilation and loading. +class PyHostSendAndRecvLoadedHostCallback final + : public llvm::RTTIExtends { + public: + static absl::StatusOr> + Create(ifrt::Client* ifrt_client, nanobind::callable callable, + absl::Span operand_shapes, + absl::Span result_shapes, + absl::Span send_channel_ids, + absl::Span recv_channel_ids, + nanobind::callable serializer); + + // PjRtLoadedHostCallback implementation. + + ~PyHostSendAndRecvLoadedHostCallback() override; + + absl::StatusOr Serialize() const override; + + static char ID; // NOLINT + + private: + PyHostSendAndRecvLoadedHostCallback( + ifrt::Client* ifrt_client, + std::unique_ptr xla_host_callback, + nanobind::callable callable, absl::Span operand_shapes, + absl::Span result_shapes, + absl::Span send_channel_ids, + absl::Span recv_channel_ids, + nanobind::callable serializer); + + template + friend tsl::RCReference tsl::MakeRef(Args&&... args); + + // Retained arguments for host callback serialization. + nanobind::callable callable_; + std::vector operand_shapes_; + std::vector result_shapes_; + std::vector send_channel_ids_; + std::vector recv_channel_ids_; + nanobind::callable serializer_; +}; + +} // namespace xla + +#endif // JAXLIB_XLA_PY_HOST_CALLBACK_H_ diff --git a/jaxlib/xla/py_host_callback.proto b/jaxlib/xla/py_host_callback.proto new file mode 100644 index 000000000000..997fc7fe450c --- /dev/null +++ b/jaxlib/xla/py_host_callback.proto @@ -0,0 +1,25 @@ +/* Copyright 2023 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla; + +// Represents a JAX host callback that is serialized using the 'cloudpickle' +// Python library. Typically used for +// `xla.ifrt.XlaHostCallbackProto.serialized_callback`. +message PyHostCallbackProto { + bytes callable = 1; +} diff --git a/jaxlib/xla/util.cc b/jaxlib/xla/util.cc new file mode 100644 index 000000000000..ef0fb2ac3afd --- /dev/null +++ b/jaxlib/xla/util.cc @@ -0,0 +1,60 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/util.h" + +#include + +#include "absl/status/status.h" +#include "absl/types/span.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/value.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/util.h" + +namespace xla { + +absl::Status AwaitBuffersReady(absl::Span ifrt_arrays) { + if (ifrt_arrays.empty()) { + return absl::OkStatus(); + } + + ifrt::Future<> future; + if (ifrt_arrays.size() == 1) { + future = ifrt_arrays[0]->GetReadyFuture(); + } else { + std::vector> values; + values.reserve(ifrt_arrays.size()); + for (ifrt::Array* const ifrt_array : ifrt_arrays) { + values.push_back(tsl::FormRef(ifrt_array)); + } + ifrt::Client* const client = ifrt_arrays.front()->client(); + future = client->GetReadyFuture(values); + } + + absl::Status s = future.Await(); + if (!s.ok()) { + // Fix up error string because some clients rely on it. + if (s.message() == "GetReadyFuture() called on deleted or donated buffer") { + s = InvalidArgument( + "BlockHostUntilReady() called on deleted or donated buffer"); + } + } + return s; +} + +} // namespace xla diff --git a/jaxlib/xla/util.h b/jaxlib/xla/util.h new file mode 100644 index 000000000000..ef5fc735fc33 --- /dev/null +++ b/jaxlib/xla/util.h @@ -0,0 +1,31 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_UTIL_H_ +#define JAXLIB_XLA_UTIL_H_ + +#include "absl/status/status.h" +#include "absl/types/span.h" +#include "xla/python/ifrt/array.h" + +namespace xla { + +// Requests if given buffers are ready, awaits for results and returns OK if +// all of the buffers are ready or the last non-ok status. +absl::Status AwaitBuffersReady(absl::Span ifrt_arrays); + +} // namespace xla + +#endif // JAXLIB_XLA_UTIL_H_ diff --git a/jaxlib/xla/xla.cc b/jaxlib/xla/xla.cc index 0e1ba031670f..a0508013910b 100644 --- a/jaxlib/xla/xla.cc +++ b/jaxlib/xla/xla.cc @@ -87,6 +87,7 @@ limitations under the License. #include "jaxlib/xla/config.h" #include "jaxlib/xla/custom_call_sharding.h" #include "jaxlib/xla/dlpack.h" +#include "jaxlib/xla/guard_lib.h" #include "jaxlib/xla/jax_jit.h" #include "jaxlib/xla/mlir.h" #include "jaxlib/xla/pjit.h" @@ -109,7 +110,6 @@ limitations under the License. #include "xla/pjrt/pjrt_common.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_layout.h" -#include "xla/python/guard_lib.h" #include "xla/python/logging.h" // IWYU pragma: keep #include "xla/python/nb_absl_flat_hash_map.h" // IWYU pragma: keep #include "xla/python/nb_absl_span.h" // IWYU pragma: keep From 087a38988c5608c300119ff16dce01132a931951 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 25 Mar 2025 16:41:37 -0700 Subject: [PATCH 164/483] [sharding_in_types] Add `out_sharding` to `jax.random.uniform`. Drop into `Auto` mode inside for implementation. Co-authored-by: Roy Frostig PiperOrigin-RevId: 740529498 --- jax/_src/random.py | 16 +++++++++++++--- tests/pjit_test.py | 19 +++++++++++++++++++ 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/jax/_src/random.py b/jax/_src/random.py index c0663dc67f80..2d315ed0cc8b 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -38,6 +38,8 @@ from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir +from jax._src.sharding_impls import canonicalize_sharding +from jax._src.pjit import auto_axes from jax._src.lax import lax as lax_internal from jax._src.numpy.lax_numpy import _convert_and_clip_integer from jax._src.numpy.util import _arraylike, check_arraylike, promote_dtypes_inexact @@ -379,7 +381,8 @@ def uniform(key: ArrayLike, shape: Shape = (), dtype: DTypeLikeFloat = float, minval: RealArray = 0., - maxval: RealArray = 1.) -> Array: + maxval: RealArray = 1., + out_sharding=None) -> Array: """Sample uniform random values in [minval, maxval) with given shape/dtype. Args: @@ -397,14 +400,21 @@ def uniform(key: ArrayLike, key, _ = _check_prng_key("uniform", key) dtypes.check_user_dtype_supported(dtype) shape = core.canonicalize_shape(shape) + out_sharding = canonicalize_sharding(out_sharding, "uniform") if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `uniform` must be a float dtype, " f"got {dtype}") dtype = dtypes.canonicalize_dtype(dtype) - return _uniform(key, shape, dtype, minval, maxval) + return _uniform_auto(key, shape, dtype, minval, maxval, out_sharding) + +@partial(jit, static_argnums=(1, 2, 5)) +def _uniform_auto(key, shape, dtype, minval, maxval, out_sharding) -> Array: + if out_sharding is None: + return _uniform(key, shape, dtype, minval, maxval) + def f(key, minval, maxval): return _uniform(key, shape, dtype, minval, maxval) + return auto_axes(f, out_shardings=out_sharding)(key, minval, maxval) -@partial(jit, static_argnums=(1, 2)) def _uniform(key, shape, dtype, minval, maxval) -> Array: _check_shape("uniform", shape) if not jnp.issubdtype(dtype, np.floating): diff --git a/tests/pjit_test.py b/tests/pjit_test.py index ed1f9e9b62d8..528384358351 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7226,6 +7226,25 @@ def f(key): out = f(key) self.assertEqual(out.sharding, NamedSharding(mesh, P())) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_random_uniform(self, mesh): + @jax.jit + def f(key): + out = jax.random.uniform(key, shape=(8, 12), out_sharding=P('x', 'y')) + self.assertEqual(out.aval.sharding.spec, P('x', 'y')) + return out + + key = jax.random.key(1) + out = f(key) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + + lowered_text = f.lower(key).as_text() + if config.use_shardy_partitioner.value: + self.assertIn('sdy.sharding_constraint', lowered_text) + self.assertIn('<@mesh, [{"x"}, {"y"}]>', lowered_text) + else: + self.assertIn('mhlo.sharding = "{devices=[2,2]<=[4]}"}', lowered_text) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): From f1a92411872dcb43c1f709701f1163bbf23299be Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 25 Mar 2025 17:02:45 -0700 Subject: [PATCH 165/483] Add standard_insert_broadcasts to all traceables in lax.py and checks in abstract_eval rules of those primitives. PiperOrigin-RevId: 740536031 --- jax/_src/lax/lax.py | 182 ++++++++++++++++++++++++++++++++++------- jax/_src/lax/linalg.py | 14 +++- jax/_src/lax/utils.py | 10 ++- 3 files changed, 172 insertions(+), 34 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index dd6e7399321b..655ef763f1ef 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -296,6 +296,7 @@ def neg(x: ArrayLike) -> Array: .. _stablehlo.negate: https://openxla.org/stablehlo/spec#negate """ + x, = core.standard_insert_pbroadcast(x) return neg_p.bind(x) @export @@ -339,6 +340,7 @@ def sign(x: ArrayLike) -> Array: .. _stablehlo.sign: https://openxla.org/stablehlo/spec#sign """ + x, = core.standard_insert_pbroadcast(x) return sign_p.bind(x) @export @@ -369,6 +371,7 @@ def nextafter(x1: ArrayLike, x2: ArrayLike) -> Array: For the smallest usable (i.e. normal) float, use ``tiny`` of ``jnp.finfo``. """ + x1, x2 = core.standard_insert_pbroadcast(x1, x2) return nextafter_p.bind(x1, x2) @export @@ -390,6 +393,7 @@ def floor(x: ArrayLike) -> Array: .. _stablehlo.floor: https://openxla.org/stablehlo/spec#floor """ + x, = core.standard_insert_pbroadcast(x) return floor_p.bind(x) @export @@ -411,6 +415,7 @@ def ceil(x: ArrayLike) -> Array: .. _stablehlo.ceil: https://openxla.org/stablehlo/spec#ceil """ + x, = core.standard_insert_pbroadcast(x) return ceil_p.bind(x) class RoundingMethod(enum.IntEnum): @@ -460,6 +465,7 @@ def round(x: ArrayLike, .. _stablehlo.round: https://openxla.org/stablehlo/spec#round """ rounding_method = RoundingMethod(rounding_method) + x, = core.standard_insert_pbroadcast(x) return round_p.bind(x, rounding_method=rounding_method) @export @@ -481,6 +487,7 @@ def is_finite(x: ArrayLike) -> Array: .. _stablehlo.is_finite: https://openxla.org/stablehlo/spec#is_finite """ + x, = core.standard_insert_pbroadcast(x) return is_finite_p.bind(x) @export @@ -502,6 +509,7 @@ def exp(x: ArrayLike) -> Array: .. _stablehlo.exponential: https://openxla.org/stablehlo/spec#exponential """ + x, = core.standard_insert_pbroadcast(x) return exp_p.bind(x) @export @@ -525,6 +533,7 @@ def exp2(x: ArrayLike) -> Array: .. _stablehlo.exponential: https://openxla.org/stablehlo/spec#exponential .. _stablehlo.multiply: https://openxla.org/stablehlo/spec#multiply """ + x, = core.standard_insert_pbroadcast(x) return exp2_p.bind(x) @export @@ -548,6 +557,7 @@ def expm1(x: ArrayLike) -> Array: .. _stablehlo.exponential_minus_one: https://openxla.org/stablehlo/spec#exponential_minus_one """ + x, = core.standard_insert_pbroadcast(x) return expm1_p.bind(x) @export @@ -568,6 +578,7 @@ def log(x: ArrayLike) -> Array: .. _stablehlo.log: https://openxla.org/stablehlo/spec#log """ + x, = core.standard_insert_pbroadcast(x) return log_p.bind(x) @export @@ -591,6 +602,7 @@ def log1p(x: ArrayLike) -> Array: .. _stablehlo.log_plus_one: https://openxla.org/stablehlo/spec#log_plus_one """ + x, = core.standard_insert_pbroadcast(x) return log1p_p.bind(x) @export @@ -613,6 +625,7 @@ def tanh(x: ArrayLike) -> Array: .. _stablehlo.tanh: https://openxla.org/stablehlo/spec#tanh """ + x, = core.standard_insert_pbroadcast(x) return tanh_p.bind(x) @export @@ -632,6 +645,7 @@ def logistic(x: ArrayLike) -> Array: See also: - :func:`jax.nn.sigmoid`: an alternative API for this functionality. """ + x, = core.standard_insert_pbroadcast(x) return logistic_p.bind(x) @export @@ -656,6 +670,7 @@ def sin(x: ArrayLike) -> Array: .. _stablehlo.sine: https://openxla.org/stablehlo/spec#sine """ + x, = core.standard_insert_pbroadcast(x) return sin_p.bind(x) @export @@ -680,6 +695,7 @@ def cos(x: ArrayLike) -> Array: .. _stablehlo.cosine: https://openxla.org/stablehlo/spec#cosine """ + x, = core.standard_insert_pbroadcast(x) return cos_p.bind(x) @export @@ -704,6 +720,7 @@ def atan2(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.atan2: https://openxla.org/stablehlo/spec#atan2 """ + x, y = core.standard_insert_pbroadcast(x, y) return atan2_p.bind(x, y) @export @@ -726,6 +743,7 @@ def real(x: ArrayLike) -> Array: .. _stablehlo.real: https://openxla.org/stablehlo/spec#real """ + x, = core.standard_insert_pbroadcast(x) return real_p.bind(x) @export @@ -748,6 +766,7 @@ def imag(x: ArrayLike) -> Array: .. _stablehlo.imag: https://openxla.org/stablehlo/spec#imag """ + x, = core.standard_insert_pbroadcast(x) return imag_p.bind(x) @export @@ -773,6 +792,7 @@ def complex(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.complex: https://openxla.org/stablehlo/spec#complex """ + x, y = core.standard_insert_pbroadcast(x, y) return complex_p.bind(x, y) @export @@ -799,6 +819,7 @@ def conj(x: ArrayLike) -> Array: .. _stablehlo.complex: https://openxla.org/stablehlo/spec#complex """ # TODO(mattjj): remove input_dtype, not needed anymore + x, = core.standard_insert_pbroadcast(x) return conj_p.bind(x, input_dtype=_dtype(x)) @export @@ -819,6 +840,7 @@ def abs(x: ArrayLike) -> Array: .. _stablehlo.abs: https://openxla.org/stablehlo/spec#abs """ + x, = core.standard_insert_pbroadcast(x) return abs_p.bind(x) @export @@ -844,6 +866,7 @@ def pow(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.convert: https://openxla.org/stablehlo/spec#convert .. _stablehlo.pow: https://openxla.org/stablehlo/spec#pow """ + x, y = core.standard_insert_pbroadcast(x, y) return pow_p.bind(x, y) @export @@ -865,6 +888,7 @@ def integer_pow(x: ArrayLike, y: int) -> Array: .. _stablehlo.multiply: https://openxla.org/stablehlo/spec#multiply """ + x, = core.standard_insert_pbroadcast(x) return integer_pow_p.bind(x, y=y) @export @@ -886,6 +910,7 @@ def sqrt(x: ArrayLike) -> Array: .. _stablehlo.sqrt: https://openxla.org/stablehlo/spec#sqrt """ + x, = core.standard_insert_pbroadcast(x) return sqrt_p.bind(x) @export @@ -908,6 +933,7 @@ def rsqrt(x: ArrayLike) -> Array: .. _stablehlo.rsqrt: https://openxla.org/stablehlo/spec#rsqrt """ + x, = core.standard_insert_pbroadcast(x) return rsqrt_p.bind(x) @export @@ -929,6 +955,7 @@ def cbrt(x: ArrayLike) -> Array: .. _stablehlo.cbrt: https://openxla.org/stablehlo/spec#cbrt """ + x, = core.standard_insert_pbroadcast(x) return cbrt_p.bind(x) @export @@ -953,6 +980,7 @@ def bitwise_not(x: ArrayLike) -> Array: .. _stablehlo.not: https://openxla.org/stablehlo/spec#not """ + x, = core.standard_insert_pbroadcast(x) return not_p.bind(x) @export @@ -979,6 +1007,7 @@ def bitwise_and(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.and: https://openxla.org/stablehlo/spec#and """ + x, y = core.standard_insert_pbroadcast(x, y) return and_p.bind(x, y) @export @@ -1005,6 +1034,7 @@ def bitwise_or(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.or: https://openxla.org/stablehlo/spec#or """ + x, y = core.standard_insert_pbroadcast(x, y) return or_p.bind(x, y) @export @@ -1031,6 +1061,7 @@ def bitwise_xor(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.xor: https://openxla.org/stablehlo/spec#xor """ + x, y = core.standard_insert_pbroadcast(x, y) return xor_p.bind(x, y) @export @@ -1052,6 +1083,7 @@ def population_count(x: ArrayLike) -> Array: .. _stablehlo.popcnt: https://openxla.org/stablehlo/spec#popcnt """ + x, = core.standard_insert_pbroadcast(x) return population_count_p.bind(x) @export @@ -1072,6 +1104,7 @@ def clz(x: ArrayLike) -> Array: .. _stablehlo.count_leading_zeros: https://openxla.org/stablehlo/spec#count_leading_zeros """ + x, = core.standard_insert_pbroadcast(x) return clz_p.bind(x) @export @@ -1095,6 +1128,7 @@ def add(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.add: https://openxla.org/stablehlo/spec#add """ + x, y = core.standard_insert_pbroadcast(x, y) return add_p.bind(x, y) @export @@ -1118,6 +1152,7 @@ def sub(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.subtract: https://openxla.org/stablehlo/spec#subtract """ + x, y = core.standard_insert_pbroadcast(x, y) return sub_p.bind(x, y) @export @@ -1171,6 +1206,7 @@ def div(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.divide: https://openxla.org/stablehlo/spec#divide """ + x, y = core.standard_insert_pbroadcast(x, y) return div_p.bind(x, y) @export @@ -1198,6 +1234,7 @@ def rem(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.remainder: https://openxla.org/stablehlo/spec#remainder """ + x, y = core.standard_insert_pbroadcast(x, y) return rem_p.bind(x, y) @export @@ -1223,6 +1260,7 @@ def max(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.maximum: https://openxla.org/stablehlo/spec#maximum """ + x, y = core.standard_insert_pbroadcast(x, y) return max_p.bind(x, y) @export @@ -1248,6 +1286,7 @@ def min(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.minimum: https://openxla.org/stablehlo/spec#minimum """ + x, y = core.standard_insert_pbroadcast(x, y) return min_p.bind(x, y) @export @@ -1273,6 +1312,7 @@ def shift_left(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.shift_left: https://openxla.org/stablehlo/spec#shift_left """ + x, y = core.standard_insert_pbroadcast(x, y) return shift_left_p.bind(x, y) @export @@ -1299,6 +1339,7 @@ def shift_right_arithmetic(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.shift_right_arithmetic: https://openxla.org/stablehlo/spec#shift_right_arithmetic """ + x, y = core.standard_insert_pbroadcast(x, y) return shift_right_arithmetic_p.bind(x, y) @export @@ -1325,6 +1366,7 @@ def shift_right_logical(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.shift_right_logical: https://openxla.org/stablehlo/spec#shift_right_logical """ + x, y = core.standard_insert_pbroadcast(x, y) return shift_right_logical_p.bind(x, y) @export @@ -1355,6 +1397,7 @@ def eq(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare """ + x, y = core.standard_insert_pbroadcast(x, y) return eq_p.bind(x, y) @export @@ -1385,6 +1428,7 @@ def ne(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare """ + x, y = core.standard_insert_pbroadcast(x, y) return ne_p.bind(x, y) @export @@ -1415,6 +1459,7 @@ def ge(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare """ + x, y = core.standard_insert_pbroadcast(x, y) return ge_p.bind(x, y) @export @@ -1445,6 +1490,7 @@ def gt(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare """ + x, y = core.standard_insert_pbroadcast(x, y) return gt_p.bind(x, y) @export @@ -1475,6 +1521,7 @@ def le(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare """ + x, y = core.standard_insert_pbroadcast(x, y) return le_p.bind(x, y) @export @@ -1505,6 +1552,7 @@ def lt(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare """ + x, y = core.standard_insert_pbroadcast(x, y) return lt_p.bind(x, y) @export @@ -1574,6 +1622,8 @@ def _convert_element_type( "Instead, convert to and from their representation dtypes, e.g.:\n" f"{dtype_to_string(old_dtype)} -> {dtype_to_string(old_rep_dtype)} " f"-> {dtype_to_string(new_rep_dtype)} -> {dtype_to_string(new_dtype)}") + + operand, = core.standard_insert_pbroadcast(operand) if isinstance(new_dtype, dtypes.ExtendedDType): return to_edtype_p.bind(operand, edtype=new_dtype) return from_edtype_p.bind(operand, dtype=np.dtype(new_dtype)) @@ -1649,6 +1699,7 @@ def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array: .. _stablehlo.bitcast_convert: https://openxla.org/stablehlo/spec#bitcast_convert """ new_dtype = dtypes.canonicalize_dtype(new_dtype) + operand, = core.standard_insert_pbroadcast(operand) return bitcast_convert_type_p.bind(operand, new_dtype=new_dtype) def clamp(min: ArrayLike, x: ArrayLike, max: ArrayLike) -> Array: @@ -1660,6 +1711,7 @@ def clamp(min: ArrayLike, x: ArrayLike, max: ArrayLike) -> Array: x & \text{otherwise} \end{cases}`. """ + min, x, max = core.standard_insert_pbroadcast(min, x, max) return clamp_p.bind(min, x, max) @@ -1766,6 +1818,7 @@ def _decorator(*args, **kwargs): closed_jaxpr, out_tree = _trace_composite_to_jaxpr( partial(decomposition, **kwargs), in_tree, in_avals, name, debug_info ) + flat_args = core.standard_insert_pbroadcast(*flat_args) out_flat = composite_p.bind( *flat_args, name=name, @@ -1883,6 +1936,7 @@ def concatenate(operands: Array | Sequence[ArrayLike], dimension: int) -> Array: op, = operands if isinstance(op, Array): return op + operands = core.standard_insert_pbroadcast(*operands) return concatenate_p.bind(*operands, dimension=dimension) @@ -1902,6 +1956,7 @@ def split(operand: ArrayLike, sizes: Sequence[int], taken along ``axis``. """ operand = asarray(operand) + operand, = core.standard_insert_pbroadcast(operand) return split_p.bind(operand, sizes=tuple(sizes), axis=canonicalize_axis(axis, operand.ndim)) @@ -2408,6 +2463,7 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN preferred_element_type = ( None if preferred_element_type is None else dtypes.canonicalize_dtype(np.dtype(preferred_element_type))) + lhs, rhs = core.standard_insert_pbroadcast(lhs, rhs) return dot_general_p.bind(lhs, rhs, dimension_numbers=(cdims, bdims), precision=canonicalize_precision(precision), @@ -2543,6 +2599,7 @@ def ragged_dot_general( extra leading dimension of size `g` in the case where the lhs ragged dimension is a contracting dimension. """ + lhs, rhs, group_sizes = core.standard_insert_pbroadcast(lhs, rhs, group_sizes) return ragged_dot_general_p.bind( lhs, rhs, @@ -2605,6 +2662,7 @@ def broadcast_in_dim(operand: ArrayLike, shape: Shape, dyn_shape, static_shape = _extract_tracers_dyn_shape(shape) else: dyn_shape, static_shape = [], shape # type: ignore + operand, = core.standard_insert_pbroadcast(operand) return broadcast_in_dim_p.bind( operand, *dyn_shape, shape=tuple(static_shape), broadcast_dimensions=tuple(broadcast_dimensions), @@ -2671,6 +2729,7 @@ def reshape(operand: ArrayLike, new_sizes: Shape, else: dyn_shape, static_new_sizes = _extract_tracers_dyn_shape(new_sizes) out_sharding = canonicalize_sharding(out_sharding, 'reshape') + operand, = core.standard_insert_pbroadcast(operand) return reshape_p.bind( operand, *dyn_shape, new_sizes=tuple(static_new_sizes), dimensions=None if dims is None or same_dims else dims, @@ -2726,6 +2785,7 @@ def pad(operand: ArrayLike, padding_value: ArrayLike, [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], dtype=int32) """ + operand, padding_value = core.standard_insert_pbroadcast(operand, padding_value) return pad_p.bind(operand, padding_value, padding_config=tuple(padding_config)) def rev(operand: ArrayLike, dimensions: Sequence[int]) -> Array: @@ -2733,6 +2793,7 @@ def rev(operand: ArrayLike, dimensions: Sequence[int]) -> Array: `_ operator. """ + operand, = core.standard_insert_pbroadcast(operand) return rev_p.bind(operand, dimensions=tuple(dimensions)) def select(pred: ArrayLike, on_true: ArrayLike, on_false: ArrayLike) -> Array: @@ -2758,6 +2819,8 @@ def select(pred: ArrayLike, on_true: ArrayLike, on_false: ArrayLike) -> Array: """ # Caution! The select_n_p primitive has the *opposite* order of arguments to # select(). This is because it implements `select_n`. + pred, on_false, on_true = core.standard_insert_pbroadcast( + pred, on_false, on_true) return select_n_p.bind(pred, on_false, on_true) def select_n(which: ArrayLike, *cases: ArrayLike) -> Array: @@ -2783,6 +2846,7 @@ def select_n(which: ArrayLike, *cases: ArrayLike) -> Array: """ if len(cases) == 0: raise ValueError("select_n() must have at least one case") + which, *cases = core.standard_insert_pbroadcast(which, *cases) return select_n_p.bind(which, *cases) @@ -2796,17 +2860,20 @@ def transpose(operand: ArrayLike, if permutation == tuple(range(np.ndim(operand))) and isinstance(operand, Array): return operand else: + operand, = core.standard_insert_pbroadcast(operand) return transpose_p.bind(operand, permutation=permutation) def argmin(operand: ArrayLike, axis: int, index_dtype: DTypeLike) -> Array: """Computes the index of the minimum element along ``axis``.""" + operand, = core.standard_insert_pbroadcast(operand) return argmin_p.bind(operand, axes=(axis,), index_dtype=dtypes.canonicalize_dtype(index_dtype)) def argmax(operand: ArrayLike, axis: int, index_dtype: DTypeLike) -> Array: """Computes the index of the maximum element along ``axis``.""" + operand, = core.standard_insert_pbroadcast(operand) return argmax_p.bind(operand, axes=(axis,), index_dtype=dtypes.canonicalize_dtype(index_dtype)) @@ -2972,6 +3039,7 @@ def reduce_sum(operand: ArrayLike, axes: Sequence[int]) -> Array: :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_max`, :func:`jax.lax.reduce_min`, :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_or`, :func:`jax.lax.reduce_xor`. """ + operand, = core.standard_insert_pbroadcast(operand) return reduce_sum_p.bind(operand, axes=tuple(axes)) def reduce_prod(operand: ArrayLike, axes: Sequence[int]) -> Array: @@ -2998,6 +3066,7 @@ def reduce_prod(operand: ArrayLike, axes: Sequence[int]) -> Array: :func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_max`, :func:`jax.lax.reduce_min`, :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_or`, :func:`jax.lax.reduce_xor`. """ + operand, = core.standard_insert_pbroadcast(operand) return reduce_prod_p.bind(operand, axes=tuple(axes)) def reduce_max(operand: ArrayLike, axes: Sequence[int]) -> Array: @@ -3019,6 +3088,7 @@ def reduce_max(operand: ArrayLike, axes: Sequence[int]) -> Array: :func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_min`, :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_or`, :func:`jax.lax.reduce_xor`. """ + operand, = core.standard_insert_pbroadcast(operand) return reduce_max_p.bind(operand, axes=tuple(axes)) def reduce_min(operand: ArrayLike, axes: Sequence[int]) -> Array: @@ -3040,6 +3110,7 @@ def reduce_min(operand: ArrayLike, axes: Sequence[int]) -> Array: :func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_max`, :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_or`, :func:`jax.lax.reduce_xor`. """ + operand, = core.standard_insert_pbroadcast(operand) return reduce_min_p.bind(operand, axes=tuple(axes)) def reduce_or(operand: ArrayLike, axes: Sequence[int]) -> Array: @@ -3062,6 +3133,7 @@ def reduce_or(operand: ArrayLike, axes: Sequence[int]) -> Array: :func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_max`, :func:`jax.lax.reduce_min`, :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_xor`. """ + operand, = core.standard_insert_pbroadcast(operand) return reduce_or_p.bind(operand, axes=tuple(axes)) def reduce_and(operand: ArrayLike, axes: Sequence[int]) -> Array: @@ -3084,6 +3156,7 @@ def reduce_and(operand: ArrayLike, axes: Sequence[int]) -> Array: :func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_max`, :func:`jax.lax.reduce_min`, :func:`jax.lax.reduce_or`, :func:`jax.lax.reduce_xor`. """ + operand, = core.standard_insert_pbroadcast(operand) return reduce_and_p.bind(operand, axes=tuple(axes)) def reduce_xor(operand: ArrayLike, axes: Sequence[int]) -> Array: @@ -3106,6 +3179,7 @@ def reduce_xor(operand: ArrayLike, axes: Sequence[int]) -> Array: :func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_max`, :func:`jax.lax.reduce_min`, :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_or`. """ + operand, = core.standard_insert_pbroadcast(operand) return reduce_xor_p.bind(operand, axes=tuple(axes)) @overload @@ -3143,6 +3217,7 @@ def sort(operand: Array | Sequence[Array], dimension: int = -1, if not (1 <= num_keys <= len(operand)): raise ValueError(f"{num_keys=} must be between 1 and {len(operand)=}") dimension = canonicalize_axis(dimension, len(operand[0].shape)) + operand = core.standard_insert_pbroadcast(*operand) return tuple(sort_p.bind(*operand, dimension=dimension, is_stable=is_stable, num_keys=num_keys)) @@ -3190,6 +3265,7 @@ def top_k(operand: ArrayLike, k: int) -> tuple[Array, Array]: k = int(k) if k < 0: raise ValueError(f"k argument to top_k must be nonnegative, got {k}") + operand, = core.standard_insert_pbroadcast(operand) return top_k_p.bind(operand, k=k) def tie_in(x: Any, y: T) -> T: @@ -3375,7 +3451,9 @@ def reduce_precision(operand: float | ArrayLike, operator.index, exponent_bits, "exponent_bits argument of lax.reduce_precision") mantissa_bits = core.concrete_or_error( operator.index, mantissa_bits, "mantissa_bits argument of lax.reduce_precision") - return reduce_precision_p.bind(operand, exponent_bits=exponent_bits, mantissa_bits=mantissa_bits) + operand, = core.standard_insert_pbroadcast(operand) + return reduce_precision_p.bind(operand, exponent_bits=exponent_bits, + mantissa_bits=mantissa_bits) def squeeze(array: ArrayLike, dimensions: Sequence[int]) -> Array: """Squeeze any number of size 1 dimensions from an array.""" @@ -3383,6 +3461,7 @@ def squeeze(array: ArrayLike, dimensions: Sequence[int]) -> Array: dimensions = tuple(sorted(canonicalize_axis(i, ndim) for i in dimensions)) if not dimensions and isinstance(array, Array): return array + array, = core.standard_insert_pbroadcast(array) return squeeze_p.bind(array, dimensions=dimensions) def expand_dims(array: ArrayLike, dimensions: Sequence[int]) -> Array: @@ -3503,6 +3582,7 @@ def batch_matmul(lhs: Array, rhs: Array, def square(x: ArrayLike) -> Array: r"""Elementwise square: :math:`x^2`.""" + x, = core.standard_insert_pbroadcast(x) return square_p.bind(x) def reciprocal(x: ArrayLike) -> Array: @@ -3530,6 +3610,7 @@ def tan(x: ArrayLike) -> Array: .. _stablehlo.tangent: https://openxla.org/stablehlo/spec#tangent """ + x, = core.standard_insert_pbroadcast(x) return tan_p.bind(x) @export @@ -3550,6 +3631,7 @@ def asin(x: ArrayLike) -> Array: - :func:`jax.lax.acos`: elementwise arc cosine. - :func:`jax.lax.atan`: elementwise arc tangent. """ + x, = core.standard_insert_pbroadcast(x) return asin_p.bind(x) @export @@ -3570,6 +3652,7 @@ def acos(x: ArrayLike) -> Array: - :func:`jax.lax.asin`: elementwise arc sine. - :func:`jax.lax.atan`: elementwise arc tangent. """ + x, = core.standard_insert_pbroadcast(x) return acos_p.bind(x) @export @@ -3591,6 +3674,7 @@ def atan(x: ArrayLike) -> Array: - :func:`jax.lax.asin`: elementwise arc sine. - :func:`jax.lax.atan2`: elementwise 2-term arc tangent. """ + x, = core.standard_insert_pbroadcast(x) return atan_p.bind(x) @export @@ -3611,6 +3695,7 @@ def sinh(x: ArrayLike) -> Array: - :func:`jax.lax.cosh`: elementwise hyperbolic cosine. - :func:`jax.lax.tanh`: elementwise hyperbolic tangent. """ + x, = core.standard_insert_pbroadcast(x) return sinh_p.bind(x) @export @@ -3631,6 +3716,7 @@ def cosh(x: ArrayLike) -> Array: - :func:`jax.lax.sinh`: elementwise hyperbolic sine. - :func:`jax.lax.tanh`: elementwise hyperbolic tangent. """ + x, = core.standard_insert_pbroadcast(x) return cosh_p.bind(x) @export @@ -3651,6 +3737,7 @@ def asinh(x: ArrayLike) -> Array: - :func:`jax.lax.atanh`: elementwise inverse hyperbolic tangent. - :func:`jax.lax.sinh`: elementwise hyperbolic sine. """ + x, = core.standard_insert_pbroadcast(x) return asinh_p.bind(x) @export @@ -3671,6 +3758,7 @@ def acosh(x: ArrayLike) -> Array: - :func:`jax.lax.atanh`: elementwise inverse hyperbolic tangent. - :func:`jax.lax.cosh`: elementwise hyperbolic cosine. """ + x, = core.standard_insert_pbroadcast(x) return acosh_p.bind(x) @export @@ -3691,6 +3779,7 @@ def atanh(x: ArrayLike) -> Array: - :func:`jax.lax.asinh`: elementwise inverse hyperbolic sine. - :func:`jax.lax.tanh`: elementwise hyperbolic tangent. """ + x, = core.standard_insert_pbroadcast(x) return atanh_p.bind(x) @@ -3759,7 +3848,8 @@ def unop_dtype_rule(result_dtype, accepted_dtypes, name, aval, **kwargs): def unop(result_dtype, accepted_dtypes, name): dtype_rule = partial(unop_dtype_rule, result_dtype, accepted_dtypes, name) prim = standard_primitive(_attrgetter('shape'), dtype_rule, name, - sharding_rule=_attrgetter('sharding')) + sharding_rule=_attrgetter('sharding'), + vma_rule=_attrgetter('vma')) batching.defvectorized(prim) pe.def_trivial_padding(prim) return prim @@ -4314,7 +4404,7 @@ def _integer_pow_jvp(g, x, *, y): integer_pow_p = standard_primitive( _attrgetter('shape'), _integer_pow_dtype_rule, 'integer_pow', - sharding_rule=_attrgetter('sharding')) + sharding_rule=_attrgetter('sharding'), vma_rule=_attrgetter('vma')) batching.defvectorized(integer_pow_p) ad.defjvp(integer_pow_p, _integer_pow_jvp) pe.def_trivial_padding(integer_pow_p) @@ -4883,7 +4973,8 @@ def _bitcast_convert_type_dtype_rule(operand, *, new_dtype): bitcast_convert_type_p = standard_primitive( _bitcast_convert_type_shape_rule, _bitcast_convert_type_dtype_rule, 'bitcast_convert_type', weak_type_rule=_strip_weak_type, - sharding_rule=_bitcast_convert_type_sharding_rule) + sharding_rule=_bitcast_convert_type_sharding_rule, + vma_rule=partial(standard_vma_rule, 'bitcast_convert_type')) ad.defjvp_zero(bitcast_convert_type_p) batching.defvectorized(bitcast_convert_type_p) @@ -5352,6 +5443,7 @@ def _dot_general_ragged_prop_rule(eqn_params, invar_raggedness, outvars): _dot_general_dtype_rule, 'dot_general', sharding_rule=_dot_general_sharding_rule, + vma_rule=partial(standard_vma_rule, 'dot_general') ) @@ -6352,7 +6444,10 @@ def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions, new_sharding = _broadcast_in_dim_sharding_rule( x, shape=shape, broadcast_dimensions=broadcast_dimensions, sharding=sharding) - return core.ShapedArray(shape, x.dtype, x.weak_type, sharding=new_sharding) + new_vma = (standard_vma_rule('broadcast_in_dim', x) + if config.varying_axes_in_types.value else frozenset()) + return core.ShapedArray(shape, x.dtype, x.weak_type, sharding=new_sharding, + vma=new_vma) # If any BInts in shape, or Tracers in dyn_shape, produce a DShapedArray # (even if x is a ShapedArray) # TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code @@ -6436,7 +6531,8 @@ def _clamp_batch_rule(batched_args, batch_dims, **params): return clamp_p.bind(min, x, max), 0 clamp_p = standard_primitive(_clamp_shape_rule, _clamp_dtype_rule, 'clamp', - sharding_rule=_clamp_sharding_rule) + sharding_rule=_clamp_sharding_rule, + vma_rule=partial(standard_vma_rule, 'clamp')) ad.defjvp(clamp_p, lambda g, min, operand, max: select(bitwise_and(gt(min, operand), lt(min, max)), @@ -6523,7 +6619,8 @@ def _concatenate_pad_rule(in_avals, out_avals, *operands, dimension): concatenate_p = standard_primitive( _concatenate_shape_rule, _concatenate_dtype_rule, 'concatenate', - sharding_rule=_concatenate_sharding_rule) + sharding_rule=_concatenate_sharding_rule, + vma_rule=partial(standard_vma_rule, 'concatenate')) ad.deflinear2(concatenate_p, _concatenate_transpose_rule) ad.primitive_transposes[concatenate_p] = _concatenate_transpose_rule batching.primitive_batchers[concatenate_p] = _concatenate_batch_rule @@ -6595,11 +6692,17 @@ def _split_sharding_rule(operand, *, sizes, axis): return [slicing._get_sharding_for_varying_out_shape(out_sh, operand, 'split') for out_sh in out_shapes] +def _split_vma_rule(operand, *, sizes, axis): + out_vma = standard_vma_rule('split', operand) + out_shapes = _split_shape_rule(operand, sizes=sizes, axis=axis) + return [out_vma] * len(out_shapes) + split_p = core.Primitive('split') split_p.multiple_results = True split_p.def_abstract_eval( partial(standard_multi_result_abstract_eval, split_p, _split_shape_rule, - _split_dtype_rule, _split_weak_type_rule, _split_sharding_rule)) + _split_dtype_rule, _split_weak_type_rule, _split_sharding_rule, + _split_vma_rule)) split_p.def_impl(partial(dispatch.apply_primitive, split_p)) ad.deflinear2(split_p, _split_transpose_rule) batching.primitive_batchers[split_p] = _split_batch_rule @@ -6681,7 +6784,8 @@ def _pad_batch_rule(batched_args, batch_dims, *, padding_config): return select(mask, x, broadcasted_padding), operand_bdim pad_p = standard_primitive(_pad_shape_rule, _pad_dtype_rule, 'pad', - sharding_rule=_pad_sharding_rule) + sharding_rule=_pad_sharding_rule, + vma_rule=partial(standard_vma_rule, 'pad')) ad.deflinear2(pad_p, _pad_transpose) batching.primitive_batchers[pad_p] = _pad_batch_rule @@ -6745,7 +6849,8 @@ def _squeeze_batch_rule(batched_args, batch_dims, *, dimensions): return squeeze(operand, dimensions=dimensions), bdim_out squeeze_p = standard_primitive(_squeeze_shape_rule, _squeeze_dtype_rule, - 'squeeze', sharding_rule=_squeeze_sharding_rule) + 'squeeze', sharding_rule=_squeeze_sharding_rule, + vma_rule=partial(standard_vma_rule, 'squeeze')) ad.deflinear2(squeeze_p, _squeeze_transpose_rule) batching.primitive_batchers[squeeze_p] = _squeeze_batch_rule pe.def_trivial_padding(squeeze_p) @@ -6979,7 +7084,8 @@ def _reshape_staging_rule( return _dyn_shape_staging_rule(trace, reshape_p, av, x, *dyn, **params) reshape_p = standard_primitive(_reshape_shape_rule, _reshape_dtype_rule, - 'reshape', sharding_rule=_reshape_sharding_rule) + 'reshape', sharding_rule=_reshape_sharding_rule, + vma_rule=partial(standard_vma_rule, 'reshape')) ad.deflinear2(reshape_p, _reshape_transpose_rule) batching.fancy_primitive_batchers[reshape_p] = _reshape_batch_rule batching.skippable_batchers[reshape_p] = lambda _: () @@ -7011,7 +7117,8 @@ def _rev_batch_rule(batched_args, batch_dims, *, dimensions): return rev(operand, new_dimensions), bdim rev_p = standard_primitive(_rev_shape_rule, _input_dtype, 'rev', - sharding_rule=_rev_sharding_rule) + sharding_rule=_rev_sharding_rule, + vma_rule=partial(standard_vma_rule, 'rev')) ad.deflinear2(rev_p, lambda t, _, dimensions: [rev(t, dimensions)]) batching.primitive_batchers[rev_p] = _rev_batch_rule @@ -7059,7 +7166,8 @@ def _transpose_lower(ctx, x, *, permutation): transpose_p = standard_primitive( _transpose_shape_rule, _input_dtype, 'transpose', - sharding_rule=_transpose_sharding_rule) + sharding_rule=_transpose_sharding_rule, + vma_rule=partial(standard_vma_rule, 'transpose')) ad.deflinear2(transpose_p, lambda t, _, permutation: [transpose(t, np.argsort(permutation))]) batching.primitive_batchers[transpose_p] = _transpose_batch_rule @@ -7235,7 +7343,8 @@ def _select(offset, cases): select_n_p = standard_primitive( _select_shape_rule, _select_dtype_rule, 'select_n', - weak_type_rule=_select_weak_type_rule, sharding_rule=_select_sharding_rule) + weak_type_rule=_select_weak_type_rule, sharding_rule=_select_sharding_rule, + vma_rule=partial(standard_vma_rule, 'select_n')) ad.primitive_jvps[select_n_p] = _select_jvp ad.primitive_transposes[select_n_p] = _select_transpose_rule batching.primitive_batchers[select_n_p] = _select_batch_rule @@ -7341,7 +7450,8 @@ def _reduce_jvp_rule(primals, tangents, *, computation, jaxpr, dimensions): reduce_p.def_impl(partial(dispatch.apply_primitive, reduce_p)) reduce_p.def_abstract_eval( partial(standard_multi_result_abstract_eval, reduce_p, _reduce_shape_rule, - _reduce_dtype_rule, _reduce_weak_type_rule, _reduce_sharding_rule)) + _reduce_dtype_rule, _reduce_weak_type_rule, _reduce_sharding_rule, + None)) batching.primitive_batchers[reduce_p] = _reduce_batch_rule ad.primitive_jvps[reduce_p] = _reduce_jvp_rule @@ -7415,7 +7525,8 @@ def _reduce_op_sharding_rule(operand, *, axes): reduce_sum_p = standard_primitive( _reduce_op_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_sum'), - 'reduce_sum', sharding_rule=_reduce_op_sharding_rule) + 'reduce_sum', sharding_rule=_reduce_op_sharding_rule, + vma_rule=partial(standard_vma_rule, 'reduce_sum')) ad.deflinear2(reduce_sum_p, _reduce_sum_transpose_rule) batching.defreducer(reduce_sum_p, _get_sum_identity) pe.padding_rules[reduce_sum_p] = partial(_reducer_padding, reduce_sum, @@ -7430,7 +7541,8 @@ def _reduce_prod_jvp_rule(primals, tangents, *, axes): reduce_prod_p = standard_primitive( _reduce_op_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_prod'), - 'reduce_prod', sharding_rule=_reduce_op_sharding_rule) + 'reduce_prod', sharding_rule=_reduce_op_sharding_rule, + vma_rule=partial(standard_vma_rule, 'reduce_prod')) ad.primitive_jvps[reduce_prod_p] = _reduce_prod_jvp_rule batching.defreducer(reduce_prod_p, _get_prod_identity) pe.padding_rules[reduce_prod_p] = partial(_reducer_padding, reduce_prod, @@ -7450,7 +7562,8 @@ def _reduce_chooser_jvp_rule(g, ans, operand, *, axes): reduce_max_p = standard_primitive( _reduce_op_shape_rule, _input_dtype, 'reduce_max', - sharding_rule=_reduce_op_sharding_rule) + sharding_rule=_reduce_op_sharding_rule, + vma_rule=partial(standard_vma_rule, 'reduce_max')) ad.defjvp2(reduce_max_p, _reduce_chooser_jvp_rule) batching.defreducer(reduce_max_p, _get_max_identity) pe.padding_rules[reduce_max_p] = partial(_reducer_padding, reduce_max, @@ -7460,7 +7573,8 @@ def _reduce_chooser_jvp_rule(g, ans, operand, *, axes): reduce_min_p = standard_primitive( _reduce_op_shape_rule, _input_dtype, 'reduce_min', - sharding_rule=_reduce_op_sharding_rule) + sharding_rule=_reduce_op_sharding_rule, + vma_rule=partial(standard_vma_rule, 'reduce_min')) ad.defjvp2(reduce_min_p, _reduce_chooser_jvp_rule) batching.defreducer(reduce_min_p, _get_min_identity) pe.padding_rules[reduce_min_p] = partial(_reducer_padding, reduce_min, @@ -7527,13 +7641,15 @@ def _compute_argminmax(value_comparator, get_identity, argmin_p = standard_primitive(_argminmax_shape_rule, _argminmax_dtype_rule, 'argmin', weak_type_rule=_strip_weak_type, - sharding_rule=_argminmax_sharding_rule) + sharding_rule=_argminmax_sharding_rule, + vma_rule=partial(standard_vma_rule, 'argmin')) batching.defreducer(argmin_p, _get_min_identity) ad.defjvp_zero(argmin_p) argmax_p = standard_primitive(_argminmax_shape_rule, _argminmax_dtype_rule, 'argmax', weak_type_rule=_strip_weak_type, - sharding_rule=_argminmax_sharding_rule) + sharding_rule=_argminmax_sharding_rule, + vma_rule=partial(standard_vma_rule, 'argmax')) batching.defreducer(argmax_p, _get_max_identity) ad.defjvp_zero(argmax_p) @@ -7556,20 +7672,23 @@ def _reduce_logical_sharding_rule(operand, *, axes): reduce_or_p = standard_primitive( _reduce_logical_shape_rule, _input_dtype, 'reduce_or', - weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule) + weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule, + vma_rule=partial(standard_vma_rule, 'reduce_or')) batching.defreducer(reduce_or_p, _get_bitwise_or_identity) reduce_and_p = standard_primitive( _reduce_logical_shape_rule, _input_dtype, 'reduce_and', - weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule) + weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule, + vma_rule=partial(standard_vma_rule, 'reduce_and')) batching.defreducer(reduce_and_p, _get_bitwise_and_identity) batching.ragged_prop_rules[reduce_and_p] = batching.ragged_mask_elementwise_rule reduce_xor_p = standard_primitive( _reduce_logical_shape_rule, _input_dtype, 'reduce_xor', - weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule) + weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule, + vma_rule=partial(standard_vma_rule, 'reduce_xor')) batching.defreducer(reduce_xor_p, _get_bitwise_or_identity) @@ -7616,7 +7735,8 @@ def _reduce_precision_sharding_rule(operand, *, exponent_bits, mantissa_bits): reduce_precision_p = standard_primitive( _reduce_precision_shape_rule, partial(unop_dtype_rule, _identity, _float, 'reduce_precision'), - name='reduce_precision', sharding_rule=_reduce_precision_sharding_rule) + name='reduce_precision', sharding_rule=_reduce_precision_sharding_rule, + vma_rule=partial(standard_vma_rule, 'reduce_precision')) ad.deflinear(reduce_precision_p, lambda t, **kwargs: [reduce_precision_p.bind(t, **kwargs)]) batching.defvectorized(reduce_precision_p) @@ -7893,6 +8013,7 @@ def after_all(*operands): """Merges one or more XLA token values. Experimental. Wraps the XLA AfterAll operator.""" + operands = core.standard_insert_pbroadcast(*operands) return after_all_p.bind(*operands) def _after_all_abstract_eval(*operands): @@ -8027,6 +8148,7 @@ def rng_uniform(a, b, shape): This API may be removed at any time. """ + a, b = core.standard_insert_pbroadcast(a, b) return rng_uniform_p.bind(a, b, shape=tuple(shape)) def _rng_uniform_abstract_eval(a, b, *, shape): @@ -8161,7 +8283,8 @@ def _rng_bit_generator_lowering( rng_bit_generator_p.def_abstract_eval( partial(standard_multi_result_abstract_eval, rng_bit_generator_p, _rng_bit_generator_shape_rule, _rng_bit_generator_dtype_rule, - _rng_bit_generator_weak_type_rule, _rng_bit_generator_sharding_rule)) + _rng_bit_generator_weak_type_rule, _rng_bit_generator_sharding_rule, + None)) mlir.register_lowering(rng_bit_generator_p, _rng_bit_generator_lowering) @@ -8245,6 +8368,7 @@ def rng_bit_generator(key, shape, dtype=np.uint32, if np.dtype(dtype) not in {np.dtype('uint8'), np.dtype('uint16'), np.dtype('uint32'), np.dtype('uint64')}: raise TypeError(f'rng_bit_generator: unsupported dtype {dtype}') + key, = core.standard_insert_pbroadcast(key) return tuple( rng_bit_generator_p.bind( key, shape=shape, dtype=dtype, algorithm=algorithm, @@ -8703,8 +8827,10 @@ def optimization_barrier(operand, /): Array(0., dtype=float32, weak_type=True) """ flat_args, treedef = tree_util.tree_flatten(operand) - return tree_util.tree_unflatten( - treedef, optimization_barrier_p.bind(*flat_args)) + # TODO(yashkatariya): Enable this + # flat_args = core.standard_insert_pbroadcast(flat_args) + out = optimization_barrier_p.bind(*flat_args) + return tree_util.tree_unflatten(treedef, out) def _optimization_barrier_abstract_eval(*args): diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index b22a4cf56062..b455257e107c 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -740,6 +740,14 @@ def linalg_sharding_rule( ndim = len(output_shapes) - len(batch_spec) return sharding.with_spec(P(*(tuple(batch_spec) + (None,) * ndim))) +def linalg_vma_rule(multiple_results, shape_rule, name, *avals, **kwargs): + output_shapes = shape_rule(*avals, **kwargs) + out_vma = lax_internal.standard_vma_rule(name, *avals) + if multiple_results: + return [out_vma] * len(output_shapes) + else: + return out_vma + def linalg_primitive(result_dtype, accepted_dtypes, ranks, result_shape, name, multiple_results=False, supports_batching=True, require_same=True): @@ -754,6 +762,7 @@ def linalg_primitive(result_dtype, accepted_dtypes, ranks, result_shape, name, linalg_sharding_rule, multiple_results, shape_rule, ranks, name) else: sharding_rule = None + vma_rule = partial(linalg_vma_rule, multiple_results, shape_rule, name) prim = core.Primitive(name) prim.multiple_results = multiple_results prim.def_impl(partial(dispatch.apply_primitive, prim)) @@ -761,11 +770,12 @@ def linalg_primitive(result_dtype, accepted_dtypes, ranks, result_shape, name, prim.def_abstract_eval( partial(lax_utils.standard_multi_result_abstract_eval, prim, shape_rule, dtype_rule, lax_utils._standard_weak_type_rule, - sharding_rule)) + sharding_rule, vma_rule)) else: prim.def_abstract_eval( partial(lax_utils.standard_abstract_eval, prim, shape_rule, dtype_rule, - lax_utils._standard_weak_type_rule, sharding_rule, None)) + lax_utils._standard_weak_type_rule, sharding_rule, + partial(lax_internal.standard_vma_rule, name))) if supports_batching: batching.primitive_batchers[prim] = partial( batching.expand_dims_batcher, prim) diff --git a/jax/_src/lax/utils.py b/jax/_src/lax/utils.py index 63088d665afd..0a641c122064 100644 --- a/jax/_src/lax/utils.py +++ b/jax/_src/lax/utils.py @@ -131,7 +131,7 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule, raise TypeError(avals, least_specialized) def standard_multi_result_abstract_eval( - prim, shape_rule, dtype_rule, weak_type_rule, sharding_rule, + prim, shape_rule, dtype_rule, weak_type_rule, sharding_rule, vma_rule, *avals, **kwargs): assert prim.multiple_results assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals @@ -141,11 +141,13 @@ def standard_multi_result_abstract_eval( core.check_avals_context_mesh(avals, prim.name) out_shapes, out_dtypes, out_shardings = call_shape_dtype_sharding_rule( prim, shape_rule, dtype_rule, sharding_rule, True, *avals, **kwargs) + out_vmas = (vma_rule(*avals, **kwargs) if config.varying_axes_in_types.value + else [frozenset()] * len(out_shapes)) if isinstance(weak_types, bool): weak_types = (weak_types,) * len(out_shapes) - out_avals = [core.ShapedArray(s, d, weak_type=weak_type, sharding=sh) - for s, d, weak_type, sh in zip(out_shapes, out_dtypes, - weak_types, out_shardings)] + out_avals = [core.ShapedArray(s, d, weak_type=weak_type, sharding=sh, vma=vma) + for s, d, weak_type, sh, vma in zip( + out_shapes, out_dtypes, weak_types, out_shardings, out_vmas)] core.check_avals_context_mesh(out_avals, prim.name) return out_avals elif least_specialized is core.UnshapedArray: From cc5141201976cbc1ce823cf24a5b6dd26412d888 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 25 Mar 2025 17:11:44 -0700 Subject: [PATCH 166/483] [sharding_in_types] Add out_sharding to `jax.random.normal`. Drop into `Auto` mode inside for implementation. Co-authored-by: Roy Frostig PiperOrigin-RevId: 740538785 --- jax/_src/random.py | 9 +++++++-- tests/pjit_test.py | 19 +++++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/jax/_src/random.py b/jax/_src/random.py index 2d315ed0cc8b..7277ed5aa966 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -690,7 +690,8 @@ def choice(key: ArrayLike, def normal(key: ArrayLike, shape: Shape = (), - dtype: DTypeLikeFloat = float) -> Array: + dtype: DTypeLikeFloat = float, + out_sharding=None) -> Array: r"""Sample standard normal random values with given shape and float dtype. The values are returned according to the probability density function: @@ -712,12 +713,16 @@ def normal(key: ArrayLike, """ key, _ = _check_prng_key("normal", key) shape = core.canonicalize_shape(shape) + out_sharding = canonicalize_sharding(out_sharding, 'normal') dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.inexact): raise ValueError(f"dtype argument to `normal` must be a float or complex dtype, " f"got {dtype}") dtype = dtypes.canonicalize_dtype(dtype) - return _normal(key, shape, dtype) + if out_sharding is None: + return _normal(key, shape, dtype) + return auto_axes(partial(_normal, shape=shape, dtype=dtype), + out_shardings=out_sharding)(key) @partial(jit, static_argnums=(1, 2)) def _normal(key, shape, dtype) -> Array: diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 528384358351..d6673c6b6d5a 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7245,6 +7245,25 @@ def f(key): else: self.assertIn('mhlo.sharding = "{devices=[2,2]<=[4]}"}', lowered_text) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_random_normal(self, mesh): + @jax.jit + def f(key): + out = jax.random.normal(key, shape=(8, 12), out_sharding=P('x', 'y')) + self.assertEqual(out.aval.sharding.spec, P('x', 'y')) + return out + + key = jax.random.key(1) + out = f(key) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + + lowered_text = f.lower(key).as_text() + if config.use_shardy_partitioner.value: + self.assertIn('sdy.sharding_constraint', lowered_text) + self.assertIn('<@mesh, [{"x"}, {"y"}]>', lowered_text) + else: + self.assertIn('mhlo.sharding = "{devices=[2,2]<=[4]}"}', lowered_text) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): From 3a593219d413a081247cae309872617bf5d2819f Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Tue, 25 Mar 2025 17:40:47 -0700 Subject: [PATCH 167/483] [jaxlib:cpu] Cleaning up after callback FFI refactor. PiperOrigin-RevId: 740547947 --- CHANGELOG.md | 2 ++ jax/_src/callback.py | 1 + jaxlib/xla/callback.cc | 13 --------- jaxlib/xla/callback.h | 4 --- jaxlib/xla/py_host_callback.cc | 31 -------------------- jaxlib/xla/py_host_callback.h | 53 +--------------------------------- 6 files changed, 4 insertions(+), 100 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1acb2b48eab6..93bbe81b5e63 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * {func}`jax.tree_util.build_tree` is deprecated. Use {func}`jax.tree.unflatten` instead. + * Implemented host callback handlers for CPU and GPU devices using XLA's FFI + and removed existing CPU/GPU handlers using XLA's custom call. * Several previously-deprecated APIs have been removed, including: * From `jax.lib.xla_client`: `FftType`, `PaddingType`, `dtype_to_etype`, and `shape_from_pyval`. diff --git a/jax/_src/callback.py b/jax/_src/callback.py index 92c275e7e924..dc60bfb94356 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -826,6 +826,7 @@ def _wrapped_callback(*args): for result_aval in result_avals] return outputs, token, None + # TODO(dsuo): Remove this once we bump minimum_jaxlib_version to "0.5.4". if xla_extension_version <= 320: result_types = mlir.flatten_ir_types([mlir.aval_to_ir_type(aval) for aval in result_avals]) if token: diff --git a/jaxlib/xla/callback.cc b/jaxlib/xla/callback.cc index 4eab8290c7bb..2df1715d099f 100644 --- a/jaxlib/xla/callback.cc +++ b/jaxlib/xla/callback.cc @@ -26,7 +26,6 @@ limitations under the License. #include #include -#include "absl/base/casts.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" @@ -40,7 +39,6 @@ limitations under the License. #include "xla/primitive_util.h" #include "xla/python/nb_numpy.h" #include "xla/python/python_ref_manager.h" -#include "xla/service/custom_call_status.h" #include "xla/tsl/platform/statusor.h" namespace nb = nanobind; @@ -170,15 +168,4 @@ absl::StatusOr CpuCallback::Call(nb::tuple args) { return result_tuple; } -void XlaPythonCpuCallback(void* output, void** inputs, - XlaCustomCallStatus* status) { - CpuCallback* callback = - absl::bit_cast(*static_cast(inputs[0])); - auto s = callback->PrepareAndCall(output, inputs + 1); - if (!s.ok()) { - auto msg = s.message(); - XlaCustomCallStatusSetFailure(status, msg.data(), msg.length()); - } -} - } // namespace xla diff --git a/jaxlib/xla/callback.h b/jaxlib/xla/callback.h index b63025efe120..ebd0aaca4e6d 100644 --- a/jaxlib/xla/callback.h +++ b/jaxlib/xla/callback.h @@ -28,7 +28,6 @@ limitations under the License. #include "nanobind/nanobind.h" #include "xla/pjrt/transpose.h" #include "xla/python/nb_numpy.h" -#include "xla/service/custom_call_status.h" #include "xla/xla_data.pb.h" namespace xla { @@ -84,9 +83,6 @@ class CpuCallback { xla::TransposePlanCache transpose_cache_; }; -void XlaPythonCpuCallback(void* output, void** inputs, - XlaCustomCallStatus* status); - } // namespace xla #endif // JAXLIB_XLA_CALLBACK_H_ diff --git a/jaxlib/xla/py_host_callback.cc b/jaxlib/xla/py_host_callback.cc index 9d759cc6b77c..833079335a36 100644 --- a/jaxlib/xla/py_host_callback.cc +++ b/jaxlib/xla/py_host_callback.cc @@ -34,7 +34,6 @@ limitations under the License. #include "jaxlib/xla/py_host_callback.pb.h" #include "xla/layout_util.h" #include "xla/pjrt/host_callback.h" -#include "xla/pjrt/pjrt_compiler.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/host_callback.h" #include "xla/python/pjrt_ifrt/pjrt_host_callback.h" @@ -54,7 +53,6 @@ namespace nb = nanobind; namespace xla { char PyFfiLoadedHostCallback::ID = 0; -char PyCpuLoadedHostCallback::ID = 0; char PyHostSendAndRecvLoadedHostCallback::ID = 0; namespace { @@ -128,35 +126,6 @@ PyFfiLoadedHostCallback::~PyFfiLoadedHostCallback() { GlobalPyRefManager()->AddGarbage(absl::MakeSpan(objects)); } -absl::StatusOr> -PyCpuLoadedHostCallback::Create(ifrt::Client* ifrt_client, - nb::callable callable, - absl::Span operand_shapes, - absl::Span result_shapes) { - ifrt::PlatformId platform_id = ifrt_client->platform_id(); - if (platform_id != CpuId() && platform_id != CudaId() && - platform_id != RocmId() && platform_id != SyclId()) { - return Unimplemented("CpuCallback supports CPU and GPU only"); - } - - TF_ASSIGN_OR_RETURN(auto callback_args, CreateCallbackArgs(operand_shapes)); - TF_ASSIGN_OR_RETURN(auto callback_results, - CreateCallbackResults(result_shapes)); - - // `callable` will be destroyed safely with `PythonRefManager` when - // `CpuCallback` is destroyed. - auto cpu_callback = std::make_unique( - std::move(callable), callback_args, callback_results); - return tsl::RCReference( - tsl::MakeRef(ifrt_client, - std::move(cpu_callback))); -} - -absl::StatusOr PyCpuLoadedHostCallback::Serialize() const { - return Unimplemented( - "PyCpuLoadedHostCallback serialization is not supported"); -} - absl::StatusOr> PyHostSendAndRecvLoadedHostCallback::Create( ifrt::Client* ifrt_client, nb::callable callable, diff --git a/jaxlib/xla/py_host_callback.h b/jaxlib/xla/py_host_callback.h index da504d0c12ca..1a1402a4eee2 100644 --- a/jaxlib/xla/py_host_callback.h +++ b/jaxlib/xla/py_host_callback.h @@ -22,13 +22,10 @@ limitations under the License. #include #include -#include "absl/base/casts.h" -#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" #include "llvm/Support/ExtensibleRTTI.h" #include "nanobind/nanobind.h" -#include "jaxlib/xla/callback.h" #include "xla/pjrt/host_callback.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/host_callback.h" @@ -56,7 +53,7 @@ class PyFfiLoadedHostCallback final ifrt::Client* client() const override { return ifrt_client_; } absl::StatusOr Serialize() const override { return Unimplemented( - "PyCpuLoadedHostCallback::callback_data() is not supported"); + "PyFfiLoadedHostCallback::Serialize() is not supported"); }; static char ID; // NOLINT @@ -66,54 +63,6 @@ class PyFfiLoadedHostCallback final nanobind::callable callable_; }; -// `PyCpuLoadedHostCallback` implements a Python host callback that uses a -// descriptor (a raw pointer to JAX `CpuCallback`). The descriptor should be -// passed into a 'xla_python_cpu_callback' or 'xla_python_gpu_callback' -// CustomCall as its first argument. -// -// Serialization is not supported. Once the descriptor is embedded in -// CustomCall in an XLA computation, the computation will not be serializable. -class PyCpuLoadedHostCallback final - : public llvm::RTTIExtends { - public: - static absl::StatusOr> Create( - ifrt::Client* ifrt_client, nanobind::callable callable, - absl::Span operand_shapes, - absl::Span result_shapes); - - // Returns the descriptor of `CpuCallback`. - uint64_t descriptor() const { - return absl::bit_cast(cpu_callback_.get()); - } - - CpuCallback* cpu_callback() { return cpu_callback_.get(); } - - // LoadedHostCallback implementation. - - ~PyCpuLoadedHostCallback() override = default; - - ifrt::Client* client() const override { return ifrt_client_; } - - absl::StatusOr Serialize() const override; - - static char ID; // NOLINT - - private: - PyCpuLoadedHostCallback(ifrt::Client* ifrt_client, - std::unique_ptr cpu_callback) - : llvm::RTTIExtends( - ifrt_client, cpu_callback->callback()), - cpu_callback_(std::move(cpu_callback)) {} - - template - friend tsl::RCReference tsl::MakeRef(Args&&... args); - - ifrt::Client* ifrt_client_; - std::unique_ptr cpu_callback_; -}; - // `PyHostSendAndRecvLoadedHostCallback` implements a Python host callback that // uses XLA host send and recv. This object should be passed to the compiler // when creating `xla::ifrt::LoadedExecutable`. From fd5c1dc8a6855eed485df81a03c75e96f569399b Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Tue, 25 Mar 2025 19:04:48 -0700 Subject: [PATCH 168/483] [jaxlib:cpu] Return an error if we try to use subbyte types in CPU callbacks instead of failing silently. We will be adding subbyte type support in subsequence changes. PiperOrigin-RevId: 740569954 --- jaxlib/xla/BUILD | 1 + jaxlib/xla/py_client_cpu.cc | 15 +++++++++++++++ 2 files changed, 16 insertions(+) diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD index e4db73d6c86d..e10977d526ed 100644 --- a/jaxlib/xla/BUILD +++ b/jaxlib/xla/BUILD @@ -664,6 +664,7 @@ cc_library( "@nanobind", "@xla//third_party/python_runtime:headers", # buildcleaner: keep "@xla//xla:shape_util", + "@xla//xla:xla_data_proto_cc", "@xla//xla/ffi:ffi_api", "@xla//xla/ffi/api:ffi", "@xla//xla/pjrt:host_callback", diff --git a/jaxlib/xla/py_client_cpu.cc b/jaxlib/xla/py_client_cpu.cc index 936a89aa3b42..ac4e7bee5680 100644 --- a/jaxlib/xla/py_client_cpu.cc +++ b/jaxlib/xla/py_client_cpu.cc @@ -41,6 +41,7 @@ limitations under the License. #include "xla/python/nb_numpy.h" #include "xla/python/types.h" #include "xla/shape_util.h" +#include "xla/xla_data.pb.h" namespace nb = nanobind; @@ -77,6 +78,13 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks, for (size_t i = 0; i < args.size(); ++i) { auto arg = args.get(i); auto ptype = static_cast(arg->element_type()); + // TODO(b/395428868): Remove this check once we support subbyte types. + if (ptype == S1 || ptype == S2 || ptype == S4 || ptype == U1 || + ptype == U2 || ptype == U4) { + return ffi::Error(ffi::ErrorCode::kUnimplemented, + absl::StrFormat("Unsupported primitive type: %s", + PrimitiveType_Name(ptype))); + } if (ptype == TOKEN) { PyTuple_SET_ITEM(nb_args.ptr(), i, nb::none().release().ptr()); continue; @@ -111,6 +119,13 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks, for (size_t i = 0; i < rets.size(); ++i) { auto ret = rets.get(i).value(); auto ptype = static_cast(ret->element_type()); + // TODO(b/395428868): Remove this check once we support subbyte types. + if (ptype == S1 || ptype == S2 || ptype == S4 || ptype == U1 || + ptype == U2 || ptype == U4) { + return ffi::Error(ffi::ErrorCode::kUnimplemented, + absl::StrFormat("Unsupported primitive type: %s", + PrimitiveType_Name(ptype))); + } if (ptype == TOKEN) continue; nb::object output = nb::borrow(PyTuple_GetItem(result_tuple.ptr(), i)); From 81abbac53675f9d7fa71af822ef394413f2f86d5 Mon Sep 17 00:00:00 2001 From: Matt Bahr Date: Tue, 25 Mar 2025 06:36:28 +0000 Subject: [PATCH 169/483] add pascal matrix --- docs/jax.scipy.rst | 1 + jax/_src/scipy/linalg.py | 61 ++++++++++++++++++++++++++++++++++++++++ jax/scipy/linalg.py | 1 + tests/linalg_test.py | 16 +++++++++++ 4 files changed, 79 insertions(+) diff --git a/docs/jax.scipy.rst b/docs/jax.scipy.rst index dcbb673997ad..3c436697e1be 100644 --- a/docs/jax.scipy.rst +++ b/docs/jax.scipy.rst @@ -69,6 +69,7 @@ jax.scipy.linalg lu lu_factor lu_solve + pascal polar qr rsf2csf diff --git a/jax/_src/scipy/linalg.py b/jax/_src/scipy/linalg.py index 9917cbaa0b12..55961607b252 100644 --- a/jax/_src/scipy/linalg.py +++ b/jax/_src/scipy/linalg.py @@ -2182,3 +2182,64 @@ def hilbert(n: int) -> Array: """ a = lax.broadcasted_iota(jnp.float64, (n, 1), 0) return 1/(a + a.T + 1) + +@partial(jit, static_argnames=("n", "kind",)) +def pascal(n: int, kind: str | None = None) -> Array: + r"""Create a Pascal matrix approximation of order n. + + JAX implementation of :func:`scipy.linalg.pascal`. + + The elements of the Pascal matrix approximate the binomial coefficents. This + implementation is not exact as JAX does not support exact factorials. + + Args: + n: the size of the matrix to create. + kind: (optional) must be one of ``lower``, ``upper``, or ``symmetric`` (default). + + Returns: + A Pascal matrix of shape ``(n, n)`` + + Examples: + >>> with jnp.printoptions(precision=3): + ... print(jax.scipy.linalg.pascal(3, kind="lower")) + ... print(jax.scipy.linalg.pascal(4, kind="upper")) + ... print(jax.scipy.linalg.pascal(5)) + [[1. 0. 0.] + [1. 1. 0.] + [1. 2. 1.]] + [[1. 1. 1. 1.] + [0. 1. 2. 3.] + [0. 0. 1. 3.] + [0. 0. 0. 1.]] + [[ 1. 1. 1. 1. 1.] + [ 1. 2. 3. 4. 5.] + [ 1. 3. 6. 10. 15.] + [ 1. 4. 10. 20. 35.] + [ 1. 5. 15. 35. 70.]] + """ + if kind is None: + kind = "symmetric" + + valid_kind = ["symmetric", "lower", "upper"] + + if kind not in valid_kind: + raise ValueError(f"Expected kind to be on of: {valid_kind}; got {kind}") + + a = jnp.arange(n, dtype=jnp.float32) + + L_n = _binom(a[:, None], a[None, :]) + + if kind == "lower": + return L_n + + if kind == "upper": + return L_n.T + + return jnp.dot(L_n, L_n.T) + +@jit +def _binom(n, k): + a = lax.lgamma(n + 1.0) + b = lax.lgamma(n - k + 1.0) + c = lax.lgamma(k + 1.0) + return lax.exp(a - b - c) diff --git a/jax/scipy/linalg.py b/jax/scipy/linalg.py index 64bc0544000b..c8a2d5f81957 100644 --- a/jax/scipy/linalg.py +++ b/jax/scipy/linalg.py @@ -31,6 +31,7 @@ lu as lu, lu_factor as lu_factor, lu_solve as lu_solve, + pascal as pascal, polar as polar, qr as qr, rsf2csf as rsf2csf, diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 60e507d84782..20c998d6a685 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -2329,6 +2329,22 @@ def testSymmetricProduct(self, shape, dtype, symmetrize_output): self.assertAllClose( new_product_with_batching, old_product, atol=atol) + @jtu.sample_product( + n=[0, 1, 5, 10, 20], + kind=["symmetric", "lower", "upper"], + ) + @jax.default_matmul_precision("float32") + def testPascal(self, n, kind): + args_maker = lambda: [] + osp_fun = partial(osp.linalg.pascal, n=n, kind=kind, exact=False) + jsp_fun = partial(jsp.linalg.pascal, n=n, kind=kind) + self._CheckAgainstNumpy(osp_fun, + jsp_fun, args_maker, + atol=1e-3, + rtol=1e-2 if jtu.test_device_matches(['tpu']) else 1e-3, + check_dtypes=False) + self._CompileAndCheck(jsp_fun, args_maker) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From fd7775856e1de1ee161c821e123f0fc2cd21fc9d Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Tue, 25 Mar 2025 19:35:13 -0700 Subject: [PATCH 170/483] [jaxlib:gpu] Return an error if we try to use subbyte types in GPU callbacks instead of failing silently. We will be adding subbyte type support in subsequence changes. PiperOrigin-RevId: 740577676 --- jaxlib/cuda/BUILD | 1 + jaxlib/gpu/py_client_gpu.cc | 17 ++++++++++++++++- jaxlib/rocm/BUILD | 1 + 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index 48441632fba9..c47bc3c8126f 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -682,6 +682,7 @@ cc_library( "@xla//third_party/python_runtime:headers", # buildcleaner: keep "@xla//xla:comparison_util", "@xla//xla:shape_util", + "@xla//xla:xla_data_proto_cc", "@xla//xla/ffi:ffi_api", "@xla//xla/ffi/api:ffi", "@xla//xla/pjrt:host_callback", diff --git a/jaxlib/gpu/py_client_gpu.cc b/jaxlib/gpu/py_client_gpu.cc index c39d5201f223..59cc385825a0 100644 --- a/jaxlib/gpu/py_client_gpu.cc +++ b/jaxlib/gpu/py_client_gpu.cc @@ -43,6 +43,7 @@ limitations under the License. #include "xla/python/nb_numpy.h" #include "xla/python/types.h" #include "xla/shape_util.h" +#include "xla/xla_data.pb.h" namespace nb = nanobind; @@ -79,6 +80,13 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, for (size_t i = 0; i < arity; ++i) { auto arg = args.get(i); auto ptype = static_cast(arg->element_type()); + // TODO(b/395428868): Remove this check once we support subbyte types. + if (ptype == xla::S1 || ptype == xla::S2 || ptype == xla::S4 || + ptype == xla::U1 || ptype == xla::U2 || ptype == xla::U4) { + return xla::ffi::Error(xla::ffi::ErrorCode::kUnimplemented, + absl::StrFormat("Unsupported primitive type: %s", + PrimitiveType_Name(ptype))); + } if (ptype == xla::TOKEN) { host_input_buffers[i] = nullptr; continue; @@ -115,7 +123,7 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, auto dims = absl::Span(arg->dimensions().begin(), arg->dimensions().size()); auto array = xla::nb_numpy_ndarray(dtype, dims, std::nullopt, - host_input_buffers[i], base); + host_input_buffers[i], base); array.attr("flags").attr("writeable") = nb::bool_(false); PyTuple_SET_ITEM(host_input_arrays.ptr(), i, array.inc_ref().ptr()); } @@ -137,6 +145,13 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, for (size_t i = 0; i < rets.size(); ++i) { auto ret = rets.get(i).value(); auto ptype = static_cast(ret->element_type()); + // TODO(b/395428868): Remove this check once we support subbyte types. + if (ptype == xla::S1 || ptype == xla::S2 || ptype == xla::S4 || + ptype == xla::U1 || ptype == xla::U2 || ptype == xla::U4) { + return xla::ffi::Error(xla::ffi::ErrorCode::kUnimplemented, + absl::StrFormat("Unsupported primitive type: %s", + PrimitiveType_Name(ptype))); + } if (ptype == xla::TOKEN) continue; nb::object output = nb::borrow(PyTuple_GetItem(result_tuple.ptr(), i)); diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index 258556be8b1e..2c13228d3c51 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -580,6 +580,7 @@ cc_library( "@xla//third_party/python_runtime:headers", # buildcleaner: keep "@xla//xla:comparison_util", "@xla//xla:shape_util", + "@xla//xla:xla_data_proto_cc", "@xla//xla/ffi:ffi_api", "@xla//xla/ffi/api:ffi", "@xla//xla/pjrt:host_callback", From 83989f6fc674a35a599a5d3cbfbdb5aa8a23fd2a Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Wed, 26 Mar 2025 01:29:10 -0700 Subject: [PATCH 171/483] [Pallas/Mosaic GPU] Add a test tracking primitives warpgroup lowering rules. The goal is to use this to figure out when we can enable warpgroup lowering by default. PiperOrigin-RevId: 740670338 --- tests/pallas/mosaic_gpu_test.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 7f8cfa21e980..32851797fff5 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -26,9 +26,13 @@ from absl.testing import parameterized import jax from jax import lax +from jax._src import pjit from jax._src import test_util as jtu from jax._src.pallas import pallas_call +from jax._src.pallas.mosaic_gpu import lowering as mgpu_lowering from jax._src.pallas.mosaic_gpu import pipeline as mgpu_pipeline +from jax._src.pallas.mosaic_gpu import primitives as mgpu_primitives +from jax._src.state import discharge from jax.experimental import pallas as pl from jax.experimental.pallas import mosaic_gpu as plgpu import jax.numpy as jnp @@ -1351,6 +1355,29 @@ class PallasCallWGTest( ): ... + def test_missing_primitive_lowerings_are_tracked(self): + # This test is a way to keep track of which primitives need to be adapted + # to using warpgroup semantics. Once the set is empty, we should be able to + # enable warpgroup semantics by default (assuming we haven't overspecialized + # lowerings). + rules = mgpu_lowering.mosaic_lowering_rules + wg_lowered_primitives = set(rules[plgpu.ThreadSemantics.Warpgroup]) + lane_lowered_primitives = set(rules[plgpu.ThreadSemantics.Lane]) + + actual_missing_primitives = lane_lowered_primitives - wg_lowered_primitives + expected_missing_primitives = { + lax.optimization_barrier_p, + mgpu_primitives.broadcasted_iota_p, + lax.exp2_p, + mgpu_primitives.layout_cast_p, + mgpu_primitives.load_p, + pjit.mesh_cast_p, + lax.slice_p, + discharge.run_state_p, + } + + self.assertSetEqual(actual_missing_primitives, expected_missing_primitives) + class PallasCallSm90ATest(PallasSm90ATest): From 660f536300e23680c14b26373d111c07477f669a Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Wed, 26 Mar 2025 02:22:01 -0700 Subject: [PATCH 172/483] [Pallas/Mosaic GPU] Add a lowering rule for `lax.optimization_barrier_p` with warpgroup semantics. PiperOrigin-RevId: 740684030 --- jax/_src/pallas/mosaic_gpu/lowering.py | 12 ++++++++++ tests/pallas/mosaic_gpu_test.py | 31 +++++++++++++++++++++++++- 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 493d8c07b941..c5436c818e1d 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -2272,6 +2272,18 @@ def _optimization_barrier_lowering(ctx: LoweringRuleContext, *args): return mgpu.optimization_barrier(*args) +@register_lowering_rule( + lax.optimization_barrier_p, mgpu.ThreadSemantics.Warpgroup +) +def _optimization_barrier_lowering_wg(ctx: LoweringRuleContext, *args): + args = [ + _ensure_ir_value(arg, aval.dtype) for arg, aval in zip(args, ctx.avals_in) + ] + result = mgpu.dialect.optimization_barrier(args) + + return (result,) if len(args) == 1 else result + + def _bcast( x: ir.Value, y: ir.Value, diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 32851797fff5..73440ebf5fa5 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1349,6 +1349,36 @@ def convert(x_ref, y_ref): convert(x), jax.lax.bitcast_convert_type(x, out_dtype) ) + def test_optimization_barrier(self): + if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Lane: + self.skipTest("This test crashes with lane semantics") + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((128,), jnp.float32), + ) + def kernel(x_ref, o_ref): + o_ref[...] = lax.optimization_barrier(x_ref[...]) + + x = jax.lax.iota(jnp.float32, 128) + np.testing.assert_array_equal(kernel(x), x) + + def test_optimization_barrier_multiple_inputs(self): + if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Lane: + self.skipTest("This test crashes with lane semantics") + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((128,), jnp.float32), + ) + def kernel(x_ref, y_ref, o_ref): + x, y = lax.optimization_barrier([x_ref[...], y_ref[...]]) + o_ref[...] = x + y + + x = jax.lax.iota(jnp.float32, 128) + y = jax.lax.iota(jnp.float32, 128) * 3 + np.testing.assert_array_equal(kernel(x, y), x + y) + class PallasCallWGTest( PallasCallTest, thread_semantics=plgpu.ThreadSemantics.Warpgroup @@ -1366,7 +1396,6 @@ def test_missing_primitive_lowerings_are_tracked(self): actual_missing_primitives = lane_lowered_primitives - wg_lowered_primitives expected_missing_primitives = { - lax.optimization_barrier_p, mgpu_primitives.broadcasted_iota_p, lax.exp2_p, mgpu_primitives.layout_cast_p, From 3f3081d46ed77f8fd37a7497013c26df8abaa62c Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Wed, 26 Mar 2025 04:45:34 -0700 Subject: [PATCH 173/483] [Pallas/Mosaic GPU] Add a lowering rule for `pjit.mesh_cast_p` for warpgroup semantics. PiperOrigin-RevId: 740719326 --- jax/_src/pallas/mosaic_gpu/lowering.py | 3 +++ tests/pallas/mosaic_gpu_test.py | 2 -- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index c5436c818e1d..c67633125fc0 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1202,9 +1202,12 @@ def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **kwargs): ) @register_lowering_rule(pjit.mesh_cast_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(pjit.mesh_cast_p, mgpu.ThreadSemantics.Warpgroup) def _mesh_cast_lowering_rule(ctx, x, dst_sharding): + del ctx, dst_sharding # Unused. return x + @register_lowering_rule(lax.slice_p, mgpu.ThreadSemantics.Lane) def _slice_lowering_rule( ctx: LoweringRuleContext, x, limit_indices, start_indices, strides diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 73440ebf5fa5..4531bd568913 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -26,7 +26,6 @@ from absl.testing import parameterized import jax from jax import lax -from jax._src import pjit from jax._src import test_util as jtu from jax._src.pallas import pallas_call from jax._src.pallas.mosaic_gpu import lowering as mgpu_lowering @@ -1400,7 +1399,6 @@ def test_missing_primitive_lowerings_are_tracked(self): lax.exp2_p, mgpu_primitives.layout_cast_p, mgpu_primitives.load_p, - pjit.mesh_cast_p, lax.slice_p, discharge.run_state_p, } From 9ff08909557c1a322740f15cada3e6514152a321 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Wed, 26 Mar 2025 04:49:46 -0700 Subject: [PATCH 174/483] [jax:callbacks] Add a test for callbacks with subbyte types. Today, we have TPU support for subbyte types, but not on CPU/GPU. Explicitly raise an error for now with a TODO for when we implement CPU/GPU support. PiperOrigin-RevId: 740720316 --- jaxlib/xla/xla_client.py | 2 +- tests/python_callback_test.py | 51 +++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/jaxlib/xla/xla_client.py b/jaxlib/xla/xla_client.py index 0e4eebdfb26f..a9b1109c3bd3 100644 --- a/jaxlib/xla/xla_client.py +++ b/jaxlib/xla/xla_client.py @@ -50,7 +50,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 321 +_version = 322 # Version number for MLIR:Python components. mlir_api_version = 58 diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 05b4c8d7c0ff..5650a2d4f48b 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -28,6 +28,7 @@ from jax._src import dispatch from jax._src import test_util as jtu from jax._src import util +from jax._src.lib import xla_extension_version from jax.experimental import io_callback from jax.experimental import pjit from jax.experimental.shard_map import shard_map @@ -585,6 +586,56 @@ def fun(x): self.assertAllClose(2 * x, fun(x)) self.assertEqual(count(), 1) + @parameterized.parameters("int2", "int4", "uint2", "uint4") + def test_subbyte_operands(self, dtype: str): + if xla_extension_version <= 321: + self.skipTest("Requires xla_extension_version >= 322.") + def get(x): + return x + def f(x): + y = jax.pure_callback( + get, + jax.ShapeDtypeStruct((8,), dtype=dtype), + x, + ) + return y + x = np.arange(8, dtype=dtype) + # TODO(b/395428868): Remove this check once we support subbyte types. + if jtu.test_device_matches(["tpu"]): + if "2" in dtype: + self.skipTest("TODO(dsuo): TPU callbacks send SIGABRT for int2/uint2.") + np.testing.assert_array_equal(jax.jit(f)(x), np.arange(8, dtype=dtype)) + else: + with self.assertRaisesRegex( + Exception, "Unsupported primitive type" + ): + _ = jax.jit(f)(x) + + @parameterized.parameters("int2", "int4", "uint2", "uint4") + def test_subbyte_results(self, dtype: str): + if xla_extension_version <= 321: + self.skipTest("Requires xla_extension_version >= 322.") + def get(): + return np.arange(8, dtype=dtype) + + def f(): + y = jax.pure_callback( + get, + jax.ShapeDtypeStruct((8,), dtype) + ) + return y + + # TODO(b/395428868): Remove this check once we support subbyte types. + if jtu.test_device_matches(["tpu"]): + if "2" in dtype: + self.skipTest("TODO(dsuo): TPU callbacks send SIGABRT for int2/uint2.") + np.testing.assert_array_equal(jax.jit(f)(), np.arange(8, dtype=dtype)) + else: + with self.assertRaisesRegex( + Exception, "Unsupported primitive type" + ): + _ = jax.jit(f)() + class PureCallbackTest(jtu.JaxTestCase): From 07ebcb2d63c3fa38d84c8c7eceef23c9d980bab8 Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Wed, 26 Mar 2025 05:16:44 -0700 Subject: [PATCH 175/483] [Mosaic] Use large 2nd minor tiling for x2. To avoid relayout from (16, 128) to (128, 128) because we always use native tiling for ext/trunc. PiperOrigin-RevId: 740726621 --- .../mosaic/dialect/tpu/transforms/infer_memref_layout.cc | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc index cdf48632784b..fdfd04949bce 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc @@ -62,18 +62,21 @@ int getTilingFactor(const int src_sublane, const int hardware_generation, const int max_normal_tiling = tiling_sublane; int large_tiling = [&] { + if (bitwidth == 2) { + return target_sublane_count * 16; + } if (bitwidth == 4 && tpu_tiling_flags.use_x4_large_second_minor) { - return tiling_sublane * 8; + return target_sublane_count * 8; } if (bitwidth == 8 && tpu_tiling_flags.use_x8_large_second_minor) { - return tiling_sublane * 4; + return target_sublane_count * 4; } // 16-bit values are generally always possible to relayout on the fly in v6, // so we allow large 2nd minor tiling whenever possible. We can't do this // for kernel arguments, because the layout of those is controlled by XLA. if (bitwidth == 16 && (tpu_tiling_flags.use_x16_large_second_minor || (!is_kernel_argument && hardware_generation >= 6))) { - return tiling_sublane * 2; + return target_sublane_count * 2; } return tiling_sublane; }(); From 5e3330cf8cf93744f4a6ce512443ca1ee936bc3b Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 26 Mar 2025 05:16:55 -0700 Subject: [PATCH 176/483] Update XLA dependency to use revision http://github.com/openxla/xla/commit/d0b25f9cd8222a348c9728f88e909c4e2c30991b. PiperOrigin-RevId: 740726667 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 8fcda2281ea7..359048ffacbb 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "d505fef9c5eb6cc1bf282fdf62139783d7fe4ec5" -XLA_SHA256 = "4fe51bd389428ce65415b08693f966b142fe8218ced771becab9033503a70a3d" +XLA_COMMIT = "d0b25f9cd8222a348c9728f88e909c4e2c30991b" +XLA_SHA256 = "8cd70a67a56a8b18087fc4849908f52c95c6413eb7edc9f800fdff6304804fa4" def repo(): tf_http_archive( From 7a42e3d39d9beff823469ba0c87722248e6ace29 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 26 Mar 2025 07:06:29 -0700 Subject: [PATCH 177/483] [pallas:mosaic_gpu] `thread_semantics=` should still default to lane-level PiperOrigin-RevId: 740753009 --- jax/_src/pallas/mosaic_gpu/pallas_call_registration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py index 5399727878a6..40b12215c003 100644 --- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py @@ -57,7 +57,7 @@ def pallas_call_lowering( print(grid_mapping) thread_semantics = compiler_params.get("mosaic_gpu", {}).get( - "thread_semantics", mgpu.ThreadSemantics.Warpgroup + "thread_semantics", mgpu.ThreadSemantics.Lane ) if thread_semantics == mgpu.ThreadSemantics.Warpgroup: mgpu.dialect.register_dialect(ctx.module_context.context) # pytype: disable=attribute-error From c15921243936d8027de28678b3ad199f9ac498d5 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Tue, 25 Mar 2025 22:33:21 +0000 Subject: [PATCH 178/483] Some codebase fixes required for python 3.14 - Fix for "SyntaxWarning: 'return' in a 'finally' block" - Fix for "AttributeError: 'typing.Union' object attribute '__doc__' is read-only" --- jax/_src/basearray.py | 13 +++++++++++-- jax/_src/util.py | 5 +++-- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/jax/_src/basearray.py b/jax/_src/basearray.py index a89d4a2949be..fbd14d157e78 100644 --- a/jax/_src/basearray.py +++ b/jax/_src/basearray.py @@ -17,6 +17,7 @@ from __future__ import annotations import abc +import sys import numpy as np from typing import Any, Union from collections.abc import Sequence @@ -175,7 +176,11 @@ def copy_to_host_async(self): np.bool_, np.number, # NumPy scalar types bool, int, float, complex, # Python scalar types ] -StaticScalar.__doc__ = "Type annotation for JAX-compatible static scalars." + +if sys.version_info[:2] < (3, 14): + # Python 3.14 raises + # AttributeError: 'typing.Union' object attribute '__doc__' is read-only + StaticScalar.__doc__ = "Type annotation for JAX-compatible static scalars." # ArrayLike is a Union of all objects that can be implicitly converted to a @@ -187,4 +192,8 @@ def copy_to_host_async(self): np.ndarray, # NumPy array type StaticScalar, # valid scalars ] -ArrayLike.__doc__ = "Type annotation for JAX array-like objects." + +if sys.version_info[:2] < (3, 14): + # Python 3.14 raises + # AttributeError: 'typing.Union' object attribute '__doc__' is read-only + ArrayLike.__doc__ = "Type annotation for JAX array-like objects." diff --git a/jax/_src/util.py b/jax/_src/util.py index d558954e881c..b3f7becee7eb 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -417,8 +417,9 @@ def wrapper(fun: T) -> T: else docstr.format(fun=name, doc=doc, **kwargs)) fun.__qualname__ = getattr(wrapped, "__qualname__", fun.__name__) fun.__wrapped__ = wrapped - finally: - return fun + except Exception: + pass + return fun return wrapper From 9f40440d476aee980b519fd82911e2ec5102466c Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 26 Mar 2025 07:56:27 -0700 Subject: [PATCH 179/483] Add missing `jax` wheel dependencies. PiperOrigin-RevId: 740767116 --- BUILD.bazel | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index ebf852a60924..e7cf6de66cad 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -47,6 +47,7 @@ transitive_py_deps( "//jax:sparse_test_util", "//jax:test_util", "//jax/_src/lib", + "//jax/_src/pallas/fuser", "//jax/_src/pallas/mosaic_gpu", "//jax/experimental/array_serialization:serialization", "//jax/experimental/jax2tf", @@ -105,14 +106,19 @@ jax_source_package( ) genrule( - name = "internal_test_util_sources", + name = "wheel_additives", srcs = [ "//jax:internal_export_back_compat_test_util", "//jax:internal_test_harnesses", "//jax:internal_test_util", "//jax:internal_export_back_compat_test_data", + "//jax:experimental/pallas/ops/tpu/random/philox.py", + "//jax:experimental/pallas/ops/tpu/random/prng_utils.py", + "//jax:experimental/pallas/ops/tpu/random/threefry.py", + "//jax/experimental/mosaic/gpu/examples:flash_attention.py", + "//jax/experimental/mosaic/gpu/examples:matmul.py", ], - outs = ["internal_test_util_sources.zip"], + outs = ["wheel_additives.zip"], cmd = "$(location @bazel_tools//tools/zip:zipper) c $@ $(SRCS)", tools = ["@bazel_tools//tools/zip:zipper"], ) @@ -131,15 +137,16 @@ COMMON_DEPS = py_deps([ py_import( name = "jax_py_import", wheel = ":jax_wheel", - wheel_deps = [":internal_test_util_sources"], + wheel_deps = [":wheel_additives"], deps = COMMON_DEPS, ) -# This target is used to add internal test util sources to the jax wheel. -# This is needed for the tests that depend on jax and use internal test util sources. +# This target is used to add more sources to the jax wheel. +# This is needed for the tests that depend on jax and use modules that are not part of +# the jax wheel, but share the same package paths as the modules in the jax wheel. py_import( name = "jax_wheel_with_internal_test_util", wheel = "@pypi_jax//:whl", - wheel_deps = [":internal_test_util_sources"], + wheel_deps = [":wheel_additives"], deps = COMMON_DEPS, ) From dfa2f46968f07797ba9c21b2570f651f2c123c69 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Wed, 26 Mar 2025 08:51:37 -0700 Subject: [PATCH 180/483] [Pallas/Mosaic GPU] Delete `mesh_cast_p` lowering rules. They don't seem to be used. PiperOrigin-RevId: 740785108 --- jax/_src/pallas/mosaic_gpu/lowering.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index c67633125fc0..fc3bdaac7aed 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1201,12 +1201,6 @@ def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **kwargs): ctx.module_ctx, ctx.launch_ctx, jaxpr.jaxpr, args, ) -@register_lowering_rule(pjit.mesh_cast_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(pjit.mesh_cast_p, mgpu.ThreadSemantics.Warpgroup) -def _mesh_cast_lowering_rule(ctx, x, dst_sharding): - del ctx, dst_sharding # Unused. - return x - @register_lowering_rule(lax.slice_p, mgpu.ThreadSemantics.Lane) def _slice_lowering_rule( From 9d768c475454f484769d15ebe177b8c63f3620bb Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Wed, 26 Mar 2025 09:09:20 -0700 Subject: [PATCH 181/483] [pallas:mgpu] Use the ExitStack context to manage smem allocations. PiperOrigin-RevId: 740790684 --- jax/_src/pallas/mosaic_gpu/lowering.py | 134 ++++++++++++------------- 1 file changed, 67 insertions(+), 67 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index fc3bdaac7aed..e2c4ce322b1a 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1838,76 +1838,76 @@ def _run_scoped_lowering_rule( ): input_refs = [] should_discharge = [] - alloc_stack = contextlib.ExitStack() - for v in jaxpr.invars: - aval = v.aval - if isinstance(aval, gpu_core.WGMMAAbstractAccumulatorRef): - dtype = mlir.dtype_to_ir_type(aval.dtype) - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: - input_refs.append(mgpu.WGMMAAccumulator.zero(*aval.shape, dtype)) + with contextlib.ExitStack() as alloc_stack: + for v in jaxpr.invars: + aval = v.aval + if isinstance(aval, gpu_core.WGMMAAbstractAccumulatorRef): + dtype = mlir.dtype_to_ir_type(aval.dtype) + if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + input_refs.append(mgpu.WGMMAAccumulator.zero(*aval.shape, dtype)) + else: + zero = arith_dialect.constant(dtype, ir.FloatAttr.get(dtype, 0.0)) + acc = vector_dialect.splat(ir.VectorType.get(aval.shape, dtype), zero) + acc = mgpu.dialect.optimization_barrier([acc]) + nvvm_dialect.wgmma_fence_aligned() + input_refs.append(acc) + should_discharge.append(True) + elif isinstance(aval.dtype, gpu_core.BarrierType): + input_refs.append( + ctx.module_ctx.reserve_barrier( + mgpu.Barrier( + aval.dtype.num_arrivals + * ctx.estimator_ctx.arrival_multiplier, + *aval.shape, + ) + ) + ) + should_discharge.append(False) + elif aval.memory_space == gpu_core.SMEM: + [input_ref] = alloc_stack.enter_context( + ctx.module_ctx.scratch_view( + [jax.ShapeDtypeStruct(shape=aval.shape, dtype=aval.dtype)] + ) + ) + input_refs.append(input_ref) + should_discharge.append(False) + elif aval.memory_space == gpu_core.TMEM: + input_ref = alloc_stack.enter_context( + ctx.module_ctx.alloc_tmem( + jax.ShapeDtypeStruct(shape=aval.shape, dtype=aval.dtype), + ) + ) + input_refs.append(input_ref) + should_discharge.append(False) else: - zero = arith_dialect.constant(dtype, ir.FloatAttr.get(dtype, 0.0)) - acc = vector_dialect.splat(ir.VectorType.get(aval.shape, dtype), zero) - acc = mgpu.dialect.optimization_barrier([acc]) - nvvm_dialect.wgmma_fence_aligned() - input_refs.append(acc) - should_discharge.append(True) - elif isinstance(aval.dtype, gpu_core.BarrierType): - input_refs.append( - ctx.module_ctx.reserve_barrier( - mgpu.Barrier( - aval.dtype.num_arrivals - * ctx.estimator_ctx.arrival_multiplier, - *aval.shape, - ) - ) - ) - should_discharge.append(False) - elif aval.memory_space == gpu_core.SMEM: - [input_ref] = alloc_stack.enter_context( - ctx.module_ctx.scratch_view( - [jax.ShapeDtypeStruct(shape=aval.shape, dtype=aval.dtype)] - ) - ) - input_refs.append(input_ref) - should_discharge.append(False) - elif aval.memory_space == gpu_core.TMEM: - input_ref = alloc_stack.enter_context( - ctx.module_ctx.alloc_tmem( - jax.ShapeDtypeStruct(shape=aval.shape, dtype=aval.dtype), - ) + raise ValueError(f"Can't convert to ref: {aval}") + + if any(should_discharge): + # We convert consts to args, because we only have ir.Values and + # not JAX values during lowering. discharge_state() produces JAX + # valiues for the aguments but expects them to be provided for the + # consts. We also don't want to wrap the values in refs. + no_const_jaxpr = pe.convert_constvars_jaxpr(jaxpr) + should_discharge = [False] * len(consts) + should_discharge + discharged_jaxpr, _ = discharge.discharge_state(no_const_jaxpr, (), should_discharge=should_discharge) + new_input_vals = consts + tuple(input_refs) + outs = lower_jaxpr_to_mosaic_gpu( + ctx.module_ctx, + ctx.launch_ctx, + discharged_jaxpr, + new_input_vals, + (), ) - input_refs.append(input_ref) - should_discharge.append(False) + # Discharge appends to the output the refs that got discharged. + outs = outs[:-sum(should_discharge)] else: - raise ValueError(f"Can't convert to ref: {aval}") - - if any(should_discharge): - # We convert consts to args, because we only have ir.Values and - # not JAX values during lowering. discharge_state() produces JAX - # valiues for the aguments but expects them to be provided for the - # consts. We also don't want to wrap the values in refs. - no_const_jaxpr = pe.convert_constvars_jaxpr(jaxpr) - should_discharge = [False] * len(consts) + should_discharge - discharged_jaxpr, _ = discharge.discharge_state(no_const_jaxpr, (), should_discharge=should_discharge) - new_input_vals = consts + tuple(input_refs) - outs = lower_jaxpr_to_mosaic_gpu( - ctx.module_ctx, - ctx.launch_ctx, - discharged_jaxpr, - new_input_vals, - (), - ) - # Discharge appends to the output the refs that got discharged. - outs = outs[:-sum(should_discharge)] - else: - outs = lower_jaxpr_to_mosaic_gpu( - ctx.module_ctx, - ctx.launch_ctx, - jaxpr, - input_refs, - consts, - ) + outs = lower_jaxpr_to_mosaic_gpu( + ctx.module_ctx, + ctx.launch_ctx, + jaxpr, + input_refs, + consts, + ) assert len(outs) == len(jaxpr.outvars), (jaxpr, outs) return outs From 6851d6a1c81dab8437c3d7a7bc94c4df66ac9af6 Mon Sep 17 00:00:00 2001 From: Michael Whittaker Date: Wed, 26 Mar 2025 09:11:36 -0700 Subject: [PATCH 182/483] Skip some array_extensibility_tests on TPU. PiperOrigin-RevId: 740791514 --- tests/array_extensibility_test.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/array_extensibility_test.py b/tests/array_extensibility_test.py index 3e84f6668b8d..7c0ec07e6a05 100644 --- a/tests/array_extensibility_test.py +++ b/tests/array_extensibility_test.py @@ -53,6 +53,7 @@ class NumPyAPI(NamedTuple): fun: Callable[..., Any] args: list[jax.ShapeDtypeStruct] kwargs: dict[str, Any] + skip_on_devices: list[str] | None def name(self): return self.fun.__name__ @@ -61,9 +62,12 @@ def make_args(self, rng): rng = jtu.rand_default(rng) return jax.tree.map(lambda arg: rng(arg.shape, arg.dtype), self.args) + def with_skip_on_devices(self, disabled_devices: list[str]) -> 'NumPyAPI': + return self._replace(skip_on_devices=disabled_devices) + @classmethod def sig(cls, fun: Callable[..., Any], *args: Any, **kwargs: Any) -> 'NumPyAPI': - return cls(fun, args, kwargs) + return cls(fun, args, kwargs, None) class ShapeDtype: @@ -444,7 +448,7 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.rint, Float[5]), NumPyAPI.sig(jnp.roll, Float[5], Int[1]), NumPyAPI.sig(jnp.rollaxis, Float[5, 4], axis=1), - NumPyAPI.sig(jnp.roots, Float[5]), + NumPyAPI.sig(jnp.roots, Float[5]).with_skip_on_devices(['tpu']), NumPyAPI.sig(jnp.rot90, Float[5, 3]), NumPyAPI.sig(jnp.round, Float[5]), NumPyAPI.sig(jnp.searchsorted, Float[5], Float[5]), @@ -512,6 +516,8 @@ class JaxArrayTests(jtu.JaxTestCase): @parameterized.named_parameters( {'testcase_name': api.name(), 'api': api} for api in NUMPY_APIS) def test_numpy_api_supports_jax_array(self, api): + if api.skip_on_devices and jtu.test_device_matches(api.skip_on_devices): + self.skipTest(f'{api.name()} not supported on {api.skip_on_devices}') fun = api.fun args = api.make_args(self.rng()) wrapped_args = jax.tree.map(JaxArrayWrapper, args) From 6386efe369dfa5c234c36c205dda6b270a1a91eb Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 26 Mar 2025 09:46:26 -0700 Subject: [PATCH 183/483] [pallas:mosaic_gpu] `plgpu.kernel` now accepts scratch shapes This frees the caller from another level of indirection via `pl.run_scoped`. PiperOrigin-RevId: 740802977 --- jax/_src/pallas/mosaic_gpu/core.py | 18 +++-- tests/pallas/mosaic_gpu_test.py | 102 ++++++++++++++--------------- 2 files changed, 63 insertions(+), 57 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 99a84962ae50..19007b6850fd 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -18,7 +18,7 @@ import abc import collections -from collections.abc import Iterable, Sequence +from collections.abc import Callable, Iterable, Sequence import dataclasses import enum import itertools as it @@ -31,6 +31,7 @@ from jax._src import tree_util from jax._src.lib.mlir.dialects import arith as arith_dialect from jax._src.pallas import core as pallas_core +from jax._src.pallas import primitives as pallas_primitives from jax._src.state import discharge as state_discharge from jax._src.state import indexing from jax._src.state import types as state_types @@ -114,20 +115,29 @@ def __call__( shape: tuple[int, ...], dtype: jnp.dtype, transforms: Sequence[MemoryRefTransform] = (), - ) -> pallas_core.MemoryRef: # A convenience function for constructing MemoryRef types. return GPUMemoryRef(shape, dtype, memory_space=self, transforms=transforms) -def kernel(body, out_shape, *, compiler_params=None, **mesh_kwargs): +def kernel( + body: Callable[..., None], + out_shape: object, + *, + scratch_shapes: Sequence[pallas_core.ScratchShape] = (), + compiler_params: object | None = None, + **mesh_kwargs: object, +): if unwrap_out := not isinstance(out_shape, (tuple, list)): out_shape = (out_shape,) def wrapper(*operands): def stateful(operand_and_out_refs): operand_refs, out_refs = operand_and_out_refs def cmap_body(): - body(*operand_refs, *out_refs) + pallas_primitives.run_scoped( + lambda *scratch_refs: body(*operand_refs, *out_refs, *scratch_refs), + *scratch_shapes, + ) pallas_core.core_map( GPUMesh(**mesh_kwargs), compiler_params=compiler_params )(cmap_body) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 4531bd568913..c10f06f8bb5d 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1681,22 +1681,19 @@ def test_tmem_alloc(self): @functools.partial( plgpu.kernel, out_shape=jnp.zeros((128, 128), jnp.float32), + scratch_shapes=[ + plgpu.TMEM((128, 128), jnp.float32), + plgpu.SMEM((128, 128), jnp.float32), + ], num_threads=1, axis_names=("x",), ) - def kernel(y_ref): - def scope(tmem_ref, smem_ref): - # Issue a write so the TMEM load is not DCE'd. - smem_ref[...] = tmem_ref[...] - plgpu.commit_smem() - plgpu.copy_smem_to_gmem(smem_ref, y_ref) - plgpu.wait_smem_to_gmem(0) - - pl.run_scoped( - scope, - plgpu.TMEM((128, 128), jnp.float32), - plgpu.SMEM((128, 128), jnp.float32), - ) + def kernel(y_ref, tmem_ref, smem_ref): + # Issue a write so the TMEM load is not DCE'd. + smem_ref[...] = tmem_ref[...] + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(smem_ref, y_ref) + plgpu.wait_smem_to_gmem(0) # Test that this runs without errors. jax.block_until_ready(kernel()) @@ -2122,7 +2119,18 @@ def tiled_add_kernel(x_smem, y_smem, o_smem): def test_carry_accumulate(self, m=256, n=256, num_compute_wgs=2): blk_m = blk_n = 64 - def _scoped(acc_smem, x_gmem, acc_gmem): + @functools.partial( + plgpu.kernel, + out_shape=jax.ShapeDtypeStruct((blk_m, blk_n), jnp.float32), + scratch_shapes=[ + plgpu.SMEM((blk_m, blk_n), jnp.float32), + ], + compiler_params=plgpu.GPUCompilerParams(approx_math=True), + grid=(1,), + num_threads=num_compute_wgs + 1, + axis_names=("_", "wg"), + ) + def kernel(x_gmem, acc_gmem, acc_smem): def _compute_thread(): # Cast the init value to the same layout as x_smem, so the pipeline loop # carry has a constant signature. @@ -2162,20 +2170,6 @@ def tiled_acc_kernel(x_smem, carry): ) pipeline(x_gmem) - @functools.partial( - plgpu.kernel, - out_shape=jax.ShapeDtypeStruct((blk_m, blk_n), jnp.float32), - compiler_params=plgpu.GPUCompilerParams(approx_math=True), - grid=(1,), - num_threads=num_compute_wgs + 1, - axis_names=("_", "wg"), - ) - def kernel(x_ref, acc_ref): - pl.run_scoped( - functools.partial(_scoped, x_gmem=x_ref, acc_gmem=acc_ref), - plgpu.SMEM((blk_m, blk_n), jnp.float32), - ) - x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float32) ref = jnp.sum(jnp.stack(np.split(x, m // blk_m, axis=0)), axis=0) ref = jnp.sum(jnp.stack(np.split(ref, n // blk_n, axis=1)), axis=0) @@ -2259,18 +2253,16 @@ def test_cross_wg_barrier(self): @functools.partial( plgpu.kernel, out_shape=jnp.zeros((2, 128), np.int32), + # Each warpgroup is a single logical thread! + scratch_shapes=[plgpu.Barrier(num_arrivals=2)], num_threads=2, axis_names=("wg",), ) - def kernel(y_ref): - def scoped(barrier): - plgpu.barrier_arrive(barrier) - plgpu.barrier_wait(barrier) - wg_idx = jax.lax.axis_index("wg") - y_ref[wg_idx] = jnp.broadcast_to(wg_idx, (128,)) - - # Each warpgroup is a single logical thread! - pl.run_scoped(scoped, plgpu.Barrier(num_arrivals=2)) + def kernel(o_ref, barrier): + plgpu.barrier_arrive(barrier) + plgpu.barrier_wait(barrier) + wg_idx = jax.lax.axis_index("wg") + o_ref[wg_idx] = jnp.broadcast_to(wg_idx, (128,)) np.testing.assert_array_equal( kernel(), np.repeat([0, 1], 128).reshape(2, 128) @@ -2329,25 +2321,29 @@ def body(l_ref, r_ref, o_ref): # Async copies def test_stage3(self): row_block, col_block = 64, 128 - def body(l_ref, r_ref, o_ref): + + @functools.partial( + plgpu.kernel, + out_shape=jax.ShapeDtypeStruct((128, 128), jnp.float16), + scratch_shapes=[ + *([plgpu.SMEM((row_block, col_block), jnp.float16)] * 3), + plgpu.Barrier(num_arrivals=2), + ], + grid=(2,), + axis_names=("rows",), + ) + def kernel(l_ref, r_ref, o_ref, l_smem, r_smem, o_smem, barrier): my_slice = pl.ds(lax.axis_index("rows") * row_block, row_block) - def scoped(l_smem, r_smem, o_smem, barrier): - plgpu.copy_gmem_to_smem(l_ref.at[my_slice], l_smem, barrier) - plgpu.copy_gmem_to_smem(r_ref.at[my_slice], r_smem, barrier) - plgpu.barrier_wait(barrier) - o_smem[...] = l_smem[...] + r_smem[...] - plgpu.commit_smem() - plgpu.copy_smem_to_gmem(o_smem, o_ref.at[my_slice]) - plgpu.wait_smem_to_gmem(0) - pl.run_scoped( - scoped, - *([plgpu.SMEM((row_block, col_block), jnp.float16)] * 3), - plgpu.Barrier(num_arrivals=2), - ) + plgpu.copy_gmem_to_smem(l_ref.at[my_slice], l_smem, barrier) + plgpu.copy_gmem_to_smem(r_ref.at[my_slice], r_smem, barrier) + plgpu.barrier_wait(barrier) + o_smem[...] = l_smem[...] + r_smem[...] + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(o_smem, o_ref.at[my_slice]) + plgpu.wait_smem_to_gmem(0) x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) - out = plgpu.kernel(body, out_shape=x, grid=(2,), axis_names=("rows",))(x, x) - np.testing.assert_allclose(out, x + x) + np.testing.assert_allclose(kernel(x, x), x + x) # Pipelining def test_stage4(self): From 2057df13ba70996324b3617436ffd04639f89dd7 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Wed, 26 Mar 2025 09:49:37 -0700 Subject: [PATCH 184/483] [Pallas/Mosaic GPU] Fix `copy_smem_to_gmem` lowering to not use a `single_thread_predicate` when using warpgroup semantics. Also avoid generating the predicate at all when using warpgroup semantics. PiperOrigin-RevId: 740803927 --- jax/_src/pallas/mosaic_gpu/lowering.py | 10 ++++++++-- jax/_src/pallas/mosaic_gpu/primitives.py | 17 +++++++++++++---- jaxlib/mosaic/dialect/gpu/mosaic_gpu.td | 12 +++++------- 3 files changed, 26 insertions(+), 13 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index e2c4ce322b1a..a41a657ba738 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -292,7 +292,7 @@ class ModuleContext: axis_names: _AxisNames | None program_ids: Sequence[ir.Value] | None approx_math: bool - single_wg_lane_predicate: ir.Value + single_wg_lane_predicate: ir.Value | None smem_requested_bytes: int smem_used_bytes: int tmem_requested_cols: int @@ -703,12 +703,18 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): tmem_cols = math.prod(runtime_tmem.shape) // tcgen05.TMEM_ROWS else: tmem_cols = 0 + + if thread_semantics == mgpu.ThreadSemantics.Lane: + single_lane_predicate = mgpu.single_thread_predicate(per_block=False) + else: # Warpgroup semantics do not have a single lane predicate. + single_lane_predicate = None + module_ctx = ModuleContext( mlir.sanitize_name(debug_info.func_name), axis_names, [_program_id(axis, squashed_dims) for axis in range(len(grid))], approx_math, - mgpu.single_thread_predicate(per_block=False), + single_lane_predicate, smem_requested_bytes=math.prod(ir.MemRefType(runtime_smem.type).shape), smem_used_bytes=0, tmem_requested_cols=tmem_cols, diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 8eafa0ac8e6d..9dc65c1bef88 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -186,12 +186,21 @@ def _copy_smem_to_gmem_lowering( has_user_predicate, commit_group, ): - predicate = ctx.module_ctx.single_wg_lane_predicate if has_user_predicate: flat_args, user_predicate = flat_args[:-1], flat_args[-1] - predicate = arith_dialect.andi( - predicate, lowering._ensure_ir_value(user_predicate, jnp.bool) - ) + predicate = lowering._ensure_ir_value(user_predicate, jnp.bool) + else: + predicate = None + + if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if predicate is not None: + assert ctx.module_ctx.single_wg_lane_predicate is not None + predicate = arith_dialect.andi( + predicate, ctx.module_ctx.single_wg_lane_predicate + ) + else: + predicate = ctx.module_ctx.single_wg_lane_predicate + flat_src_transforms, flat_dst_transforms = util.split_list( flat_args, [src_transforms_treedef.num_leaves], diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td index f0a37084b759..85929080faec 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td @@ -281,12 +281,11 @@ def MosaicGPU_AsyncLoadOp : Op Date: Wed, 26 Mar 2025 09:51:13 -0700 Subject: [PATCH 185/483] [AutoPGLE] Prevent an AutoPGLE to run if user launched an external profiler. Reverts d4745b9bd81b49e2a7a8938ea98516296d54635f PiperOrigin-RevId: 740804528 --- jax/_src/profiler.py | 5 +++++ jaxlib/xla/xla_extension/profiler.pyi | 1 + tests/pgle_test.py | 14 ++++++++++++-- 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/jax/_src/profiler.py b/jax/_src/profiler.py index f06933f57e22..96e742f33904 100644 --- a/jax/_src/profiler.py +++ b/jax/_src/profiler.py @@ -33,6 +33,7 @@ from jax._src import xla_bridge from jax._src.lib import xla_client +from jax._src.lib import version as jaxlib_version _profiler_server: xla_client.profiler.ProfilerServer | None = None @@ -426,6 +427,10 @@ def trace(cls, runner: PGLEProfiler | None): else: options = xla_client.profiler.ProfileOptions() options.enable_hlo_proto = True + + # ToDo(patrios): Remove when jaxlib version is updated to 0.5.4. + if jaxlib_version > (0, 5, 3): + options.raise_error_on_start_failure = True runner.current_session = xla_client.profiler.ProfilerSession(options) try: diff --git a/jaxlib/xla/xla_extension/profiler.pyi b/jaxlib/xla/xla_extension/profiler.pyi index 7610ce1000bf..95749f61978a 100644 --- a/jaxlib/xla/xla_extension/profiler.pyi +++ b/jaxlib/xla/xla_extension/profiler.pyi @@ -42,6 +42,7 @@ class ProfileOptions: start_timestamp_ns: int duration_ms: int repository_path: str + raise_error_on_start_failure: bool def aggregate_profiled_instructions(profiles: List[bytes], percentile: int) -> str: ... diff --git a/tests/pgle_test.py b/tests/pgle_test.py index 7f9ea598d51b..7dabd809d95e 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -65,7 +65,11 @@ def testPGLEProfilerGetFDOProfile(self): jax.jit, in_shardings=NamedSharding(mesh, PartitionSpec('x')), out_shardings=NamedSharding(mesh, PartitionSpec('x')), - compiler_options={'xla_gpu_enable_latency_hiding_scheduler': 'True'}, + compiler_options={ + 'xla_gpu_enable_latency_hiding_scheduler': 'True', + # Make sure that matmul is not emitted as Triton GEMM. + 'xla_gpu_enable_triton_gemm': 'False', + }, ) def f(x, y): return x @ y @@ -93,6 +97,8 @@ def testPGLEProfilerGetFDOProfileLarge(self): compiler_options = { 'xla_gpu_enable_latency_hiding_scheduler': 'True', + # Make sure that matmul is not emitted as Triton GEMM. + 'xla_gpu_enable_triton_gemm': 'False', } # TODO(b/37664749): Remove this flag once the bug is fixed. compiler_options['xla_gpu_enable_command_buffer'] = '' @@ -321,7 +327,11 @@ def testPassingFDOProfile(self): jax.jit, in_shardings=NamedSharding(mesh, PartitionSpec('x')), out_shardings=NamedSharding(mesh, PartitionSpec('x')), - compiler_options={'xla_gpu_enable_latency_hiding_scheduler': 'True'}, + compiler_options={ + 'xla_gpu_enable_latency_hiding_scheduler': 'True', + # Make sure that matmul is not emitted as Triton GEMM. + 'xla_gpu_enable_triton_gemm': 'False', + }, ) def f(x, y): return x @ y From 91a07ea2e8911a5b6fab7b989d28402cc0176352 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 25 Mar 2025 08:40:43 -0700 Subject: [PATCH 186/483] Clean up a number of finalized deprecations --- jax/__init__.py | 5 ----- jax/_src/numpy/lax_numpy.py | 13 +------------ jax/core.py | 23 ----------------------- jax/interpreters/xla.py | 13 ------------- jax/lib/xla_bridge.py | 9 --------- jax/lib/xla_client.py | 23 ----------------------- jax/numpy/__init__.py | 16 ---------------- jax/numpy/__init__.pyi | 3 +-- jax/sharding.py | 15 --------------- tests/lax_numpy_test.py | 5 ----- 10 files changed, 2 insertions(+), 123 deletions(-) diff --git a/jax/__init__.py b/jax/__init__.py index ae3bac4ad3fa..988c224e4772 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -220,11 +220,6 @@ "or jax.tree_util.tree_map (any JAX version).", _deprecated_tree_map ), - # Finalized Nov 12 2024; remove after Feb 12 2025 - "clear_backends": ( - "jax.clear_backends was removed in JAX v0.4.36", - None - ), } import typing as _typing diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 16355695792d..fd6209ab22c4 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -1944,8 +1944,7 @@ def isrealobj(x: Any) -> bool: @export def reshape( - a: ArrayLike, shape: DimSize | Shape | None = None, order: str = "C", *, - newshape: DimSize | Shape | DeprecatedArg = DeprecatedArg(), + a: ArrayLike, shape: DimSize | Shape, order: str = "C", *, copy: bool | None = None) -> Array: """Return a reshaped copy of an array. @@ -1962,8 +1961,6 @@ def reshape( JAX does not support ``order="A"``. copy: unused by JAX; JAX always returns a copy, though under JIT the compiler may optimize such copies away. - newshape: deprecated alias of the ``shape`` argument. Will result in a - :class:`DeprecationWarning` if used. Returns: reshaped copy of input array with the specified shape. @@ -2021,14 +2018,6 @@ def reshape( __tracebackhide__ = True util.check_arraylike("reshape", a) - # TODO(jakevdp): finalized 2024-12-2; remove argument after JAX v0.4.40. - if not isinstance(newshape, DeprecatedArg): - raise TypeError("The newshape argument to jnp.reshape was removed in JAX v0.4.36." - " Use shape instead.") - if shape is None: - raise TypeError( - "jnp.shape requires passing a `shape` argument, but none was given." - ) try: # forward to method for ndarrays return a.reshape(shape, order=order) # type: ignore[call-overload,union-attr] diff --git a/jax/core.py b/jax/core.py index 3fd7af440d4a..b404e66c2691 100644 --- a/jax/core.py +++ b/jax/core.py @@ -160,29 +160,6 @@ _src_core.lattice_join), "raise_to_shaped": ("jax.core.raise_to_shaped is deprecated. It is a no-op as of JAX v0.4.36.", _src_core.raise_to_shaped), - # Finalized 2024-12-11; remove after 2025-3-11 - "check_eqn": ("jax.core.check_eqn was removed in JAX v0.4.38.", None), - "check_type": ("jax.core.check_type was removed in JAX v0.4.38.", None), - "check_valid_jaxtype": ( - ("jax.core.check_valid_jaxtype was removed in JAX v0.4.38. Instead, you can manually" - " raise an error if core.valid_jaxtype() returns False."), - None), - "non_negative_dim": ( - "jax.core.non_negative_dim was removed in JAX v0.4.38. Use max_dim(..., 0).", None, - ), - # Finalized 2024-09-25; remove after 2024-12-25 - "pp_aval": ("jax.core.pp_aval was removed in JAX v0.4.34.", None), - "pp_eqn": ("jax.core.pp_eqn was removed in JAX v0.4.34.", None), - "pp_eqn_rules": ("jax.core.pp_eqn_rules was removed in JAX v0.4.34.", None), - "pp_eqns": ("jax.core.pp_eqns was removed in JAX v0.4.34.", None), - "pp_jaxpr": ("jax.core.pp_jaxpr was removed in JAX v0.4.34.", None), - "pp_jaxpr_eqn_range": ("jax.core.pp_jaxpr_eqn_range was removed in JAX v0.4.34.", None), - "pp_jaxpr_skeleton": ("jax.core.pp_jaxpr_skeleton was removed in JAX v0.4.34.", None), - "pp_jaxprs": ("jax.core.pp_jaxprs was removed in JAX v0.4.34.", None), - "pp_kv_pair": ("jax.core.pp_kv_pair was removed in JAX v0.4.34.", None), - "pp_kv_pairs": ("jax.core.pp_kv_pairs was removed in JAX v0.4.34.", None), - "pp_var": ("jax.core.pp_var was removed in JAX v0.4.34.", None), - "pp_vars": ("jax.core.pp_vars was removed in JAX v0.4.34.", None), } import typing diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index bd3b83e37d24..2f8417ade1f8 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -38,19 +38,6 @@ "jax.interpreters.xla.pytype_aval_mappings is deprecated.", _src_core.pytype_aval_mappings ), - # Finalized 2024-10-24; remove after 2025-01-24 - "xb": ( - ("jax.interpreters.xla.xb was removed in JAX v0.4.36. " - "Use jax.lib.xla_bridge instead."), None - ), - "xc": ( - ("jax.interpreters.xla.xc was removed in JAX v0.4.36. " - "Use jax.lib.xla_client instead."), None - ), - "xe": ( - ("jax.interpreters.xla.xe was removed in JAX v0.4.36. " - "Use jax.lib.xla_extension instead."), None - ), } import typing as _typing diff --git a/jax/lib/xla_bridge.py b/jax/lib/xla_bridge.py index b158d9b1ff51..95598c447262 100644 --- a/jax/lib/xla_bridge.py +++ b/jax/lib/xla_bridge.py @@ -27,15 +27,6 @@ "jax.lib.xla_bridge.get_backend is deprecated; use jax.extend.backend.get_backend.", _deprecated_get_backend ), - # Finalized 2024-12-11; remove after 2025-3-11 - "xla_client": ( - "jax.lib.xla_bridge.xla_client was removed in JAX v0.4.38; use jax.lib.xla_client directly.", - None - ), - "default_backend": ( - "jax.lib.xla_bridge.default_backend was removed in JAX v0.4.38; use jax.default_backend.", - None - ), } import typing as _typing diff --git a/jax/lib/xla_client.py b/jax/lib/xla_client.py index 07c6914a1f59..314788bfa5e7 100644 --- a/jax/lib/xla_client.py +++ b/jax/lib/xla_client.py @@ -26,27 +26,6 @@ Traceback = _xc.Traceback _deprecations = { - # Finalized 2024-12-11; remove after 2025-3-11 - "_xla": ( - "jax.lib.xla_client._xla was removed in JAX v0.4.38; use jax.lib.xla_extension.", - None, - ), - "bfloat16": ( - "jax.lib.xla_client.bfloat16 was removed in JAX v0.4.38; use ml_dtypes.bfloat16.", - None, - ), - # Finalized 2024-12-23; remove after 2024-03-23 - "Device": ( - "jax.lib.xla_client.Device is deprecated; use jax.Device instead.", - None, - ), - "XlaRuntimeError": ( - ( - "jax.lib.xla_client.XlaRuntimeError is deprecated; use" - " jax.errors.JaxRuntimeError." - ), - None, - ), # Finalized 2025-03-25; remove after 2025-06-25 "FftType": ( "jax.lib.xla_client.FftType was removed in JAX v0.6.0; use jax.lax.FftType.", @@ -106,12 +85,10 @@ ops = _xc.ops register_custom_call_target = _xc.register_custom_call_target ArrayImpl = _xc.ArrayImpl - Device = _xc.Device PrimitiveType = _xc.PrimitiveType Shape = _xc.Shape XlaBuilder = _xc.XlaBuilder XlaComputation = _xc.XlaComputation - XlaRuntimeError = _xc.XlaRuntimeError else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index cb291bdca79a..31cca3578916 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -506,19 +506,3 @@ from jax._src.numpy.array_methods import register_jax_array_methods register_jax_array_methods() del register_jax_array_methods - - -_deprecations = { - # Finalized 2024-12-13; remove after 2024-3-13 - "round_": ( - "jnp.round_ was deprecated in JAX 0.4.38; use jnp.round instead.", - None - ), -} - -import typing -if not typing.TYPE_CHECKING: - from jax._src.deprecations import deprecation_getattr as _deprecation_getattr - __getattr__ = _deprecation_getattr(__name__, _deprecations) - del _deprecation_getattr -del typing diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index b73a3b95b9a5..640e9de7eac3 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -808,8 +808,7 @@ def remainder(x: ArrayLike, y: ArrayLike, /) -> Array: ... def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = ..., *, total_repeat_length: int | None = ...) -> Array: ... def reshape( - a: ArrayLike, shape: DimSize | Shape = ..., - newshape: DimSize | Shape | None = ..., order: str = ... + a: ArrayLike, shape: DimSize | Shape, order: str = ..., *, copy: bool | None = ..., ) -> Array: ... def resize(a: ArrayLike, new_shape: Shape) -> Array: ... diff --git a/jax/sharding.py b/jax/sharding.py index 55ff0f6aea0b..bacf848f07ed 100644 --- a/jax/sharding.py +++ b/jax/sharding.py @@ -34,18 +34,3 @@ AxisType as AxisType, get_abstract_mesh as get_abstract_mesh, ) - -_deprecations = { - # Finalized 2024-10-01; remove after 2025-01-01. - "XLACompatibleSharding": ( - ( - "jax.sharding.XLACompatibleSharding was removed in JAX v0.4.34. " - "Use jax.sharding.Sharding instead." - ), - None, - ) -} - -from jax._src.deprecations import deprecation_getattr as _deprecation_getattr -__getattr__ = _deprecation_getattr(__name__, _deprecations) -del _deprecation_getattr diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 98f10d9c02b3..c0650441edd7 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -3496,11 +3496,6 @@ def testReshape(self, arg_shape, out_shape, dtype, order): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) - def testReshapeDeprecatedArgs(self): - msg = "The newshape argument to jnp.reshape was removed in JAX v0.4.36." - with self.assertRaisesRegex(TypeError, msg): - jnp.reshape(jnp.arange(4), newshape=(2, 2)) - @jtu.sample_product( [dict(arg_shape=arg_shape, out_shape=out_shape) for arg_shape, out_shape in [ From b1b281a427c7182d5345f5f1a88b83feb13a46f2 Mon Sep 17 00:00:00 2001 From: Ayaka Date: Wed, 26 Mar 2025 10:36:35 -0700 Subject: [PATCH 187/483] Prototype of adding error checking to jax.numpy functions PiperOrigin-RevId: 740822504 --- jax/_src/error_check.py | 149 ++++++---------------------------- jax/_src/numpy/error.py | 131 ++++++++++++++++++++++++++++++ jax/_src/numpy/ufuncs.py | 9 +- tests/BUILD | 5 ++ tests/error_check_test.py | 86 -------------------- tests/jax_numpy_error_test.py | 130 +++++++++++++++++++++++++++++ 6 files changed, 296 insertions(+), 214 deletions(-) create mode 100644 jax/_src/numpy/error.py create mode 100644 tests/jax_numpy_error_test.py diff --git a/jax/_src/error_check.py b/jax/_src/error_check.py index 9d493c1f351b..e78b9bc82115 100644 --- a/jax/_src/error_check.py +++ b/jax/_src/error_check.py @@ -14,18 +14,15 @@ from __future__ import annotations -import contextlib import dataclasses from functools import partial import json import threading import traceback as tb_lib from types import TracebackType -from typing import Literal import warnings import jax -from jax._src import config from jax._src import core from jax._src import source_info_util from jax._src import traceback_util @@ -118,56 +115,39 @@ def __exit__(self, exc_type, exc_value, traceback): _error_storage.ref = self.old_ref -# TODO(ayx): Move all category-related logic into the jax.numpy integration once -# it is ready. This logic is specific to how jax.numpy decides when to call -# `set_error_if`, and doesn't belong in the core error-checking library itself. -# The responsibility for deciding whether to predicate an error should lie with -# the user or the higher-level library (like jax.numpy), not with -# `set_error_if`. -Category = Literal["nan", "divide", "oob"] - - -def _is_category_disabled( - category: Category | None, -) -> bool: - """Check if the error checking behavior for the given category is disabled.""" - if category is None: - return False - if category == "nan": - return config.error_checking_behavior_nan.value == "ignore" - if category == "divide": - return config.error_checking_behavior_divide.value == "ignore" - if category == "oob": - return config.error_checking_behavior_oob.value == "ignore" - raise ValueError(f"Invalid category: {category}") - - -def _set_error_if_with_category( - pred: jax.Array, - /, - msg: str, - category: Category | None = None, -) -> None: +def set_error_if(pred: jax.Array, /, msg: str) -> None: """Set the internal error state if any element of `pred` is `True`. - This function is similar to :func:`set_error_if`, but it also takes a category - argument. The category can be "nan", "divide", or "oob". The error checking - behavior for each category can be configured using - :func:`set_error_checking_behavior`. If not provided, there will be no - category. + This function is used inside JAX computations to detect runtime errors without + immediately halting execution. When this function is traced (e.g., inside + :func:`jax.jit`), the corresponding error message and its traceback are + recorded. At execution time, if `pred` contains any `True` values, the error + state is set, but execution continues without interruption. The recorded error + can later be raised using :func:`raise_if_error`. + + If the error state has already been set, subsequent errors are ignored and + will not override the existing error. + + For multi-device environments, in explicit mode, users must call + :func:`error_checking_context` to initialize a new error tracking state that + matches the device mesh. In auto mode, implicit cross-device communication may + occur inside this function, which could impact performance. A warning is + issued in such cases. - This function is intended for use in JAX internal APIs (e.g., `jax.numpy`) - to perform category-specific runtime checks tied to the operation being - performed. + When exporting a function with `jax.export`, error checking must be explicitly + wrapped using :func:`wrap_for_export` before export and + :func:`unwrap_from_import` after import. + + Args: + pred: A JAX boolean array. If any element of `pred` is `True`, the internal + error state will be set. + msg: The corresponding error message to be raised later. """ if _error_storage.ref is None: with core.eval_context(): _initialize_error_code_ref() assert _error_storage.ref is not None - if _is_category_disabled(category): - return - # Get the traceback. traceback = source_info_util.current().traceback assert traceback is not None @@ -219,37 +199,6 @@ def _set_error_if_with_category( _error_storage.ref[...] = error_code -def set_error_if(pred: jax.Array, /, msg: str) -> None: - """Set the internal error state if any element of `pred` is `True`. - - This function is used inside JAX computations to detect runtime errors without - immediately halting execution. When this function is traced (e.g., inside - :func:`jax.jit`), the corresponding error message and its traceback are - recorded. At execution time, if `pred` contains any `True` values, the error - state is set, but execution continues without interruption. The recorded error - can later be raised using :func:`raise_if_error`. - - If the error state has already been set, subsequent errors are ignored and - will not override the existing error. - - For multi-device environments, in explicit mode, users must call - :func:`error_checking_context` to initialize a new error tracking state that - matches the device mesh. In auto mode, implicit cross-device communication may - occur inside this function, which could impact performance. A warning is - issued in such cases. - - When exporting a function with `jax.export`, error checking must be explicitly - wrapped using :func:`wrap_for_export` before export and - :func:`unwrap_from_import` after import. - - Args: - pred: A JAX boolean array. If any element of `pred` is `True`, the internal - error state will be set. - msg: The corresponding error message to be raised later. - """ - _set_error_if_with_category(pred, msg) - - def raise_if_error() -> None: """Raise an exception if the internal error state is set. @@ -406,53 +355,3 @@ def inner(*args, **kwargs): return out return inner - - -Behavior = Literal["ignore", "raise"] - - -class error_checking_behavior: - """A context manager to set the error checking behavior. - - If both `all` and a category are provided, the category will override the - `all` setting. - - When the error checking behavior is set to "ignore", all errors will be - ignored. When set to "raise", errors will be detected and recorded, but an - exception will not be raised immediately. Users must call - :func:`raise_if_error` to at the end of the computation to raise the - exception. - """ - - def __init__( - self, - *, - all: Behavior | None = None, - nan: Behavior | None = None, - divide: Behavior | None = None, - oob: Behavior | None = None, - ) -> None: - new_settings = {} - if all is not None: - new_settings["nan"] = new_settings["divide"] = new_settings["oob"] = all - if nan is not None: - new_settings["nan"] = nan - if divide is not None: - new_settings["divide"] = divide - if oob is not None: - new_settings["oob"] = oob - self.new_settings = new_settings - self.stack = contextlib.ExitStack() - - def __enter__(self): - config_flags = { - "nan": config.error_checking_behavior_nan, - "divide": config.error_checking_behavior_divide, - "oob": config.error_checking_behavior_oob, - } - for key, value in self.new_settings.items(): - self.stack.enter_context(config_flags[key](value)) - return self - - def __exit__(self, exc_type, exc_value, traceback): - self.stack.close() diff --git a/jax/_src/numpy/error.py b/jax/_src/numpy/error.py new file mode 100644 index 000000000000..52b996a0b050 --- /dev/null +++ b/jax/_src/numpy/error.py @@ -0,0 +1,131 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +from typing import Literal + +import jax +from jax._src import config + +Category = Literal["nan", "divide", "oob"] + + +def _is_category_disabled( + category: Category | None, +) -> bool: + """Check if the error checking behavior for the given category is disabled.""" + if category is None: + return False + if category == "nan": + raise ValueError("nan is deprecated. Use `_set_error_if_nan` instead.") + if category == "divide": + return config.error_checking_behavior_divide.value == "ignore" + if category == "oob": + return config.error_checking_behavior_oob.value == "ignore" + raise ValueError(f"Invalid category: {category}") + + +def _set_error_if_with_category( + pred: jax.Array, + /, + msg: str, + category: Category | None = None, +) -> None: + """Set the internal error state if any element of `pred` is `True`. + + This function is similar to :func:`set_error_if`, but it also takes a category + argument. The category can be "nan", "divide", or "oob". The error checking + behavior for each category can be configured using + :func:`set_error_checking_behavior`. If not provided, there will be no + category. + + This function is intended for use in JAX internal APIs (e.g., `jax.numpy`) + to perform category-specific runtime checks tied to the operation being + performed. + """ + if _is_category_disabled(category): + return + + # TODO(mattjj): fix the circular import issue. + from jax._src import error_check as error_check_lib + error_check_lib.set_error_if(pred, msg) + + +def _set_error_if_nan(pred: jax.Array, /): + """Set the internal error state if any element of `pred` is `NaN`. + + This function is disabled if the `jax_error_checking_behavior_nan` flag is + set to "ignore". + """ + if config.error_checking_behavior_nan.value == "ignore": + return + + # TODO(mattjj): fix the circular import issue. + import jax.numpy as jnp + if not jnp.issubdtype(pred.dtype, jnp.floating): # only check floats + return + + # TODO(mattjj): fix the circular import issue. + from jax._src import error_check as error_check_lib + error_check_lib.set_error_if(jnp.isnan(pred), "NaN encountered") + + +Behavior = Literal["ignore", "raise"] + + +class error_checking_behavior: + """A context manager to set the error checking behavior. + + If both `all` and a category are provided, the category will override the + `all` setting. + + When the error checking behavior is set to "ignore", all errors will be + ignored. When set to "raise", errors will be detected and recorded, but an + exception will not be raised immediately. Users must call + :func:`raise_if_error` to at the end of the computation to raise the + exception. + """ + + def __init__( + self, + *, + all: Behavior | None = None, + nan: Behavior | None = None, + divide: Behavior | None = None, + oob: Behavior | None = None, + ) -> None: + new_settings = {} + if all is not None: + new_settings["nan"] = new_settings["divide"] = new_settings["oob"] = all + if nan is not None: + new_settings["nan"] = nan + if divide is not None: + new_settings["divide"] = divide + if oob is not None: + new_settings["oob"] = oob + self.new_settings = new_settings + self.stack = contextlib.ExitStack() + + def __enter__(self): + config_flags = { + "nan": config.error_checking_behavior_nan, + "divide": config.error_checking_behavior_divide, + "oob": config.error_checking_behavior_oob, + } + for key, value in self.new_settings.items(): + self.stack.enter_context(config_flags[key](value)) + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.stack.close() diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 91191d24a12e..1df973039213 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -32,12 +32,13 @@ from jax._src.lax import lax from jax._src.lax import other as lax_other from jax._src.typing import Array, ArrayLike +from jax._src.numpy import error as jnp_error +from jax._src.numpy import reductions +from jax._src.numpy.ufunc_api import ufunc from jax._src.numpy.util import ( check_arraylike, promote_args, promote_args_inexact, promote_args_numeric, promote_dtypes_inexact, promote_dtypes_numeric, promote_shapes, _where, check_no_float0s) -from jax._src.numpy.ufunc_api import ufunc -from jax._src.numpy import reductions from jax._src.util import set_module @@ -486,7 +487,9 @@ def log(x: ArrayLike, /) -> Array: >>> jnp.allclose(jnp.log(x1*x2), jnp.log(x1)+jnp.log(x2)) Array(True, dtype=bool) """ - return lax.log(*promote_args_inexact('log', x)) + out = lax.log(*promote_args_inexact('log', x)) + jnp_error._set_error_if_nan(out) + return out @export diff --git a/tests/BUILD b/tests/BUILD index 2e03f331744c..1baeb4f83af7 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1165,6 +1165,11 @@ jax_multiplatform_test( srcs = ["error_check_test.py"], ) +jax_multiplatform_test( + name = "jax_numpy_error_test", + srcs = ["jax_numpy_error_test.py"], +) + jax_multiplatform_test( name = "stax_test", srcs = ["stax_test.py"], diff --git a/tests/error_check_test.py b/tests/error_check_test.py index af3f35c7ab62..0c77989b8a43 100644 --- a/tests/error_check_test.py +++ b/tests/error_check_test.py @@ -372,92 +372,6 @@ def run_import(serialized): ): error_check.raise_if_error() - @parameterized.product(jit=[True, False]) - def test_error_category_nan_check(self, jit): - def f(x): - error_check._set_error_if_with_category( - jnp.isnan(x), "x is NaN", category="nan" - ) - return x - - if jit: - f = jax.jit(f) - - x = jnp.full((4,), jnp.nan, dtype=jnp.float32) - - with error_check.error_checking_behavior(nan="ignore"): - _ = f(x) - error_check.raise_if_error() # should not raise error - - with error_check.error_checking_behavior(nan="raise"): - _ = f(x) - with self.assertRaisesRegex(JaxValueError, "x is NaN"): - error_check.raise_if_error() - - @parameterized.product(jit=[True, False]) - def test_error_category_divide_check(self, jit): - def f(x, y): - error_check._set_error_if_with_category( - y == 0.0, "division by zero", category="divide" - ) - return x / y - - if jit: - f = jax.jit(f) - - x = jnp.arange(4, dtype=jnp.float32) + 1 - y = jnp.arange(4, dtype=jnp.float32) - - with error_check.error_checking_behavior(divide="ignore"): - _ = f(x, y) - error_check.raise_if_error() # should not raise error - - with error_check.error_checking_behavior(divide="raise"): - _ = f(x, y) - with self.assertRaisesRegex(JaxValueError, "division by zero"): - error_check.raise_if_error() - - @parameterized.product(jit=[True, False]) - def test_error_category_oob_check(self, jit): - def f(x, start_indices, slice_sizes): - error_check._set_error_if_with_category( - jnp.logical_or( - start_indices < 0, - start_indices + jnp.array(slice_sizes, dtype=jnp.int32) - >= jnp.array(x.shape, dtype=jnp.int32), - ), - "Out of bounds in dynamic_slice", - category="oob", - ) - y = jax.lax.dynamic_slice( - x, start_indices, slice_sizes, allow_negative_indices=False - ) - return y - - if jit: - f = jax.jit(f, static_argnums=(2,)) - - x = jnp.arange(12).reshape(3, 4) - start_indices = jnp.array([0, -1], dtype=jnp.int32) - slice_sizes = (3, 4) - - with error_check.error_checking_behavior(oob="ignore"): - _ = f(x, start_indices, slice_sizes) - error_check.raise_if_error() # should not raise error - - with error_check.error_checking_behavior(oob="raise"): - _ = f(x, start_indices, slice_sizes) - with self.assertRaisesRegex( - JaxValueError, "Out of bounds in dynamic_slice", - ): - error_check.raise_if_error() - - def test_error_category_invalid_category(self): - with self.assertRaisesRegex(ValueError, "Invalid category"): - error_check._set_error_if_with_category( - jnp.isnan(jnp.float32(1.0)), "x is NaN", category="invalid" - ) - if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/jax_numpy_error_test.py b/tests/jax_numpy_error_test.py new file mode 100644 index 000000000000..c2883f2005e0 --- /dev/null +++ b/tests/jax_numpy_error_test.py @@ -0,0 +1,130 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax._src import config +from jax._src import error_check +from jax._src import test_util as jtu +from jax._src.numpy import error as jnp_error +import jax.numpy as jnp + +config.parse_flags_with_absl() + + +JaxValueError = error_check.JaxValueError + + +class JaxNumpyErrorTests(jtu.JaxTestCase): + @parameterized.product(jit=[True, False]) + def test_set_error_if_nan(self, jit): + def f(x): + jnp_error._set_error_if_nan(x) + return x + + if jit: + f = jax.jit(f) + + x = jnp.full((4,), jnp.nan, dtype=jnp.float32) + + with jnp_error.error_checking_behavior(nan="ignore"): + _ = f(x) + error_check.raise_if_error() # should not raise error + + with jnp_error.error_checking_behavior(nan="raise"): + _ = f(x) + with self.assertRaisesRegex(JaxValueError, "NaN"): + error_check.raise_if_error() + + @parameterized.product(jit=[True, False]) + def test_error_category_divide_check(self, jit): + def f(x, y): + jnp_error._set_error_if_with_category( + y == 0.0, "division by zero", category="divide" + ) + return x / y + + if jit: + f = jax.jit(f) + + x = jnp.arange(4, dtype=jnp.float32) + 1 + y = jnp.arange(4, dtype=jnp.float32) + + with jnp_error.error_checking_behavior(divide="ignore"): + _ = f(x, y) + error_check.raise_if_error() # should not raise error + + with jnp_error.error_checking_behavior(divide="raise"): + _ = f(x, y) + with self.assertRaisesRegex(JaxValueError, "division by zero"): + error_check.raise_if_error() + + @parameterized.product(jit=[True, False]) + def test_error_category_oob_check(self, jit): + def f(x, start_indices, slice_sizes): + jnp_error._set_error_if_with_category( + jnp.logical_or( + start_indices < 0, + start_indices + jnp.array(slice_sizes, dtype=jnp.int32) + >= jnp.array(x.shape, dtype=jnp.int32), + ), + "Out of bounds in dynamic_slice", + category="oob", + ) + y = jax.lax.dynamic_slice( + x, start_indices, slice_sizes, allow_negative_indices=False + ) + return y + + if jit: + f = jax.jit(f, static_argnums=(2,)) + + x = jnp.arange(12).reshape(3, 4) + start_indices = jnp.array([0, -1], dtype=jnp.int32) + slice_sizes = (3, 4) + + with jnp_error.error_checking_behavior(oob="ignore"): + _ = f(x, start_indices, slice_sizes) + error_check.raise_if_error() # should not raise error + + with jnp_error.error_checking_behavior(oob="raise"): + _ = f(x, start_indices, slice_sizes) + with self.assertRaisesRegex( + JaxValueError, "Out of bounds in dynamic_slice", + ): + error_check.raise_if_error() + + def test_error_category_invalid_category(self): + with self.assertRaisesRegex(ValueError, "Invalid category"): + jnp_error._set_error_if_with_category( + jnp.isnan(jnp.float32(1.0)), "x is NaN", category="invalid" + ) + + @parameterized.product(jit=[True, False]) + def test_can_raise_nan_error(self, jit): + x = jnp.arange(4, dtype=jnp.float32) - 1 + + f = jnp.log + if jit: + f = jax.jit(f) + + with jnp_error.error_checking_behavior(nan="raise"): + f(x) + with self.assertRaisesRegex(JaxValueError, "NaN"): + error_check.raise_if_error() + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) From 55318d582424ba78dbbf7c0a7c9a33a60a43ada8 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 26 Mar 2025 10:55:14 -0700 Subject: [PATCH 188/483] `build/build.py` changes: copy the wheels created by the new build wheel targets into the path specified by `--output_path`. PiperOrigin-RevId: 740829299 --- build/build.py | 31 +++++++++++++++++++++++++++++++ build/tools/utils.py | 26 ++++++++++++++++++++++++++ ci/build_artifacts.sh | 11 ++--------- 3 files changed, 59 insertions(+), 9 deletions(-) diff --git a/build/build.py b/build/build.py index cdb568171b66..7f70f7b2ffef 100755 --- a/build/build.py +++ b/build/build.py @@ -76,6 +76,8 @@ "jax-rocm-pjrt": "//jaxlib/tools:jax_rocm_pjrt_wheel", } +_JAX_CUDA_VERSION = "12" + def add_global_arguments(parser: argparse.ArgumentParser): """Adds all the global arguments that applies to all the CLI subcommands.""" parser.add_argument( @@ -695,6 +697,35 @@ async def main(): if result.return_code != 0: raise RuntimeError(f"Command failed with return code {result.return_code}") + if args.use_new_wheel_build_rule: + output_path = args.output_path + jax_bazel_dir = os.path.join("bazel-bin", "dist") + jaxlib_and_plugins_bazel_dir = os.path.join( + "bazel-bin", "jaxlib", "tools", "dist" + ) + for wheel in args.wheels.split(","): + if wheel == "jax": + bazel_dir = jax_bazel_dir + else: + bazel_dir = jaxlib_and_plugins_bazel_dir + if "cuda" in wheel: + wheel_dir = wheel.replace("cuda", f"cuda{_JAX_CUDA_VERSION}").replace( + "-", "_" + ) + else: + wheel_dir = wheel + + if args.editable: + src_dir = os.path.join(bazel_dir, wheel_dir) + dst_dir = os.path.join(output_path, wheel_dir) + utils.copy_dir_recursively(src_dir, dst_dir) + else: + utils.copy_individual_files(bazel_dir, output_path, f"{wheel_dir}*.whl") + if wheel == "jax": + utils.copy_individual_files( + bazel_dir, output_path, f"{wheel_dir}*.tar.gz" + ) + # Exit with success if all wheels in the list were built successfully. sys.exit(0) diff --git a/build/tools/utils.py b/build/tools/utils.py index 7e375169827b..8b8dc80d1e0f 100644 --- a/build/tools/utils.py +++ b/build/tools/utils.py @@ -14,6 +14,7 @@ # ============================================================================== # Helper script for tools/utilities used by the JAX build CLI. import collections +import glob import hashlib import logging import os @@ -256,3 +257,28 @@ def _parse_string_as_bool(s): return False else: raise ValueError(f"Expected either 'true' or 'false'; got {s}") + + +def copy_dir_recursively(src, dst): + if os.path.exists(dst): + shutil.rmtree(dst) + os.makedirs(dst, exist_ok=True) + for root, dirs, files in os.walk(src): + relative_path = os.path.relpath(root, src) + dst_dir = os.path.join(dst, relative_path) + os.makedirs(dst_dir, exist_ok=True) + for f in files: + src_file = os.path.join(root, f) + dst_file = os.path.join(dst_dir, f) + shutil.copy2(src_file, dst_file) + logging.info("Editable wheel path: %s" % dst) + + +def copy_individual_files(src, dst, regex): + os.makedirs(dst, exist_ok=True) + for f in glob.glob(os.path.join(src, regex)): + dst_file = os.path.join(dst, os.path.basename(f)) + if os.path.exists(dst_file): + os.remove(dst_file) + shutil.copy2(f, dst_file) + logging.info("Distribution path: %s" % dst_file) diff --git a/ci/build_artifacts.sh b/ci/build_artifacts.sh index 84b8d35a2a50..d7ffe82eb699 100755 --- a/ci/build_artifacts.sh +++ b/ci/build_artifacts.sh @@ -96,6 +96,7 @@ if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then --bazel_options=--config="$bazelrc_config" $bazel_remote_cache \ --python_version=$JAXCI_HERMETIC_PYTHON_VERSION \ --verbose --detailed_timestamped_log --use_new_wheel_build_rule \ + --output_path="$JAXCI_OUTPUT_DIR" \ $artifact_tag_flags # If building release artifacts, we also build a release candidate ("rc") @@ -105,18 +106,10 @@ if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then --bazel_options=--config="$bazelrc_config" $bazel_remote_cache \ --python_version=$JAXCI_HERMETIC_PYTHON_VERSION \ --verbose --detailed_timestamped_log --use_new_wheel_build_rule \ + --output_path="$JAXCI_OUTPUT_DIR" \ $artifact_tag_flags --bazel_options=--repo_env=ML_WHEEL_VERSION_SUFFIX="$JAXCI_WHEEL_RC_VERSION" fi - # Move the built artifacts from the Bazel cache directory to the output - # directory. - if [[ "$artifact" == "jax" ]]; then - mv bazel-bin/dist/*.whl "$JAXCI_OUTPUT_DIR" - mv bazel-bin/dist/*.tar.gz "$JAXCI_OUTPUT_DIR" - else - mv bazel-bin/jaxlib/tools/dist/*.whl "$JAXCI_OUTPUT_DIR" - fi - # If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we # run `auditwheel show` to verify manylinux compliance. if [[ "$os" == "linux" ]] && [[ "$artifact" != "jax" ]]; then From 2518e187f3fa63f0bc2e116b0592b18e6584f2c0 Mon Sep 17 00:00:00 2001 From: Gleb Pobudzey Date: Wed, 26 Mar 2025 11:10:42 -0700 Subject: [PATCH 189/483] [Mosaic GPU] Support more layouts in the `swap` lowering. PiperOrigin-RevId: 740835345 --- jax/_src/pallas/mosaic_gpu/lowering.py | 24 +++++++++++++---- tests/pallas/mosaic_gpu_test.py | 36 +++++++++++++++++++++----- 2 files changed, 48 insertions(+), 12 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index a41a657ba738..286fedfa44d5 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1156,11 +1156,25 @@ def _swap_lowering_rule( value.store_tiled(x_smem, swizzle=swizzle) return old_value case (): - old_value = mgpu.FragmentedArray.load_strided( - x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype) - ) - value.store_untiled(x_smem) - return old_value + match value.layout: + case mgpu.WGMMARowFragLayout(): + old_value = mgpu.FragmentedArray.load_wgmma_row( + x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype) + ) + value.store_untiled(x_smem) + return old_value + case mgpu.WGMMAColFragLayout(): + old_value = mgpu.FragmentedArray.load_wgmma_col( + x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype) + ) + value.store_untiled(x_smem) + return old_value + case _: + old_value = mgpu.FragmentedArray.load_strided( + x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype) + ) + value.store_untiled(x_smem) + return old_value case _: raise NotImplementedError(f"Unsupported transforms: {transforms}") diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index c10f06f8bb5d..aea49b645ec6 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -657,14 +657,10 @@ def kernel(x_ref, o_ref, barrier_ref): @parameterized.product( src_memory_space=[plgpu.SMEM, plgpu.GMEM], - layout=[ - plgpu.Layout.WGMMA_ROW, - plgpu.Layout.WGMMA_COL, - plgpu.Layout.WG_STRIDED((128,), vec_size=1), - None, - ], + layout=[plgpu.Layout.WG_STRIDED((128,), vec_size=1), None, + ] ) - def test_load_to_layout_with_indexing(self, src_memory_space, layout): + def test_load_to_strided_layout_with_indexing(self, src_memory_space, layout): self.skip_if_wg_semantics() @functools.partial( @@ -685,6 +681,32 @@ def kernel(x_ref, o_ref): x = jnp.arange(2 * 128, dtype=jnp.float32).reshape(2, 128) np.testing.assert_array_equal(kernel(x), x) + @parameterized.product( + src_memory_space=[plgpu.SMEM, plgpu.GMEM], + layout=[plgpu.Layout.WGMMA_ROW, plgpu.Layout.WGMMA_COL], + m=[64, 128, 192], + ) + def test_load_to_wgmma_row_col_layout_with_indexing(self, src_memory_space, layout, m): + self.skip_if_wg_semantics() + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct([2, m], jnp.float32), + in_specs=[pl.BlockSpec(memory_space=src_memory_space)], + out_specs=plgpu.GPUBlockSpec( + (2, m), + lambda: (0, 0), + memory_space=plgpu.SMEM, + ), + ) + def kernel(x_ref, o_ref): + for i in range(2): + x = plgpu.load(x_ref, (i,), layout=layout) + o_ref[i, ...] = x + + x = jnp.arange(2 * m, dtype=jnp.float32).reshape(2, m) + np.testing.assert_array_equal(kernel(x), x) + @parameterized.product( src_memory_space=[plgpu.SMEM], layout=[plgpu.Layout.WGMMA_ROW, plgpu.Layout.WGMMA_COL], From feed69c56192ae5883082ecf4155bb2f69d1658b Mon Sep 17 00:00:00 2001 From: Ayaka Date: Wed, 26 Mar 2025 11:18:08 -0700 Subject: [PATCH 190/483] Add nan checking to jax.numpy functions PiperOrigin-RevId: 740838221 --- jax/_src/numpy/ufuncs.py | 71 ++++++++++++++++++++++++++--------- tests/jax_numpy_error_test.py | 64 ++++++++++++++++++++++++++++--- 2 files changed, 113 insertions(+), 22 deletions(-) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 1df973039213..0ea2992c9955 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -575,7 +575,9 @@ def log1p(x: ArrayLike, /) -> Array: >>> jnp.expm1(jnp.log(x1+1)) # doctest: +SKIP Array([1.000166e-04, 9.536743e-07, 0.000000e+00], dtype=float32) """ - return lax.log1p(*promote_args_inexact('log1p', x)) + out = lax.log1p(*promote_args_inexact('log1p', x)) + jnp_error._set_error_if_nan(out) + return out @export @@ -607,7 +609,9 @@ def sin(x: ArrayLike, /) -> Array: ... print(jnp.sin(x)) [ 0.707 1. 0.707 -0. ] """ - return lax.sin(*promote_args_inexact('sin', x)) + out = lax.sin(*promote_args_inexact('sin', x)) + jnp_error._set_error_if_nan(out) + return out @export @@ -638,7 +642,9 @@ def cos(x: ArrayLike, /) -> Array: ... print(jnp.cos(x)) [ 0.707 -0. -0.707 -0.866] """ - return lax.cos(*promote_args_inexact('cos', x)) + out = lax.cos(*promote_args_inexact('cos', x)) + jnp_error._set_error_if_nan(out) + return out @export @@ -669,7 +675,9 @@ def tan(x: ArrayLike, /) -> Array: ... print(jnp.tan(x)) [ 0. 0.577 1. -1. -0.577] """ - return lax.tan(*promote_args_inexact('tan', x)) + out = lax.tan(*promote_args_inexact('tan', x)) + jnp_error._set_error_if_nan(out) + return out @export @@ -711,7 +719,9 @@ def arcsin(x: ArrayLike, /) -> Array: ... jnp.arcsin(3+4j) Array(0.634+2.306j, dtype=complex64, weak_type=True) """ - return lax.asin(*promote_args_inexact('arcsin', x)) + out = lax.asin(*promote_args_inexact('arcsin', x)) + jnp_error._set_error_if_nan(out) + return out @export @@ -754,7 +764,9 @@ def arccos(x: ArrayLike, /) -> Array: ... jnp.arccos(4-1j) Array(0.252+2.097j, dtype=complex64, weak_type=True) """ - return lax.acos(*promote_args_inexact('arccos', x)) + out = lax.acos(*promote_args_inexact('arccos', x)) + jnp_error._set_error_if_nan(out) + return out @export @@ -1008,6 +1020,7 @@ def arccosh(x: ArrayLike, /) -> Array: # Note: arccosh is multi-valued for complex input, and lax.acosh # uses a different convention than np.arccosh. result = lax.acosh(*promote_args_inexact("arccosh", x)) + jnp_error._set_error_if_nan(result) if dtypes.issubdtype(result.dtype, np.complexfloating): result = _where(real(result) < 0, lax.neg(result), result) return result @@ -1113,7 +1126,9 @@ def arctanh(x: ArrayLike, /) -> Array: ... jnp.arctanh(x1) Array([-0.549+1.571j, 0.347+1.571j, 0.239-1.509j], dtype=complex64) """ - return lax.atanh(*promote_args_inexact('arctanh', x)) + out = lax.atanh(*promote_args_inexact('arctanh', x)) + jnp_error._set_error_if_nan(out) + return out @export @@ -1146,7 +1161,9 @@ def sqrt(x: ArrayLike, /) -> Array: >>> jnp.sqrt(-1) Array(nan, dtype=float32, weak_type=True) """ - return lax.sqrt(*promote_args_inexact('sqrt', x)) + out = lax.sqrt(*promote_args_inexact('sqrt', x)) + jnp_error._set_error_if_nan(out) + return out @export @@ -1215,7 +1232,11 @@ def add(x: ArrayLike, y: ArrayLike, /) -> Array: Array([10, 11, 12, 13], dtype=int32) """ x, y = promote_args("add", x, y) - return lax.add(x, y) if x.dtype != bool else lax.bitwise_or(x, y) + if x.dtype == bool: + return lax.bitwise_or(x, y) + out = lax.add(x, y) + jnp_error._set_error_if_nan(out) + return out def _multiply_at(a: Array, indices: Any, b: ArrayLike) -> Array: @@ -1544,7 +1565,9 @@ def subtract(x: ArrayLike, y: ArrayLike, /) -> Array: >>> x - 10 Array([-10, -9, -8, -7], dtype=int32) """ - return lax.sub(*promote_args("subtract", x, y)) + out = lax.sub(*promote_args("subtract", x, y)) + jnp_error._set_error_if_nan(out) + return out @export @@ -1768,7 +1791,9 @@ def float_power(x: ArrayLike, y: ArrayLike, /) -> Array: >>> jnp.float_power(-3, 1.7) Array(nan, dtype=float32, weak_type=True) """ - return lax.pow(*promote_args_inexact("float_power", x, y)) + out = lax.pow(*promote_args_inexact("float_power", x, y)) + jnp_error._set_error_if_nan(out) + return out @export @@ -2446,7 +2471,9 @@ def true_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: :func:`jax.numpy.floor_divide` for integer division """ x1, x2 = promote_args_inexact("true_divide", x1, x2) - return lax.div(x1, x2) + out = lax.div(x1, x2) + jnp_error._set_error_if_nan(out) + return out @export @@ -2648,7 +2675,9 @@ def power(x1: ArrayLike, x2: ArrayLike, /) -> Array: return lax.integer_pow(x1, x2) # Handle cases #2 and #3 under a jit: - return _power(x1, x2) + out = _power(x1, x2) + jnp_error._set_error_if_nan(out) + return out @export def pow(x1: ArrayLike, x2: ArrayLike, /) -> Array: @@ -2774,7 +2803,9 @@ def log2(x: ArrayLike, /) -> Array: im = lax.imag(r) ln2 = lax.log(_constant_like(re, 2)) return lax.complex(lax.div(re, ln2), lax.div(im, ln2)) - return lax.div(lax.log(x), lax.log(_constant_like(x, 2))) + out = lax.div(lax.log(x), lax.log(_constant_like(x, 2))) + jnp_error._set_error_if_nan(out) + return out @export @@ -2804,7 +2835,9 @@ def log10(x: ArrayLike, /) -> Array: im = lax.imag(r) ln10 = lax.log(_constant_like(re, 10)) return lax.complex(lax.div(re, ln10), lax.div(im, ln10)) - return lax.div(lax.log(x), lax.log(_constant_like(x, 10))) + out = lax.div(lax.log(x), lax.log(_constant_like(x, 10))) + jnp_error._set_error_if_nan(out) + return out @export @@ -3064,7 +3097,9 @@ def remainder(x1: ArrayLike, x2: ArrayLike, /) -> Array: trunc_mod_not_zero = lax.ne(trunc_mod, zero) do_plus = lax.bitwise_and( lax.ne(lax.lt(trunc_mod, zero), lax.lt(x2, zero)), trunc_mod_not_zero) - return lax.select(do_plus, lax.add(trunc_mod, x2), trunc_mod) + out = lax.select(do_plus, lax.add(trunc_mod, x2), trunc_mod) + jnp_error._set_error_if_nan(out) + return out @export @@ -3112,7 +3147,9 @@ def fmod(x1: ArrayLike, x2: ArrayLike, /) -> Array: check_arraylike("fmod", x1, x2) if dtypes.issubdtype(dtypes.result_type(x1, x2), np.integer): x2 = _where(x2 == 0, lax._ones(x2), x2) - return lax.rem(*promote_args_numeric("fmod", x1, x2)) + out = lax.rem(*promote_args_numeric("fmod", x1, x2)) + jnp_error._set_error_if_nan(out) + return out @export diff --git a/tests/jax_numpy_error_test.py b/tests/jax_numpy_error_test.py index c2883f2005e0..08917aeed364 100644 --- a/tests/jax_numpy_error_test.py +++ b/tests/jax_numpy_error_test.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import operator + from absl.testing import absltest from absl.testing import parameterized import jax @@ -112,16 +114,68 @@ def test_error_category_invalid_category(self): jnp.isnan(jnp.float32(1.0)), "x is NaN", category="invalid" ) - @parameterized.product(jit=[True, False]) - def test_can_raise_nan_error(self, jit): - x = jnp.arange(4, dtype=jnp.float32) - 1 + @staticmethod + def op_cases(cases): + for jit in (True, False): + for func, operands in cases: + if not isinstance(operands, tuple): + operands = (operands,) + + jit_str = "jit" if jit else "nojit" + func_str = f"{func.__module__}.{func.__name__}" + name = f"_{jit_str}_{func_str}" + + yield name, jit, func, operands + + @parameterized.named_parameters( + op_cases(( + # list of all NaN-producing jax.numpy functions + # go/keep-sorted start + (jnp.acos, 2.0), + (jnp.acosh, 0.5), + (jnp.add, (jnp.inf, -jnp.inf)), + (jnp.arccos, 2.0), + (jnp.arccosh, 0.5), + (jnp.arcsin, -2.0), + (jnp.arctanh, -2.0), + (jnp.asin, -2.0), + (jnp.atanh, -2.0), + (jnp.cos, jnp.inf), + (jnp.divide, (0.0, 0.0)), + (jnp.divmod, (1.0, 0.0)), + (jnp.float_power, (-1.0, 0.5)), + (jnp.fmod, (1.0, 0.0)), + (jnp.log, -1.0), + (jnp.log10, -1.0), + (jnp.log1p, -1.5), + (jnp.log2, -1.0), + (jnp.mod, (1.0, 0.0)), + (jnp.pow, (-1.0, 0.5)), + (jnp.power, (-1.0, 0.5)), + (jnp.remainder, (1.0, 0.0)), + (jnp.sin, jnp.inf), + # TODO(https://github.com/jax-ml/jax/issues/27470): Not yet supported. + # (jnp.sinc, jnp.inf), + (jnp.sqrt, -4.0), + (jnp.subtract, (jnp.inf, jnp.inf)), + (jnp.tan, jnp.inf), + (jnp.true_divide, (0.0, 0.0)), + (operator.add, (jnp.inf, -jnp.inf)), + (operator.mod, (1.0, 0.0)), + (operator.pow, (-1.0, 0.5)), + (operator.sub, (jnp.inf, jnp.inf)), + (operator.truediv, (0.0, 0.0)), + # go/keep-sorted end + )) + ) + def test_can_raise_nan_error(self, jit, f, operands): + operands = [jnp.float32(x) for x in operands] - f = jnp.log if jit: f = jax.jit(f) with jnp_error.error_checking_behavior(nan="raise"): - f(x) + f(*operands) with self.assertRaisesRegex(JaxValueError, "NaN"): error_check.raise_if_error() From 1b7c8e8d08d1308c438d47334387c6339f0456f8 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 26 Mar 2025 11:25:04 -0700 Subject: [PATCH 191/483] Add editable `jax` wheel target. The set of editable wheels (`jaxlib`, `jax-cuda-plugin` and `jax-cuda-pjrt`) was used as dependencies in `requirements.in` file together with `:build_jaxlib=false` flag. After [adding `jax` wheel dependencies](https://github.com/jax-ml/jax/commit/f5a4d1a85c41a42ed8fb389259a241513970ff9a) to the tests when `:build_jaxlib=false` is used, we need an editable `jax` wheel target as well to get the tests passing. PiperOrigin-RevId: 740840736 --- BUILD.bazel | 42 ++++++++++++++++++++++-------------------- build/build.py | 11 +++++++++-- build_wheel.py | 31 ++++++++++++++++++++++--------- 3 files changed, 53 insertions(+), 31 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index e7cf6de66cad..2c10f0d9a748 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -72,35 +72,37 @@ py_binary( ], ) +WHEEL_SOURCE_FILES = [ + ":transitive_py_data", + ":transitive_py_deps", + "//jax:py.typed", + "AUTHORS", + "LICENSE", + "README.md", + "pyproject.toml", + "setup.py", +] + jax_wheel( name = "jax_wheel", platform_independent = True, - source_files = [ - ":transitive_py_data", - ":transitive_py_deps", - "//jax:py.typed", - "AUTHORS", - "LICENSE", - "README.md", - "pyproject.toml", - "setup.py", - ], + source_files = WHEEL_SOURCE_FILES, + wheel_binary = ":build_wheel", + wheel_name = "jax", +) + +jax_wheel( + name = "jax_wheel_editable", + editable = True, + platform_independent = True, + source_files = WHEEL_SOURCE_FILES, wheel_binary = ":build_wheel", wheel_name = "jax", ) jax_source_package( name = "jax_source_package", - source_files = [ - ":transitive_py_data", - ":transitive_py_deps", - "//jax:py.typed", - "AUTHORS", - "LICENSE", - "README.md", - "pyproject.toml", - "setup.py", - ], + source_files = WHEEL_SOURCE_FILES, source_package_binary = ":build_wheel", source_package_name = "jax", ) diff --git a/build/build.py b/build/build.py index 7f70f7b2ffef..4d16851f837c 100755 --- a/build/build.py +++ b/build/build.py @@ -68,10 +68,14 @@ # rule as the default. WHEEL_BUILD_TARGET_DICT_NEW = { "jax": "//:jax_wheel", + "jax_editable": "//:jax_wheel_editable", "jax_source_package": "//:jax_source_package", "jaxlib": "//jaxlib/tools:jaxlib_wheel", + "jaxlib_editable": "//jaxlib/tools:jaxlib_wheel_editable", "jax-cuda-plugin": "//jaxlib/tools:jax_cuda_plugin_wheel", + "jax-cuda-plugin_editable": "//jaxlib/tools:jax_cuda_plugin_wheel_editable", "jax-cuda-pjrt": "//jaxlib/tools:jax_cuda_pjrt_wheel", + "jax-cuda-pjrt_editable": "//jaxlib/tools:jax_cuda_pjrt_wheel_editable", "jax-rocm-plugin": "//jaxlib/tools:jax_rocm_plugin_wheel", "jax-rocm-pjrt": "//jaxlib/tools:jax_rocm_pjrt_wheel", } @@ -662,9 +666,12 @@ async def main(): ) # Append the build target to the Bazel command. - build_target = wheel_build_targets[wheel] + if args.use_new_wheel_build_rule and args.editable: + build_target = wheel_build_targets[wheel + "_editable"] + else: + build_target = wheel_build_targets[wheel] wheel_build_command.append(build_target) - if args.use_new_wheel_build_rule and wheel == "jax": + if args.use_new_wheel_build_rule and wheel == "jax" and not args.editable: wheel_build_command.append(wheel_build_targets["jax_source_package"]) if not args.use_new_wheel_build_rule: diff --git a/build_wheel.py b/build_wheel.py index b4db96773527..793523e8e3b2 100644 --- a/build_wheel.py +++ b/build_wheel.py @@ -61,6 +61,11 @@ "Whether to build the source package only. Optional." ), ) +parser.add_argument( + "--editable", + action="store_true", + help="Create an 'editable' jax build instead of a wheel.", +) args = parser.parse_args() @@ -90,7 +95,11 @@ def prepare_srcs(deps: list[str], srcs_dir: str) -> None: """ for file in deps: - if not (file.startswith("bazel-out") or file.startswith("external")): + if not ( + file.startswith("bazel-out") + or file.startswith("external") + or file.startswith("jaxlib") + ): copy_file(file, srcs_dir) @@ -103,14 +112,18 @@ def prepare_srcs(deps: list[str], srcs_dir: str) -> None: try: os.makedirs(args.output_path, exist_ok=True) prepare_srcs(args.srcs, pathlib.Path(sources_path)) - build_utils.build_wheel( - sources_path, - args.output_path, - package_name="jax", - git_hash=args.jaxlib_git_hash, - build_wheel_only=args.build_wheel_only, - build_source_package_only=args.build_source_package_only, - ) + package_name = "jax" + if args.editable: + build_utils.build_editable(sources_path, args.output_path, package_name) + else: + build_utils.build_wheel( + sources_path, + args.output_path, + package_name, + git_hash=args.jaxlib_git_hash, + build_wheel_only=args.build_wheel_only, + build_source_package_only=args.build_source_package_only, + ) finally: if tmpdir: tmpdir.cleanup() From e364abe961ed251915b1a1c7374a0ebd9974c201 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Wed, 12 Mar 2025 05:15:58 +0000 Subject: [PATCH 192/483] Prune passthrough outputs in lax.switch. --- jax/_src/lax/control_flow/conditionals.py | 18 +++++++++++++-- tests/lax_control_flow_test.py | 28 +++++++++++++++++++++++ 2 files changed, 44 insertions(+), 2 deletions(-) diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 63896cc2a0bf..1e9372254ca1 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -151,12 +151,27 @@ def switch(index, branches, *operands): out_trees[0], jaxprs[0].out_avals, f"branch {i + 1} output", out_tree, jaxpr.out_avals) + # prune passthrough outputs + fwds = [pe._jaxpr_forwarding(jaxpr.jaxpr) for jaxpr in jaxprs] + in_fwd = [xs[0] if len(set(xs)) == 1 else None for xs in zip(*fwds)] + keep = [f is None for f in in_fwd] + jaxprs = [pe.prune_closed_jaxpr_outputs(jaxpr, keep) for jaxpr in jaxprs] + joined_effects = core.join_effects(*(jaxpr.effects for jaxpr in jaxprs)) disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(joined_effects) if disallowed_effects: raise NotImplementedError( f'Effects not supported in `switch`: {disallowed_effects}') + jaxprs = [replace_jaxpr_effects(jaxpr, joined_effects) for jaxpr in jaxprs] out = cond_p.bind(index, *consts, *ops, branches=tuple(jaxprs)) + out_ = iter(out) + + all_inputs = [*consts, *ops] + out = [ + next(out_) if fwd is None else lax.asarray(all_inputs[fwd]) + for fwd in in_fwd + ] + assert next(out_, None) is None return tree_unflatten(out_trees[0], out) @@ -259,7 +274,7 @@ def cond(pred, true_fun, false_fun, *operands): out_tree, true_jaxpr.out_avals, "false_fun output", false_out_tree, false_jaxpr.out_avals) - # prune passhtrough outputs + # prune passthrough outputs true_fwds = pe._jaxpr_forwarding(true_jaxpr.jaxpr) false_fwds = pe._jaxpr_forwarding(false_jaxpr.jaxpr) in_fwd = [i if i == j else None for i, j in zip(true_fwds, false_fwds)] @@ -278,7 +293,6 @@ def cond(pred, true_fun, false_fun, *operands): true_jaxpr = replace_jaxpr_effects(true_jaxpr, joined_effects) out = cond_p.bind(index, *consts, *ops, branches=(false_jaxpr, true_jaxpr)) - num_consts = len(consts) out_ = iter(out) all_inputs = [*consts, *ops] diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 3871a87a7a3e..9ac4e8c6da80 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -1309,6 +1309,34 @@ def f(x): self.assertAllClose(ans, expected, check_dtypes=False) jtu.check_grads(f, (x,), order=2, modes=["fwd", "rev"]) + @parameterized.parameters(itertools.product(range(4), repeat=3)) + @jtu.run_on_devices("cpu") + def testSwitchGradWithForwarding(self, seed, num_input_fwd, num_output_fwd): + num_args = 3 + num_branches = 4 + rng = np.random.RandomState(seed) + in_perm = rng.permutation(num_args) + out_perm = rng.permutation(num_args) + + def branch(s, inputs): + inputs = [inputs[i] for i in in_perm] + outputs = inputs[:num_input_fwd] + [ + s * jnp.exp(inputs[i]) if i < num_output_fwd else jnp.sin(inputs[i]) + for i in range(num_args - num_input_fwd)] + return [outputs[i] for i in out_perm] + + branches = [partial(branch, i) for i in range(num_branches)] + + @jax.jit + def f_(idx, inputs): + idx = lax.convert_element_type(idx // 1, np.int32) + return lax.switch(idx, branches, inputs) + + for idx in range(num_branches): + f = partial(f_, idx) + jtu.check_grads(f, (jnp.arange(float(num_args)),), + order=1, modes=['fwd', 'rev'], atol=1e-2, rtol=1e-2) + def testSwitchGradWithWeakTypeMismatch(self): # issue #4696, PR #4896 dtype = dtypes.canonicalize_dtype(np.float64) dtype = jnp.float32 if dtype == jnp.float32 else jnp.float64 From ec2f0f5913a3376bb940e17cc0151090f5d07d2d Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 26 Mar 2025 11:56:05 -0700 Subject: [PATCH 193/483] [sharding_in_types] Enable auto_axes to work without any mesh context manager. We extract the mesh from `out_shardings` given. This allows APIs like `random.uniform` to accept NamedSharding in `out_sharding` argument and continue to work without a mesh context. PiperOrigin-RevId: 740852542 --- jax/_src/pjit.py | 46 ++++++++++++++++++++++++++++++++++------------ tests/pjit_test.py | 42 +++++++++++++++++++++++++++++++++++++++--- 2 files changed, 73 insertions(+), 15 deletions(-) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index bcdbe6b1bdb7..054b55e32918 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -2821,29 +2821,50 @@ def _reshard_batcher(axis_data, vals_in, dims_in, dst_sharding): # -------------------- auto and user mode ------------------------- def _get_new_mesh(axes: str | tuple[str, ...] | None, - axis_type: mesh_lib.AxisType, name: str, - error_on_manual_to_auto_explict=False): + axis_type: mesh_lib.AxisType, name: str, shardings=None, + error_on_manual_to_auto_explicit=False): cur_mesh = mesh_lib.get_abstract_mesh() - # TODO(yashkatariya): Maybe allow fetching mesh from the args to enable - # computation follows data? - if cur_mesh.empty: + flat_shardings, _ = tree_flatten(shardings) + sharding_mesh = mesh_lib.empty_abstract_mesh + for i in flat_shardings: + if isinstance(i, NamedSharding): + if not sharding_mesh.empty and sharding_mesh != i.mesh: + raise ValueError( + f'Shardings passed to {name} should have the same mesh. Got one' + f' mesh {sharding_mesh} and another {i.mesh}') + sharding_mesh = i.mesh.abstract_mesh + + if sharding_mesh.empty and cur_mesh.empty: raise ValueError( f'Context mesh {cur_mesh} cannot be empty. Please use' ' `jax.sharding.use_mesh` API to enter into a mesh context when using' f' `{name}` API.') + if not sharding_mesh.empty and not cur_mesh.empty: + if sharding_mesh != cur_mesh: + raise ValueError( + f'Context mesh {cur_mesh} must match the mesh passed to shardings' + f' {sharding_mesh}. Recommended approach is to use' + ' `jax.sharding.use_mesh` context manager.') + mesh_to_use = cur_mesh + elif sharding_mesh.empty and not cur_mesh.empty: + mesh_to_use = cur_mesh + else: + assert not sharding_mesh.empty and cur_mesh.empty + mesh_to_use = sharding_mesh + if axes is None: - axes = cur_mesh.axis_names + axes = mesh_to_use.axis_names if not isinstance(axes, tuple): axes = (axes,) for a in axes: - if (error_on_manual_to_auto_explict and - cur_mesh._name_to_type[a] == mesh_lib.AxisType.Manual and + if (error_on_manual_to_auto_explicit and + mesh_to_use._name_to_type[a] == mesh_lib.AxisType.Manual and axis_type in {mesh_lib.AxisType.Auto, mesh_lib.AxisType.Explicit}): raise NotImplementedError( 'Going from `Manual` AxisType to `Auto` or `Explicit` AxisType is not' ' allowed. Please file a bug at https://github.com/jax-ml/jax/issues' ' with your use case') - return cur_mesh.update_axis_types({a: axis_type for a in axes}) + return mesh_to_use.update_axis_types({a: axis_type for a in axes}) def auto_axes(fun, *, axes: str | tuple[str, ...] | None = None, out_shardings=None): @@ -2855,8 +2876,9 @@ def decorator(*args, **kwargs): raise TypeError("Missing required keyword argument: 'out_shardings'") else: _out_shardings = out_shardings - new_mesh = _get_new_mesh(axes, mesh_lib.AxisType.Auto, 'auto_axes', - error_on_manual_to_auto_explict=True) + new_mesh = _get_new_mesh( + axes, mesh_lib.AxisType.Auto, 'auto_axes', shardings=_out_shardings, + error_on_manual_to_auto_explicit=True) with mesh_lib.use_abstract_mesh(new_mesh): in_specs = tree_map(lambda a: core.modify_spec_for_auto_manual( core.get_aval(a).sharding.spec, new_mesh), args) @@ -2883,7 +2905,7 @@ def decorator(*args, **kwargs): else: _in_shardings = in_shardings new_mesh = _get_new_mesh(axes, mesh_lib.AxisType.Explicit, 'explicit_axes', - error_on_manual_to_auto_explict=True) + error_on_manual_to_auto_explicit=True) with mesh_lib.use_abstract_mesh(new_mesh): args = mesh_cast(args, _in_shardings) out = fun(*args, **kwargs) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index d6673c6b6d5a..5cd16e1e6925 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7117,7 +7117,7 @@ def g(x): out = jax.grad(g)(arr) self.assertEqual(out.sharding, arr.sharding) - def test_auto_axes_computation_follows_data_error(self): + def test_auto_axes_computation_follows_data(self): mesh = jtu.create_mesh((2,), ('x',), axis_types=(AxisType.Explicit,)) s = NamedSharding(mesh, P('x')) arr = jax.device_put(np.arange(8), s) @@ -7126,8 +7126,9 @@ def test_auto_axes_computation_follows_data_error(self): def f(x): return x * 2 - with self.assertRaisesRegex(ValueError, "Context mesh.*cannot be empty"): - auto_axes(f, out_shardings=s)(arr) + out = auto_axes(f, out_shardings=s)(arr) + self.assertEqual(out.sharding, s) + self.assertArraysEqual(out, arr * 2) def test_divisbility_aval_error(self): abstract_mesh = mesh_lib.AbstractMesh( @@ -7264,6 +7265,41 @@ def f(key): else: self.assertIn('mhlo.sharding = "{devices=[2,2]<=[4]}"}', lowered_text) + def test_random_normal_wo_mesh_context(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y'), + axis_types=(AxisType.Explicit,) * 2) + s = NamedSharding(mesh, P('x', 'y')) + + @jax.jit + def f(arr, key): + out = jax.random.normal(key, shape=(8, 12), out_sharding=s) + self.assertEqual(out.aval.sharding.spec, P('x', 'y')) + return arr + out + + key = jax.random.key(1) + out = f(jax.device_put(np.arange(8 * 12.).reshape(8, 12), s), key) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + + def test_auto_axes_no_context_mesh(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y'), axis_types=(AxisType.Explicit,) * 2) + np_inp = np.arange(16.).reshape(8, 2) + s = NamedSharding(mesh, P('x', 'y')) + arr = jax.device_put(np_inp, s) + + @partial(auto_axes, axes='x', + out_shardings=NamedSharding(mesh, P('x', 'y'))) + def h(y): + self.assertEqual(y.aval.sharding.spec, P(None, 'y')) + z = jnp.sin(y) + self.assertEqual(z.aval.sharding.spec, P(None, 'y')) + return z + + out = jax.jit(h)(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + + out = h(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): From aa160937cf3a7aa4dd953c18c1bc1ef83ddc0546 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 26 Mar 2025 12:05:37 -0700 Subject: [PATCH 194/483] [JAX] [XLA:Python] Migrate more modules to JAX. PiperOrigin-RevId: 740855958 --- jaxlib/xla/BUILD | 102 ++++++-- jaxlib/xla/callback.cc | 2 +- jaxlib/xla/config.cc | 2 +- jaxlib/xla/dlpack.cc | 6 +- jaxlib/xla/dlpack.h | 2 +- jaxlib/xla/ifrt_proxy.cc | 2 +- jaxlib/xla/jax_jit.h | 2 +- jaxlib/xla/nb_class_ptr.h | 59 +++++ jaxlib/xla/pjit.cc | 6 +- jaxlib/xla/pmap_lib.cc | 6 +- jaxlib/xla/py_array.cc | 6 +- jaxlib/xla/py_array.h | 4 +- jaxlib/xla/py_client.cc | 6 +- jaxlib/xla/py_client.h | 2 +- jaxlib/xla/py_compile_only_client.cc | 2 +- jaxlib/xla/py_compile_only_client.h | 2 +- jaxlib/xla/py_device.cc | 4 +- jaxlib/xla/py_device.h | 2 +- jaxlib/xla/py_device_list.cc | 4 +- jaxlib/xla/py_device_list.h | 2 +- jaxlib/xla/py_executable.cc | 4 +- jaxlib/xla/py_executable.h | 4 +- jaxlib/xla/py_host_callback.cc | 2 +- jaxlib/xla/py_memory_space.cc | 2 +- jaxlib/xla/py_memory_space.h | 2 +- jaxlib/xla/py_program.cc | 4 +- jaxlib/xla/py_socket_transfer.cc | 4 +- jaxlib/xla/py_values.cc | 2 +- jaxlib/xla/python_ref_manager.cc | 104 ++++++++ jaxlib/xla/python_ref_manager.h | 108 ++++++++ jaxlib/xla/pytree.cc | 2 +- jaxlib/xla/pytree.h | 2 +- jaxlib/xla/sharding.cc | 2 +- jaxlib/xla/sharding.h | 2 +- jaxlib/xla/to_ifrt_sharding.cc | 2 +- jaxlib/xla/traceback.cc | 357 +++++++++++++++++++++++++++ jaxlib/xla/traceback.h | 108 ++++++++ jaxlib/xla/xla.cc | 6 +- 38 files changed, 866 insertions(+), 74 deletions(-) create mode 100644 jaxlib/xla/nb_class_ptr.h create mode 100644 jaxlib/xla/python_ref_manager.cc create mode 100644 jaxlib/xla/python_ref_manager.h create mode 100644 jaxlib/xla/traceback.cc create mode 100644 jaxlib/xla/traceback.h diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD index e10977d526ed..512eeb867618 100644 --- a/jaxlib/xla/BUILD +++ b/jaxlib/xla/BUILD @@ -53,11 +53,14 @@ nanobind_extension( ":ifrt_proxy", ":jax_jit", ":mlir", + ":nb_class_ptr", ":pjit", ":pmap_lib", ":py_client", + ":python_ref_manager", ":pytree", ":sdy", + ":traceback", ":util", ":weakref_lru_cache", ":xla_compiler", @@ -104,13 +107,10 @@ nanobind_extension( "@xla//xla/python:logging", "@xla//xla/python:nb_absl_flat_hash_map", "@xla//xla/python:nb_absl_span", - "@xla//xla/python:nb_class_ptr", "@xla//xla/python:ops", "@xla//xla/python:pprof_profile_builder", "@xla//xla/python:profiler", - "@xla//xla/python:python_ref_manager", "@xla//xla/python:refine_polymorphic_shapes", - "@xla//xla/python:traceback", "@xla//xla/python:types", "@xla//xla/python/ifrt", "@xla//xla/python/ifrt:plugin_program", @@ -162,6 +162,7 @@ cc_library( ], features = ["-use_header_modules"], deps = [ + ":python_ref_manager", "@com_google_absl//absl/base", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/status", @@ -176,7 +177,6 @@ cc_library( "@xla//xla/pjrt:host_callback", "@xla//xla/pjrt:transpose", "@xla//xla/python:nb_numpy", - "@xla//xla/python:python_ref_manager", "@xla//xla/service:custom_call_status", "@xla//xla/tsl/platform:statusor", ], @@ -193,13 +193,13 @@ cc_library( ], features = ["-use_header_modules"], deps = [ + ":python_ref_manager", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@nanobind", "@xla//third_party/python_runtime:headers", # buildcleaner: keep - "@xla//xla/python:python_ref_manager", "@xla//xla/tsl/platform:logging", ], ) @@ -246,7 +246,10 @@ cc_library( ], features = ["-use_header_modules"], deps = [ + ":nb_class_ptr", ":py_client", + ":python_ref_manager", + ":traceback", ":util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", @@ -265,9 +268,6 @@ cc_library( "@xla//xla/pjrt:pjrt_common", "@xla//xla/pjrt:pjrt_compiler", "@xla//xla/pjrt:pjrt_layout", - "@xla//xla/python:nb_class_ptr", - "@xla//xla/python:python_ref_manager", - "@xla//xla/python:traceback", "@xla//xla/python:types", "@xla//xla/python/ifrt", "@xla//xla/python/pjrt_ifrt", @@ -308,6 +308,7 @@ cc_library( ], features = ["-use_header_modules"], deps = [ + ":nb_class_ptr", ":py_client", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -320,7 +321,6 @@ cc_library( "@tsl//tsl/platform:env", "@tsl//tsl/platform:statusor", "@xla//xla/pjrt:status_casters", - "@xla//xla/python:nb_class_ptr", "@xla//xla/python/ifrt", "@xla//xla/python/ifrt:attribute_map", "@xla//xla/python/ifrt_proxy/client:grpc_client", @@ -340,6 +340,7 @@ cc_library( features = ["-use_header_modules"], deps = [ ":py_client", + ":python_ref_manager", ":pytree", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", @@ -359,7 +360,6 @@ cc_library( "@xla//xla/python:nb_absl_inlined_vector", "@xla//xla/python:nb_absl_span", "@xla//xla/python:nb_helpers", - "@xla//xla/python:python_ref_manager", "@xla//xla/python:types", "@xla//xla/tsl/platform:logging", ], @@ -404,6 +404,15 @@ cc_library( ], ) +cc_library( + name = "nb_class_ptr", + hdrs = ["nb_class_ptr.h"], + copts = ["-fexceptions"], + features = ["-use_header_modules"], + visibility = jax_visibility("jaxlib/xla/nb_class_ptr"), + deps = ["@nanobind"], +) + cc_library( name = "pjit", srcs = ["pjit.cc"], @@ -418,8 +427,11 @@ cc_library( ":config", ":guard_lib", ":jax_jit", + ":nb_class_ptr", ":py_client", + ":python_ref_manager", ":pytree", + ":traceback", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", @@ -438,11 +450,8 @@ cc_library( "@xla//xla:util", "@xla//xla/pjrt:exceptions", "@xla//xla/pjrt:lru_cache", - "@xla//xla/python:nb_class_ptr", "@xla//xla/python:nb_helpers", "@xla//xla/python:nb_numpy", - "@xla//xla/python:python_ref_manager", - "@xla//xla/python:traceback", "@xla//xla/python/ifrt", "@xla//xla/tsl/concurrency:ref_count", "@xla//xla/tsl/platform:env", @@ -465,8 +474,11 @@ cc_library( deps = [ ":config", ":jax_jit", + ":nb_class_ptr", ":py_client", + ":python_ref_manager", ":pytree", + ":traceback", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/hash", @@ -485,11 +497,8 @@ cc_library( "@xla//xla:xla_data_proto_cc", "@xla//xla/pjrt:exceptions", "@xla//xla/pjrt:status_casters", - "@xla//xla/python:nb_class_ptr", "@xla//xla/python:nb_helpers", "@xla//xla/python:nb_numpy", - "@xla//xla/python:python_ref_manager", - "@xla//xla/python:traceback", "@xla//xla/python:types", "@xla//xla/python/ifrt", "@xla//xla/tsl/concurrency:ref_count", @@ -539,9 +548,12 @@ cc_library( deps = [ ":callback", ":guard_lib", + ":nb_class_ptr", ":py_client_cpu", ":py_host_callback", ":py_host_callback_cc_proto", + ":python_ref_manager", + ":traceback", ":util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", @@ -603,12 +615,9 @@ cc_library( "@xla//xla/pjrt/distributed:client", "@xla//xla/python:aggregate_profile", "@xla//xla/python:nb_absl_span", - "@xla//xla/python:nb_class_ptr", "@xla//xla/python:nb_helpers", "@xla//xla/python:nb_numpy", "@xla//xla/python:pprof_profile_builder", - "@xla//xla/python:python_ref_manager", - "@xla//xla/python:traceback", "@xla//xla/python:types", "@xla//xla/python:xplane_to_profile_instructions", "@xla//xla/python/compile_only_ifrt:client", @@ -688,6 +697,7 @@ cc_library( deps = [ ":callback", ":py_host_callback_cc_proto", + ":python_ref_manager", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/log:check", @@ -703,7 +713,6 @@ cc_library( "@xla//xla:xla_data_proto_cc", "@xla//xla/pjrt:host_callback", "@xla//xla/pjrt:pjrt_compiler", - "@xla//xla/python:python_ref_manager", "@xla//xla/python:types", "@xla//xla/python/ifrt", "@xla//xla/python/pjrt_ifrt", @@ -734,7 +743,9 @@ cc_library( ], features = ["-use_header_modules"], deps = [ + ":nb_class_ptr", ":py_client", + ":traceback", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -744,9 +755,7 @@ cc_library( "@xla//xla:util", "@xla//xla/pjrt:pjrt_client", "@xla//xla/pjrt:status_casters", - "@xla//xla/python:nb_class_ptr", "@xla//xla/python:nb_numpy", - "@xla//xla/python:traceback", "@xla//xla/python:types", "@xla//xla/python/ifrt", "@xla//xla/python/pjrt_ifrt", @@ -761,6 +770,26 @@ cc_library( ], ) +cc_library( + name = "python_ref_manager", + srcs = ["python_ref_manager.cc"], + hdrs = ["python_ref_manager.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + visibility = jax_visibility("jaxlib/xla/python_ref_manager"), + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + ], +) + proto_library( name = "pytree_proto", srcs = ["pytree.proto"], @@ -783,6 +812,7 @@ cc_library( features = ["-use_header_modules"], visibility = jax_visibility("jaxlib/xla/pytree"), deps = [ + ":nb_class_ptr", ":pytree_cc_proto", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -794,7 +824,6 @@ cc_library( "@nanobind", "@xla//third_party/python_runtime:headers", # buildcleaner: keep "@xla//xla/pjrt:exceptions", - "@xla//xla/python:nb_class_ptr", "@xla//xla/tsl/platform:logging", ], ) @@ -833,6 +862,33 @@ cc_library( ], ) +cc_library( + name = "traceback", + srcs = ["traceback.cc"], + hdrs = ["traceback.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + visibility = jax_visibility("jaxlib/xla/traceback"), + deps = [ + ":nb_class_ptr", + "@com_google_absl//absl/base", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@nanobind", + "@tsl//tsl/platform", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla/pjrt:exceptions", + "@xla//xla/tsl/platform:logging", + ], +) + cc_library( name = "util", srcs = ["util.cc"], diff --git a/jaxlib/xla/callback.cc b/jaxlib/xla/callback.cc index 2df1715d099f..bb238e6991ec 100644 --- a/jaxlib/xla/callback.cc +++ b/jaxlib/xla/callback.cc @@ -34,11 +34,11 @@ limitations under the License. #include "absl/types/span.h" #include "nanobind/nanobind.h" #include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "jaxlib/xla/python_ref_manager.h" #include "xla/pjrt/host_callback.h" #include "xla/pjrt/transpose.h" #include "xla/primitive_util.h" #include "xla/python/nb_numpy.h" -#include "xla/python/python_ref_manager.h" #include "xla/tsl/platform/statusor.h" namespace nb = nanobind; diff --git a/jaxlib/xla/config.cc b/jaxlib/xla/config.cc index b5bc5830acbf..82f0bd0b0f5a 100644 --- a/jaxlib/xla/config.cc +++ b/jaxlib/xla/config.cc @@ -26,7 +26,7 @@ limitations under the License. #include "absl/types/span.h" #include "nanobind/nanobind.h" #include "nanobind/stl/optional.h" // IWYU pragma: keep -#include "xla/python/python_ref_manager.h" +#include "jaxlib/xla/python_ref_manager.h" #include "xla/tsl/platform/logging.h" namespace jax { diff --git a/jaxlib/xla/dlpack.cc b/jaxlib/xla/dlpack.cc index 94d57e07c34a..6c4c24bfe10e 100644 --- a/jaxlib/xla/dlpack.cc +++ b/jaxlib/xla/dlpack.cc @@ -34,8 +34,11 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "nanobind/nanobind.h" #include "nanobind/ndarray.h" +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_array.h" #include "jaxlib/xla/py_client.h" +#include "jaxlib/xla/python_ref_manager.h" +#include "jaxlib/xla/traceback.h" #include "jaxlib/xla/util.h" #include "xla/layout.h" #include "xla/pjrt/exceptions.h" @@ -45,12 +48,9 @@ limitations under the License. #include "xla/pjrt/pjrt_layout.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/device.h" -#include "xla/python/nb_class_ptr.h" #include "xla/python/pjrt_ifrt/pjrt_array.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/python/pjrt_ifrt/pjrt_device.h" -#include "xla/python/python_ref_manager.h" -#include "xla/python/traceback.h" #include "xla/python/types.h" #include "xla/shape_util.h" #include "xla/status_macros.h" diff --git a/jaxlib/xla/dlpack.h b/jaxlib/xla/dlpack.h index e73c477b1495..46b0954105f7 100644 --- a/jaxlib/xla/dlpack.h +++ b/jaxlib/xla/dlpack.h @@ -22,9 +22,9 @@ limitations under the License. #include "absl/status/statusor.h" #include "nanobind/nanobind.h" #include "nanobind/ndarray.h" +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_client.h" #include "xla/python/ifrt/device.h" -#include "xla/python/nb_class_ptr.h" namespace xla { diff --git a/jaxlib/xla/ifrt_proxy.cc b/jaxlib/xla/ifrt_proxy.cc index e03fde194d49..eda57be86ba5 100644 --- a/jaxlib/xla/ifrt_proxy.cc +++ b/jaxlib/xla/ifrt_proxy.cc @@ -36,12 +36,12 @@ #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/unordered_map.h" // IWYU pragma: keep #include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_client.h" #include "xla/pjrt/status_casters.h" #include "xla/python/ifrt/attribute_map.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt_proxy/client/registry.h" -#include "xla/python/nb_class_ptr.h" #include "tsl/platform/env.h" #include "tsl/platform/statusor.h" diff --git a/jaxlib/xla/jax_jit.h b/jaxlib/xla/jax_jit.h index e2c186c5d3ff..a2e6d725f3b0 100644 --- a/jaxlib/xla/jax_jit.h +++ b/jaxlib/xla/jax_jit.h @@ -36,11 +36,11 @@ limitations under the License. #include "nanobind/nanobind.h" #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "jaxlib/xla/py_values.h" +#include "jaxlib/xla/python_ref_manager.h" #include "jaxlib/xla/pytree.h" #include "jaxlib/xla/sharding.h" #include "xla/pjrt/pjrt_client.h" #include "xla/python/nb_helpers.h" -#include "xla/python/python_ref_manager.h" #include "xla/tsl/platform/logging.h" namespace jax { diff --git a/jaxlib/xla/nb_class_ptr.h b/jaxlib/xla/nb_class_ptr.h new file mode 100644 index 000000000000..e468860dc661 --- /dev/null +++ b/jaxlib/xla/nb_class_ptr.h @@ -0,0 +1,59 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_NB_CLASS_PTR_H_ +#define JAXLIB_XLA_NB_CLASS_PTR_H_ + +#include "nanobind/nanobind.h" + +namespace xla { + +// A reference-counting smart pointer to a nanobind-wrapped class on the Python +// heap. Type T must be a class known to nanobind via a nanobind::class_ +// declaration. nb_class_ptr is useful for managing C++ classes that may be +// allocated inline in Python objects on the Python heap. +template +class nb_class_ptr : public nanobind::object { + public: + inline nb_class_ptr() : nanobind::object() {} + inline nb_class_ptr(nanobind::handle h, ::nanobind::detail::borrow_t) + : nanobind::object(h, ::nanobind::detail::borrow_t{}) {} + inline nb_class_ptr(nanobind::handle h, ::nanobind::detail::steal_t) + : nanobind::object(h, ::nanobind::detail::steal_t{}) {} + inline static bool check_(nanobind::handle h) { + nanobind::handle type = nanobind::type(); + return h.type().is(type); + }; + + T* operator->() const { return nanobind::inst_ptr(ptr()); } + T& operator*() const { return *nanobind::inst_ptr(ptr()); } + T* get() const { return ptr() ? nanobind::inst_ptr(ptr()) : nullptr; } +}; + +// This function is analogous to std::make_unique(...), but instead it +// allocates the object on the Python heap +template +nb_class_ptr make_nb_class(Args&&... args) { + nanobind::handle type = nanobind::type(); + nanobind::object instance = nanobind::inst_alloc(type); + T* ptr = nanobind::inst_ptr(instance); + new (ptr) T(std::forward(args)...); + nanobind::inst_mark_ready(instance); + return nb_class_ptr(instance.release(), ::nanobind::detail::steal_t{}); +} + +} // namespace xla + +#endif // JAXLIB_XLA_NB_CLASS_PTR_H_ diff --git a/jaxlib/xla/pjit.cc b/jaxlib/xla/pjit.cc index 0409397c82de..508bf79f9ec0 100644 --- a/jaxlib/xla/pjit.cc +++ b/jaxlib/xla/pjit.cc @@ -53,11 +53,14 @@ limitations under the License. #include "jaxlib/xla/config.h" #include "jaxlib/xla/guard_lib.h" #include "jaxlib/xla/jax_jit.h" +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_array.h" #include "jaxlib/xla/py_executable.h" #include "jaxlib/xla/py_values.h" +#include "jaxlib/xla/python_ref_manager.h" #include "jaxlib/xla/pytree.h" #include "jaxlib/xla/sharding.h" +#include "jaxlib/xla/traceback.h" #include "xla/layout.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/lru_cache.h" @@ -67,11 +70,8 @@ limitations under the License. #include "xla/python/ifrt/executable.h" #include "xla/python/ifrt/memory.h" #include "xla/python/ifrt/sharding.h" -#include "xla/python/nb_class_ptr.h" #include "xla/python/nb_helpers.h" #include "xla/python/nb_numpy.h" -#include "xla/python/python_ref_manager.h" -#include "xla/python/traceback.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/platform/env.h" #include "xla/tsl/platform/errors.h" diff --git a/jaxlib/xla/pmap_lib.cc b/jaxlib/xla/pmap_lib.cc index 3dbd736076db..295ac8bfccfb 100644 --- a/jaxlib/xla/pmap_lib.cc +++ b/jaxlib/xla/pmap_lib.cc @@ -46,15 +46,18 @@ limitations under the License. #include "nanobind/stl/vector.h" // IWYU pragma: keep #include "jaxlib/xla/config.h" #include "jaxlib/xla/jax_jit.h" +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_array.h" #include "jaxlib/xla/py_client.h" #include "jaxlib/xla/py_device.h" #include "jaxlib/xla/py_executable.h" #include "jaxlib/xla/py_values.h" +#include "jaxlib/xla/python_ref_manager.h" #include "jaxlib/xla/pytree.h" #include "jaxlib/xla/sharded_device_array.h" #include "jaxlib/xla/sharding.h" #include "jaxlib/xla/to_ifrt_sharding.h" +#include "jaxlib/xla/traceback.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/status_casters.h" #include "xla/python/ifrt/array.h" @@ -64,11 +67,8 @@ limitations under the License. #include "xla/python/ifrt/memory.h" #include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" -#include "xla/python/nb_class_ptr.h" #include "xla/python/nb_helpers.h" #include "xla/python/nb_numpy.h" -#include "xla/python/python_ref_manager.h" -#include "xla/python/traceback.h" #include "xla/python/types.h" #include "xla/status_macros.h" #include "xla/tsl/concurrency/ref_count.h" diff --git a/jaxlib/xla/py_array.cc b/jaxlib/xla/py_array.cc index a348b47454e7..1325f0cbd2bc 100644 --- a/jaxlib/xla/py_array.cc +++ b/jaxlib/xla/py_array.cc @@ -58,12 +58,15 @@ limitations under the License. #include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep #include "jaxlib/xla/guard_lib.h" +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_client.h" #include "jaxlib/xla/py_device.h" #include "jaxlib/xla/py_device_list.h" #include "jaxlib/xla/py_values.h" +#include "jaxlib/xla/python_ref_manager.h" #include "jaxlib/xla/sharding.h" #include "jaxlib/xla/to_ifrt_sharding.h" +#include "jaxlib/xla/traceback.h" #include "jaxlib/xla/util.h" #include "xla/layout.h" #include "xla/layout_util.h" @@ -86,15 +89,12 @@ limitations under the License. #include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" #include "xla/python/nb_absl_span.h" // IWYU pragma: keep -#include "xla/python/nb_class_ptr.h" #include "xla/python/nb_helpers.h" #include "xla/python/nb_numpy.h" #include "xla/python/pjrt_ifrt/pjrt_array.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/python/pjrt_ifrt/pjrt_device.h" #include "xla/python/pjrt_ifrt/pjrt_dtype.h" -#include "xla/python/python_ref_manager.h" -#include "xla/python/traceback.h" #include "xla/python/types.h" #include "xla/shape.h" #include "xla/shape_util.h" diff --git a/jaxlib/xla/py_array.h b/jaxlib/xla/py_array.h index f914639e383f..645f51096c1d 100644 --- a/jaxlib/xla/py_array.h +++ b/jaxlib/xla/py_array.h @@ -33,7 +33,9 @@ limitations under the License. #include "absl/types/span.h" #include "llvm/Support/Casting.h" #include "nanobind/nanobind.h" +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_client.h" +#include "jaxlib/xla/traceback.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_future.h" @@ -41,10 +43,8 @@ limitations under the License. #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/device_list.h" #include "xla/python/ifrt/future.h" -#include "xla/python/nb_class_ptr.h" #include "xla/python/nb_numpy.h" #include "xla/python/pjrt_ifrt/pjrt_array.h" -#include "xla/python/traceback.h" #include "xla/shape.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/util.h" diff --git a/jaxlib/xla/py_client.cc b/jaxlib/xla/py_client.cc index b74c37f28863..795a4fee29fa 100644 --- a/jaxlib/xla/py_client.cc +++ b/jaxlib/xla/py_client.cc @@ -50,12 +50,15 @@ limitations under the License. #include "nanobind/stl/vector.h" // IWYU pragma: keep #include "jaxlib/xla/callback.h" #include "jaxlib/xla/guard_lib.h" +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_array.h" #include "jaxlib/xla/py_device.h" #include "jaxlib/xla/py_executable.h" #include "jaxlib/xla/py_host_callback.h" #include "jaxlib/xla/py_memory_space.h" #include "jaxlib/xla/py_values.h" +#include "jaxlib/xla/python_ref_manager.h" +#include "jaxlib/xla/traceback.h" #include "xla/literal.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/mlir_to_hlo.h" @@ -74,14 +77,11 @@ limitations under the License. #include "xla/python/ifrt/memory.h" #include "xla/python/ifrt/program.h" #include "xla/python/nb_absl_span.h" // IWYU pragma: keep -#include "xla/python/nb_class_ptr.h" #include "xla/python/nb_numpy.h" #include "xla/python/pjrt_ifrt/pjrt_array.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/python/pjrt_ifrt/xla_compiler.h" #include "xla/python/pprof_profile_builder.h" -#include "xla/python/python_ref_manager.h" -#include "xla/python/traceback.h" #include "xla/python/types.h" #include "xla/service/custom_call_target_registry.h" #include "xla/service/platform_util.h" // IWYU pragma: keep diff --git a/jaxlib/xla/py_client.h b/jaxlib/xla/py_client.h index 8f50c6451627..898a40141307 100644 --- a/jaxlib/xla/py_client.h +++ b/jaxlib/xla/py_client.h @@ -34,6 +34,7 @@ limitations under the License. #include "absl/types/span.h" #include "llvm/Support/Casting.h" #include "nanobind/nanobind.h" +#include "jaxlib/xla/nb_class_ptr.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_executable.h" @@ -42,7 +43,6 @@ limitations under the License. #include "xla/python/ifrt/compiler.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/program.h" -#include "xla/python/nb_class_ptr.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/shape.h" diff --git a/jaxlib/xla/py_compile_only_client.cc b/jaxlib/xla/py_compile_only_client.cc index 6319c70f91b0..673dfc214346 100644 --- a/jaxlib/xla/py_compile_only_client.cc +++ b/jaxlib/xla/py_compile_only_client.cc @@ -30,6 +30,7 @@ limitations under the License. #include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_client.h" #include "xla/pjrt/mlir_to_hlo.h" #include "xla/pjrt/pjrt_compiler.h" @@ -37,7 +38,6 @@ limitations under the License. #include "xla/pjrt/status_casters.h" #include "xla/python/compile_only_ifrt/client.h" #include "xla/python/ifrt/executable.h" -#include "xla/python/nb_class_ptr.h" #include "xla/python/pjrt_ifrt/pjrt_executable.h" #include "xla/python/pjrt_ifrt/pjrt_topology.h" #include "xla/python/pjrt_ifrt/xla_compiler.h" diff --git a/jaxlib/xla/py_compile_only_client.h b/jaxlib/xla/py_compile_only_client.h index 721830d6f52e..6cc700e1d3a9 100644 --- a/jaxlib/xla/py_compile_only_client.h +++ b/jaxlib/xla/py_compile_only_client.h @@ -20,8 +20,8 @@ limitations under the License. // placeholder for index annotation headers #include "nanobind/nanobind.h" +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_client.h" -#include "xla/python/nb_class_ptr.h" #include "xla/python/pjrt_ifrt/pjrt_topology.h" namespace xla { diff --git a/jaxlib/xla/py_device.cc b/jaxlib/xla/py_device.cc index 20c257bb7d1a..253bfd439a9b 100644 --- a/jaxlib/xla/py_device.cc +++ b/jaxlib/xla/py_device.cc @@ -36,17 +36,17 @@ limitations under the License. #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "nanobind/stl/variant.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_client.h" #include "jaxlib/xla/py_memory_space.h" +#include "jaxlib/xla/python_ref_manager.h" #include "xla/layout_util.h" #include "xla/literal.h" #include "xla/pjrt/status_casters.h" #include "xla/python/ifrt/device.h" -#include "xla/python/nb_class_ptr.h" #include "xla/python/nb_helpers.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/python/pjrt_ifrt/pjrt_device.h" -#include "xla/python/python_ref_manager.h" #include "xla/python/types.h" #include "xla/shape.h" #include "xla/shape_util.h" diff --git a/jaxlib/xla/py_device.h b/jaxlib/xla/py_device.h index 6d2b3893dea8..6071ede52305 100644 --- a/jaxlib/xla/py_device.h +++ b/jaxlib/xla/py_device.h @@ -24,10 +24,10 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "nanobind/nanobind.h" +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_client.h" #include "xla/literal.h" #include "xla/python/ifrt/device.h" -#include "xla/python/nb_class_ptr.h" #include "xla/shape.h" namespace xla { diff --git a/jaxlib/xla/py_device_list.cc b/jaxlib/xla/py_device_list.cc index 593a86ccbe42..300e477dbbd0 100644 --- a/jaxlib/xla/py_device_list.cc +++ b/jaxlib/xla/py_device_list.cc @@ -32,13 +32,13 @@ limitations under the License. #include "nanobind/nanobind.h" #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_client.h" #include "jaxlib/xla/py_device.h" +#include "jaxlib/xla/python_ref_manager.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/device_list.h" -#include "xla/python/nb_class_ptr.h" #include "xla/python/nb_helpers.h" -#include "xla/python/python_ref_manager.h" #include "xla/python/types.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/util.h" diff --git a/jaxlib/xla/py_device_list.h b/jaxlib/xla/py_device_list.h index ea574c5dc5a2..1d0f64003f8c 100644 --- a/jaxlib/xla/py_device_list.h +++ b/jaxlib/xla/py_device_list.h @@ -23,9 +23,9 @@ limitations under the License. #include "absl/status/statusor.h" #include "nanobind/nanobind.h" +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_client.h" #include "xla/python/ifrt/device_list.h" -#include "xla/python/nb_class_ptr.h" #include "xla/tsl/concurrency/ref_count.h" namespace jax { diff --git a/jaxlib/xla/py_executable.cc b/jaxlib/xla/py_executable.cc index 5a02a8f6dd20..71e6cfbdba7f 100644 --- a/jaxlib/xla/py_executable.cc +++ b/jaxlib/xla/py_executable.cc @@ -35,9 +35,11 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "nanobind/nanobind.h" +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_array.h" #include "jaxlib/xla/py_client.h" #include "jaxlib/xla/py_device.h" +#include "jaxlib/xla/traceback.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/pjrt/pjrt_future.h" #include "xla/pjrt/pjrt_layout.h" @@ -48,8 +50,6 @@ limitations under the License. #include "xla/python/ifrt/future.h" #include "xla/python/ifrt/memory.h" #include "xla/python/ifrt/sharding.h" -#include "xla/python/nb_class_ptr.h" -#include "xla/python/traceback.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/platform/env.h" #include "xla/tsl/platform/logging.h" diff --git a/jaxlib/xla/py_executable.h b/jaxlib/xla/py_executable.h index 214431f9472e..688eb779df8d 100644 --- a/jaxlib/xla/py_executable.h +++ b/jaxlib/xla/py_executable.h @@ -32,8 +32,10 @@ limitations under the License. #include "absl/types/span.h" #include "llvm/Support/Casting.h" #include "nanobind/nanobind.h" +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_array.h" #include "jaxlib/xla/py_client.h" +#include "jaxlib/xla/traceback.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/layout.h" #include "xla/pjrt/exceptions.h" @@ -45,9 +47,7 @@ limitations under the License. #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/attribute_map.h" #include "xla/python/ifrt/executable.h" -#include "xla/python/nb_class_ptr.h" #include "xla/python/pjrt_ifrt/pjrt_executable.h" -#include "xla/python/traceback.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/platform/status.h" #include "xla/xla_data.pb.h" diff --git a/jaxlib/xla/py_host_callback.cc b/jaxlib/xla/py_host_callback.cc index 833079335a36..fdb40c04b517 100644 --- a/jaxlib/xla/py_host_callback.cc +++ b/jaxlib/xla/py_host_callback.cc @@ -32,13 +32,13 @@ limitations under the License. #include "nanobind/nanobind.h" #include "jaxlib/xla/callback.h" #include "jaxlib/xla/py_host_callback.pb.h" +#include "jaxlib/xla/python_ref_manager.h" #include "xla/layout_util.h" #include "xla/pjrt/host_callback.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/host_callback.h" #include "xla/python/pjrt_ifrt/pjrt_host_callback.h" #include "xla/python/pjrt_ifrt/xla_host_callback.pb.h" -#include "xla/python/python_ref_manager.h" #include "xla/python/types.h" #include "xla/shape.h" #include "xla/shape_util.h" diff --git a/jaxlib/xla/py_memory_space.cc b/jaxlib/xla/py_memory_space.cc index f365dd25dfb6..0409861dd3b9 100644 --- a/jaxlib/xla/py_memory_space.cc +++ b/jaxlib/xla/py_memory_space.cc @@ -22,9 +22,9 @@ limitations under the License. #include "absl/strings/string_view.h" #include "nanobind/nanobind.h" #include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_client.h" #include "xla/python/ifrt/device.h" -#include "xla/python/nb_class_ptr.h" namespace nb = ::nanobind; diff --git a/jaxlib/xla/py_memory_space.h b/jaxlib/xla/py_memory_space.h index 4ad7b852f416..f111263497fb 100644 --- a/jaxlib/xla/py_memory_space.h +++ b/jaxlib/xla/py_memory_space.h @@ -19,9 +19,9 @@ limitations under the License. #include #include "nanobind/nanobind.h" +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_client.h" #include "xla/python/ifrt/memory.h" -#include "xla/python/nb_class_ptr.h" namespace xla { diff --git a/jaxlib/xla/py_program.cc b/jaxlib/xla/py_program.cc index ec82292a50cd..b3828f5372d9 100644 --- a/jaxlib/xla/py_program.cc +++ b/jaxlib/xla/py_program.cc @@ -34,8 +34,10 @@ limitations under the License. #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_device.h" #include "jaxlib/xla/py_device_list.h" +#include "jaxlib/xla/python_ref_manager.h" #include "jaxlib/xla/sharding.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/pjrt/mlir_to_hlo.h" @@ -53,11 +55,9 @@ limitations under the License. #include "xla/python/ifrt/program.h" #include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" -#include "xla/python/nb_class_ptr.h" #include "xla/python/nb_numpy.h" #include "xla/python/pjrt_ifrt/xla_compiler.h" #include "xla/python/pjrt_ifrt/xla_sharding.h" -#include "xla/python/python_ref_manager.h" #include "xla/python/types.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/platform/statusor.h" diff --git a/jaxlib/xla/py_socket_transfer.cc b/jaxlib/xla/py_socket_transfer.cc index b1c4fbcc541f..05397cdf116f 100644 --- a/jaxlib/xla/py_socket_transfer.cc +++ b/jaxlib/xla/py_socket_transfer.cc @@ -34,9 +34,11 @@ limitations under the License. #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_array.h" #include "jaxlib/xla/py_client.h" #include "jaxlib/xla/to_ifrt_sharding.h" +#include "jaxlib/xla/traceback.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/status_casters.h" #include "xla/python/ifrt/array.h" @@ -45,13 +47,11 @@ limitations under the License. #include "xla/python/ifrt/memory.h" #include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" -#include "xla/python/nb_class_ptr.h" #include "xla/python/nb_numpy.h" #include "xla/python/pjrt_ifrt/pjrt_array.h" #include "xla/python/pjrt_ifrt/pjrt_device.h" #include "xla/python/pjrt_ifrt/pjrt_dtype.h" #include "xla/python/pjrt_ifrt/pjrt_memory.h" -#include "xla/python/traceback.h" #include "xla/python/transfer/event_loop.h" #include "xla/python/transfer/socket-server.h" #include "xla/python/transfer/socket_bulk_transport.h" diff --git a/jaxlib/xla/py_values.cc b/jaxlib/xla/py_values.cc index 9375dd5440c6..1c7db0bec13a 100644 --- a/jaxlib/xla/py_values.cc +++ b/jaxlib/xla/py_values.cc @@ -42,6 +42,7 @@ limitations under the License. #include "nanobind/stl/complex.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "jaxlib/xla/py_array.h" +#include "jaxlib/xla/python_ref_manager.h" #include "jaxlib/xla/sharding.h" #include "xla/primitive_util.h" #include "xla/python/ifrt/array.h" @@ -53,7 +54,6 @@ limitations under the License. #include "xla/python/ifrt/sharding.h" #include "xla/python/nb_numpy.h" #include "xla/python/pjrt_ifrt/pjrt_dtype.h" -#include "xla/python/python_ref_manager.h" #include "xla/python/types.h" #include "xla/shape.h" #include "xla/tsl/concurrency/ref_count.h" diff --git a/jaxlib/xla/python_ref_manager.cc b/jaxlib/xla/python_ref_manager.cc new file mode 100644 index 000000000000..a19622d94244 --- /dev/null +++ b/jaxlib/xla/python_ref_manager.cc @@ -0,0 +1,104 @@ +/* Copyright 2019 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/python_ref_manager.h" + +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" + +namespace xla { + +namespace nb = nanobind; + +PythonRefManager::ManagedPyObjects::ManagedPyObjects( + PythonRefManager* manager, absl::Span objects) + : manager_(manager) { + objects_.reserve(objects.size()); + for (nb::object& object : objects) { + objects_.push_back(std::move(object)); + } +} + +PythonRefManager::ManagedPyObjects::~ManagedPyObjects() { + if (manager_ && !objects_.empty()) { + manager_->AddGarbage(absl::MakeSpan(objects_)); + } +} + +std::shared_ptr +PythonRefManager::ManageReference(nb::object object) { + return std::make_shared(this, + absl::Span(&object, 1)); +} + +std::shared_ptr +PythonRefManager::ManageReferences(absl::Span objects) { + return std::make_shared(this, objects); +} + +void PythonRefManager::AddGarbage(nb::object garbage) { + absl::MutexLock lock(&mu_); + // We want to collect arbitrary python garbage (e.g., buffers) aggressively. + garbage_count_.fetch_add(100, std::memory_order_relaxed); + python_garbage_.push_back(std::move(garbage)); +} + +void PythonRefManager::AddGarbage(absl::Span garbage) { + absl::MutexLock lock(&mu_); + // We want to collect arbitrary python garbage (e.g., buffers) aggressively. + garbage_count_.fetch_add(100, std::memory_order_relaxed); + for (nb::object& o : garbage) { + python_garbage_.push_back(std::move(o)); + } +} + +void PythonRefManager::AddGarbage( + absl::Span const> garbage) { + absl::MutexLock lock(&mu_); + // We don't care about collecting stack frame objects often. We grab a lot of + // tracebacks and the code objects are most likely live for the entire + // process. + garbage_count_.fetch_add(1, std::memory_order_relaxed); + for (const auto& o : garbage) { + python_garbage_.push_back(nb::steal(reinterpret_cast(o.first))); + } +} + +void PythonRefManager::CollectGarbage() { + // TODO(phawkins): we should CHECK(PyGILState_Check()); + std::deque garbage; + { + absl::MutexLock lock(&mu_); + garbage_count_ = 0; + garbage.swap(python_garbage_); + } + // We defer deleting garbage until the lock is released. It's possible that + // deleting garbage will lead to more Python garbage being added; if we held + // the lock we would deadlock because absl::Mutex is not reentrant. +} + +PythonRefManager* GlobalPyRefManager() { + static PythonRefManager* static_ref_manager = new PythonRefManager(); + return static_ref_manager; +} + +} // namespace xla diff --git a/jaxlib/xla/python_ref_manager.h b/jaxlib/xla/python_ref_manager.h new file mode 100644 index 000000000000..c0630da2ebd5 --- /dev/null +++ b/jaxlib/xla/python_ref_manager.h @@ -0,0 +1,108 @@ +/* Copyright 2019 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_PYTHON_REF_MANAGER_H_ +#define JAXLIB_XLA_PYTHON_REF_MANAGER_H_ + +#include + +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/inlined_vector.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" + +namespace xla { + +// Class that manages destruction of Python objects. +// +// We must not destroy Python objects without holding the GIL. However, we +// frequently want to hold references to Python objects for the duration of +// an asynchronous transfer on a Stream, and release our reference when the +// transfer completes. +// +// This class holds references to Python objects outside a GIL scope, that can +// be collected later when the GIL is held by calling CollectGarbage(). +class PythonRefManager { + public: + PythonRefManager() = default; + + // Holds references to a set of nanobind::objects, adding the references to + // the PythonRefManager on destruction. + class ManagedPyObjects { + public: + ManagedPyObjects() = default; + ManagedPyObjects(PythonRefManager* manager, + absl::Span objects); + + ~ManagedPyObjects(); + + ManagedPyObjects(const ManagedPyObjects& other) = delete; + ManagedPyObjects(ManagedPyObjects&& other) = default; + ManagedPyObjects& operator=(const ManagedPyObjects& other) = delete; + ManagedPyObjects& operator=(ManagedPyObjects&& other) noexcept = default; + + private: + PythonRefManager* manager_ = nullptr; + absl::InlinedVector objects_; + }; + + // Creates a managed std::shared_ptr to an object. When the shared_ptr is + // destroyed, the reference to 'object' will be added to python_garbage_, + // and collected next time CollectGarbage() is called. + std::shared_ptr ManageReference(nanobind::object object); + std::shared_ptr ManageReferences( + absl::Span objects); + + // Adds garbage objects to the manager. + void AddGarbage(nanobind::object garbage); + void AddGarbage(absl::Span garbage); + void AddGarbage(absl::Span const> garbage); + + // Releases the contents of python_garbage_. Requires that the GIL is held. + // The client calls this method during API entry points where the GIL is held + // to free any garbage that has accumulated. + void CollectGarbage(); + + // Cheaper version of CollectGarbage() with relaxed consistency and frequency. + // The purpose of this function is to amortize lock acquisition costs over + // a larger number of API calls. + void MaybeCollectGarbage() { + if (garbage_count_.load(std::memory_order_relaxed) >= 100) { + CollectGarbage(); + } + } + + private: + absl::Mutex mu_; + std::deque python_garbage_ ABSL_GUARDED_BY(mu_); + + // Writes to garbage_count_ are protected by mu_, reads are not protected. + std::atomic garbage_count_{0}; +}; + +// A global PythonRefManager. Unless `CollectGarbage()` is called before +// shutdown, this container will hold on to Python objects and thus cause a +// leak. This behavior is similar to `tensorflow::ClearDecRefCache()`. +PythonRefManager* GlobalPyRefManager(); + +} // namespace xla + +#endif // JAXLIB_XLA_PYTHON_REF_MANAGER_H_ diff --git a/jaxlib/xla/pytree.cc b/jaxlib/xla/pytree.cc index 7d1f7676bada..175e753515d0 100644 --- a/jaxlib/xla/pytree.cc +++ b/jaxlib/xla/pytree.cc @@ -49,9 +49,9 @@ limitations under the License. #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "nanobind/stl/tuple.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/pytree.pb.h" #include "xla/pjrt/exceptions.h" -#include "xla/python/nb_class_ptr.h" #include "xla/tsl/platform/logging.h" namespace xla { diff --git a/jaxlib/xla/pytree.h b/jaxlib/xla/pytree.h index 471d25af89bc..9c4aaff0bfae 100644 --- a/jaxlib/xla/pytree.h +++ b/jaxlib/xla/pytree.h @@ -34,8 +34,8 @@ limitations under the License. #include "absl/hash/hash.h" #include "absl/types/span.h" #include "nanobind/nanobind.h" +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/pytree.pb.h" -#include "xla/python/nb_class_ptr.h" namespace xla { diff --git a/jaxlib/xla/sharding.cc b/jaxlib/xla/sharding.cc index 9952c31bd393..409dddb62268 100644 --- a/jaxlib/xla/sharding.cc +++ b/jaxlib/xla/sharding.cc @@ -30,13 +30,13 @@ limitations under the License. #include "nanobind/nanobind.h" #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_client.h" #include "jaxlib/xla/py_device_list.h" #include "jaxlib/xla/sharded_device_array.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/pjrt/status_casters.h" #include "xla/python/ifrt/device_list.h" -#include "xla/python/nb_class_ptr.h" #include "xla/python/nb_numpy.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/platform/logging.h" diff --git a/jaxlib/xla/sharding.h b/jaxlib/xla/sharding.h index 572a6cd3c86e..dac18a4160b5 100644 --- a/jaxlib/xla/sharding.h +++ b/jaxlib/xla/sharding.h @@ -24,13 +24,13 @@ limitations under the License. #include "absl/hash/hash.h" #include "absl/status/statusor.h" #include "nanobind/nanobind.h" +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_client.h" #include "jaxlib/xla/py_device_list.h" #include "jaxlib/xla/sharded_device_array.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/pjrt/status_casters.h" #include "xla/python/ifrt/device.h" -#include "xla/python/nb_class_ptr.h" #include "xla/python/nb_numpy.h" #include "xla/xla_data.pb.h" diff --git a/jaxlib/xla/to_ifrt_sharding.cc b/jaxlib/xla/to_ifrt_sharding.cc index 96ec9c77071d..116ead49ad23 100644 --- a/jaxlib/xla/to_ifrt_sharding.cc +++ b/jaxlib/xla/to_ifrt_sharding.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "nanobind/nanobind.h" #include "nanobind/stl/string.h" // IWYU pragma: keep +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_device_list.h" #include "jaxlib/xla/sharding.h" #include "xla/hlo/ir/hlo_sharding.h" @@ -32,7 +33,6 @@ limitations under the License. #include "xla/python/ifrt/memory.h" #include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" -#include "xla/python/nb_class_ptr.h" #include "xla/python/pjrt_ifrt/pjrt_dtype.h" #include "xla/python/pjrt_ifrt/xla_sharding.h" #include "xla/shape.h" diff --git a/jaxlib/xla/traceback.cc b/jaxlib/xla/traceback.cc new file mode 100644 index 000000000000..35085b3e32fa --- /dev/null +++ b/jaxlib/xla/traceback.cc @@ -0,0 +1,357 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/xla/traceback.h" + +#include + +#include +#include +#include +#include +#include + +#include "absl/base/casts.h" +#include "absl/hash/hash.h" +#include "absl/log/check.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/xla/nb_class_ptr.h" +#include "xla/pjrt/exceptions.h" +#include "tsl/platform/platform.h" + +#ifdef PLATFORM_GOOGLE +#define Py_BUILD_CORE +#include "internal/pycore_frame.h" +#undef Py_BUILD_CORE +#endif // PLATFORM_GOOGLE + +namespace xla { + +namespace nb = nanobind; + +bool Traceback::enabled_ = true; + +Traceback::Traceback() { + DCHECK(PyGILState_Check()); + PyThreadState* thread_state = PyThreadState_GET(); + +#if PY_VERSION_HEX < 0x030b0000 + // The representation of frame->f_lasti changed from bytes to words in Python + // 3.10, see https://docs.python.org/3/whatsnew/3.10.html#changes-in-the-c-api + // This should match sizeof(_Py_CODEUNIT) which is unfortunately private. + constexpr int kLastiWordBytes = 2; + + for (PyFrameObject* py_frame = thread_state->frame; py_frame != nullptr; + py_frame = py_frame->f_back) { + Py_INCREF(py_frame->f_code); + frames_.emplace_back(py_frame->f_code, py_frame->f_lasti * kLastiWordBytes); + } +#else // PY_VERSION_HEX < 0x030b0000 + +#ifdef PLATFORM_GOOGLE + // This code is equivalent to the version using public APIs, but it saves us + // an allocation of one object per stack frame. However, this is definitely + // violating the API contract of CPython, so we only use this where we can be + // confident we know exactly which CPython we are using (internal to Google). + // Feel free to turn this on if you like, but it might break at any time! + for (_PyInterpreterFrame* f = thread_state->cframe->current_frame; + f != nullptr; f = f->previous) { + if (_PyFrame_IsIncomplete(f)) continue; + Py_INCREF(f->f_code); + frames_.emplace_back(f->f_code, + _PyInterpreterFrame_LASTI(f) * sizeof(_Py_CODEUNIT)); + } +#else // PLATFORM_GOOGLE + PyFrameObject* next; + for (PyFrameObject* py_frame = PyThreadState_GetFrame(thread_state); + py_frame != nullptr; py_frame = next) { + frames_.emplace_back(PyFrame_GetCode(py_frame), PyFrame_GetLasti(py_frame)); + next = PyFrame_GetBack(py_frame); + Py_XDECREF(py_frame); + } +#endif // PLATFORM_GOOGLE + +#endif // PY_VERSION_HEX < 0x030b0000 +} + +Traceback::~Traceback() { + for (auto& frame : frames_) { + DCHECK(PyGILState_Check()); + Py_DECREF(frame.first); + } +} + +Traceback::Traceback(Traceback&& other) noexcept + : frames_(std::move(other.frames_)) { + // absl::InlinedVector does not always clear itself if moved. Since we rely on + // its empty() method to destroy Traceback differently, we explicitly clear + // here. + other.frames_.clear(); +} + +std::string Traceback::Frame::ToString() const { + return absl::StrFormat("%s:%d (%s)", nb::cast(file_name), + line_num, nb::cast(function_name)); +} + +std::string Traceback::ToString() const { + std::vector frame_strs; + frame_strs.reserve(frames_.size()); + for (const Frame& frame : Frames()) { + frame_strs.push_back(frame.ToString()); + } + return absl::StrJoin(frame_strs, "\n"); +} + +std::vector Traceback::Frames() const { + // We require the GIL because we manipulate Python strings. + CHECK(PyGILState_Check()); + std::vector frames; + frames.reserve(frames_.size()); + for (const auto& frame : frames_) { + frames.push_back(Frame{nb::borrow(frame.first->co_filename), + nb::borrow(frame.first->co_name), + frame.first->co_firstlineno, + PyCode_Addr2Line(frame.first, frame.second)}); + } + return frames; +} + +std::optional> Traceback::Get() { + DCHECK(PyGILState_Check()); + if (!enabled_) { + return std::nullopt; + } + return make_nb_class(); +} + +void Traceback::SetEnabled(bool enabled) { enabled_ = enabled; } + +nb::object Traceback::AsPythonTraceback() const { + nb::object traceback = nb::none(); + nb::dict globals; + nb::handle traceback_type(reinterpret_cast(&PyTraceBack_Type)); + for (const std::pair& frame : frames_) { + int lineno = PyCode_Addr2Line(frame.first, frame.second); + // Under Python 3.11 we observed crashes when using a fake PyFrameObject + // with a real PyCodeObject (https://github.com/google/jax/issues/16027). + // because the frame does not have fields necessary to compute the locals, + // notably the closure object, leading to crashes in CPython in + // _PyFrame_FastToLocalsWithError + // https://github.com/python/cpython/blob/deaf509e8fc6e0363bd6f26d52ad42f976ec42f2/Objects/frameobject.c#LL1116C2-L1116C2 + // We therefore always build a fake code object to go along with our fake + // frame. + PyCodeObject* py_code = + PyCode_NewEmpty(PyUnicode_AsUTF8(frame.first->co_filename), + PyUnicode_AsUTF8(frame.first->co_name), lineno); + PyFrameObject* py_frame = PyFrame_New(PyThreadState_Get(), py_code, + globals.ptr(), /*locals=*/nullptr); + Py_DECREF(py_code); + + traceback = traceback_type( + /*tb_next=*/std::move(traceback), + /*tb_frame=*/ + nb::steal(reinterpret_cast(py_frame)), + /*tb_lasti=*/0, + /*tb_lineno=*/ + PyCode_Addr2Line(frame.first, frame.second)); + } + return traceback; +} + +namespace { + +Py_hash_t traceback_tp_hash(PyObject* o) { + Traceback* tb; + if (!nb::try_cast(nb::handle(o), tb)) { + PyErr_SetString(PyExc_TypeError, "Expected a Traceback object"); + return -1; + } + size_t h = absl::HashOf(*tb); + Py_hash_t s = absl::bit_cast(h); // Python hashes are signed. + return s == -1 ? -2 : s; // -1 must not be used as a Python hash value. +} + +PyObject* traceback_tp_richcompare(PyObject* self, PyObject* other, int op) { + if (op != Py_EQ && op != Py_NE) { + return Py_NewRef(Py_NotImplemented); + } + + Traceback* x; + if (!nb::try_cast(nb::handle(self), x)) { + PyErr_SetString(PyExc_TypeError, "Expected a Traceback object"); + return nullptr; + } + + bool result; + Traceback* y; + if (nb::try_cast(nb::handle(other), y)) { + result = ((*x == *y) == (op == Py_EQ)); + } else { + result = (op == Py_NE); + } + return Py_NewRef(result ? Py_True : Py_False); +} + +// It turns out to be slightly faster to define a tp_hash slot rather than +// defining __hash__ and __eq__ on the class. +PyType_Slot traceback_slots_[] = { + {Py_tp_hash, (void*)traceback_tp_hash}, + {Py_tp_richcompare, (void*)traceback_tp_richcompare}, + {0, nullptr}, +}; + +} // namespace + +void BuildTracebackSubmodule(nb::module_& m) { + nb::class_(m, "Frame") + .def(nb::init()) + .def_ro("file_name", &Traceback::Frame::file_name) + .def_ro("function_name", &Traceback::Frame::function_name) + .def_ro("function_start_line", &Traceback::Frame::function_start_line) + .def_ro("line_num", &Traceback::Frame::line_num) + .def("__repr__", [](const Traceback::Frame& frame) { + return absl::StrFormat( + "%s;%s:%d", nb::cast(frame.function_name), + nb::cast(frame.file_name), frame.line_num); + }); + + nb::class_ traceback(m, "Traceback", + nb::type_slots(traceback_slots_), + "Represents a Python stack trace."); + traceback.def_prop_rw_static( + "enabled", [](nb::object /* cls */) { return Traceback::enabled(); }, + [](nb::object /* cls */, bool enabled) { + return Traceback::SetEnabled(enabled); + }); + traceback.def_static( + "get_traceback", []() { return Traceback::Get(); }, + R"doc( + Returns a :class:`Traceback` for the current thread. + + If ``Traceback.enabled`` is ``True``, returns a :class:`Traceback` object + that describes the Python stack of the calling thread. Stack trace + collection has a small overhead, so it is disabled by default. If traceback + collection is disabled, returns ``None``. + )doc"); + traceback.def_prop_ro("frames", &Traceback::Frames); + traceback.def("raw_frames", [](const Traceback& tb) -> nb::tuple { + // We return a tuple of lists, rather than a list of tuples, because it + // is cheaper to allocate only three Python objects for everything rather + // than one per frame. + nb::list out_code = nb::steal(PyList_New(tb.raw_frames().size())); + nb::list out_lasti = + nb::steal(PyList_New(tb.raw_frames().size())); + for (size_t i = 0; i < tb.raw_frames().size(); ++i) { + const auto& frame = tb.raw_frames()[i]; + PyObject* code = reinterpret_cast(frame.first); + Py_INCREF(code); + PyList_SET_ITEM(out_code.ptr(), i, code); + PyList_SET_ITEM(out_lasti.ptr(), i, + nb::int_(frame.second).release().ptr()); + } + return nb::make_tuple(out_code, out_lasti); + }); + traceback.def("__str__", &Traceback::ToString); + traceback.def("as_python_traceback", &Traceback::AsPythonTraceback); + + traceback.def_static( + "traceback_from_frames", + [](std::vector frames) { + nb::object traceback = nb::none(); + nb::dict globals; + nb::handle traceback_type( + reinterpret_cast(&PyTraceBack_Type)); + for (const Traceback::Frame& frame : frames) { + PyCodeObject* py_code = + PyCode_NewEmpty(frame.file_name.c_str(), + frame.function_name.c_str(), frame.line_num); + PyFrameObject* py_frame = PyFrame_New(PyThreadState_Get(), py_code, + globals.ptr(), /*locals=*/ + nullptr); + Py_DECREF(py_code); + traceback = traceback_type( + /*tb_next=*/std::move(traceback), + /*tb_frame=*/ + nb::steal(reinterpret_cast(py_frame)), + /*tb_lasti=*/0, + /*tb_lineno=*/ + frame.line_num); + } + return traceback; + }, + "Creates a traceback from a list of frames."); + + traceback.def_static( + "code_addr2line", + [](nb::handle code, int lasti) { + if (!PyCode_Check(code.ptr())) { + throw xla::XlaRuntimeError("code argument must be a code object"); + } + return PyCode_Addr2Line(reinterpret_cast(code.ptr()), + lasti); + }, + "Python wrapper around the Python C API function PyCode_Addr2Line"); + +#if PY_VERSION_HEX >= 0x030b0000 + traceback.def_static( + "code_addr2location", + [](nb::handle code, int lasti) { + if (!PyCode_Check(code.ptr())) { + throw xla::XlaRuntimeError("code argument must be a code object"); + } + int start_line, start_column, end_line, end_column; + if (!PyCode_Addr2Location(reinterpret_cast(code.ptr()), + lasti, &start_line, &start_column, &end_line, + &end_column)) { + throw nb::python_error(); + } + return nb::make_tuple(start_line, start_column, end_line, end_column); + }, + "Python wrapper around the Python C API function PyCode_Addr2Location"); +#endif // PY_VERSION_HEX >= 0x030b0000 + +#if PY_VERSION_HEX < 0x030b0000 + // This function replaces the exception traceback associated with the current + // Python thread. + m.def( + "replace_thread_exc_traceback", + [](nb::object tb) { + if (!tb.is_none() && !PyTraceBack_Check(tb.ptr())) { + throw xla::XlaRuntimeError( + "argument must be a traceback object or None"); + } + PyThreadState* thread_state = PyThreadState_Get(); + if (!thread_state->exc_info->exc_traceback) { + throw xla::XlaRuntimeError( + "Current thread does not have an active " + "exception traceback"); + } + PyObject* old_exc_traceback = thread_state->exc_info->exc_traceback; + PyObject* new_tb = tb.is_none() ? nullptr : tb.release().ptr(); + thread_state->exc_info->exc_traceback = new_tb; + Py_XDECREF(old_exc_traceback); + }, + nb::arg("traceback").none()); +#endif // PY_VERSION_HEX < 0x30b0000 +} +} // namespace xla diff --git a/jaxlib/xla/traceback.h b/jaxlib/xla/traceback.h new file mode 100644 index 000000000000..953d626439c4 --- /dev/null +++ b/jaxlib/xla/traceback.h @@ -0,0 +1,108 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_TRACEBACK_H_ +#define JAXLIB_XLA_TRACEBACK_H_ + +#include +#include +#include +#include +#include + +// placeholder for index annotation headers +#include "absl/container/inlined_vector.h" +#include "nanobind/nanobind.h" +#include "jaxlib/xla/nb_class_ptr.h" + +namespace xla { + +// Represents a Python traceback. This object is designed to be allocated on +// the Python heap; creating or destroying a traceback requires the GIL. +class Traceback { + public: + // Requires GIL. Creates a Traceback object that requires destructor to be + // invoked with GIL held as well. + static std::optional> Get(); + + // Requires GIL. + static bool enabled() { return enabled_; } + // Requires GIL. + static void SetEnabled(bool enabled); + + // Requires GIL. Don't call this directly, you're looking for Get(). + Traceback(); + // Requires GIL. + ~Traceback(); + + Traceback(const Traceback&) = delete; + Traceback(Traceback&& other) noexcept; + Traceback& operator=(const Traceback&) = delete; + Traceback& operator=(Traceback&&) = delete; + + // Requires the GIL be held. + std::string ToString() const; + + struct Frame { + nanobind::str file_name; + nanobind::str function_name; + int function_start_line; + int line_num; + + std::string ToString() const; + }; + std::vector Frames() const; + + const absl::InlinedVector, 32>& raw_frames() + const { + return frames_; + } + + // Returns the traceback as a fake Python Traceback object, suitable for + // using as an exception traceback. + nanobind::object AsPythonTraceback() const; + + bool operator==(const Traceback& other) const { + return frames_ == other.frames_; + } + bool operator!=(const Traceback& other) const { + return frames_ != other.frames_; + } + + private: + // Each frame is a pair of a code object and a "lasti" instruction location + // in bytes. The size of _Py_CODEUNIT has changed across different Python + // versions; the lasti value here has already been multiplied by + // sizeof(_Py_CODEUNIT) if needed and is suitable for passing to functions + // like PyCode_Addr2Line(). + absl::InlinedVector, 32> frames_; + + // Protected by GIL. + static bool enabled_; +}; + +using nb_traceback = nb_class_ptr; + +template +H AbslHashValue(H h, const Traceback& traceback) { + h = H::combine(std::move(h), traceback.raw_frames()); + return h; +} + +void BuildTracebackSubmodule(nanobind::module_& m); + +} // namespace xla + +#endif // JAXLIB_XLA_TRACEBACK_H_ diff --git a/jaxlib/xla/xla.cc b/jaxlib/xla/xla.cc index a0508013910b..6e47be15fc68 100644 --- a/jaxlib/xla/xla.cc +++ b/jaxlib/xla/xla.cc @@ -90,6 +90,7 @@ limitations under the License. #include "jaxlib/xla/guard_lib.h" #include "jaxlib/xla/jax_jit.h" #include "jaxlib/xla/mlir.h" +#include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/pjit.h" #include "jaxlib/xla/pmap_lib.h" #include "jaxlib/xla/py_array.h" @@ -98,8 +99,10 @@ limitations under the License. #include "jaxlib/xla/py_device_list.h" #include "jaxlib/xla/py_executable.h" #include "jaxlib/xla/py_memory_space.h" +#include "jaxlib/xla/python_ref_manager.h" #include "jaxlib/xla/pytree.h" #include "jaxlib/xla/sharding.h" +#include "jaxlib/xla/traceback.h" #include "jaxlib/xla/weakref_lru_cache.h" #include "jaxlib/xla/xla_compiler.h" #include "xla/pjrt/distributed/key_value_store_interface.h" @@ -113,15 +116,12 @@ limitations under the License. #include "xla/python/logging.h" // IWYU pragma: keep #include "xla/python/nb_absl_flat_hash_map.h" // IWYU pragma: keep #include "xla/python/nb_absl_span.h" // IWYU pragma: keep -#include "xla/python/nb_class_ptr.h" #include "xla/python/ops.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/python/pjrt_ifrt/pjrt_executable.h" #include "xla/python/pjrt_ifrt/pjrt_topology.h" #include "xla/python/pprof_profile_builder.h" #include "xla/python/profiler.h" -#include "xla/python/python_ref_manager.h" -#include "xla/python/traceback.h" #include "xla/tsl/distributed_runtime/preemption/preemption_sync_manager.h" #include "xla/tsl/platform/status.h" #include "tsl/platform/platform.h" From 096810a72150df391d4986004795baeb19e7e1db Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 26 Mar 2025 12:11:47 -0700 Subject: [PATCH 195/483] [array API] make capabilities more accurate --- jax/_src/numpy/array_api_metadata.py | 5 +++-- tests/array_api_test.py | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/jax/_src/numpy/array_api_metadata.py b/jax/_src/numpy/array_api_metadata.py index 4a01f579a67e..d8d2c2d1a2a4 100644 --- a/jax/_src/numpy/array_api_metadata.py +++ b/jax/_src/numpy/array_api_metadata.py @@ -51,8 +51,9 @@ class ArrayNamespaceInfo: .. _Python array API: https://data-apis.org/array-api/ """ _capabilities = { - "boolean indexing": True, - "data-dependent shapes": False, + "boolean indexing": False, # within transformations + "data-dependent shapes": False, # within transformations + "max dimensions": 64, # XLA limitation } def _build_dtype_dict(self): diff --git a/tests/array_api_test.py b/tests/array_api_test.py index 250eeb810872..d509fe78c35f 100644 --- a/tests/array_api_test.py +++ b/tests/array_api_test.py @@ -275,8 +275,9 @@ def build_dtype_dict(self, dtypes): def test_capabilities_info(self): capabilities = self.info.capabilities() - assert capabilities["boolean indexing"] + assert not capabilities["boolean indexing"] assert not capabilities["data-dependent shapes"] + assert capabilities["max dimensions"] == 64 def test_default_device_info(self): assert self.info.default_device() is None From 4644b2ba67fd8a8144602df014595a564104b8ae Mon Sep 17 00:00:00 2001 From: Ayaka Date: Wed, 26 Mar 2025 12:55:20 -0700 Subject: [PATCH 196/483] Add tests to ensure nan checks do not produce false positives in jax.numpy functions PiperOrigin-RevId: 740872313 --- tests/jax_numpy_error_test.py | 90 +++++++++++++++++++---------------- 1 file changed, 49 insertions(+), 41 deletions(-) diff --git a/tests/jax_numpy_error_test.py b/tests/jax_numpy_error_test.py index 08917aeed364..566e0b1ba209 100644 --- a/tests/jax_numpy_error_test.py +++ b/tests/jax_numpy_error_test.py @@ -117,65 +117,73 @@ def test_error_category_invalid_category(self): @staticmethod def op_cases(cases): for jit in (True, False): - for func, operands in cases: - if not isinstance(operands, tuple): - operands = (operands,) + for func, ops_error, ops_no_err in cases: + if not isinstance(ops_error, tuple): + ops_error = (ops_error,) + if not isinstance(ops_no_err, tuple): + ops_no_err = (ops_no_err,) jit_str = "jit" if jit else "nojit" func_str = f"{func.__module__}.{func.__name__}" name = f"_{jit_str}_{func_str}" - yield name, jit, func, operands + yield name, jit, func, ops_error, ops_no_err @parameterized.named_parameters( op_cases(( - # list of all NaN-producing jax.numpy functions + # List of all NaN-producing jax.numpy functions. + # The first group of numbers is the input that will produce a NaN, and + # the second group is the input that will not produce a NaN. # go/keep-sorted start - (jnp.acos, 2.0), - (jnp.acosh, 0.5), - (jnp.add, (jnp.inf, -jnp.inf)), - (jnp.arccos, 2.0), - (jnp.arccosh, 0.5), - (jnp.arcsin, -2.0), - (jnp.arctanh, -2.0), - (jnp.asin, -2.0), - (jnp.atanh, -2.0), - (jnp.cos, jnp.inf), - (jnp.divide, (0.0, 0.0)), - (jnp.divmod, (1.0, 0.0)), - (jnp.float_power, (-1.0, 0.5)), - (jnp.fmod, (1.0, 0.0)), - (jnp.log, -1.0), - (jnp.log10, -1.0), - (jnp.log1p, -1.5), - (jnp.log2, -1.0), - (jnp.mod, (1.0, 0.0)), - (jnp.pow, (-1.0, 0.5)), - (jnp.power, (-1.0, 0.5)), - (jnp.remainder, (1.0, 0.0)), - (jnp.sin, jnp.inf), + (jnp.acos, 2.0, 0.5), + (jnp.acosh, 0.5, 2.0), + (jnp.add, (jnp.inf, -jnp.inf), (0.0, 0.0)), + (jnp.arccos, 2.0, 0.5), + (jnp.arccosh, 0.5, 2.0), + (jnp.arcsin, -2.0, 0.5), + (jnp.arctanh, -2.0, 0.5), + (jnp.asin, -2.0, 0.5), + (jnp.atanh, -2.0, 0.5), + (jnp.cos, jnp.inf, 1.0), + (jnp.divide, (0.0, 0.0), (1.0, 1.0)), + (jnp.divmod, (1.0, 0.0), (1.0, 1.0)), + (jnp.float_power, (-1.0, 0.5), (1.0, 1.0)), + (jnp.fmod, (1.0, 0.0), (1.0, 1.0)), + (jnp.log, -1.0, 1.0), + (jnp.log10, -1.0, 1.0), + (jnp.log1p, -1.5, 1.0), + (jnp.log2, -1.0, 1.0), + (jnp.mod, (1.0, 0.0), (1.0, 1.0)), + (jnp.pow, (-1.0, 0.5), (1.0, 1.0)), + (jnp.power, (-1.0, 0.5), (1.0, 1.0)), + (jnp.remainder, (1.0, 0.0), (1.0, 1.0)), + (jnp.sin, jnp.inf, 1.0), # TODO(https://github.com/jax-ml/jax/issues/27470): Not yet supported. - # (jnp.sinc, jnp.inf), - (jnp.sqrt, -4.0), - (jnp.subtract, (jnp.inf, jnp.inf)), - (jnp.tan, jnp.inf), - (jnp.true_divide, (0.0, 0.0)), - (operator.add, (jnp.inf, -jnp.inf)), - (operator.mod, (1.0, 0.0)), - (operator.pow, (-1.0, 0.5)), - (operator.sub, (jnp.inf, jnp.inf)), - (operator.truediv, (0.0, 0.0)), + # (jnp.sinc, jnp.inf, 1.0), + (jnp.sqrt, -4.0, 4.0), + (jnp.subtract, (jnp.inf, jnp.inf), (0.0, 0.0)), + (jnp.tan, jnp.inf, 1.0), + (jnp.true_divide, (0.0, 0.0), (1.0, 1.0)), + (operator.add, (jnp.inf, -jnp.inf), (0.0, 0.0)), + (operator.mod, (1.0, 0.0), (1.0, 1.0)), + (operator.pow, (-1.0, 0.5), (1.0, 1.0)), + (operator.sub, (jnp.inf, jnp.inf), (0.0, 0.0)), + (operator.truediv, (0.0, 0.0), (1.0, 1.0)), # go/keep-sorted end )) ) - def test_can_raise_nan_error(self, jit, f, operands): - operands = [jnp.float32(x) for x in operands] + def test_can_raise_nan_error(self, jit, f, ops_err, ops_no_err): + ops_err = [jnp.float32(x) for x in ops_err] + ops_no_err = [jnp.float32(x) for x in ops_no_err] if jit: f = jax.jit(f) with jnp_error.error_checking_behavior(nan="raise"): - f(*operands) + f(*ops_no_err) + error_check.raise_if_error() # should not raise error + + f(*ops_err) with self.assertRaisesRegex(JaxValueError, "NaN"): error_check.raise_if_error() From c9bc5f094d7b839d458c7191bbfc2c9defc02235 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Longeri?= Date: Wed, 26 Mar 2025 13:22:41 -0700 Subject: [PATCH 197/483] [Mosaic:TPU] 32-bit sublane broadcast for non-native tilings PiperOrigin-RevId: 740881404 --- .../tpu/transforms/apply_vector_layout.cc | 25 +++++++++++-------- .../tpu/transforms/infer_vector_layout.cc | 6 ++--- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 7755738a4fc7..71924739595c 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -3416,20 +3416,25 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op, if (tiling[1] != ctx.target_shape[1]) { return op.emitOpError("Not implemented: unsupported tiling"); } - int64_t num_tiles = layout_in.tilesPerVreg(ctx.target_shape); + const int64_t num_tiles = layout_in.tilesPerVreg(ctx.target_shape); + const int64_t sublanes_per_tile = + layout_in.sublanesPerTile(ctx.target_shape); if (needs_physical_broadcast == std::array{true, false}) { // Sublane broadcast const int packing = layout_in.packing(); - if (num_tiles != 1) { - return op.emitOpError( - "Not implemented: Only native tiling supported"); - } TPU_ASSERT_EQ_OP(*(src_tiles.dimensions().end() - 2), 1); TPU_ASSERT_OP(offsets_in[0].has_value()); const int64_t sublane_offset = *offsets_in[0] / packing; const int64_t subelement_offset = *offsets_in[0] % packing; - const DenseI32ArrayAttr indices = builder.getDenseI32ArrayAttr( - SmallVector(ctx.target_shape[0], sublane_offset)); + SmallVector pattern; + pattern.reserve(ctx.target_shape[0]); + for (int32_t t = 0; t < num_tiles; ++t) { + for (int32_t i = 0; i < sublanes_per_tile; ++i) { + pattern.push_back(sublanes_per_tile * t + sublane_offset); + } + } + const DenseI32ArrayAttr sublane_pattern = + builder.getDenseI32ArrayAttr(pattern); const absl::Status status = src_tiles.EachStatus([&](const absl::Span src_idx, Value *const src_vreg) { @@ -3446,8 +3451,8 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op, return absl::InternalError(""); } } - dst_vreg = builder.create(dst_vreg.getType(), - dst_vreg, indices, 0); + dst_vreg = builder.create( + dst_vreg.getType(), dst_vreg, sublane_pattern, 0); SmallVector dst_starts(dst_tiles_implicit_shape.size()); SmallVector dst_limits(dst_tiles_implicit_shape.size()); for (int64_t i = 0; i < dst_tiles.num_dimensions(); ++i) { @@ -3469,8 +3474,6 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op, std::array{false, true}) { // Lane broadcast TPU_ASSERT_EQ_OP(*(src_tiles.dimensions().end() - 1), 1); TPU_ASSERT_OP(offsets_in[1].has_value()); - const int64_t sublanes_per_tile = - layout_in.sublanesPerTile(ctx.target_shape); const int64_t offset = *offsets_in[1]; const int64_t lane_offset = offset % ctx.target_shape[1]; const int64_t tile_offset = offset / ctx.target_shape[1]; diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index 00e53314e588..c1a642b48f04 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -1092,13 +1092,11 @@ class VectorLayoutInferer { } auto src_tiled_ishape = layout.getImplicitTiledDims(src_ty.getShape(), 1); auto dst_tiled_ishape = layout.getImplicitTiledDims(res_ty.getShape(), 1); - // Since we can only do sublane broadcasts in the (8, 128) tiling, we - // should always use that when sublane broadcasting is required. if (src_tiled_ishape[0] != dst_tiled_ishape[0] && layout.offsets()[0] != std::nullopt) { + // TODO(tlongeri): Remove this. We support non-native tiling now, but + // things may still break downstream due to missing relayouts. LayoutOffsets offsets = layout.offsets(); - // At the moment relayout can only produce replicated sublanes when - // converting to (8, 128) if the input was in (1, 128) tiling if (layout.tiling()[0] == 1 && layout.bitwidth() == kNativeBitwidth) { offsets[0] = std::nullopt; } From b92b9b0e26700202cea26ebbbc8e5f9ab42997d1 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 26 Mar 2025 13:36:39 -0700 Subject: [PATCH 198/483] Raise an informative error when the length of device_assignment doesn't match the mesh.size of out_avals. This happens when (1) we can't extract the device_assignment from the arguments and (2) there is no concrete mesh in context. For example: ``` def test_random_normal_wo_mesh_context_error(self): mesh = jtu.create_mesh((2, 2), ('x', 'y'), axis_types=(AxisType.Explicit,) * 2) s = NamedSharding(mesh, P('x', 'y')) @jax.jit def f(key): out = jax.random.normal(key, shape=(8, 12), out_sharding=s) self.assertEqual(out.aval.sharding.spec, P('x', 'y')) self.assertEqual(out.aval.sharding.mesh, mesh.abstract_mesh) return out key = jax.random.key(1) with self.assertRaisesRegex( ValueError, 'Length of device assignment.*is not equal to the size of the mesh'): f(key) ``` PiperOrigin-RevId: 740886114 --- jax/_src/interpreters/pxla.py | 10 ++++++++++ tests/pjit_test.py | 18 ++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 387f0661ae9d..6f95b1b72281 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2281,6 +2281,16 @@ def lower_sharding_computation( devices_from_context) unique_intermediate_shardings = [js for js, _ in unique_intermediate_shardings] + for a in global_out_avals: + if (a is not core.abstract_token and not a.sharding.mesh.empty and + a.sharding.mesh._are_all_axes_explicit and + len(device_assignment) != a.sharding.mesh.size): + raise ValueError( + f"Length of device assignment {len(device_assignment)} is not equal" + f" to the size of the mesh {a.sharding.mesh.size} of aval" + f" {a.str_short(True, True)}. Please enter your `jit` into a mesh" + " context via `jax.sharding.use_mesh`.") + # TODO(parkers): One _raw_platform has been unified with platform, # change this back to just read platform. platforms = lowering_platforms or ( diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 5cd16e1e6925..277f24bd703f 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7265,6 +7265,24 @@ def f(key): else: self.assertIn('mhlo.sharding = "{devices=[2,2]<=[4]}"}', lowered_text) + def test_random_normal_wo_mesh_context_error(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y'), + axis_types=(AxisType.Explicit,) * 2) + s = NamedSharding(mesh, P('x', 'y')) + + @jax.jit + def f(key): + out = jax.random.normal(key, shape=(8, 12), out_sharding=s) + self.assertEqual(out.aval.sharding.spec, P('x', 'y')) + self.assertEqual(out.aval.sharding.mesh, mesh.abstract_mesh) + return out + + key = jax.random.key(1) + with self.assertRaisesRegex( + ValueError, + 'Length of device assignment.*is not equal to the size of the mesh'): + f(key) + def test_random_normal_wo_mesh_context(self): mesh = jtu.create_mesh((2, 2), ('x', 'y'), axis_types=(AxisType.Explicit,) * 2) From d9a6cd1a5ec1aeba2aa479f8f79f37c2504e4b77 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 26 Mar 2025 13:40:50 -0700 Subject: [PATCH 199/483] Remove xla_client.make_gpu_client. Cleanup; this code is not used any more because we use C API plugins instead. PiperOrigin-RevId: 740887556 --- jax/_src/xla_bridge.py | 75 ++++------------------------------------ jaxlib/xla/xla_client.py | 45 ------------------------ 2 files changed, 6 insertions(+), 114 deletions(-) diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 72d88b9735b7..227359dc4676 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -89,13 +89,13 @@ 'Restricts the set of ROCM devices that JAX will use. Either "all", or a ' 'comma-separate list of integer device IDs.') -_MOCK_NUM_GPU_PROCESSES = config.int_flag( +MOCK_NUM_GPU_PROCESSES = config.int_flag( name="mock_num_gpu_processes", default=0, help="Mock number of JAX processes in GPU client. Value zero turns " "off mocking.", ) -_MOCK_GPU_TOPOLOGY = config.string_flag( +MOCK_GPU_TOPOLOGY = config.string_flag( name="jax_mock_gpu_topology", default="", help='Mock multi-host GPU topology in GPU client. The value should ' @@ -432,7 +432,7 @@ def _version_check(name: str, f'following issues with CUDA components:\n' f'{join_str.join(errors)}') -def _get_num_nodes_from_gpu_topology(topology: str) -> int: +def get_num_nodes_from_gpu_topology(topology: str) -> int: try: slices_str, hosts_per_slice_str, _ = topology.split("x", 2) return int(slices_str) * int(hosts_per_slice_str) @@ -441,69 +441,6 @@ def _get_num_nodes_from_gpu_topology(topology: str) -> int: '" x x ' '".') -def make_gpu_client( - *, platform_name: str, visible_devices_flag: config.Flag[str] -) -> xla_client.Client: - visible_devices = visible_devices_flag.value - allowed_devices = None - if visible_devices != "all": - allowed_devices = {int(x) for x in visible_devices.split(",")} - - mock_gpu_topology = _MOCK_GPU_TOPOLOGY.value or None - mock_num_gpu_processes = (_get_num_nodes_from_gpu_topology(mock_gpu_topology) if - mock_gpu_topology else _MOCK_NUM_GPU_PROCESSES.value) - - use_mock_gpu_client = mock_num_gpu_processes > 0 - num_nodes = (mock_num_gpu_processes if use_mock_gpu_client - else distributed.global_state.num_processes) - - if platform_name == "cuda": - if not os.getenv("JAX_SKIP_CUDA_CONSTRAINTS_CHECK"): - _check_cuda_versions() - else: - print('Skipped CUDA versions constraints check due to the ' - 'JAX_SKIP_CUDA_CONSTRAINTS_CHECK env var being set.') - - devices_to_check = ( - allowed_devices - if allowed_devices - else range(cuda_versions.cuda_device_count()) - ) - _check_cuda_compute_capability(devices_to_check) - - return xla_client.make_gpu_client( - distributed_client=distributed.global_state.client, - node_id=distributed.global_state.process_id, - num_nodes=num_nodes, - platform_name=platform_name, - allowed_devices=allowed_devices, - mock=use_mock_gpu_client, - ) - - -if hasattr(xla_client, "make_gpu_client"): - register_backend_factory( - "cuda", - partial( - make_gpu_client, - platform_name="cuda", - visible_devices_flag=CUDA_VISIBLE_DEVICES, - ), - priority=200, - fail_quietly=True, - ) - register_backend_factory( - "rocm", - partial( - make_gpu_client, - platform_name="rocm", - visible_devices_flag=_ROCM_VISIBLE_DEVICES, - ), - priority=200, - fail_quietly=True, - ) - - if hasattr(xla_client, "make_tpu_client"): # TODO(phawkins,skyewm): switch TPU plugin to use the PJRT plugin mechanism, # and then fail loudly on initialization failure. @@ -652,9 +589,9 @@ def _options_from_jax_configs(plugin_name): else _ROCM_VISIBLE_DEVICES.value) if visible_devices != 'all': options['visible_devices'] = [int(x) for x in visible_devices.split(',')] - mock_gpu_topology = _MOCK_GPU_TOPOLOGY.value or None - mock_num_processes = (_get_num_nodes_from_gpu_topology(mock_gpu_topology) if - mock_gpu_topology else _MOCK_NUM_GPU_PROCESSES.value) + mock_gpu_topology = MOCK_GPU_TOPOLOGY.value or None + mock_num_processes = (get_num_nodes_from_gpu_topology(mock_gpu_topology) if + mock_gpu_topology else MOCK_NUM_GPU_PROCESSES.value) options['enable_mock_nccl'] = mock_num_processes > 0 if mock_num_processes > 0: options['num_nodes'] = mock_num_processes diff --git a/jaxlib/xla/xla_client.py b/jaxlib/xla/xla_client.py index a9b1109c3bd3..ce881bee17c0 100644 --- a/jaxlib/xla/xla_client.py +++ b/jaxlib/xla/xla_client.py @@ -85,51 +85,6 @@ def make_cpu_client( ) -def make_gpu_client( - distributed_client=None, - node_id=0, - num_nodes=1, - platform_name=None, - allowed_devices=None, - mock=False, - mock_gpu_topology=None, -): - """Returns a GPU client. BFC allocator is used by default.""" - options = generate_pjrt_gpu_plugin_options() - allocator = options['allocator'] - config = _xla.GpuAllocatorConfig() - if allocator == 'default': - config.kind = _xla.GpuAllocatorConfig.Kind.DEFAULT - if allocator == 'platform': - config.kind = _xla.GpuAllocatorConfig.Kind.PLATFORM - if allocator == 'bfc': - config.kind = _xla.GpuAllocatorConfig.Kind.BFC - if allocator == 'cuda_async': - config.kind = _xla.GpuAllocatorConfig.Kind.CUDA_ASYNC - if 'memory_fraction' in options: - config.memory_fraction = options['memory_fraction'] - if 'preallocate' in options: - config.preallocate = options['preallocate'] - if 'collective_memory_size' in options: - config.collective_memory_size = options['collective_memory_size'] - register_custom_call_handler('CUDA', _xla.register_custom_call_target) - register_custom_call_handler('ROCM', _xla.register_custom_call_target) - register_custom_type_id_handler('CUDA', _xla.register_custom_type_id) - register_custom_type_id_handler('ROCM', _xla.register_custom_type_id) - - return _xla.get_gpu_client( - asynchronous=True, - allocator_config=config, - distributed_client=distributed_client, - node_id=node_id, - num_nodes=num_nodes, - platform_name=platform_name, - allowed_devices=allowed_devices, - mock=mock, - mock_gpu_topology=mock_gpu_topology, - ) - - def make_tfrt_tpu_c_api_client(options: _NameValueMapping | None = None): assert pjrt_plugin_loaded('tpu') if not pjrt_plugin_initialized('tpu'): From 66908372af2a21832eb44fb9c652dda317397b4c Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 26 Mar 2025 14:06:26 -0700 Subject: [PATCH 200/483] jnp.tri*_indices: support __jax_array__ inputs --- jax/_src/numpy/lax_numpy.py | 18 +++++++++++++----- jax/numpy/__init__.pyi | 4 ++-- tests/array_extensibility_test.py | 4 ++-- 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index fd6209ab22c4..83d5e9ee3e80 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -60,7 +60,7 @@ from jax._src.numpy.vectorize import vectorize from jax._src.sharding_impls import SingleDeviceSharding from jax._src.typing import ( - Array, ArrayLike, DType, DTypeLike, DeprecatedArg, DimSize, Shape + Array, ArrayLike, DType, DTypeLike, DeprecatedArg, DimSize, DuckTypedArray, Shape ) from jax._src.util import ( NumpyComplexWarning, canonicalize_axis as _canonicalize_axis, @@ -7557,7 +7557,7 @@ def tril_indices(n: int, k: int = 0, m: int | None = None) -> tuple[Array, Array @export -def triu_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]: +def triu_indices_from(arr: ArrayLike | DuckTypedArray, k: int = 0) -> tuple[Array, Array]: """Return the indices of upper triangle of a given array. JAX implementation of :func:`numpy.triu_indices_from`. @@ -7608,14 +7608,18 @@ def triu_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]: >>> jnp.triu_indices_from(arr, k=-1) (Array([0, 0, 0, 1, 1, 1, 2, 2], dtype=int32), Array([0, 1, 2, 0, 1, 2, 1, 2], dtype=int32)) """ - arr_shape = np.shape(arr) + if hasattr(arr, "shape"): + arr_shape = arr.shape + else: + arr = util.ensure_arraylike("triu_indices_from", arr) + arr_shape = arr.shape if len(arr_shape) != 2: raise ValueError("Only 2-D inputs are accepted") return triu_indices(arr_shape[0], k=k, m=arr_shape[1]) @export -def tril_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]: +def tril_indices_from(arr: ArrayLike | DuckTypedArray, k: int = 0) -> tuple[Array, Array]: """Return the indices of lower triangle of a given array. JAX implementation of :func:`numpy.tril_indices_from`. @@ -7666,7 +7670,11 @@ def tril_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]: >>> jnp.tril_indices_from(arr, k=-1) (Array([1, 2, 2], dtype=int32), Array([0, 0, 1], dtype=int32)) """ - arr_shape = np.shape(arr) + if hasattr(arr, "shape"): + arr_shape = arr.shape + else: + arr = util.ensure_arraylike("tril_indices_from", arr) + arr_shape = arr.shape if len(arr_shape) != 2: raise ValueError("Only 2-D inputs are accepted") return tril_indices(arr_shape[0], k=k, m=arr_shape[1]) diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index 640e9de7eac3..fb679969fe31 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -929,14 +929,14 @@ def tril(m: ArrayLike, k: int = ...) -> Array: ... def tril_indices( n: int, k: int = ..., m: int | None = ... ) -> tuple[Array, Array]: ... -def tril_indices_from(arr: ArrayLike, k: int = ...) -> tuple[Array, Array]: ... +def tril_indices_from(arr: ArrayLike | DuckTypedArray, k: int = ...) -> tuple[Array, Array]: ... def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap: builtins.bool = ..., *, inplace: builtins.bool = ...) -> Array: ... def trim_zeros(filt: ArrayLike, trim: str = ...) -> Array: ... def triu(m: ArrayLike, k: int = ...) -> Array: ... def triu_indices( n: int, k: int = ..., m: int | None = ... ) -> tuple[Array, Array]: ... -def triu_indices_from(arr: ArrayLike, k: int = ...) -> tuple[Array, Array]: ... +def triu_indices_from(arr: ArrayLike | DuckTypedArray, k: int = ...) -> tuple[Array, Array]: ... def true_divide(x: ArrayLike, y: ArrayLike, /) -> Array: ... def trunc(x: ArrayLike, /) -> Array: ... uint: Any diff --git a/tests/array_extensibility_test.py b/tests/array_extensibility_test.py index 7c0ec07e6a05..551f6d45dc41 100644 --- a/tests/array_extensibility_test.py +++ b/tests/array_extensibility_test.py @@ -484,10 +484,10 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: # NumPyAPI.sig(jnp.transpose, Float[5, 6]), NumPyAPI.sig(jnp.trapezoid, Float[5]), NumPyAPI.sig(jnp.tril, Float[5, 6]), - # NumPyAPI.sig(jnp.tril_indices_from, Float[5, 6]), + NumPyAPI.sig(jnp.tril_indices_from, Float[5, 6]), NumPyAPI.sig(jnp.trim_zeros, Float[5]), NumPyAPI.sig(jnp.triu, Float[5, 6]), - # NumPyAPI.sig(jnp.triu_indices_from, Float[5, 6]), + NumPyAPI.sig(jnp.triu_indices_from, Float[5, 6]), NumPyAPI.sig(jnp.true_divide, Float[5], Float[5]), NumPyAPI.sig(jnp.trunc, Float[5]), NumPyAPI.sig(jnp.union1d, Int[5], Int[5]), From c450b69dd7cb3f4ddc1700866ce7ab5dc9c4c459 Mon Sep 17 00:00:00 2001 From: Ayaka Date: Wed, 26 Mar 2025 14:26:54 -0700 Subject: [PATCH 201/483] Add missing `__len__` to MutableArray Fixes https://github.com/jax-ml/jax/issues/27476 PiperOrigin-RevId: 740903637 --- jax/_src/core.py | 4 ++-- jax/_src/state/types.py | 6 ++++++ tests/mutable_array_test.py | 12 ++++++++++++ 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index ef90341f5cf7..14f3d9cc18e2 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1940,8 +1940,7 @@ def update(self, shape=None, dtype=None, weak_type=None, **kwargs): if 'sharding' not in kwargs: kwargs['sharding'] = self.sharding if 'vma' not in kwargs: - kwargs['vma'] = getattr(self, 'vma', - frozenset()) + kwargs['vma'] = getattr(self, 'vma', frozenset()) return ShapedArray(shape, dtype, weak_type, **kwargs) ndim = property(lambda self: len(self.shape)) @@ -2170,6 +2169,7 @@ def __init__(self, aval, buf): def __getitem__(self, idx): return self._aval._getitem(self, idx) def __setitem__(self, idx, x): return self._aval._setitem(self, idx, x) def __repr__(self) -> str: return 'Mutable' + repr(self[...]) + def __len__(self) -> int: return self._aval._len(self) pytype_aval_mappings[MutableArray] = lambda x: x._aval def mutable_array(init_val): diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index fa9d0cb9fb16..e926e3a35f80 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -333,6 +333,12 @@ def update(self, inner_aval=None): ndim = property(lambda self: len(self.shape)) size = property(lambda self: math.prod(self.shape)) + def _len(self, ignored_tracer) -> int: + try: + return self.shape[0] + except IndexError as err: + raise TypeError("len() of unsized object") from err # same as numpy error + @property def shape(self): try: diff --git a/tests/mutable_array_test.py b/tests/mutable_array_test.py index e962653ed32d..950bddf544d7 100644 --- a/tests/mutable_array_test.py +++ b/tests/mutable_array_test.py @@ -116,6 +116,18 @@ def f(y_mut, z): check_dtypes=False) self.assertAllClose(w, 10, check_dtypes=False) + @parameterized.parameters([True, False]) + def test_len_mutable_array(self, jit): + x_mut = core.mutable_array(jnp.zeros(3)) + + def f(): + return jnp.int32(len(x_mut)) + + if jit: + f = jax.jit(f) + + self.assertEqual(f(), 3) + @parameterized.parameters([True, False]) def test_internal_mutarray_basic(self, jit): def f(): From ce3941c635b9994b4e27ee3cd377d2bd568d5ea7 Mon Sep 17 00:00:00 2001 From: Ayaka Date: Wed, 26 Mar 2025 14:35:10 -0700 Subject: [PATCH 202/483] Add division-by-zero checks to jax.numpy functions PiperOrigin-RevId: 740906595 --- jax/_src/numpy/error.py | 23 +++++++++- jax/_src/numpy/ufuncs.py | 4 ++ tests/jax_numpy_error_test.py | 80 +++++++++++++++++++++++++++-------- 3 files changed, 88 insertions(+), 19 deletions(-) diff --git a/jax/_src/numpy/error.py b/jax/_src/numpy/error.py index 52b996a0b050..20dab289d779 100644 --- a/jax/_src/numpy/error.py +++ b/jax/_src/numpy/error.py @@ -30,7 +30,9 @@ def _is_category_disabled( if category == "nan": raise ValueError("nan is deprecated. Use `_set_error_if_nan` instead.") if category == "divide": - return config.error_checking_behavior_divide.value == "ignore" + raise ValueError( + "divide is deprecated. Use `_set_error_if_divide_by_zero` instead." + ) if category == "oob": return config.error_checking_behavior_oob.value == "ignore" raise ValueError(f"Invalid category: {category}") @@ -81,6 +83,25 @@ def _set_error_if_nan(pred: jax.Array, /): error_check_lib.set_error_if(jnp.isnan(pred), "NaN encountered") +def _set_error_if_divide_by_zero(pred: jax.Array, /): + """Set the internal error state if any element of `pred` is zero. + + This function is intended for checking if the denominator of a division is + zero. + + This function is disabled if the `jax_error_checking_behavior_divide` flag is + set to "ignore". + """ + if config.error_checking_behavior_divide.value == "ignore": + return + + # TODO(ayx): fix the circular import issue. + from jax._src import error_check as error_check_lib + import jax.numpy as jnp + zero = jnp.zeros_like(pred, shape=()) + error_check_lib.set_error_if(pred == zero, "Division by zero encountered") + + Behavior = Literal["ignore", "raise"] diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 0ea2992c9955..e561b7ae71b6 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -2471,6 +2471,7 @@ def true_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: :func:`jax.numpy.floor_divide` for integer division """ x1, x2 = promote_args_inexact("true_divide", x1, x2) + jnp_error._set_error_if_divide_by_zero(x2) out = lax.div(x1, x2) jnp_error._set_error_if_nan(out) return out @@ -2523,6 +2524,7 @@ def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: Array([3., 2., 2.], dtype=float32) """ x1, x2 = promote_args_numeric("floor_divide", x1, x2) + jnp_error._set_error_if_divide_by_zero(x2) dtype = dtypes.dtype(x1) if dtypes.issubdtype(dtype, np.unsignedinteger): return lax.div(x1, x2) @@ -2577,6 +2579,7 @@ def divmod(x1: ArrayLike, x2: ArrayLike, /) -> tuple[Array, Array]: if dtypes.issubdtype(dtypes.dtype(x1), np.integer): return floor_divide(x1, x2), remainder(x1, x2) else: + jnp_error._set_error_if_divide_by_zero(x2) return _float_divmod(x1, x2) @@ -3090,6 +3093,7 @@ def remainder(x1: ArrayLike, x2: ArrayLike, /) -> Array: [ 0., 2., -2.]], dtype=float32) """ x1, x2 = promote_args_numeric("remainder", x1, x2) + jnp_error._set_error_if_divide_by_zero(x2) zero = _constant_like(x1, 0) if dtypes.issubdtype(x2.dtype, np.integer): x2 = _where(x2 == 0, lax._ones(x2), x2) diff --git a/tests/jax_numpy_error_test.py b/tests/jax_numpy_error_test.py index 566e0b1ba209..f2262d8b5dc0 100644 --- a/tests/jax_numpy_error_test.py +++ b/tests/jax_numpy_error_test.py @@ -51,11 +51,9 @@ def f(x): error_check.raise_if_error() @parameterized.product(jit=[True, False]) - def test_error_category_divide_check(self, jit): + def test_set_error_if_divide_by_zero(self, jit): def f(x, y): - jnp_error._set_error_if_with_category( - y == 0.0, "division by zero", category="divide" - ) + jnp_error._set_error_if_divide_by_zero(y) return x / y if jit: @@ -70,7 +68,7 @@ def f(x, y): with jnp_error.error_checking_behavior(divide="raise"): _ = f(x, y) - with self.assertRaisesRegex(JaxValueError, "division by zero"): + with self.assertRaisesRegex(JaxValueError, "Division by zero"): error_check.raise_if_error() @parameterized.product(jit=[True, False]) @@ -115,22 +113,22 @@ def test_error_category_invalid_category(self): ) @staticmethod - def op_cases(cases): + def nan_cases(cases): for jit in (True, False): - for func, ops_error, ops_no_err in cases: - if not isinstance(ops_error, tuple): - ops_error = (ops_error,) - if not isinstance(ops_no_err, tuple): - ops_no_err = (ops_no_err,) + for func, args_error, args_no_err in cases: + if not isinstance(args_error, tuple): + args_error = (args_error,) + if not isinstance(args_no_err, tuple): + args_no_err = (args_no_err,) jit_str = "jit" if jit else "nojit" func_str = f"{func.__module__}.{func.__name__}" name = f"_{jit_str}_{func_str}" - yield name, jit, func, ops_error, ops_no_err + yield name, jit, func, args_error, args_no_err @parameterized.named_parameters( - op_cases(( + nan_cases(( # List of all NaN-producing jax.numpy functions. # The first group of numbers is the input that will produce a NaN, and # the second group is the input that will not produce a NaN. @@ -172,21 +170,67 @@ def op_cases(cases): # go/keep-sorted end )) ) - def test_can_raise_nan_error(self, jit, f, ops_err, ops_no_err): - ops_err = [jnp.float32(x) for x in ops_err] - ops_no_err = [jnp.float32(x) for x in ops_no_err] + def test_can_raise_nan_error(self, jit, f, args_err, args_no_err): + args_err = [jnp.float32(x) for x in args_err] + args_no_err = [jnp.float32(x) for x in args_no_err] if jit: f = jax.jit(f) with jnp_error.error_checking_behavior(nan="raise"): - f(*ops_no_err) + f(*args_no_err) error_check.raise_if_error() # should not raise error - f(*ops_err) + f(*args_err) with self.assertRaisesRegex(JaxValueError, "NaN"): error_check.raise_if_error() + INT_TYPES = (jnp.int32, jnp.uint32, jnp.int64, jnp.uint64, jnp.int16, + jnp.uint16, jnp.int8, jnp.uint8) + FLOAT_TYPES = (jnp.float32, jnp.float64, jnp.float16, jnp.bfloat16) + + @staticmethod + def divide_cases(cases): + for jit in (True, False): + for func, dtypes in cases: + for dtype in dtypes: + jit_str = "jit" if jit else "nojit" + func_str = f"{func.__module__}.{func.__name__}" + dtype_str = dtype.__name__ + name = f"_{jit_str}_{func_str}_{dtype_str}" + yield name, jit, func, dtype + + @parameterized.named_parameters( + divide_cases(( + # go/keep-sorted start + (jnp.divmod, FLOAT_TYPES + INT_TYPES), + (jnp.floor_divide, INT_TYPES), + (jnp.mod, FLOAT_TYPES + INT_TYPES), + (jnp.remainder, FLOAT_TYPES + INT_TYPES), + (jnp.true_divide, FLOAT_TYPES), + (operator.mod, FLOAT_TYPES + INT_TYPES), + (operator.truediv, FLOAT_TYPES), + # go/keep-sorted end + )) + ) + def test_can_raise_divide_by_zero_error(self, jit, div_func, dtype): + if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8: + self.skipTest("64-bit types require x64_enabled") + + args_err = (dtype(1), dtype(0)) + args_no_err = (dtype(1), dtype(1)) + + if jit: + div_func = jax.jit(div_func) + + with jnp_error.error_checking_behavior(divide="raise"): + div_func(*args_no_err) + error_check.raise_if_error() # should not raise error + + div_func(*args_err) + with self.assertRaisesRegex(JaxValueError, "Division by zero"): + error_check.raise_if_error() + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From af25dc47196e1237b6e92d8bb399b3439f22eebd Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 26 Mar 2025 15:19:24 -0700 Subject: [PATCH 203/483] Update the Windows docker image to ltsc2022 PiperOrigin-RevId: 740921613 --- ci/envs/docker.env | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/envs/docker.env b/ci/envs/docker.env index 82a76d33350c..a0f558520d45 100644 --- a/ci/envs/docker.env +++ b/ci/envs/docker.env @@ -41,5 +41,5 @@ fi # Windows image for building JAX artifacts, running Pytests CPU tests, and Bazel # tests if [[ $os =~ "msys_nt" ]]; then - export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/tf-test-windows@sha256:6e2b299f12418d70ea522646b3dd618042a102f2ac2e4f8b1e423638549ea801" + export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/tf-test-windows:latest" fi \ No newline at end of file From 667c4a0ee0da4cb96795624f7b91a9deacdeca14 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 26 Mar 2025 15:27:25 -0700 Subject: [PATCH 204/483] Support __jax_array__ for jnp.shape/jnp.size/jnp.ndim --- jax/_src/numpy/lax_numpy.py | 6 ++--- jax/_src/numpy/util.py | 38 +++++++++++++++++++++++-------- jax/_src/typing.py | 14 +++++++++++- jax/numpy/__init__.pyi | 12 +++++----- tests/array_extensibility_test.py | 6 ++--- 5 files changed, 53 insertions(+), 23 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 83d5e9ee3e80..63edaed0adeb 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -60,7 +60,7 @@ from jax._src.numpy.vectorize import vectorize from jax._src.sharding_impls import SingleDeviceSharding from jax._src.typing import ( - Array, ArrayLike, DType, DTypeLike, DeprecatedArg, DimSize, DuckTypedArray, Shape + Array, ArrayLike, DType, DTypeLike, DeprecatedArg, DimSize, Shape, SupportsShape ) from jax._src.util import ( NumpyComplexWarning, canonicalize_axis as _canonicalize_axis, @@ -7557,7 +7557,7 @@ def tril_indices(n: int, k: int = 0, m: int | None = None) -> tuple[Array, Array @export -def triu_indices_from(arr: ArrayLike | DuckTypedArray, k: int = 0) -> tuple[Array, Array]: +def triu_indices_from(arr: ArrayLike | SupportsShape, k: int = 0) -> tuple[Array, Array]: """Return the indices of upper triangle of a given array. JAX implementation of :func:`numpy.triu_indices_from`. @@ -7619,7 +7619,7 @@ def triu_indices_from(arr: ArrayLike | DuckTypedArray, k: int = 0) -> tuple[Arra @export -def tril_indices_from(arr: ArrayLike | DuckTypedArray, k: int = 0) -> tuple[Array, Array]: +def tril_indices_from(arr: ArrayLike | SupportsShape, k: int = 0) -> tuple[Array, Array]: """Return the indices of lower triangle of a given array. JAX implementation of :func:`numpy.tril_indices_from`. diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index e281c63ae654..e0e20d443e02 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -27,7 +27,8 @@ from jax._src.lib import xla_client as xc from jax._src.sharding_impls import SingleDeviceSharding from jax._src.util import safe_zip, safe_map, set_module -from jax._src.typing import Array, ArrayLike, DimSize, DType, DTypeLike, Shape +from jax._src.typing import (Array, ArrayLike, DimSize, DType, DTypeLike, + Shape, SupportsNdim, SupportsShape, SupportsSize) from jax.sharding import Sharding import numpy as np @@ -313,7 +314,7 @@ def normalize_device_to_sharding(device: xc.Device | Sharding | None) -> Shardin @export -def ndim(a: ArrayLike) -> int: +def ndim(a: ArrayLike | SupportsNdim) -> int: """Return the number of dimensions of an array. JAX implementation of :func:`numpy.ndim`. Unlike ``np.ndim``, this function @@ -321,7 +322,7 @@ def ndim(a: ArrayLike) -> int: tuple. Args: - a: array-like object. + a: array-like object, or any object with an ``ndim`` attribute. Returns: An integer specifying the number of dimensions of ``a``. @@ -346,13 +347,18 @@ def ndim(a: ArrayLike) -> int: >>> x.ndim 1 """ + if hasattr(a, "ndim"): + return a.ndim # Deprecation warning added 2025-2-20. check_arraylike("ndim", a, emit_warning=True) - return np.ndim(a) # NumPy dispatches to a.ndim if available. + if hasattr(a, "__jax_array__"): + a = a.__jax_array__() + # NumPy dispatches to a.ndim if available. + return np.ndim(a) # type: ignore[arg-type] @export -def shape(a: ArrayLike) -> tuple[int, ...]: +def shape(a: ArrayLike | SupportsShape) -> tuple[int, ...]: """Return the shape an array. JAX implementation of :func:`numpy.shape`. Unlike ``np.shape``, this function @@ -360,7 +366,7 @@ def shape(a: ArrayLike) -> tuple[int, ...]: tuple. Args: - a: array-like object. + a: array-like object, or any object with a ``shape`` attribute. Returns: An tuple of integers representing the shape of ``a``. @@ -385,13 +391,18 @@ def shape(a: ArrayLike) -> tuple[int, ...]: >>> x.shape (10,) """ + if hasattr(a, "shape"): + return a.shape # Deprecation warning added 2025-2-20. check_arraylike("shape", a, emit_warning=True) - return np.shape(a) # NumPy dispatches to a.shape if available. + if hasattr(a, "__jax_array__"): + a = a.__jax_array__() + # NumPy dispatches to a.shape if available. + return np.shape(a) # type: ignore[arg-type] @export -def size(a: ArrayLike, axis: int | None = None) -> int: +def size(a: ArrayLike | SupportsSize | SupportsShape, axis: int | None = None) -> int: """Return number of elements along a given axis. JAX implementation of :func:`numpy.size`. Unlike ``np.size``, this function @@ -399,7 +410,8 @@ def size(a: ArrayLike, axis: int | None = None) -> int: tuple. Args: - a: array-like object + a: array-like object, or any object with a ``size`` attribute when ``axis`` is not + specified, or with a ``shape`` attribute when ``axis`` is specified. axis: optional integer along which to count elements. By default, return the total number of elements. @@ -428,6 +440,12 @@ def size(a: ArrayLike, axis: int | None = None) -> int: >>> y.size 6 """ + if (axis is None and hasattr(a, "size")) or (axis is not None and hasattr(a, "shape")): + # NumPy dispatches to a.size/a.shape if available. + return np.size(a, axis=axis) # type: ignore[arg-type] # Deprecation warning added 2025-2-20. check_arraylike("size", a, emit_warning=True) - return np.size(a, axis=axis) # NumPy dispatches to a.size if available. + if hasattr(a, "__jax_array__"): + a = a.__jax_array__() + # NumPy dispatches to a.size/a.shape if available. + return np.size(a, axis=axis) # type: ignore[arg-type] diff --git a/jax/_src/typing.py b/jax/_src/typing.py index 010841b45dd2..ee2422dd2d73 100644 --- a/jax/_src/typing.py +++ b/jax/_src/typing.py @@ -47,7 +47,19 @@ @typing.runtime_checkable class SupportsDType(Protocol): @property - def dtype(self) -> DType: ... + def dtype(self, /) -> DType: ... + +class SupportsShape(Protocol): + @property + def shape(self, /) -> tuple[int, ...]: ... + +class SupportsSize(Protocol): + @property + def size(self, /) -> int: ... + +class SupportsNdim(Protocol): + @property + def ndim(self, /) -> int: ... # DTypeLike is meant to annotate inputs to np.dtype that return # a valid JAX dtype. It's different than numpy.typing.DTypeLike diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index fb679969fe31..259f6e3ed2ee 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -15,7 +15,7 @@ from jax._src.numpy.index_tricks import _Mgrid, _Ogrid, CClass as _CClass, RClas from jax._src.numpy.array_api_metadata import ArrayNamespaceInfo from jax._src.typing import ( Array, ArrayLike, DType, DTypeLike, DeprecatedArg, - DimSize, DuckTypedArray, Shape, StaticScalar, + DimSize, DuckTypedArray, Shape, StaticScalar, SupportsNdim, SupportsShape, SupportsSize, ) from jax._src.sharding_impls import NamedSharding, PartitionSpec as P from jax.numpy import fft as fft, linalg as linalg @@ -728,7 +728,7 @@ def nanvar(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., ddof: int = 0, keepdims: builtins.bool = False, where: ArrayLike | None = ...) -> Array: ... ndarray = Array -def ndim(a: ArrayLike) -> int: ... +def ndim(a: ArrayLike | SupportsNdim) -> int: ... def negative(x: ArrayLike, /) -> Array: ... newaxis = None def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array: ... @@ -841,7 +841,7 @@ def setdiff1d( fill_value: ArrayLike | None = ..., ) -> Array: ... def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: builtins.bool = ...) -> Array: ... -def shape(a: ArrayLike) -> tuple[int, ...]: ... +def shape(a: ArrayLike | SupportsShape) -> tuple[int, ...]: ... def sign(x: ArrayLike, /) -> Array: ... def signbit(x: ArrayLike, /) -> Array: ... signedinteger = _np.signedinteger @@ -849,7 +849,7 @@ def sin(x: ArrayLike, /) -> Array: ... def sinc(x: ArrayLike, /) -> Array: ... single: Any def sinh(x: ArrayLike, /) -> Array: ... -def size(a: ArrayLike, axis: int | None = None) -> int: ... +def size(a: ArrayLike | SupportsSize, axis: int | None = None) -> int: ... def sort( a: ArrayLike, axis: int | None = ..., @@ -929,14 +929,14 @@ def tril(m: ArrayLike, k: int = ...) -> Array: ... def tril_indices( n: int, k: int = ..., m: int | None = ... ) -> tuple[Array, Array]: ... -def tril_indices_from(arr: ArrayLike | DuckTypedArray, k: int = ...) -> tuple[Array, Array]: ... +def tril_indices_from(arr: ArrayLike | SupportsShape, k: int = ...) -> tuple[Array, Array]: ... def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap: builtins.bool = ..., *, inplace: builtins.bool = ...) -> Array: ... def trim_zeros(filt: ArrayLike, trim: str = ...) -> Array: ... def triu(m: ArrayLike, k: int = ...) -> Array: ... def triu_indices( n: int, k: int = ..., m: int | None = ... ) -> tuple[Array, Array]: ... -def triu_indices_from(arr: ArrayLike | DuckTypedArray, k: int = ...) -> tuple[Array, Array]: ... +def triu_indices_from(arr: ArrayLike | SupportsShape, k: int = ...) -> tuple[Array, Array]: ... def true_divide(x: ArrayLike, y: ArrayLike, /) -> Array: ... def trunc(x: ArrayLike, /) -> Array: ... uint: Any diff --git a/tests/array_extensibility_test.py b/tests/array_extensibility_test.py index 551f6d45dc41..fae9129dd99a 100644 --- a/tests/array_extensibility_test.py +++ b/tests/array_extensibility_test.py @@ -403,7 +403,7 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.nanstd, Float[5]), NumPyAPI.sig(jnp.nansum, Float[5]), NumPyAPI.sig(jnp.nanvar, Float[5]), - # NumPyAPI.sig(jnp.ndim, Float[5]), + NumPyAPI.sig(jnp.ndim, Float[5]), NumPyAPI.sig(jnp.negative, Float[5]), NumPyAPI.sig(jnp.nextafter, Float[5], Float[5]), NumPyAPI.sig(jnp.nonzero, Float[5]), @@ -455,13 +455,13 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: # NumPyAPI.sig(jnp.select, [Bool[5], Bool[5]], [Float[5], Float[5]], Float[5]), NumPyAPI.sig(jnp.setdiff1d, Int[5], Int[5]), NumPyAPI.sig(jnp.setxor1d, Int[5], Int[5]), - # NumPyAPI.sig(jnp.shape, Float[5, 3]), + NumPyAPI.sig(jnp.shape, Float[5, 3]), NumPyAPI.sig(jnp.sign, Float[5]), NumPyAPI.sig(jnp.signbit, Float[5]), NumPyAPI.sig(jnp.sin, Float[5]), NumPyAPI.sig(jnp.sinc, Float[5]), NumPyAPI.sig(jnp.sinh, Float[5]), - # NumPyAPI.sig(jnp.size, Float[5]), + NumPyAPI.sig(jnp.size, Float[5]), NumPyAPI.sig(jnp.sort, Float[5]), NumPyAPI.sig(jnp.sort_complex, Complex[5]), NumPyAPI.sig(jnp.spacing, Float[5]), From 5bc4c57f09c778043b4932dd52cb4ea45c5d7069 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 26 Mar 2025 15:27:20 -0700 Subject: [PATCH 205/483] Inline make_tfrt_tpu_c_api_client into its only caller. PiperOrigin-RevId: 740923936 --- jaxlib/xla/xla_client.py | 16 ++++++---------- jaxlib/xla/xla_client.pyi | 3 --- 2 files changed, 6 insertions(+), 13 deletions(-) diff --git a/jaxlib/xla/xla_client.py b/jaxlib/xla/xla_client.py index ce881bee17c0..eb7a3b5759e5 100644 --- a/jaxlib/xla/xla_client.py +++ b/jaxlib/xla/xla_client.py @@ -85,15 +85,6 @@ def make_cpu_client( ) -def make_tfrt_tpu_c_api_client(options: _NameValueMapping | None = None): - assert pjrt_plugin_loaded('tpu') - if not pjrt_plugin_initialized('tpu'): - initialize_pjrt_plugin('tpu') - if options is None: - options = {} - return _xla.get_c_api_client('tpu', options) - - DeviceTopology = _xla.DeviceTopology get_topology_for_devices = _xla.get_topology_for_devices @@ -169,7 +160,12 @@ def make_tpu_client( if not pjrt_plugin_loaded('tpu'): c_api = load_pjrt_plugin_dynamically('tpu', library_path or 'libtpu.so') profiler.register_plugin_profiler(c_api) - return make_tfrt_tpu_c_api_client(options) + assert pjrt_plugin_loaded('tpu') + if not pjrt_plugin_initialized('tpu'): + initialize_pjrt_plugin('tpu') + if options is None: + options = {} + return _xla.get_c_api_client('tpu', options) def generate_pjrt_gpu_plugin_options() -> _NameValueMapping: diff --git a/jaxlib/xla/xla_client.pyi b/jaxlib/xla/xla_client.pyi index 234af8f7b87d..5ac837ef1d85 100644 --- a/jaxlib/xla/xla_client.pyi +++ b/jaxlib/xla/xla_client.pyi @@ -106,9 +106,6 @@ def make_gpu_client( ) -> Client: ... -def make_tfrt_tpu_c_api_client(options: _NameValueMapping | None = None) -> Client: - ... - def make_tfrt_tpu_c_api_device_topology( topology_name: str | None = None, **kwargs ) -> DeviceTopology: From c88ea23035454a95d7e20e4976d1e595114c8a66 Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Wed, 26 Mar 2025 15:47:55 -0700 Subject: [PATCH 206/483] [JAX] Add caching to `colocated_python.colocated_cpu_devices()` For a deployment with many devices, `colocated_python.colocated_cpu_devices()` can take some time to find colocated devices as it needs to find matching devices one by one in Python. This change adds caching as an optimization to reduce the overall cost of API calls. PiperOrigin-RevId: 740930124 --- jax/experimental/colocated_python/api.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/jax/experimental/colocated_python/api.py b/jax/experimental/colocated_python/api.py index b855bba48abb..e72e04c6ded9 100644 --- a/jax/experimental/colocated_python/api.py +++ b/jax/experimental/colocated_python/api.py @@ -28,6 +28,15 @@ def colocated_cpu_devices( devices: Sequence[jax.Device], ) -> Sequence[jax.Device]: """Finds CPU devices colocated with the given devices.""" + if not isinstance(devices, tuple): + devices = tuple(devices) + return _colocated_cpu_devices_cached(devices) + + +@jax.util.cache(max_size=1024, trace_context_in_key=False) +def _colocated_cpu_devices_cached( + devices: tuple[jax.Device, ...], +) -> Sequence[jax.Device]: cpu_devices_by_colocation_id = collections.defaultdict(list) for device in devices[0].client._get_all_devices(): # pylint: disable=protected-access if device.device_kind == "cpu": From e8038501d0ee0bef99a8a772d2ad9b0f38b018bb Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 26 Mar 2025 16:30:23 -0700 Subject: [PATCH 207/483] Fix a bug where jit was forwarding inputs to outputs even when donation was True for that inputs. This caused the output to be marked as deleted since the input was being forwarded to the output. Since this functionality was added for a dynamic shapes experiment, only enable it when dynamic_shapes config is True. Co-authored-by: Matthew Johnson PiperOrigin-RevId: 740942785 --- jax/_src/checkify.py | 6 ++--- jax/_src/core.py | 3 +-- jax/_src/pjit.py | 25 ++++++++++++--------- tests/api_test.py | 4 ++++ tests/checkify_test.py | 15 ++++++++++++- tests/debug_info_test.py | 25 +++++++++------------ tests/memories_test.py | 1 + tests/pjit_test.py | 48 ++++++++++++++++++++++++++++------------ 8 files changed, 81 insertions(+), 46 deletions(-) diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 1ec8ad50b456..f80a0cbd1d75 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -913,14 +913,14 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr, # Update pjit params to account for extra error values. num_error_vals = len(err_vals) num_out_error_vals = out_tree.num_leaves - len(out_shardings) - sharding = sharding_impls.UNSPECIFIED new_in_shardings = (*[sharding] * num_error_vals, *in_shardings) - new_out_shardings = (*[sharding] * num_out_error_vals, *out_shardings) new_in_layouts = (*[None] * num_error_vals, *in_layouts) - new_out_layouts = (*[None] * num_out_error_vals, *out_layouts) new_donated_invars = (*[False] * num_error_vals, *donated_invars) + new_out_shardings = (*[sharding] * num_out_error_vals, *out_shardings) + new_out_layouts = (*[None] * num_out_error_vals, *out_layouts) + err_and_out = pjit.pjit_p.bind( *new_vals_in, jaxpr=checked_jaxpr, diff --git a/jax/_src/core.py b/jax/_src/core.py index 14f3d9cc18e2..3a1558802682 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1977,8 +1977,7 @@ def to_tangent_aval(self): def str_short(self, short_dtypes=False, mesh_axis_types=False): return str_short_aval( self.shape, self.dtype, self.sharding.mesh, self.sharding.spec, - getattr(self, 'varying_manual_axes', frozenset()), - short_dtypes, mesh_axis_types) + getattr(self, 'vma', frozenset()), short_dtypes, mesh_axis_types) def _len(self, ignored_tracer): try: diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 054b55e32918..03eb6835cb06 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1777,11 +1777,12 @@ def pjit_staging_rule(trace, *args, **params): return pe.inline_jaxpr_into_trace( trace, jaxpr.jaxpr, jaxpr.consts, *args) - jaxpr, in_fwd, out_shardings, out_layouts = _pjit_forwarding( - params['jaxpr'], params['out_shardings'], params['out_layouts']) - params = dict(params, jaxpr=jaxpr, out_shardings=out_shardings, - out_layouts=out_layouts) + jaxpr = params['jaxpr'] if config.dynamic_shapes.value: + jaxpr, in_fwd, out_shardings, out_layouts = _pjit_forwarding( + jaxpr, params['out_shardings'], params['out_layouts']) + params = dict(params, jaxpr=jaxpr, out_shardings=out_shardings, + out_layouts=out_layouts) source_info = source_info_util.current() out_tracers = [] for aval in _out_type(jaxpr): @@ -1795,6 +1796,10 @@ def pjit_staging_rule(trace, *args, **params): map(trace.getvar, args), map(trace.makevar, out_tracers), pjit_p, params, jaxpr.effects, source_info) trace.frame.add_eqn(eqn) + out_tracers_ = iter(out_tracers) + out_tracers = [args[f] if type(f) is int else next(out_tracers_) + for f in in_fwd] + assert next(out_tracers_, None) is None elif any(isinstance(c, core.MutableArray) for c in jaxpr.consts): jaxpr, consts = pxla._move_mutable_consts(jaxpr) consts = map(trace.new_const, consts) @@ -1807,19 +1812,14 @@ def pjit_staging_rule(trace, *args, **params): pjit_p, (*args, *consts), new_params) else: out_tracers = trace.default_process_primitive(pjit_p, args, params) - - out_tracers_ = iter(out_tracers) - out_tracers = [args[f] if type(f) is int else next(out_tracers_) - for f in in_fwd] - assert next(out_tracers_, None) is None return out_tracers pe.custom_staging_rules[pjit_p] = pjit_staging_rule def _pjit_forwarding(jaxpr, out_shardings, out_layouts): in_fwd: list[int | None] = pe._jaxpr_forwarding(jaxpr.jaxpr) - in_fwd = [fwd if isinstance(os, UnspecifiedValue) and ol is None else None for fwd, os, ol - in zip(in_fwd, out_shardings, out_layouts)] + in_fwd = [fwd if isinstance(os, UnspecifiedValue) and ol is None else None + for fwd, os, ol in zip(in_fwd, out_shardings, out_layouts)] keep = [f is None for f in in_fwd] jaxpr = pe.prune_closed_jaxpr_outputs(jaxpr, keep) out_shardings = [o for o, k in zip(out_shardings, keep) if k] @@ -1827,6 +1827,8 @@ def _pjit_forwarding(jaxpr, out_shardings, out_layouts): return jaxpr, in_fwd, out_shardings, out_layouts def pjit_forwarding_rule(eqn): + if not config.dynamic_shapes.value: + return [None] * len(eqn.outvars), eqn jaxpr, in_fwd, out_shardings, out_layouts = _pjit_forwarding( eqn.params['jaxpr'], eqn.params['out_shardings'], eqn.params['out_layouts']) new_outvars = [v for v, f in zip(eqn.outvars, in_fwd) if f is None] @@ -1835,6 +1837,7 @@ def pjit_forwarding_rule(eqn): new_eqn = eqn.replace(params=new_params, outvars=new_outvars) fwd_vars = [eqn.invars[f] if f is not None else None for f in in_fwd] return fwd_vars, new_eqn +# TODO(mattjj): Remove pjit_forwarding_rule and also in staging rule. pe.forwarding_rules[pjit_p] = pjit_forwarding_rule diff --git a/tests/api_test.py b/tests/api_test.py index 82b673fe4b1e..9d80b5fbed74 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -4687,6 +4687,8 @@ def f(inputs): @jtu.run_on_devices("cpu") def test_inner_jit_forwarding_happens(self): + if not config.dynamic_shapes.value: + self.skipTest("Only works for dynamic shapes") jaxpr = jax.make_jaxpr(lambda: jax.jit(lambda x: x)(3))() self.assertLen(jaxpr.jaxpr.outvars, 1) self.assertIsInstance(jaxpr.jaxpr.outvars[0], core.Literal) @@ -4695,6 +4697,8 @@ def test_inner_jit_forwarding_happens(self): @parameterized.parameters(range(8)) @jtu.run_on_devices("cpu") def test_inner_jit_forwarding_correctness(self, num_input_fwd): + if not config.dynamic_shapes.value: + self.skipTest("Only works for dynamic shapes") num_args = 8 rng = np.random.RandomState(0) diff --git a/tests/checkify_test.py b/tests/checkify_test.py index 6a1660b28578..5ea99d20a2ab 100644 --- a/tests/checkify_test.py +++ b/tests/checkify_test.py @@ -24,7 +24,7 @@ from jax.experimental import checkify from jax.experimental import pjit from jax.experimental import shard_map -from jax.sharding import NamedSharding +from jax.sharding import NamedSharding, PartitionSpec as P from jax._src import array from jax._src import config from jax._src import core @@ -475,6 +475,19 @@ def f(init_val): self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), "division by zero") + def test_checify_donation_no_forwarding(self): + mesh = jtu.create_mesh((2,), ('x',)) + + @checkify.checkify + @partial(jax.jit, donate_argnums=(0,)) + def f(x: jax.Array) -> jax.Array: + checkify.check(jnp.all(x > 0), "a") + return x + + x = jax.device_put(jnp.zeros(64, dtype="int32"), NamedSharding(mesh, P())) + err, y = f(x) + err, z = f(y) # doesn't crash + @jtu.skip_on_devices("tpu") def test_while_loop_body_and_cond_error(self): def while_cond(val): diff --git a/tests/debug_info_test.py b/tests/debug_info_test.py index a39b53c3ad16..1d2935ea34d7 100644 --- a/tests/debug_info_test.py +++ b/tests/debug_info_test.py @@ -671,7 +671,7 @@ def my_g(b, d=1): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ # TODO(necula): result_paths? - "traced_for=jit, fun=my_f, arg_names=a, result_paths=", + "traced_for=jit, fun=my_f, arg_names=a, result_paths=result", "traced_for=jit, fun=my_g, arg_names=b, result_paths=result", ], expected_tracer_debug_infos=[ @@ -794,7 +794,7 @@ def my_g(u, v): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x,y, result_paths=result", - "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=result['c']" + "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=result['c'],result['d']", ], expected_tracer_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x,y, from x", @@ -1318,17 +1318,15 @@ def the_grad(c, as_): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=the_grad, arg_names=c,as_, result_paths=result[0],result[1]", - # TODO(necula): arg names, bad result paths "traced_for=jit, fun=my_f, arg_names=x,as_, result_paths=,,", - "traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2], result_paths=", "traced_for=for_loop, fun=f, arg_names=,,, result_paths=,", + "traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2], result_paths=", + "traced_for=jit, fun=my_f, arg_names=,,x,as_, result_paths=result[0],result[1]", + "traced_for=checkpoint / remat, fun=to_remat, arg_names=,,, result_paths=,", "traced_for=for_loop, fun=f, arg_names=,,,,,, result_paths=,", - "traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,, result_paths=", + "traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2], result_paths=", "traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,,,,,, result_paths=,", - "traced_for=checkpoint / remat, fun=to_remat, arg_names=,,, result_paths=,", - "traced_for=jit, fun=my_f, arg_names=as_,,, result_paths=" - if config.use_direct_linearize.value else - "traced_for=jit, fun=my_f, arg_names=,,x,as_, result_paths=", + "traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,, result_paths=", ], expected_tracer_debug_infos=[ "traced_for=jit, fun=the_grad, arg_names=c,as_, from c", @@ -1467,7 +1465,7 @@ def my_g(u, v): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x,y, result_paths=result", - "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=result['c']", + "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=result['c'],result['d']", ], expected_tracer_debug_infos=[ # TODO(necula): missing debug info @@ -1611,11 +1609,8 @@ def my_f(x): x, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x, result_paths=result", - # TODO(necula): arg_names and result_paths? - "traced_for=jit, fun=my_f, arg_names=x, result_paths=,,,", - "traced_for=jit, fun=my_f, arg_names=x,, result_paths=," - if config.use_direct_linearize.value else - "traced_for=jit, fun=my_f, arg_names=,x, result_paths=,", + "traced_for=jit, fun=my_f, arg_names=x, result_paths=,", + "traced_for=jit, fun=my_f, arg_names=,x, result_paths=result", ], tracer_spy=tracer_spy, expected_tracer_debug_infos=[ diff --git a/tests/memories_test.py b/tests/memories_test.py index 570b0c375834..64ee2829873d 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -1090,6 +1090,7 @@ def f_bwd(res, tx): self.assertArraysEqual(g(arr), all_true) def test_scan_offload(self): + self.skipTest('b/406586554') np_inp = jnp.arange(4096).reshape(16, 16, 16) @jax.jit diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 277f24bd703f..d72ecc98e771 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -1240,9 +1240,12 @@ def test_pretty_print_pjit_id(self): jaxpr.pretty_print(use_color=False), textwrap.dedent(""" { lambda ; a:f32[1]. let - pjit[name= jaxpr={ lambda ; a:f32[1] b:f32[1]. let in () }] a a - c:f32[1] = add a a - in (c,) } + b:f32[1] = pjit[ + name= + jaxpr={ lambda ; a:f32[1] c:f32[1]. let in (a,) } + ] a a + d:f32[1] = add a b + in (d,) } """).strip(), ) @@ -1289,8 +1292,11 @@ def test_pretty_print_with_literal_outvar(self): jaxpr.pretty_print(use_color=False), textwrap.dedent(""" { lambda ; a:f32[1]. let - b:i32[] = pjit[name= jaxpr={ lambda ; a:f32[1]. let in (2,) }] a - in (b, a) } + b:i32[] c:f32[1] = pjit[ + name= + jaxpr={ lambda ; a:f32[1]. let in (2, a) } + ] a + in (b, c) } """).strip(), ) @@ -1336,19 +1342,19 @@ def f(x): self.assertEqual( jaxpr.pretty_print(use_color=False), textwrap.dedent(""" - let f = { lambda ; a:f32[1]. let in () } in - let f1 = { lambda ; b:f32[2]. let in () } in + let f = { lambda ; a:f32[1]. let in (a,) } in + let f1 = { lambda ; b:f32[2]. let in (b,) } in { lambda ; c:f32[1] d:f32[2]. let e:f32[2] = pjit[ name=g jaxpr={ lambda ; c:f32[1] d:f32[2]. let - pjit[name=f jaxpr=f] c - pjit[name=f jaxpr=f] c - g:f32[1] = mul c c - pjit[name=f jaxpr=f1] d - pjit[name=f jaxpr=f1] d - h:f32[2] = mul d d - e:f32[2] = add g h + g:f32[1] = pjit[name=f jaxpr=f] c + h:f32[1] = pjit[name=f jaxpr=f] c + i:f32[1] = mul g h + j:f32[2] = pjit[name=f jaxpr=f1] d + k:f32[2] = pjit[name=f jaxpr=f1] d + l:f32[2] = mul j k + e:f32[2] = add i l in (e,) } ] c d in (e,) } @@ -2477,6 +2483,20 @@ def test_pjit_committed_array_different_devices_variadic_args(self): r"\[1\].*"): pjit(lambda *x: x)(a, b) + def test_jit_no_forwarding(self): + mesh = jtu.create_mesh((2,), ('x',)) + + @partial(jax.jit, donate_argnums=(0,)) + def f(x): + return x, x * 2 + + x = jax.device_put(jnp.zeros(64, dtype="int32"), NamedSharding(mesh, P())) + jaxpr = jax.make_jaxpr(f)(x) + y = core.jaxpr_as_fun(jaxpr)(x) + self.assertTrue(x.is_deleted()) + self.assertFalse(y[0].is_deleted()) + self.assertFalse(y[1].is_deleted()) + def test_pjit_pytree_inp_device_assignment_mismatch(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) a = jax.device_put(np.array([1, 2, 3]), jax.devices()[0]) From 6033592a9544f8c440df871aa6502a5ffeae6641 Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Wed, 26 Mar 2025 16:35:30 -0700 Subject: [PATCH 208/483] Rename xla_extension_version to jaxlib_extension_version to reflect its new scope. PiperOrigin-RevId: 740944270 --- docs/jep/9419-jax-versioning.md | 6 +++--- jax/_src/callback.py | 4 ++-- jax/_src/lib/__init__.py | 2 +- jaxlib/xla/xla_client.py | 2 +- tests/python_callback_test.py | 10 +++++----- 5 files changed, 12 insertions(+), 12 deletions(-) diff --git a/docs/jep/9419-jax-versioning.md b/docs/jep/9419-jax-versioning.md index b964aa2af45d..85b95257ebae 100644 --- a/docs/jep/9419-jax-versioning.md +++ b/docs/jep/9419-jax-versioning.md @@ -167,16 +167,16 @@ We maintain an additional version number (`_version`) in [`xla_client.py` in the XLA repository](https://github.com/openxla/xla/blob/main/xla/python/xla_client.py). The idea is that this version number, is defined in `xla/python` together with the C++ parts of JAX, is also accessible to JAX Python as -`jax._src.lib.xla_extension_version`, and must +`jax._src.lib.jaxlib_extension_version`, and must be incremented every time that a change is made to the XLA/Python code that has backwards compatibility implications for `jax`. The JAX Python code can then use this version number to maintain backwards compatibility, e.g.: ``` -from jax._src.lib import xla_extension_version +from jax._src.lib import jaxlib_extension_version # 123 is the new version number for _version in xla_client.py -if xla_extension_version >= 123: +if jaxlib_extension_version >= 123: # Use new code path ... else: diff --git a/jax/_src/callback.py b/jax/_src/callback.py index dc60bfb94356..683da66638e6 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -33,7 +33,7 @@ from jax._src import tree_util from jax._src import util from jax._src import xla_bridge as xb -from jax._src.lib import xla_extension_version +from jax._src.lib import jaxlib_extension_version from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -827,7 +827,7 @@ def _wrapped_callback(*args): return outputs, token, None # TODO(dsuo): Remove this once we bump minimum_jaxlib_version to "0.5.4". - if xla_extension_version <= 320: + if jaxlib_extension_version <= 320: result_types = mlir.flatten_ir_types([mlir.aval_to_ir_type(aval) for aval in result_avals]) if token: diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py index be551449aa17..fef5d2c26038 100644 --- a/jax/_src/lib/__init__.py +++ b/jax/_src/lib/__init__.py @@ -117,7 +117,7 @@ def _xla_gc_callback(*args): # Only for the internal usage of the JAX developers, we expose a version # number that can be used to perform changes without breaking the main # branch on the Jax github. -xla_extension_version: int = getattr(xla_client, '_version', 0) +jaxlib_extension_version: int = getattr(xla_client, '_version', 0) import jaxlib.gpu_rnn as gpu_rnn # pytype: disable=import-error # noqa: F401 import jaxlib.gpu_triton as gpu_triton # pytype: disable=import-error # noqa: F401 diff --git a/jaxlib/xla/xla_client.py b/jaxlib/xla/xla_client.py index eb7a3b5759e5..80cdeef47387 100644 --- a/jaxlib/xla/xla_client.py +++ b/jaxlib/xla/xla_client.py @@ -49,7 +49,7 @@ profiler = _xla.profiler # Just an internal arbitrary increasing number to help with backward-compatible -# changes. In JAX, reference this via jax._src.lib.xla_extension_version. +# changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version. _version = 322 # Version number for MLIR:Python components. diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 5650a2d4f48b..a8442b4a1356 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -28,7 +28,7 @@ from jax._src import dispatch from jax._src import test_util as jtu from jax._src import util -from jax._src.lib import xla_extension_version +from jax._src.lib import jaxlib_extension_version from jax.experimental import io_callback from jax.experimental import pjit from jax.experimental.shard_map import shard_map @@ -588,8 +588,8 @@ def fun(x): @parameterized.parameters("int2", "int4", "uint2", "uint4") def test_subbyte_operands(self, dtype: str): - if xla_extension_version <= 321: - self.skipTest("Requires xla_extension_version >= 322.") + if jaxlib_extension_version <= 321: + self.skipTest("Requires jaxlib_extension_version >= 322.") def get(x): return x def f(x): @@ -613,8 +613,8 @@ def f(x): @parameterized.parameters("int2", "int4", "uint2", "uint4") def test_subbyte_results(self, dtype: str): - if xla_extension_version <= 321: - self.skipTest("Requires xla_extension_version >= 322.") + if jaxlib_extension_version <= 321: + self.skipTest("Requires jaxlib_extension_version >= 322.") def get(): return np.arange(8, dtype=dtype) From f949b8b8f62c986849fb2a59d8cac61467dc6eff Mon Sep 17 00:00:00 2001 From: kaixih Date: Wed, 26 Mar 2025 20:57:30 +0000 Subject: [PATCH 209/483] Enable public doc for scaled dot --- docs/jax.nn.rst | 3 + jax/_src/nn/functions.py | 222 ++++++++++++++++++++++++++++----------- jax/nn/__init__.py | 1 + tests/nn_test.py | 23 +--- 4 files changed, 168 insertions(+), 81 deletions(-) diff --git a/docs/jax.nn.rst b/docs/jax.nn.rst index 2e2e9644d50d..339f07f4cdcc 100644 --- a/docs/jax.nn.rst +++ b/docs/jax.nn.rst @@ -54,3 +54,6 @@ Other functions standardize one_hot dot_product_attention + scaled_matmul + get_scaled_dot_general_config + scaled_dot_general diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index ee0643e116f9..cc4a345641dd 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -1210,81 +1210,184 @@ def _check_shape_and_dtype(t: Array | None, shape: Sequence[int], return jnp.reshape(out, output_shape) def scaled_matmul( - lhs: Array, - rhs: Array, - lhs_scales: Array, - rhs_scales: Array, + a: Array, + b: Array, + a_scales: Array, + b_scales: Array, preferred_element_type: DTypeLike = jnp.float32, ) -> Array: - r""" - Performs scaled matrix multiplication between two 3D arrays, with scaling - factors applied to the matrices. - .. math:: - \mathrm{ScaledMatmul}(lhs, rhs, lhs_scales, rhs_scales)=lhs_scales \cdot rhs_scales \cdot \mathrm{dot}(lhs, rhs) + r"""Scaled matrix multiplication function. + + Performs block-scaled matmul of `a` and `b` using `a_scales` and `b_scales`. + The last dim is the contracting dim, and block size is inferred. + + Mathematically, this operation is equivalent to:: + + a_block_size = a.shape[-1] // a_scales.shape[-1] + b_block_size = b.shape[-1] // b_scales.shape[-1] + a_scaled = a * jnp.repeat(a_scales, a_block_size, axis=-1) + b_scaled = b * jnp.repeat(b_scales, b_block_size, axis=-1) + jnp.einsum('BMK,BNK->BMN', a_scaled, b_scaled) + Args: - lhs (Array): A 3D array of shape (B, M, K). - rhs (Array): A 3D array of shape (B, N, K). - lhs_scales (Array): A 3D array of shape (B, M, K_block). - rhs_scales (Array): A 3D array of shape (B, N, K_block). - preferred_element_type (DTypeLike, optional): The preferred data type - for the computation. Defaults to `jnp.float32`. + a (Array): Shape (B, M, K). + b (Array): Shape (B, N, K). + a_scales (Array): Shape (B, M, K_a), where `K % K_a == 0`. + b_scales (Array): Shape (B, N, K_b), where `K % K_b == 0`. + preferred_element_type (DTypeLike, optional): Defaults to `jnp.float32`. + Returns: - Array: A 3D array of shape (B, M, N) representing the scaled matrix - multiplication result. - Raises: - AssertionError: If the number of columns in `lhs` (`lhs_K`) does not - match the number of columns in `rhs` (`rhs_K`). + Array of shape (B, M, N). + Notes: - - The function ensures that the `preferred_element_type` is - danonicalized before passing it to the underlying computation. - - Scaling is applied to the matrices based on the `lhs_scales` and - `rhs_scales` arrays, enabling efficient computations in blocks. + - We currently do not support user-defined `precision` for customizing the + compute data type. It is fixed to `jnp.float32`. + - Block size is inferred as `K // K_a` for `a` and `K // K_b` for `b`. + - To use cuDNN with Nvidia Blackwell GPUs, inputs must match:: + + # mxfp8 + a, b: jnp.float8_e4m3fn | jnp.float8_e5m2 + a_scales, b_scales: jnp.float8_e8m0fnu + block_size: 32 + # nvfp4 + a, b: jnp.float4_e2m1fn + a_scales, b_scales: jnp.float8_e4m3fn + block_size: 16 + + Examples: + + Basic case: + + >>> a = jnp.array([1, 2, 3]).reshape((1, 1, 3)) + >>> b = jnp.array([4, 5, 6]).reshape((1, 1, 3)) + >>> a_scales = jnp.array([0.5]).reshape((1, 1, 1)) + >>> b_scales = jnp.array([0.5]).reshape((1, 1, 1)) + >>> scaled_matmul(a, b, a_scales, b_scales) + Array([[[8.]]], dtype=float32) + + Using fused cuDNN call on Blackwell GPUs: + + >>> a = random.normal(keys[0], (3, 128, 64), dtype=jnp.float8_e4m3fn) + >>> b = random.normal(keys[1], (3, 128, 64), dtype=jnp.float8_e4m3fn) + >>> a_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu) + >>> b_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu) + >>> scaled_matmul(a, b, a_scales, b_scales) """ - B, M, lhs_K = lhs.shape - _, N, rhs_K = rhs.shape - assert lhs_K == rhs_K - _, _, K_block = lhs_scales.shape + assert all(x.ndim == 3 for x in (a, b, a_scales, b_scales)) + B_a, M_a, K_a = a.shape + B_b, N_b, K_b = b.shape + assert K_a == K_b and B_a == B_b + B_as, M_as, K_as = a_scales.shape + B_bs, N_bs, K_bs = b_scales.shape + assert K_as == K_bs and B_as == B_bs + assert M_as == M_a and N_bs == N_b preferred_element_type = dtypes.canonicalize_dtype( np.dtype(preferred_element_type) ) out = cudnn_scaled_matmul( - lhs, - rhs, - lhs_scales, - rhs_scales, + a, + b, + a_scales, + b_scales, preferred_element_type=preferred_element_type, ) return out +def get_scaled_dot_general_config(mode: Literal['nvfp4', 'mxfp8'], + global_scale: Array | None = None): + r"""Get quantization configs for scaled_dot_general. + + Create quantization configs for the `jax.nn.scaled_dot_general`. + + See Also: + - :func:`jax.nn.scaled_dot_general`: Scaled dot general function. + """ + + if mode == 'nvfp4': + one = jnp.ones((1,), dtype=jnp.float32) + return BlockScaleConfig( + mode='nvfp4', + block_size=16, + data_type=jnp.float4_e2m1fn, + scale_type=jnp.float8_e4m3fn, + global_scale=one if global_scale is None else global_scale, + infer_only=False + ) + elif mode == 'mxfp8': + return BlockScaleConfig( + mode='mxfp8', + block_size=32, + data_type=jnp.float8_e4m3fn, + scale_type=jnp.float8_e8m0fnu, + global_scale=None, + infer_only=False + ) + else: + raise ValueError(f"Unsupported mode: {mode}") + def scaled_dot_general( lhs, rhs, dimension_numbers, preferred_element_type=jnp.float32, configs: List[BlockScaleConfig] | None = None, - implementation: Literal['cudnn'] | None = None, ): r"""Scaled dot general operation. - Computes the scaled dot general on lhs, rhs with quanitzation specified by configs: - .. math:: - \widehat{lhs}, s_a=\mathrm{quantize}(lhs) \\ - \widehat{rhs}, s_b=\mathrm{quantize}(rhs) \\ - \mathrm{ScaledDot}(lhs, rhs)=s_a \cdot s_b \cdot \mathrm{dot}(\widehat{lhs}, \widehat{rhs}) + + Performs a generalized dot product with block-scaled quantization on the + lhs and rhs inputs. This operation extends `lax.dot_general` to support + user-defined scaling configurations. + + Essentially, the operation follows:: + + a, a_scales = quantize(lhs, configs[0]) + b, b_scales = quantize(rhs, configs[1]) + c = jax.nn.scaled_matmul(a, b, a_scales, b_scales) + Args: - lhs: Left-hand side input tensor. - rhs: Right-hand side input tensor. - dimension_numbers: A tuple specifying the contraction and batch dimensions - for the dot general operation. Must follow the format: - `((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))`. - preferred_element_type: The preferred output data type. Supported types are - `jnp.float32`, `jnp.bfloat16`, and `jnp.float16`. Defaults to `jnp.float32`. - configs: A list of `BlockScaleConfig` specifying the scaling - configurations for the operation. Defaults to `mxfp8`. - implementation: A string to control which implementation backend to use. - Supported strings are `cudnn` (cuDNN block scaled dot). It defaults - to `None`, which will automatically select the best available backend. + lhs (ArrayLike): Input array. + rhs (ArrayLike): Input array. + dimension_numbers (DotDimensionNumbers): A tuple of two tuples specifying + the contraction and batch dimensions: + `((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))`. + preferred_element_type (DTypeLike, optional): Output data type of the dot + product. Defaults to `jnp.float32`. Other valid types include + `jnp.bfloat16` and `jnp.float16`. + configs (list of BlockScaleConfig, optional): Scaling configurations for + lhs, rhs, and gradients. Users can obtain valid configurations via + `jax.nn.get_scaled_dot_general_config`. Currently, `nvfp4` and `mxfp8` + are supported. If `None`, falls back to `lax.dot_general`. + Returns: - The result of the scaled dot general operation. + Array: The resulting tensor, with batch dimensions first, followed by + non-contracting/non-batch dimensions of lhs, and then those of rhs. + + See Also: + - :func:`jax.nn.scaled_matmul`: Scaled matmul function. + - :func:`jax.lax.dot_general`: General dot product operator. + + Notes: + - Unlike `nn.scaled_matmul`, which assumes quantized low-precision + inputs with explicit scaling factors, this operator takes high-precision + inputs, applies quantization internally, and handles the backward pass. + + Examples: + + Creating config for mxfp8: + + >>> configs = [jax.nn.get_scaled_dot_general_config('mxfp8')] * 3 + + Creating config for nvfp4: + + >>> global_scale = jnp.array([0.5], jnp.float32) + >>> configs = [jax.nn.get_scaled_dot_general_config('nvfp4', global_scale)] * 3 + + Using scaled_dot_general with the configs: + + >>> scaled_dot_general_fn = functools.partial(jax.nn.scaled_dot_general, configs=configs) + >>> lhs = random.normal(keys[0], (3, 128, 64)) + >>> rhs = random.normal(keys[1], (3, 128, 64)) + >>> out = scaled_dot_general_fn(lhs, rhs, (((2,), (2,)), ((0,), (0,)))) """ # Create configs if not provided if configs is None: @@ -1300,17 +1403,10 @@ def scaled_dot_general( ) configs = [mxfp8_config for _ in range(3)] - if implementation is None: - implementation = 'cudnn' - - match implementation: - case 'cudnn': - out = cudnn_scaled_dot_general( - lhs, rhs, dimension_numbers, - preferred_element_type=preferred_element_type, - configs=configs - ) - case _: - raise ValueError(f"Unsupported implementation option: {implementation}") + out = cudnn_scaled_dot_general( + lhs, rhs, dimension_numbers, + preferred_element_type=preferred_element_type, + configs=configs + ) return out diff --git a/jax/nn/__init__.py b/jax/nn/__init__.py index 10f11f829abe..651d9cf4e47f 100644 --- a/jax/nn/__init__.py +++ b/jax/nn/__init__.py @@ -38,6 +38,7 @@ identity as identity, relu6 as relu6, dot_product_attention as dot_product_attention, + get_scaled_dot_general_config as get_scaled_dot_general_config, scaled_dot_general as scaled_dot_general, scaled_matmul as scaled_matmul, selu as selu, diff --git a/tests/nn_test.py b/tests/nn_test.py index e46843186c02..385b216aeb57 100644 --- a/tests/nn_test.py +++ b/tests/nn_test.py @@ -31,7 +31,6 @@ from jax._src.cudnn.scaled_matmul_stablehlo import ( quantize, shape_normalization, - BlockScaleConfig, ) from jax.test_util import check_grads from jax import nn @@ -110,17 +109,7 @@ def create_mxfp8_configs_if_available(): if _dtypes.float8_e8m0fnu is None: raise unittest.SkipTest("float8_e8m0fnu is not available.") - def _create_mxfp8_config(): - return BlockScaleConfig( - mode='mxfp8', - block_size=32, - data_type=jnp.float8_e4m3fn, - scale_type=jnp.float8_e8m0fnu, - global_scale=None, - infer_only=False - ) - - return [_create_mxfp8_config() for _ in range(3)] + return [nn.get_scaled_dot_general_config("mxfp8") for _ in range(3)] @jtu.with_config(jax_legacy_prng_key="allow", @@ -130,10 +119,9 @@ class NNFunctionsTest(jtu.JaxTestCase): contract=[160, 96], lhs_non_contract=[240, 100], dtype=[jnp.float16, jnp.bfloat16, jnp.float32], - impl=['cudnn',], ) - def testScaledMatmul(self, contract, lhs_non_contract, dtype, impl): - if impl == 'cudnn' and not _is_required_cudnn_version_satisfied("10.0", 90700): + def testScaledMatmul(self, contract, lhs_non_contract, dtype): + if not _is_required_cudnn_version_satisfied("10.0", 90700): raise unittest.SkipTest("CUDA or cuDNN versions are not compatible") # Check if float8_e8m0fnu is available configs = create_mxfp8_configs_if_available() @@ -153,11 +141,10 @@ def testScaledMatmul(self, contract, lhs_non_contract, dtype, impl): @parameterized.product( is_training=[True, False], output_type=[jnp.float16, jnp.bfloat16, jnp.float32], - impl=['cudnn',], ) def testScaledDotGeneral( - self, is_training, output_type, impl): - if impl == 'cudnn' and not _is_required_cudnn_version_satisfied("10.0", 90700): + self, is_training, output_type): + if not _is_required_cudnn_version_satisfied("10.0", 90700): raise unittest.SkipTest("CUDA or cuDNN versions are not compatible") configs = create_mxfp8_configs_if_available() From be1f649b510048e30b8c07bd7e1964987c6e2907 Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Wed, 26 Mar 2025 17:30:22 -0700 Subject: [PATCH 210/483] Expose jax._src.lib.ifrt_version which tracks the version of third_party/tensorflow code inside jax. PiperOrigin-RevId: 740957982 --- jax/_src/lib/__init__.py | 1 + jaxlib/xla/BUILD | 1 + jaxlib/xla/xla.cc | 3 +++ jaxlib/xla/xla_client.py | 6 ++++++ jaxlib/xla/xla_client.pyi | 2 ++ 5 files changed, 13 insertions(+) diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py index fef5d2c26038..bb542aa5d61d 100644 --- a/jax/_src/lib/__init__.py +++ b/jax/_src/lib/__init__.py @@ -118,6 +118,7 @@ def _xla_gc_callback(*args): # number that can be used to perform changes without breaking the main # branch on the Jax github. jaxlib_extension_version: int = getattr(xla_client, '_version', 0) +ifrt_version: int = getattr(xla_client, '_ifrt_version', 0) import jaxlib.gpu_rnn as gpu_rnn # pytype: disable=import-error # noqa: F401 import jaxlib.gpu_triton as gpu_triton # pytype: disable=import-error # noqa: F401 diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD index 512eeb867618..347da6998b57 100644 --- a/jaxlib/xla/BUILD +++ b/jaxlib/xla/BUILD @@ -112,6 +112,7 @@ nanobind_extension( "@xla//xla/python:profiler", "@xla//xla/python:refine_polymorphic_shapes", "@xla//xla/python:types", + "@xla//xla/python:version", "@xla//xla/python/ifrt", "@xla//xla/python/ifrt:plugin_program", "@xla//xla/python/ifrt:plugin_program_serdes", diff --git a/jaxlib/xla/xla.cc b/jaxlib/xla/xla.cc index 6e47be15fc68..668c96869479 100644 --- a/jaxlib/xla/xla.cc +++ b/jaxlib/xla/xla.cc @@ -64,6 +64,7 @@ limitations under the License. #include "xla/python/ifrt/device_list.h" #include "xla/python/ifrt/executable.h" #include "xla/python/ifrt/topology.h" +#include "xla/python/version.h" #include "xla/python/pjrt_ifrt/pjrt_attribute_map_util.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/python/lib/core/numpy.h" // NOLINT @@ -960,6 +961,8 @@ NB_MODULE(xla_extension, m) { m.def("check_and_canonicalize_memory_kind", &jax::CheckAndCanonicalizeMemoryKind, nb::arg("memory_kind").none(), nb::arg("device_list")); + + m.attr("ifrt_version_number") = JAX_IFRT_VERSION_NUMBER; } // NOLINT(readability/fn_size) } // namespace xla diff --git a/jaxlib/xla/xla_client.py b/jaxlib/xla/xla_client.py index 80cdeef47387..30e8443276c8 100644 --- a/jaxlib/xla/xla_client.py +++ b/jaxlib/xla/xla_client.py @@ -52,6 +52,12 @@ # changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version. _version = 322 +# An internal increasing version number for protecting jaxlib code against +# ifrt changes. +# lives in xla/python/version.h. +# In JAX, reference this via jax._src.lib.ifrt_version. +_ifrt_version = _xla.ifrt_version_number + # Version number for MLIR:Python components. mlir_api_version = 58 diff --git a/jaxlib/xla/xla_client.pyi b/jaxlib/xla/xla_client.pyi index 5ac837ef1d85..b182eb65ba60 100644 --- a/jaxlib/xla/xla_client.pyi +++ b/jaxlib/xla/xla_client.pyi @@ -58,6 +58,8 @@ from jaxlib.xla_extension import XlaOp as XlaOp _version: int +_ifrt_version: int + mlir_api_version: int bfloat16: type[numpy.generic] From 8f25337a9fb7ef7a86452a2ca3a2ccfc6d1aee20 Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Wed, 26 Mar 2025 18:32:39 -0700 Subject: [PATCH 211/483] [ragged-paged-attn] Combine k_pages and v_pages into kv_pages and zip on num_kv_heads. Now we should be able to support sharding num_kv_heads to 1 even dtype is bfloat16 while still having good ragged KV scatter because ragged dim still remains in non-tiling dim. PiperOrigin-RevId: 740971413 --- .../pallas/ops/tpu/ragged_paged_attention.py | 174 +++++++++--------- .../pallas/tpu_ragged_paged_attention_test.py | 24 +-- 2 files changed, 99 insertions(+), 99 deletions(-) diff --git a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py b/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py index 60ac2e34f610..255670c22e90 100644 --- a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py +++ b/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py @@ -34,8 +34,8 @@ class MultiPageAsyncCopyDescriptor: def __init__( self, - pages_hbm_ref, # [total_num_pages, page_size, num_kv_heads_per_blk, head_dim] - vmem_buf, # [num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, head_dim] + pages_hbm_ref, # [total_num_pages, page_size, num_combined_kv_heads_per_blk, head_dim] + vmem_buf, # [num_kv_pages_per_blk, page_size, num_combined_kv_heads_per_blk, head_dim] sem, page_indices_ref, # i32[max_num_seqs, pages_per_seq] offset, # [seq_idx, kv_pages_start] @@ -72,8 +72,7 @@ def wait(self): def ref_ragged_paged_attention( queries: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] - k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] - v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + kv_pages: jax.Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] kv_lens: jax.Array, # i32[max_num_seqs] page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] cu_q_lens: jax.Array, # i32[max_num_seqs + 1] @@ -82,9 +81,16 @@ def ref_ragged_paged_attention( sm_scale: float = 1.0, sliding_window: int | None = None, soft_cap: float | None = None, - mask_value: float = DEFAULT_MASK_VALUE, + mask_value: float | None = DEFAULT_MASK_VALUE, ): - _, _, num_kv_heads, head_dim = k_pages.shape + check_inputs_shapes( + queries, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs + ) + if mask_value is None: + mask_value = DEFAULT_MASK_VALUE + _, _, num_combined_kv_heads, head_dim = kv_pages.shape + assert num_combined_kv_heads % 2 == 0 + num_kv_heads = num_combined_kv_heads // 2 num_q_heads = queries.shape[1] assert num_q_heads % num_kv_heads == 0 num_query_per_kv = num_q_heads // num_kv_heads @@ -96,8 +102,12 @@ def ref_ragged_paged_attention( kv_len = kv_lens[i] indices = page_indices[i] q = queries[q_start:q_end] - k = k_pages[indices, :, :, :].reshape(-1, num_kv_heads, head_dim)[:kv_len] - v = v_pages[indices, :, :, :].reshape(-1, num_kv_heads, head_dim)[:kv_len] + k = kv_pages[indices, :, 0::2, :].reshape(-1, num_kv_heads, head_dim)[ + :kv_len + ] + v = kv_pages[indices, :, 1::2, :].reshape(-1, num_kv_heads, head_dim)[ + :kv_len + ] k = jnp.repeat(k, num_query_per_kv, axis=1) v = jnp.repeat(v, num_query_per_kv, axis=1) attn = jnp.einsum("qhd,khd->hqk", q, k, preferred_element_type=jnp.float32) @@ -122,8 +132,7 @@ def ref_ragged_paged_attention( # Expect to run these checkes during runtime. def validate_inputs_on_runtime( q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] - k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] - v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + kv_pages: jax.Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] kv_lens: jax.Array, # i32[max_num_seqs] page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] cu_q_lens: jax.Array, # i32[max_num_seqs + 1] @@ -131,16 +140,14 @@ def validate_inputs_on_runtime( sliding_window: int | None = None, soft_cap: float | None = None, ): - check_inputs_shapes( - q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, num_seqs - ) + check_inputs_shapes(q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs) max_num_batched_tokens = q.shape[0] - page_size = k_pages.shape[1] + page_size = kv_pages.shape[1] max_num_seqs, pages_per_seq = page_indices.shape if num_seqs[0] > max_num_seqs: raise ValueError(f"{num_seqs[0]=} must be less or equal to {max_num_seqs=}") max_kv_len = jnp.max(kv_lens) - min_pages_per_seq = ceil_div(max_kv_len, page_size) + min_pages_per_seq = cdiv(max_kv_len, page_size) if pages_per_seq < min_pages_per_seq: raise ValueError( f"{pages_per_seq=} must be greater or equal to" @@ -167,22 +174,19 @@ def validate_inputs_on_runtime( # Expect to run these checks during compile time. def check_inputs_shapes( q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] - k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] - v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + kv_pages: jax.Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] kv_lens: jax.Array, # i32[max_num_seqs] page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] cu_q_lens: jax.Array, # i32[max_num_seqs + 1] num_seqs, # i32[1] ): _, num_q_heads, head_dim = q.shape - _, _, num_kv_heads, head_dim_k = k_pages.shape + _, _, num_combined_kv_heads, head_dim_k = kv_pages.shape + assert num_combined_kv_heads % 2 == 0 + num_kv_heads = num_combined_kv_heads // 2 max_num_seqs, _ = page_indices.shape if num_seqs.shape != (1,): raise ValueError(f"{num_seqs.shape=} must be (1,)") - if k_pages.shape != v_pages.shape: - raise ValueError( - f"{k_pages.shape=} and {v_pages.shape=} must have the same shape." - ) if head_dim_k != head_dim: raise ValueError( f"Q head_dim {head_dim} must be the same as that of K/V {head_dim_k}." @@ -221,13 +225,11 @@ def ragged_paged_attention_kernel( num_seqs_ref, # Input q_ref, # [num_q_per_blk, num_q_heads_per_blk, head_dim] - k_pages_hbm_ref, # [total_num_pages, page_size, num_kv_heads, head_dim] - v_pages_hbm_ref, # [total_num_pages, page_size, num_kv_heads, head_dim] + kv_pages_hbm_ref, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] # Output o_ref, # [num_q_per_blk, num_q_heads_per_blk, head_dim] # Scratch - k_bufs, # [2, num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, head_dim] - v_bufs, # [2, num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, head_dim] + kv_bufs, # [2, num_kv_pages_per_blk, page_size, num_combined_kv_heads_per_blk, head_dim] sems, # [2, 2] l_ref, # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128] m_ref, # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128] @@ -235,11 +237,16 @@ def ragged_paged_attention_kernel( sm_scale: float, sliding_window: int | None = None, soft_cap: float | None = None, - mask_value: float = DEFAULT_MASK_VALUE, + mask_value: float | None = DEFAULT_MASK_VALUE, ): + if mask_value is None: + mask_value = DEFAULT_MASK_VALUE num_q_per_blk, num_q_heads_per_blk, head_dim = q_ref.shape num_seqs = num_seqs_ref[0] - _, num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, _ = k_bufs.shape + _, num_kv_pages_per_blk, page_size, num_combined_kv_heads_per_blk, _ = ( + kv_bufs.shape + ) + num_kv_heads_per_blk = num_combined_kv_heads_per_blk // 2 num_kv_per_blk = num_kv_pages_per_blk * page_size num_q_heads_per_kv_head = num_q_heads_per_blk // num_kv_heads_per_blk heads_blk_idx, q_blk_idx = ( @@ -256,22 +263,17 @@ def create_kv_async_copy_descriptors( heads_blk_idx, seq_idx, kv_blk_idx, buf_idx ): offset = (seq_idx, kv_blk_idx * num_kv_pages_per_blk) - heads_start = heads_blk_idx * num_kv_heads_per_blk - async_copy_k = MultiPageAsyncCopyDescriptor( - k_pages_hbm_ref.at[:, :, pl.ds(heads_start, num_kv_heads_per_blk), :], - k_bufs.at[buf_idx], - sems.at[buf_idx, 0], - page_indices_ref, - offset, - ) - async_copy_v = MultiPageAsyncCopyDescriptor( - v_pages_hbm_ref.at[:, :, pl.ds(heads_start, num_kv_heads_per_blk), :], - v_bufs.at[buf_idx], - sems.at[buf_idx, 1], + heads_start = heads_blk_idx * num_combined_kv_heads_per_blk + async_copy_kv = MultiPageAsyncCopyDescriptor( + kv_pages_hbm_ref.at[ + :, :, pl.ds(heads_start, num_combined_kv_heads_per_blk), : + ], + kv_bufs.at[buf_idx], + sems.at[buf_idx], page_indices_ref, offset, ) - return async_copy_k, async_copy_v + return async_copy_kv # TODO(jevinjiang): Add these to Mosaic: # 1. Support arbitrary strided load/store for any dtype. @@ -303,11 +305,10 @@ def fold_on_2nd_minor(vec): @pl.when(heads_blk_idx + q_blk_idx == 0) def prefetch_first_kv_blk(): - async_copy_k, async_copy_v = create_kv_async_copy_descriptors( + async_copy_kv = create_kv_async_copy_descriptors( heads_blk_idx, init_seq_idx, 0, init_buf_idx ) - async_copy_k.start() - async_copy_v.start() + async_copy_kv.start() def is_cur_q_blk_needed(q_states): done, cur_seq_idx, _ = q_states @@ -512,21 +513,18 @@ def prefetch_next_kv_blk(): # TODO(jevinjiang): reuse the same buffer if it is already prefetched! # TODO(jevinjiang): only fetch effective dynamic size to hold kv_len and # DMA to fixed size buffer! - next_async_copy_k, next_async_copy_v = create_kv_async_copy_descriptors( + next_async_copy_kv = create_kv_async_copy_descriptors( next_heads_blk_idx, next_seq_idx, next_kv_blk_idx, next_buf_idx ) - next_async_copy_k.start() - next_async_copy_v.start() + next_async_copy_kv.start() - cur_async_copy_k, cur_async_copy_v = create_kv_async_copy_descriptors( + cur_async_copy_kv = create_kv_async_copy_descriptors( heads_blk_idx, cur_seq_idx, kv_blk_idx, cur_buf_idx ) - kv_to_load_shape = ( - num_kv_pages_per_blk * page_size * num_kv_heads_per_blk, + kv_ref = cur_async_copy_kv.wait().reshape( + num_kv_pages_per_blk * page_size * num_combined_kv_heads_per_blk, head_dim, ) - k_ref = cur_async_copy_k.wait().reshape(kv_to_load_shape) - v_ref = cur_async_copy_v.wait().reshape(kv_to_load_shape) for kv_head_idx in range(num_kv_heads_per_blk): q_head_idx = kv_head_idx * num_q_heads_per_kv_head # TODO(jevinjiang): extra handlig for packed type that can start at @@ -534,8 +532,12 @@ def prefetch_next_kv_blk(): q = fold_on_2nd_minor( q_ref[:, q_head_idx : q_head_idx + num_q_heads_per_kv_head, :] ) - k = strided_load_kv(k_ref, kv_head_idx, num_kv_heads_per_blk) - v = strided_load_kv(v_ref, kv_head_idx, num_kv_heads_per_blk) + k = strided_load_kv( + kv_ref, kv_head_idx * 2, num_combined_kv_heads_per_blk + ) + v = strided_load_kv( + kv_ref, kv_head_idx * 2 + 1, num_combined_kv_heads_per_blk + ) flash_attention( q, k, @@ -566,7 +568,7 @@ def prefetch_next_kv_blk(): seq_buf_idx_ref[1] = buf_idx -def ceil_div(a, b): +def cdiv(a, b): assert b != 0 return (a + b - 1) // b @@ -583,7 +585,9 @@ def get_dtype_packing(dtype): raise ValueError(f"Not implemented: unsupported {dtype=}") -def get_min_heads_per_blk(num_q_heads, num_kv_heads, q_dtype, kv_dtype): +def get_min_heads_per_blk( + num_q_heads, num_combined_kv_heads, q_dtype, kv_dtype +): q_packing = get_dtype_packing(q_dtype) kv_packing = get_dtype_packing(kv_dtype) @@ -594,22 +598,26 @@ def can_be_xla_fully_tiled(x, packing): return x in (1, 2, 4, 8) or x % 8 == 0 # TODO(jevinjiang): support unaligned number of heads! - if not can_be_xla_fully_tiled(num_kv_heads, kv_packing): + if not can_be_xla_fully_tiled(num_combined_kv_heads, kv_packing): raise ValueError( - f"Not implemented: {num_kv_heads=} can not be XLA fully tiled." + f"Not implemented: {num_combined_kv_heads=} can not be XLA fully tiled." ) + assert num_combined_kv_heads % 2 == 0 + num_kv_heads = num_combined_kv_heads // 2 assert num_q_heads % num_kv_heads == 0 ratio = num_q_heads // num_kv_heads # TODO(jevinjiang): we can choose smaller tiling for packed type if large # second minor tiling is not on. - max_kv_tiling = 8 * kv_packing - min_kv_heads = ( - max_kv_tiling if num_kv_heads % max_kv_tiling == 0 else num_kv_heads + max_combined_kv_tiling = 8 * kv_packing + min_combined_kv_heads = ( + max_combined_kv_tiling + if num_combined_kv_heads % max_combined_kv_tiling == 0 + else num_combined_kv_heads ) - min_q_heads = min_kv_heads * ratio + min_q_heads = min_combined_kv_heads // 2 * ratio if can_be_xla_fully_tiled(min_q_heads, q_packing): - return min_q_heads, min_kv_heads - return num_q_heads, num_kv_heads + return min_q_heads, min_combined_kv_heads + return num_q_heads, num_combined_kv_heads @functools.partial( @@ -627,8 +635,7 @@ def can_be_xla_fully_tiled(x, packing): def ragged_paged_attention( q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] # TODO(jevinjiang): create a write_to_kv_cache kernel! - k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] - v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + kv_pages: jax.Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] kv_lens: jax.Array, # i32[max_num_seqs] page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] cu_q_lens: jax.Array, # i32[max_num_seqs + 1] @@ -637,7 +644,7 @@ def ragged_paged_attention( sm_scale: float = 1.0, sliding_window: int | None = None, soft_cap: float | None = None, - mask_value: float = DEFAULT_MASK_VALUE, + mask_value: float | None = DEFAULT_MASK_VALUE, num_kv_pages_per_block: int = 16, num_queries_per_block: int = 128, vmem_limit_bytes: int | None = None, @@ -646,8 +653,7 @@ def ragged_paged_attention( Args: q: concatenated all sequences' queries. - k_pages: paged K cache. Normally in HBM. - v_pages: paged V cache. Normally in HBM. + kv_pages: paged K cache. Normally in HBM. kv_lens: padded kv lengths. Only the first num_seqs values are valid. page_indices: the first index indicates which page to use in the kv cache for each sequence. Only the first num_seqs values are valid. @@ -666,18 +672,22 @@ def ragged_paged_attention( Returns: The output of the attention. """ - check_inputs_shapes( - q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, num_seqs - ) + check_inputs_shapes(q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs) + if mask_value is None: + mask_value = DEFAULT_MASK_VALUE _, num_q_heads, head_dim = q.shape - _, page_size, num_kv_heads, _ = k_pages.shape + _, page_size, num_combined_kv_heads, _ = kv_pages.shape + assert num_combined_kv_heads % 2 == 0 + num_kv_heads = num_combined_kv_heads // 2 num_q_per_blk = num_queries_per_block num_kv_pages_per_blk = num_kv_pages_per_block num_q_heads_per_kv_head = num_q_heads // num_kv_heads - num_q_blks = ceil_div(cu_q_lens[num_seqs[0]], num_q_per_blk) - num_q_heads_per_blk, num_kv_heads_per_blk = get_min_heads_per_blk( - num_q_heads, num_kv_heads, q.dtype, k_pages.dtype + num_q_blks = cdiv(cu_q_lens[num_seqs[0]], num_q_per_blk) + num_q_heads_per_blk, num_combined_kv_heads_per_blk = get_min_heads_per_blk( + num_q_heads, num_combined_kv_heads, q.dtype, kv_pages.dtype ) + assert num_combined_kv_heads_per_blk % 2 == 0 + num_kv_heads_per_blk = num_combined_kv_heads_per_blk // 2 assert num_q_heads_per_blk % num_q_heads_per_kv_head == 0 num_heads_blks = num_q_heads // num_q_heads_per_blk grid = (num_heads_blks, num_q_blks) @@ -692,7 +702,6 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_): in_specs = [ q_block_spec, pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), ] out_specs = q_block_spec lm_scratch = pltpu.VMEM( @@ -706,15 +715,14 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_): 2, # For double buffering during DMA copies. num_kv_pages_per_blk, page_size, - num_kv_heads_per_blk, + num_combined_kv_heads_per_blk, head_dim, ), - k_pages.dtype, + kv_pages.dtype, ) scratch_shapes = [ - double_buf_scratch, # k_bufs - double_buf_scratch, # v_bufs - pltpu.SemaphoreType.DMA((2, 2)), # [double_buffers, k_sem/v_sem] + double_buf_scratch, # kv_bufs + pltpu.SemaphoreType.DMA((2,)), # Semaphores for double buffers. lm_scratch, # l_ref lm_scratch, # m_ref ] @@ -753,4 +761,4 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_): # TODO(jevinjiang): Use f32 acc scratch for output! So we only need # to transfer output with desired dtype back to HBM. - return kernel(*scalar_prefetches, q, k_pages, v_pages).astype(q.dtype) + return kernel(*scalar_prefetches, q, kv_pages).astype(q.dtype) diff --git a/tests/pallas/tpu_ragged_paged_attention_test.py b/tests/pallas/tpu_ragged_paged_attention_test.py index 815c9dc6327f..b76d30bd1dcf 100644 --- a/tests/pallas/tpu_ragged_paged_attention_test.py +++ b/tests/pallas/tpu_ragged_paged_attention_test.py @@ -74,32 +74,26 @@ def _test_ragged_paged_attention( cu_q_lens = jnp.pad(cu_q_lens, (0, max_num_seq + 1 - cu_q_lens.shape[0])) kv_lens = jnp.pad(kv_lens, (0, max_num_seq - kv_lens.shape[0])) prng_key = jax.random.key(1234) - k0, k1, k2, k3 = jax.random.split(prng_key, 4) + k0, k1, k2 = jax.random.split(prng_key, 3) q = jax.random.normal( k0, (max_num_batched_tokens, num_q_heads, head_dim), dtype=dtype, ) - k_pages = jax.random.normal( + kv_pages = jax.random.normal( k1, - (num_pages, page_size, num_kv_heads, head_dim), - dtype=dtype, - ) - v_pages = jax.random.normal( - k2, - (num_pages, page_size, num_kv_heads, head_dim), + (num_pages, page_size, num_kv_heads * 2, head_dim), dtype=dtype, ) page_indices = jax.random.randint( - k3, (max_num_seq, pages_per_seq), 0, num_pages, dtype=jnp.int32 + k2, (max_num_seq, pages_per_seq), 0, num_pages, dtype=jnp.int32 ) num_seqs = jnp.array([len(seq_lens)], dtype=jnp.int32) validate_inputs_on_runtime( q, - k_pages, - v_pages, + kv_pages, kv_lens, page_indices, cu_q_lens, @@ -111,8 +105,7 @@ def _test_ragged_paged_attention( actual_num_q_tokens = cu_q_lens[num_seqs[0]] output = ragged_paged_attention( q, - k_pages, - v_pages, + kv_pages, kv_lens, page_indices, cu_q_lens, @@ -126,8 +119,7 @@ def _test_ragged_paged_attention( expected = ref_ragged_paged_attention( q, - k_pages, - v_pages, + kv_pages, kv_lens, page_indices, cu_q_lens, @@ -272,7 +264,7 @@ def test_ragged_paged_attention_mixed(self, dtype): @parameterized.product( num_seqs=[1, 5, 16], # TODO(jevinjiang): Support more num_heads! - num_heads=[(32, 8), (32, 16), (12, 2), (4, 4)], + num_heads=[(32, 8), (32, 16), (12, 2), (4, 4), (8, 1)], dtype=[jnp.float32, jnp.bfloat16], num_kv_pages_per_block=[4, 8], num_queries_per_block=[32, 64], From c7d04cc75a3aac39a677d318e4b82204a2f096b2 Mon Sep 17 00:00:00 2001 From: shuw Date: Thu, 27 Mar 2025 05:09:25 +0000 Subject: [PATCH 212/483] Improve based on review 2 --- jax/_src/cudnn/scaled_matmul_stablehlo.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/jax/_src/cudnn/scaled_matmul_stablehlo.py b/jax/_src/cudnn/scaled_matmul_stablehlo.py index b1d353e7bcd1..60cdbee7fa20 100644 --- a/jax/_src/cudnn/scaled_matmul_stablehlo.py +++ b/jax/_src/cudnn/scaled_matmul_stablehlo.py @@ -494,12 +494,9 @@ def quantize(x, config): assert config.global_scale.dtype == jnp.float32 SCALE_MAX = jnp.finfo(config.scale_type).max.astype(x.dtype) - prev_amax = config.global_scale * (MAX * SCALE_MAX) - scales_q = jnp.clip( - (amax / prev_amax) * SCALE_MAX, 0, SCALE_MAX - ) - scaled_x = x / scales_q + scales_q = jnp.clip(scales / config.global_scale, 0, SCALE_MAX) scales_q = scales_q.astype(config.scale_type) + scaled_x = x / scales_q.astype(jnp.float32) else: raise ValueError(f"Unrecognized mode: {config.mode}.") @@ -644,6 +641,9 @@ def scaled_dot_bwd(dimension_numbers, preferred_element_type, configs, res, g): grad_lhs = scaled_dot_general_transpose_lhs(*args, **lhs_kw_args) grad_rhs = scaled_dot_general_transpose_rhs(*args, **rhs_kw_args) + # We apply a Straight-Through Estimator (STE) with zero-out behavior: if + # inputs are clipped during quantization in fprop, their corresponding gradients + # are zeroed out; otherwise, they pass through unchanged. if configs[2].mode == "nvfp4": assert rhs.dtype == lhs.dtype MAX = jnp.finfo(configs[0].data_type).max.astype(lhs.dtype) From 0c1f4c155ec49ccd5cde85d3dacd8e0b7c7afb47 Mon Sep 17 00:00:00 2001 From: Mudit Gokhale Date: Wed, 26 Mar 2025 23:12:32 -0700 Subject: [PATCH 213/483] Remove backward compatibility logic for tool naming. PiperOrigin-RevId: 741030788 --- jax/collect_profile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/collect_profile.py b/jax/collect_profile.py index d1309e0c5bca..b355816772a1 100644 --- a/jax/collect_profile.py +++ b/jax/collect_profile.py @@ -91,7 +91,7 @@ def collect_profile(port: int, duration_in_ms: int, host: str, in root_trace_folder.iterdir()] latest_folder = max(trace_folders, key=os.path.getmtime) xplane = next(latest_folder.glob("*.xplane.pb")) - result, _ = convert.xspace_to_tool_data([xplane], "trace_viewer^", {}) + result, _ = convert.xspace_to_tool_data([xplane], "trace_viewer", {}) with gzip.open(str(latest_folder / "remote.trace.json.gz"), "wb") as fp: fp.write(result.encode("utf-8")) From e1762b0af6c5199d53141557bdf81eaef55bd4c5 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Thu, 27 Mar 2025 00:46:21 -0700 Subject: [PATCH 214/483] Assert unused variable in lax.all_to_all batching rule P.S. minor improvement to code readability PiperOrigin-RevId: 741051082 --- jax/_src/lax/parallel.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 221fe2a9e87a..28e6dbef4a2c 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -1109,15 +1109,15 @@ def _all_to_all_batcher(vals_in, dims_in, *, axis_name, split_axis, concat_axis, def _all_to_all_batched_collective(axis_data, vals_in, dims_in, axis_name, split_axis, concat_axis, axis_index_groups, tiled): - axis_size, frame_name = axis_data.size, axis_data.name if axis_index_groups is not None: raise NotImplementedError("Please open a feature request!") + axis_size, frame_name = axis_data.size, axis_data.name if isinstance(axis_name, (list, tuple)): axes_names = axis_name else: axes_names = [axis_name] - if axis_data.name not in axes_names: + if frame_name not in axes_names: return _all_to_all_batcher( vals_in, dims_in, axis_name=axis_name, split_axis=split_axis, concat_axis=concat_axis, axis_index_groups=axis_index_groups, tiled=tiled) @@ -1157,6 +1157,7 @@ def _all_to_all_batched_collective(axis_data, vals_in, dims_in, axis_index_groups=axis_index_groups, tiled=tiled) # Split out the local part into axis new_d (NOTE: d is already in axis 1) + assert d == 1 x = _splitaxis(split_axis, axis_size, x) new_d = split_axis concat_axis += (split_axis <= concat_axis) # Offset the existing axes by the new batch axis From 8bd956d96a6979bcead917d7d0f8593203888cfe Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 27 Mar 2025 03:02:36 -0700 Subject: [PATCH 215/483] [Pallas] Skip reads/writes from/to slices of kernel input/output buffers when the slices do not change between iterations of the grid loop that interprets kernels on CPU. PiperOrigin-RevId: 741082349 --- jax/_src/pallas/mosaic/interpret.py | 192 ++++++++++++++++------ tests/pallas/tpu_pallas_interpret_test.py | 127 ++++++++++---- 2 files changed, 237 insertions(+), 82 deletions(-) diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py index 5acbabc673aa..9d7b03ad5589 100644 --- a/jax/_src/pallas/mosaic/interpret.py +++ b/jax/_src/pallas/mosaic/interpret.py @@ -1307,8 +1307,13 @@ def _compute_start_indices( jaxpr.jaxpr, *jaxpr.consts, *loop_idx, *args, compiler_params=compiler_params, interpret_params=interpret_params) if isinstance(block_mapping.indexing_mode, pallas_core.Blocked): - ret = tuple(i if b is pallas_core.mapped else b * i - for b, i in zip(block_mapping.block_shape, block_indices)) + ret = jnp.array( + tuple( + i if b is pallas_core.mapped else b * i + for b, i in zip(block_mapping.block_shape, block_indices) + ), + dtype=jnp.int32, + ) elif isinstance(block_mapping.indexing_mode, pallas_core.Unblocked): ret = block_indices else: @@ -1534,64 +1539,114 @@ def interpret_pallas_call( # Base case is always one iteration when grid is () num_iterations = 1 - def body(carry): - # The loop carry: (i, loop_idx) -- - # - i:int32 is the interation index - # - loop_idx: tuple[int32] are the program ids for each grid axis - i, loop_idx = carry - + def _get_local_grid_env(loop_idx): if grid_mapping.local_grid_env is not None: - local_grid_env = grid_mapping.local_grid_env(loop_idx, grid) + return grid_mapping.local_grid_env(loop_idx, grid) else: - local_grid_env = tuple( + return tuple( pallas_core.GridAxis(idx, b) for dim, (idx, b) in enumerate(zip(loop_idx, grid)) if dim not in grid_mapping.vmapped_dims ) - with pallas_core.grid_env(local_grid_env): - start_indices = [ + def body( + carry: tuple[ + jnp.int32, tuple[jnp.int32, ...], list[jnp.ndarray], list[jnp.ndarray] + ], + ): + """Performs a single iteration of `jaxpr` in the device grid. + + Execution of `jaxpr` is preceded by reading kernel input buffers and + followed by writing kernel output buffers. + + Args: + carry: (iteration_idx, loop_idx, prev_start_indices, cur_start_indices). + - iteration_idx is the interation index. + - loop_idx are the program ids for each grid axis. + - prev_start_indices is a rank-1 array that contains the start indices + for the slices of inputs and outputs processed in the previous loop + iteration. + - cur_start_indices is a rank-1 array that contains the start indices + for the slices of inputs and outputs processed in the current loop + iteration. + + Note that by carrying the previous *and* current start indices between + loop iterations, it suffices to compute only one list of start indices, + i.e. `next_start_indices` (see below), per iteration. + + Returns: + The carry for the next iteration. + """ + iteration_idx, loop_idx, prev_start_indices, cur_start_indices = carry + + with pallas_core.grid_env(_get_local_grid_env(loop_idx)): + next_loop_idx = _get_next_indices(grid, loop_idx) + next_start_indices = [ _compute_start_indices( - bm, loop_idx, *scalar_buffer_ids, compiler_params=compiler_params, - interpret_params=interpret_params) - for bm in grid_mapping.block_mappings] + bm, + next_loop_idx, + *scalar_buffer_ids, + compiler_params=compiler_params, + interpret_params=interpret_params, + ) + for bm in grid_mapping.block_mappings + ] + # Copy slices of the input to the kernel buffers. - # - # TODO(jburnim): Only copy slices when the index mapping has changed? - for j, var in enumerate(input_vars): - if _is_any(var.aval.memory_space): - continue + + def _store_slice_to_kernel_input(index, input_var): # Copy from the HBM buffer for the pallas_call input to the kernel # input buffer. # TODO(jburnim): Just use input_args[j] when the input is not aliased? transform = indexing.NDIndexer( - indices=tuple(indexing.ds(st, sz) if not iid else st - for st, sz, iid in zip(start_indices[j], - block_shapes[j], - is_indexing_dim[j])), - shape=input_args[j].shape, - int_indexer_shape=()) + indices=tuple( + indexing.ds(st, sz) if not iid else st + for st, sz, iid in zip( + cur_start_indices[index], + block_shapes[index], + is_indexing_dim[index], + ) + ), + shape=input_args[index].shape, + int_indexer_shape=(), + ) sliced_val = callback.io_callback( # TODO(jburnim): Pass source_info from the pallas_call, in case this # read is involved in a data race. get, - jax.ShapeDtypeStruct(var.aval.shape, var.aval.dtype), + jax.ShapeDtypeStruct(input_var.aval.shape, input_var.aval.dtype), device_id, TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY], - input_buffer_ids[j], + input_buffer_ids[index], (transform,), - ordered=True) + ordered=True, + ) callback.io_callback( # TODO(jburnim): Pass source_info from the pallas_call, in case this # store is involved in a data race. store, (), device_id, - TPU_MEMORY_SPACE_IDXS[var.aval.memory_space], - input_ids[j], + TPU_MEMORY_SPACE_IDXS[input_var.aval.memory_space], + input_ids[index], (), sliced_val, - ordered=True) + ordered=True, + ) + + for j, var in enumerate(input_vars): + if _is_any(var.aval.memory_space): + continue + assert len(cur_start_indices[j].shape) == 1 + assert len(prev_start_indices[j].shape) == 1 + jax.lax.cond( + (iteration_idx == 0) + | jax.lax.reduce_or( + cur_start_indices[j] != prev_start_indices[j], axes=(0,) + ), + functools.partial(_store_slice_to_kernel_input, j, var), + lambda: None, + ) # Invoke the kernel. _interpret_jaxpr(jaxpr, *kernel_buffer_ids, @@ -1599,29 +1654,30 @@ def body(carry): interpret_params=interpret_params) # Copy from the kernel buffers to slices of the output in HBM. - # - # TODO(jburnim): Only copy if the index mapping will change in the - # next iteration (or if this is the last iteration)? - for j, var in enumerate(output_vars): - if _is_any(var.aval.memory_space): - continue + def _store_to_output_buffer(index, output_var): kernel_output_val = callback.io_callback( # TODO(jburnim): Pass source_info from the pallas_call, in case this # get is involved in a data race. get, - var.aval, + output_var.aval, device_id, - TPU_MEMORY_SPACE_IDXS[var.aval.memory_space], + TPU_MEMORY_SPACE_IDXS[output_var.aval.memory_space], kernel_output_ids[j], (), - ordered=True) + ordered=True, + ) transform = indexing.NDIndexer( - indices=tuple(indexing.ds(st, sz) if not iid else st - for st, sz, iid in zip(start_indices[num_inputs + j], - block_shapes[num_inputs + j], - is_indexing_dim[num_inputs + j])), - shape=output_vals[j].shape, - int_indexer_shape=()) + indices=tuple( + indexing.ds(st, sz) if not iid else st + for st, sz, iid in zip( + cur_start_indices[num_inputs + index], + block_shapes[num_inputs + index], + is_indexing_dim[num_inputs + index], + ) + ), + shape=output_vals[index].shape, + int_indexer_shape=(index), + ) callback.io_callback( # TODO(jburnim): Pass source_info from the pallas_call, in case this # store is involved in a data race. @@ -1629,18 +1685,52 @@ def body(carry): (), device_id, TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY], - output_buffer_ids[j], + output_buffer_ids[index], (transform,), kernel_output_val, - ordered=True) + ordered=True, + ) - return i + 1, _get_next_indices(grid, loop_idx) + for j, var in enumerate(output_vars): + if _is_any(var.aval.memory_space): + continue + assert len(cur_start_indices[num_inputs + j].shape) == 1 + assert len(next_start_indices[num_inputs + j].shape) == 1 + jax.lax.cond( + (iteration_idx + 1 == num_iterations) + | jax.lax.reduce_or( + cur_start_indices[num_inputs + j] + != next_start_indices[num_inputs + j], + axes=(0,), + ), + functools.partial(_store_to_output_buffer, j, var), + lambda: None, + ) + return iteration_idx + 1, next_loop_idx, cur_start_indices, next_start_indices + + initial_loop_idx = (jnp.int32(0),) * len(grid) + with pallas_core.grid_env(_get_local_grid_env(initial_loop_idx)): + initial_start_indices = [ + _compute_start_indices( + bm, + initial_loop_idx, + *scalar_buffer_ids, + compiler_params=compiler_params, + interpret_params=interpret_params, + ) + for bm in grid_mapping.block_mappings + ] # TODO(jburnim): Handle parallel grid dimensions + megacore. _ = lax.while_loop( lambda carry: carry[0] < num_iterations, body, - (jnp.int32(0), (jnp.int32(0),) * len(grid)) + ( + jnp.int32(0), + initial_loop_idx, + initial_start_indices, # Previous start indices are ignored on the first iteration. + initial_start_indices, + ), ) # Read the output from the allocated output buffers. diff --git a/tests/pallas/tpu_pallas_interpret_test.py b/tests/pallas/tpu_pallas_interpret_test.py index 5b729f0fe07e..afb573f8cf44 100644 --- a/tests/pallas/tpu_pallas_interpret_test.py +++ b/tests/pallas/tpu_pallas_interpret_test.py @@ -18,24 +18,48 @@ contains only tests that do not use shard_map. """ -from absl.testing import absltest -from absl.testing import parameterized import functools +from absl.testing import absltest +from absl.testing import parameterized import jax from jax._src import test_util as jtu import jax._src.pallas.mosaic.interpret as mosaic_interpret from jax.experimental import pallas as pl from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp - import numpy as np jax.config.parse_flags_with_absl() +class CountStoreCallbacksContext(object): + """Wraps the I/O callback `store` into a callback that counts the number of calls to `store`.""" + + def __init__(self): + self._num_stores = 0 + self._saved = mosaic_interpret.store + + def __enter__(self): + def _store_callback(self, *args, **kwargs): + self._num_stores += 1 + return self._saved(*args, **kwargs) + + mosaic_interpret.store = functools.partial(_store_callback, self) + return self + + def __exit__(self, ty, value, traceback): + del ty, value, traceback + mosaic_interpret.store = self._saved + + @property + def num_stores(self): + return self._num_stores + + class InterpretTest(jtu.JaxTestCase): + def setUp(self): super().setUp() self.num_devices = jax.device_count() @@ -50,17 +74,18 @@ def matmul_kernel(x_ref, y_ref, z_ref): @jax.jit def matmul(x: jax.Array, y: jax.Array): return pl.pallas_call( - matmul_kernel, - out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype), - grid=(2, 2), - in_specs=[ - pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)), - pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j)) - ], - out_specs=pl.BlockSpec( - (x.shape[0] // 2, y.shape[1] // 2), lambda i, j: (i, j), - ), - interpret=mosaic_interpret.TPUInterpretParams(), + matmul_kernel, + out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype), + grid=(2, 2), + in_specs=[ + pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)), + pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j)), + ], + out_specs=pl.BlockSpec( + (x.shape[0] // 2, y.shape[1] // 2), + lambda i, j: (i, j), + ), + interpret=mosaic_interpret.TPUInterpretParams(), )(x, y) k1, k2 = jax.random.split(jax.random.key(0)) @@ -79,9 +104,11 @@ def block_dynamic_slice(x, starts, sizes): grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=1, grid=(1, 1), - in_specs=[pl.BlockSpec( - sizes, - lambda i, j, block_idx: (block_idx[0], block_idx[1]))], + in_specs=[ + pl.BlockSpec( + sizes, lambda i, j, block_idx: (block_idx[0], block_idx[1]) + ) + ], out_specs=pl.BlockSpec(sizes, lambda *_: (0, 0)), ) @@ -96,17 +123,21 @@ def block_dynamic_slice(x, starts, sizes): shape = (512, 512) x = jnp.reshape(jnp.arange(np.prod(shape), dtype=jnp.int32), shape) - result = block_dynamic_slice(x, starts=jnp.array([128, 256]), sizes=(128, 128)) - ref = jax.lax.dynamic_slice(x, start_indices=(128, 256), slice_sizes=(128, 128)) + result = block_dynamic_slice( + x, starts=jnp.array([128, 256]), sizes=(128, 128) + ) + ref = jax.lax.dynamic_slice( + x, start_indices=(128, 256), slice_sizes=(128, 128) + ) diff = jnp.max(jnp.abs(result - ref)) np.testing.assert_allclose(result, ref) def test_dynamic_grid_and_aliasing(self): - self.skipTest('Broken pending fix to extra reads/writes of inputs/outputs') def kernel(s_ref, x_ref, o_ref): o_ref[...] = x_ref[...] + s_ref[0].astype(x_ref.dtype) iters = jax.random.randint(jax.random.key(0), (), 10, 20, dtype=jnp.int32) + @jax.jit def f(s, x): return pl.pallas_call( @@ -119,11 +150,11 @@ def f(s, x): ], out_specs=pl.BlockSpec(x.shape, lambda i: (0, 0)), input_output_aliases={1: 0}, - interpret=mosaic_interpret.TPUInterpretParams() + interpret=mosaic_interpret.TPUInterpretParams(), )(s, x) s = jnp.array([1], dtype=jnp.int32) - x = jnp.arange(32 * 128.).reshape((32, 128)) + x = jnp.arange(32 * 128.0).reshape((32, 128)) y = f(s, x) # NOTE: No matter how many times the kernel body is run, the kernel input # buffer will only be written once by the pallas_call machinery, just @@ -136,6 +167,7 @@ def kernel(x_ref, o_ref, s_ref): @pl.when((pl.program_id(0) == 0) & (pl.program_id(1) == 0)) def _(): s_ref[0] = jnp.int32(0) + s = s_ref[0] s_ref[0] = s + 1 o_ref[:] = x_ref[:] + s.astype(x_ref.dtype) @@ -149,7 +181,8 @@ def _(): pl.BlockSpec(block_shape=(8, 128), index_map=lambda i, j: (i, j)), ], out_specs=pl.BlockSpec( - block_shape=(8, 128), index_map=lambda i, j: (j, i)), + block_shape=(8, 128), index_map=lambda i, j: (j, i) + ), scratch_shapes=(pltpu.SMEM((1,), jnp.int32),), input_output_aliases={0: 0}, interpret=mosaic_interpret.TPUInterpretParams(), @@ -184,7 +217,8 @@ def kernel_with_race(x_ref, o_ref, t_ref, sem): copy.wait() x = jnp.zeros((8, 128), jnp.float32) - y = pl.pallas_call(kernel_without_race, + y = pl.pallas_call( + kernel_without_race, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)], scratch_shapes=[ @@ -192,12 +226,14 @@ def kernel_with_race(x_ref, o_ref, t_ref, sem): pltpu.SemaphoreType.DMA, ], interpret=mosaic_interpret.TPUInterpretParams( - detect_races=True, dma_execution_mode=dma_execution_mode), + detect_races=True, dma_execution_mode=dma_execution_mode + ), )(x).block_until_ready() self.assertFalse(mosaic_interpret.races.races_found) np.testing.assert_allclose(y, x + 1.0) - pl.pallas_call(kernel_with_race, + pl.pallas_call( + kernel_with_race, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)], scratch_shapes=[ @@ -205,7 +241,8 @@ def kernel_with_race(x_ref, o_ref, t_ref, sem): pltpu.SemaphoreType.DMA, ], interpret=mosaic_interpret.TPUInterpretParams( - detect_races=True, dma_execution_mode=dma_execution_mode), + detect_races=True, dma_execution_mode=dma_execution_mode + ), )(x).block_until_ready() self.assertTrue(mosaic_interpret.races.races_found) @@ -228,8 +265,8 @@ def matmul(x: jax.Array, y: jax.Array): z = jax.jit(matmul)(x, y) np.testing.assert_array_equal(z, jnp.full_like(z, jnp.inf)) - lowered = jax.jit(matmul).lower(x, y).as_text(dialect="stablehlo") - self.assertNotIn("dot_general", lowered) + lowered = jax.jit(matmul).lower(x, y).as_text(dialect='stablehlo') + self.assertNotIn('dot_general', lowered) @parameterized.parameters('nan', 'zero') def test_uninitialized_memory(self, uninitialized_memory): @@ -250,7 +287,8 @@ def kernel(o1_ref, o2_ref, o3_ref, t1_ref, t2_ref): pltpu.VMEM((8, 128), jnp.int16), ], interpret=mosaic_interpret.TPUInterpretParams( - uninitialized_memory=uninitialized_memory), + uninitialized_memory=uninitialized_memory + ), )() if uninitialized_memory == 'nan': self.assertTrue(jnp.isnan(x).all()) @@ -261,6 +299,33 @@ def kernel(o1_ref, o2_ref, o3_ref, t1_ref, t2_ref): np.testing.assert_equal(np.array(y), 0) np.testing.assert_equal(np.array(z), 0) + def test_correct_number_of_stores(self): + def kernel(x_ref, s_ref, o_ref): + s = s_ref[0] + x_ref[:] += jax.lax.full_like(x_ref, s) + s_ref[0] = s + 1 + o_ref[:] = x_ref[:] + + def kernel_call(x, s): + return pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((16, 256), jnp.float32), + grid=(2, 2), + in_specs=[ + pl.BlockSpec((8, 256), lambda i, j: (i, 0)), + pl.BlockSpec(memory_space=pltpu.SMEM), + ], + out_specs=pl.BlockSpec((8, 256), lambda i, j: (i, 0)), + interpret=mosaic_interpret.TPUInterpretParams(), + )(x, s) + + with CountStoreCallbacksContext() as store_callbacks_counter: + result = jax.jit(kernel_call)( + jnp.zeros((16, 256), jnp.float32), jnp.zeros((1,), jnp.int32) + ) + np.testing.assert_allclose(result[::8, ::256], [[1.0], [5.0]]) + self.assertEqual(store_callbacks_counter.num_stores, 5) + -if __name__ == "__main__": +if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From 8689550376089e78f22f87305422ffc6aaf5ddb8 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 27 Mar 2025 05:17:17 -0700 Subject: [PATCH 216/483] Update XLA dependency to use revision http://github.com/openxla/xla/commit/95abd7942747bd5d1884b309baecdf5a93ff928a. PiperOrigin-RevId: 741114363 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 359048ffacbb..625f33a072f5 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "d0b25f9cd8222a348c9728f88e909c4e2c30991b" -XLA_SHA256 = "8cd70a67a56a8b18087fc4849908f52c95c6413eb7edc9f800fdff6304804fa4" +XLA_COMMIT = "95abd7942747bd5d1884b309baecdf5a93ff928a" +XLA_SHA256 = "f8472323ffe621ade5317091fdf9acd66aaf67660fedd3143a96d9a347e88bac" def repo(): tf_http_archive( From 875e4795c444071604afe441c0d0fe965ccb0d50 Mon Sep 17 00:00:00 2001 From: Ayaka Date: Thu, 27 Mar 2025 07:02:22 -0700 Subject: [PATCH 217/483] Update `test_util.get_tpu_version()` PiperOrigin-RevId: 741139032 --- jax/_src/test_util.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index c3c4a934dd0e..1cd9546a1655 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -435,10 +435,10 @@ def get_tpu_version() -> int: if device_under_test() != "tpu": raise ValueError("Device is not TPU") kind = jax.devices()[0].device_kind - if kind.endswith(' lite'): - kind = kind[:-len(' lite')] - assert kind[:-1] == "TPU v", kind - return int(kind[-1]) + match = re.match(r"TPU[^\d]*(\d+)", kind) + if match is None: + raise ValueError(f"Device kind {kind} is not supported") + return int(match.group(1)) def is_device_tpu_at_least(version: int) -> bool: if device_under_test() != "tpu": From 9932ff1f79e3488a6660b44c9390bf81dc6389f5 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 27 Mar 2025 07:27:43 -0700 Subject: [PATCH 218/483] Deprecate the contents of jax.lib.xla_extension. PiperOrigin-RevId: 741145943 --- CHANGELOG.md | 2 + jax/lib/xla_extension.py | 132 +++++++++++++++++++++++++++++++-------- 2 files changed, 108 insertions(+), 26 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 93bbe81b5e63..c8805599364d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,9 +22,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. instead. * Implemented host callback handlers for CPU and GPU devices using XLA's FFI and removed existing CPU/GPU handlers using XLA's custom call. + * All APIs in `jax.lib.xla_extension` are now deprecated. * Several previously-deprecated APIs have been removed, including: * From `jax.lib.xla_client`: `FftType`, `PaddingType`, `dtype_to_etype`, and `shape_from_pyval`. + * From `jax.lib.xla_extension`: `ArrayImpl`, `XlaRuntimeError`. ## jax 0.5.3 (Mar 19, 2025) diff --git a/jax/lib/xla_extension.py b/jax/lib/xla_extension.py index 52fe94e231d1..8f1b27070e98 100644 --- a/jax/lib/xla_extension.py +++ b/jax/lib/xla_extension.py @@ -14,42 +14,122 @@ from jax._src.lib import xla_extension as _xe -get_distributed_runtime_client = _xe.get_distributed_runtime_client -get_distributed_runtime_service = _xe.get_distributed_runtime_service -hlo_module_cost_analysis = _xe.hlo_module_cost_analysis -hlo_module_to_dot_graph = _xe.hlo_module_to_dot_graph -ifrt_proxy = _xe.ifrt_proxy -jax_jit = _xe.jax_jit -mlir = _xe.mlir -pmap_lib = _xe.pmap_lib -profiler = _xe.profiler -pytree = _xe.pytree -Device = _xe.Device -DistributedRuntimeClient = _xe.DistributedRuntimeClient -HloModule = _xe.HloModule -HloPrintOptions = _xe.HloPrintOptions -OpSharding = _xe.OpSharding -PjitFunctionCache = _xe.PjitFunctionCache -PjitFunction = _xe.PjitFunction -PmapFunction = _xe.PmapFunction - _deprecations = { - # Added Nov 20 2024 "ArrayImpl": ( - "jax.lib.xla_extension.ArrayImpl is deprecated; use jax.Array instead.", - _xe.ArrayImpl, + ( + "jax.lib.xla_extension.ArrayImpl has been removed; use jax.Array" + " instead." + ), + None, ), "XlaRuntimeError": ( - "jax.lib.xla_extension.XlaRuntimeError is deprecated; use jax.errors.JaxRuntimeError instead.", - _xe.XlaRuntimeError, + ( + "jax.lib.xla_extension.XlaRuntimeError has been removed; use" + " jax.errors.JaxRuntimeError instead." + ), + None, + ), + # Deprecated March 26 2025. + "DistributedRuntimeClient": ( + ( + "jax.lib.xla_extension.DistributedRuntimeClient is" + " deprecated; use jax.distributed instead." + ), + _xe.DistributedRuntimeClient, + ), + "get_distributed_runtime_client": ( + ( + "jax.lib.xla_extension.get_distributed_runtime_client is" + " deprecated; use jax.distributed instead." + ), + _xe.get_distributed_runtime_client, + ), + "get_distributed_runtime_service": ( + ( + "jax.lib.xla_extension.get_distributed_runtime_service is" + " deprecated; use jax.distributed instead." + ), + _xe.get_distributed_runtime_service, + ), + "Device": ( + "jax.lib.xla_extension.Device is deprecated; use jax.Device instead.", + _xe.Device, + ), + "PjitFunctionCache": ( + "jax.lib.xla_extension.PjitFunctionCache is deprecated.", + _xe.PjitFunctionCache, + ), + "ifrt_proxy": ( + "jax.lib.xla_extension.ifrt_proxy is deprecated.", + _xe.ifrt_proxy, + ), + "jax_jit": ( + "jax.lib.xla_extension.jax_jit is deprecated.", + _xe.jax_jit, + ), + "mlir": ("jax.lib.xla_extension.mlir is deprecated.", _xe.mlir), + "pmap_lib": ("jax.lib.xla_extension.pmap_lib is deprecated.", _xe.pmap_lib), + "profiler": ( + "jax.lib.xla_extension.profiler is deprecated.", + _xe.profiler, + ), + "pytree": ( + "jax.lib.xla_extension.pytree is deprecated.", + _xe.pytree, + ), + "hlo_module_cost_analysis": ( + "jax.lib.xla_extension.hlo_module_cost_analysis is deprecated.", + _xe.hlo_module_cost_analysis, + ), + "hlo_module_to_dot_graph": ( + "jax.lib.xla_extension.hlo_module_to_dot_graph is deprecated.", + _xe.hlo_module_to_dot_graph, + ), + "HloModule": ( + "jax.lib.xla_extension.HloModule is deprecated.", + _xe.HloModule, + ), + "HloPrintOptions": ( + "jax.lib.xla_extension.HloPrintOptions is deprecated.", + _xe.HloPrintOptions, + ), + "OpSharding": ( + "jax.lib.xla_extension.OpSharding is deprecated.", + _xe.OpSharding, + ), + "PjitFunction": ( + "jax.lib.xla_extension.PjitFunction is deprecated.", + _xe.PjitFunction, + ), + "PmapFunction": ( + "jax.lib.xla_extension.PmapFunction is deprecated.", + _xe.PmapFunction, ), } import typing as _typing if _typing.TYPE_CHECKING: - ArrayImpl = _xe.ArrayImpl - XlaRuntimeError = _xe.XlaRuntimeError + Device = _xe.Device + DistributedRuntimeClient = _xe.DistributedRuntimeClient + HloModule = _xe.HloModule + HloPrintOptions = _xe.HloPrintOptions + OpSharding = _xe.OpSharding + PjitFunction = _xe.PjitFunction + PjitFunctionCache = _xe.PjitFunctionCache + PmapFunction = _xe.PmapFunction + + get_distributed_runtime_client = _xe.get_distributed_runtime_client + get_distributed_runtime_service = _xe.get_distributed_runtime_service + hlo_module_cost_analysis = _xe.hlo_module_cost_analysis + hlo_module_to_dot_graph = _xe.hlo_module_to_dot_graph + ifrt_proxy = _xe.ifrt_proxy + jax_jit = _xe.jax_jit + mlir = _xe.mlir + pmap_lib = _xe.pmap_lib + profiler = _xe.profiler + pytree = _xe.pytree + else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr From 108c590b2f11ffd4ed7a75d884f907bb945ef05b Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 27 Mar 2025 08:10:00 -0700 Subject: [PATCH 219/483] Replace uses of deprecated `Shape::rank()` with: - `dimensions().size()` if it's OK for the result to be changed to an unsigned number, - `dimensions_size()` if it's important that the result is a signed number. This should be a pure refactoring that doesn't affect the code's behavior. Note that `rank()` returns `int64_t` and `dimensions().size()` returns `size_t`. Sometimes the change of the signedness is not desirable, and we use `dimensions_size()`, which returns `int`, in such cases. PiperOrigin-RevId: 741157851 --- jaxlib/xla/xla_compiler.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jaxlib/xla/xla_compiler.cc b/jaxlib/xla/xla_compiler.cc index 00f8b4c295a7..0098cc28160d 100644 --- a/jaxlib/xla/xla_compiler.cc +++ b/jaxlib/xla/xla_compiler.cc @@ -648,7 +648,7 @@ void BuildXlaCompilerSubmodule(nb::module_& m) { nb::arg("dimension")) .def("set_dynamic_dimension", &Shape::set_dynamic_dimension, nb::arg("dimension"), nb::arg("is_dynamic")) - .def("rank", &Shape::rank) + .def("rank", &Shape::dimensions_size) .def("to_serialized_proto", [](const Shape& shape) { ShapeProto proto = shape.ToProto(); From 99d92f26a6f2c01659a6afe3ad2744f86fa521fa Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 27 Mar 2025 08:10:06 -0700 Subject: [PATCH 220/483] Explicitly export mgpu runtime symbols. PiperOrigin-RevId: 741157879 --- jaxlib/mosaic/gpu/BUILD | 3 +++ 1 file changed, 3 insertions(+) diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD index abe326474808..80a8f0e51080 100644 --- a/jaxlib/mosaic/gpu/BUILD +++ b/jaxlib/mosaic/gpu/BUILD @@ -111,9 +111,12 @@ cc_library( cc_library( name = "runtime", srcs = ["runtime.cc"], + # Linker may prune these symbols if they are not explicitly exported. + linkopts = ["-Wl,--export-dynamic-symbol='mosaic_gpu_*'"], deps = [ "@local_config_cuda//cuda:cuda_headers", ], + alwayslink = True, ) cc_library( From 083bdfc9cc2613086ee2273395f127b20598dc6d Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 27 Mar 2025 08:44:31 -0700 Subject: [PATCH 221/483] Add license headers to files that were missing them. PiperOrigin-RevId: 741167870 --- jax/_src/mesh_utils.py | 2 +- jax/experimental/mesh_utils.py | 2 +- jaxlib/ffi_helpers.h | 15 +++++++++++++++ jaxlib/gpu/triton.cc | 15 +++++++++++++++ jaxlib/gpu/triton_kernels.cc | 15 +++++++++++++++ jaxlib/gpu/triton_kernels.h | 15 +++++++++++++++ jaxlib/gpu/triton_utils.cc | 15 +++++++++++++++ jaxlib/gpu/triton_utils.h | 15 +++++++++++++++ jaxlib/mlir/_mlir_libs/register_jax_dialects.cc | 15 +++++++++++++++ .../dialect/gpu/integrations/c/attributes.cc | 15 +++++++++++++++ .../dialect/tpu/transforms/apply_vector_layout.cc | 15 +++++++++++++++ .../dialect/tpu/transforms/apply_vector_layout.h | 15 +++++++++++++++ .../transforms/apply_vector_layout_extensions.h | 15 +++++++++++++++ .../dialect/tpu/transforms/canonicalize_mosaic.cc | 15 +++++++++++++++ .../extensions/apply_vector_layout_extensions.cc | 15 +++++++++++++++ .../extensions/infer_vector_layout_extensions.cc | 15 +++++++++++++++ .../dialect/tpu/transforms/infer_memref_layout.cc | 15 +++++++++++++++ .../dialect/tpu/transforms/infer_memref_layout.h | 15 +++++++++++++++ .../transforms/infer_vector_layout_extensions.h | 15 +++++++++++++++ .../dialect/tpu/transforms/relayout_insertion.cc | 15 +++++++++++++++ jaxlib/mosaic/dialect/tpu/transforms/serde.h | 15 +++++++++++++++ jaxlib/mosaic/dialect/tpu/util.h | 15 +++++++++++++++ tests/mesh_utils_test.py | 2 +- 23 files changed, 303 insertions(+), 3 deletions(-) diff --git a/jax/_src/mesh_utils.py b/jax/_src/mesh_utils.py index ccc75af8c84f..c135919b14c5 100644 --- a/jax/_src/mesh_utils.py +++ b/jax/_src/mesh_utils.py @@ -1,4 +1,4 @@ -# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# Copyright 2021 The JAX Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/jax/experimental/mesh_utils.py b/jax/experimental/mesh_utils.py index 075e4e6eed48..58d20c331d5f 100644 --- a/jax/experimental/mesh_utils.py +++ b/jax/experimental/mesh_utils.py @@ -1,4 +1,4 @@ -# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# Copyright 2021 The JAX Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/jaxlib/ffi_helpers.h b/jaxlib/ffi_helpers.h index 5c6d80093df5..634a48fcffc7 100644 --- a/jaxlib/ffi_helpers.h +++ b/jaxlib/ffi_helpers.h @@ -1,3 +1,18 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #ifndef JAXLIB_FFI_HELPERS_H_ #define JAXLIB_FFI_HELPERS_H_ diff --git a/jaxlib/gpu/triton.cc b/jaxlib/gpu/triton.cc index 135410568f6b..d0c48eef492f 100644 --- a/jaxlib/gpu/triton.cc +++ b/jaxlib/gpu/triton.cc @@ -1,3 +1,18 @@ +/* Copyright 2022 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include #include #include diff --git a/jaxlib/gpu/triton_kernels.cc b/jaxlib/gpu/triton_kernels.cc index 22397ff908bc..6565b5b87be2 100644 --- a/jaxlib/gpu/triton_kernels.cc +++ b/jaxlib/gpu/triton_kernels.cc @@ -1,3 +1,18 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include "jaxlib/gpu/triton_kernels.h" #include diff --git a/jaxlib/gpu/triton_kernels.h b/jaxlib/gpu/triton_kernels.h index c3457093c4f8..d23a9a7395e0 100644 --- a/jaxlib/gpu/triton_kernels.h +++ b/jaxlib/gpu/triton_kernels.h @@ -1,3 +1,18 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #ifndef JAXLIB_GPU_TRITON_H_ #define JAXLIB_GPU_TRITON_H_ diff --git a/jaxlib/gpu/triton_utils.cc b/jaxlib/gpu/triton_utils.cc index b3a0779118de..f6bbe46c846d 100644 --- a/jaxlib/gpu/triton_utils.cc +++ b/jaxlib/gpu/triton_utils.cc @@ -1,3 +1,18 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include "jaxlib/gpu/triton_utils.h" #include diff --git a/jaxlib/gpu/triton_utils.h b/jaxlib/gpu/triton_utils.h index 0c286391e296..19c64a88c216 100644 --- a/jaxlib/gpu/triton_utils.h +++ b/jaxlib/gpu/triton_utils.h @@ -1,3 +1,18 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #ifndef JAXLIB_GPU_TRITON_UTILS_H_ #define JAXLIB_GPU_TRITON_UTILS_H_ diff --git a/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc b/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc index 1ba6fd9375df..0eb4a57a2f4b 100644 --- a/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc +++ b/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc @@ -1,3 +1,18 @@ +/* Copyright 2022 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + // Registers MLIR dialects used by JAX. // This module is called by mlir/__init__.py during initialization. #include diff --git a/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.cc b/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.cc index eac1d104f07f..259c37fe5d07 100644 --- a/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.cc +++ b/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.cc @@ -1,3 +1,18 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include "jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h" #include diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 71924739595c..c9c8d22a1363 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -1,3 +1,18 @@ +/* Copyright 2021 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include "jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h" #include diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h index ed72a21028eb..bbf23a9f3844 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h @@ -1,3 +1,18 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_APPLY_VECTOR_LAYOUT_H_ #define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_APPLY_VECTOR_LAYOUT_H_ diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h index fded0d1dbfd7..72bd8ca370c8 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h @@ -1,3 +1,18 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANFORMS_APPLY_VECTOR_LAYOUT_EXTENSIONS_H_ #define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANFORMS_APPLY_VECTOR_LAYOUT_EXTENSIONS_H_ diff --git a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc index 6f56489ab4b1..a15947f48f78 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc @@ -1,3 +1,18 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include #include #include diff --git a/jaxlib/mosaic/dialect/tpu/transforms/extensions/apply_vector_layout_extensions.cc b/jaxlib/mosaic/dialect/tpu/transforms/extensions/apply_vector_layout_extensions.cc index 067f8e592e30..d2c149a47150 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/extensions/apply_vector_layout_extensions.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/extensions/apply_vector_layout_extensions.cc @@ -1,3 +1,18 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include "jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h" #include "llvm/ADT/StringMap.h" diff --git a/jaxlib/mosaic/dialect/tpu/transforms/extensions/infer_vector_layout_extensions.cc b/jaxlib/mosaic/dialect/tpu/transforms/extensions/infer_vector_layout_extensions.cc index 9dbf89724fef..e34ef7fcb261 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/extensions/infer_vector_layout_extensions.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/extensions/infer_vector_layout_extensions.cc @@ -1,3 +1,18 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include "jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h" #include diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc index fdfd04949bce..e2196088728f 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc @@ -1,3 +1,18 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include "jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h" #include diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h index f2ab7c624eb1..a6dd8ad1dbd3 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h @@ -1,3 +1,18 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_MEMREF_LAYOUT_H_ #define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_MEMREF_LAYOUT_H_ diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h index 36fa2ce8113f..a81e982f8e1a 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h @@ -1,3 +1,18 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_VECTOR_LAYOUT_EXTENSIONS_H_ #define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_VECTOR_LAYOUT_EXTENSIONS_H_ diff --git a/jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc b/jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc index 8aae7a10279a..6ddf8bd5ce66 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc @@ -1,3 +1,18 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include #include #include diff --git a/jaxlib/mosaic/dialect/tpu/transforms/serde.h b/jaxlib/mosaic/dialect/tpu/transforms/serde.h index ccb32131e519..5da8a9e316e0 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/serde.h +++ b/jaxlib/mosaic/dialect/tpu/transforms/serde.h @@ -1,3 +1,18 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_SERDE_H_ #define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_SERDE_H_ diff --git a/jaxlib/mosaic/dialect/tpu/util.h b/jaxlib/mosaic/dialect/tpu/util.h index dadd71800f3e..e2cf27811f09 100644 --- a/jaxlib/mosaic/dialect/tpu/util.h +++ b/jaxlib/mosaic/dialect/tpu/util.h @@ -1,3 +1,18 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_UTIL_H_ #define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_UTIL_H_ diff --git a/tests/mesh_utils_test.py b/tests/mesh_utils_test.py index 136b507942e7..28efb266b281 100644 --- a/tests/mesh_utils_test.py +++ b/tests/mesh_utils_test.py @@ -1,4 +1,4 @@ -# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# Copyright 2021 The JAX Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From e342f2dd602ea33cc395dbcd71e38191ebf593d3 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 27 Mar 2025 09:53:09 -0700 Subject: [PATCH 222/483] Update the minimum supported CuDNN version to 9.8 (previously 9.1). Announce maximum supported CUDA version 12.8 (previously 12.3). PiperOrigin-RevId: 741188737 --- CHANGELOG.md | 5 +++++ build/gpu-test-requirements.txt | 2 +- build/requirements_lock_3_10.txt | 8 ++++---- build/requirements_lock_3_11.txt | 8 ++++---- build/requirements_lock_3_12.txt | 8 ++++---- build/requirements_lock_3_13.txt | 8 ++++---- build/requirements_lock_3_13_ft.txt | 8 ++++---- jax_plugins/cuda/plugin_setup.py | 2 +- 8 files changed, 27 insertions(+), 22 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c8805599364d..cfd8c2eb340b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. ## Unreleased +* Changes + * The minimum CuDNN version is v9.8. + * JAX is now built using CUDA 12.8. All versions of CUDA 12.1 or newer remain + supported. + * Deprecations * {func}`jax.tree_util.build_tree` is deprecated. Use {func}`jax.tree.unflatten` diff --git a/build/gpu-test-requirements.txt b/build/gpu-test-requirements.txt index ff43f91ba90f..d0dda5cf526c 100644 --- a/build/gpu-test-requirements.txt +++ b/build/gpu-test-requirements.txt @@ -5,7 +5,7 @@ nvidia-cublas-cu12>=12.1.3.1 ; sys_platform == "linux" nvidia-cuda-cupti-cu12>=12.1.105 ; sys_platform == "linux" nvidia-cuda-nvcc-cu12>=12.6.85 ; sys_platform == "linux" nvidia-cuda-runtime-cu12>=12.1.105 ; sys_platform == "linux" -nvidia-cudnn-cu12>=9.1,<10.0 ; sys_platform == "linux" +nvidia-cudnn-cu12>=9.8,<10.0 ; sys_platform == "linux" nvidia-cufft-cu12>=11.0.2.54 ; sys_platform == "linux" nvidia-cusolver-cu12>=11.4.5.107 ; sys_platform == "linux" nvidia-cusparse-cu12>=12.1.0.106 ; sys_platform == "linux" diff --git a/build/requirements_lock_3_10.txt b/build/requirements_lock_3_10.txt index 6ed6b59aa584..8bf5293bd948 100644 --- a/build/requirements_lock_3_10.txt +++ b/build/requirements_lock_3_10.txt @@ -410,10 +410,10 @@ nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 # via -r build/test-requirements.txt -nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \ - --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ - --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ - --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef +nvidia-cudnn-cu12==9.8.0.87 ; sys_platform == "linux" \ + --hash=sha256:b4b5cfddc32aa4180f9d390ee99e9a9f55a89e7087329b41aba4319327e22466 \ + --hash=sha256:b883faeb2f6f15dba7bbb6756eab6a0d9cecb59db5b0fa07577b9cfa24cd99f4 \ + --hash=sha256:d6b02cd0e3e24aa31d0193a8c39fec239354360d7d81055edddb69f35d53a4c8 # via -r build/test-requirements.txt nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ diff --git a/build/requirements_lock_3_11.txt b/build/requirements_lock_3_11.txt index 8446e8361505..487346ab6d12 100644 --- a/build/requirements_lock_3_11.txt +++ b/build/requirements_lock_3_11.txt @@ -405,10 +405,10 @@ nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 # via -r build/test-requirements.txt -nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \ - --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ - --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ - --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef +nvidia-cudnn-cu12==9.8.0.87 ; sys_platform == "linux" \ + --hash=sha256:b4b5cfddc32aa4180f9d390ee99e9a9f55a89e7087329b41aba4319327e22466 \ + --hash=sha256:b883faeb2f6f15dba7bbb6756eab6a0d9cecb59db5b0fa07577b9cfa24cd99f4 \ + --hash=sha256:d6b02cd0e3e24aa31d0193a8c39fec239354360d7d81055edddb69f35d53a4c8 # via -r build/test-requirements.txt nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ diff --git a/build/requirements_lock_3_12.txt b/build/requirements_lock_3_12.txt index 0436ab6dd486..e2f76cab8abc 100644 --- a/build/requirements_lock_3_12.txt +++ b/build/requirements_lock_3_12.txt @@ -405,10 +405,10 @@ nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 # via -r build/test-requirements.txt -nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \ - --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ - --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ - --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef +nvidia-cudnn-cu12==9.8.0.87 ; sys_platform == "linux" \ + --hash=sha256:b4b5cfddc32aa4180f9d390ee99e9a9f55a89e7087329b41aba4319327e22466 \ + --hash=sha256:b883faeb2f6f15dba7bbb6756eab6a0d9cecb59db5b0fa07577b9cfa24cd99f4 \ + --hash=sha256:d6b02cd0e3e24aa31d0193a8c39fec239354360d7d81055edddb69f35d53a4c8 # via -r build/test-requirements.txt nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ diff --git a/build/requirements_lock_3_13.txt b/build/requirements_lock_3_13.txt index e74d40b798f4..403d0ad8a061 100644 --- a/build/requirements_lock_3_13.txt +++ b/build/requirements_lock_3_13.txt @@ -460,10 +460,10 @@ nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 # via -r build/test-requirements.txt -nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \ - --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ - --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ - --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef +nvidia-cudnn-cu12==9.8.0.87 ; sys_platform == "linux" \ + --hash=sha256:b4b5cfddc32aa4180f9d390ee99e9a9f55a89e7087329b41aba4319327e22466 \ + --hash=sha256:b883faeb2f6f15dba7bbb6756eab6a0d9cecb59db5b0fa07577b9cfa24cd99f4 \ + --hash=sha256:d6b02cd0e3e24aa31d0193a8c39fec239354360d7d81055edddb69f35d53a4c8 # via -r build/test-requirements.txt nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ diff --git a/build/requirements_lock_3_13_ft.txt b/build/requirements_lock_3_13_ft.txt index e7a2968e981e..5157706c00e8 100644 --- a/build/requirements_lock_3_13_ft.txt +++ b/build/requirements_lock_3_13_ft.txt @@ -413,10 +413,10 @@ nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 # via -r build/test-requirements.txt -nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \ - --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ - --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ - --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef +nvidia-cudnn-cu12==9.8.0.87 ; sys_platform == "linux" \ + --hash=sha256:b4b5cfddc32aa4180f9d390ee99e9a9f55a89e7087329b41aba4319327e22466 \ + --hash=sha256:b883faeb2f6f15dba7bbb6756eab6a0d9cecb59db5b0fa07577b9cfa24cd99f4 \ + --hash=sha256:d6b02cd0e3e24aa31d0193a8c39fec239354360d7d81055edddb69f35d53a4c8 # via -r build/test-requirements.txt nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ diff --git a/jax_plugins/cuda/plugin_setup.py b/jax_plugins/cuda/plugin_setup.py index ce31684de46f..db9928f6cf61 100644 --- a/jax_plugins/cuda/plugin_setup.py +++ b/jax_plugins/cuda/plugin_setup.py @@ -57,7 +57,7 @@ def has_ext_modules(self): "nvidia-cuda-cupti-cu12>=12.1.105", "nvidia-cuda-nvcc-cu12>=12.6.85", "nvidia-cuda-runtime-cu12>=12.1.105", - "nvidia-cudnn-cu12>=9.1,<10.0", + "nvidia-cudnn-cu12>=9.8,<10.0", "nvidia-cufft-cu12>=11.0.2.54", "nvidia-cusolver-cu12>=11.4.5.107", "nvidia-cusparse-cu12>=12.1.0.106", From 3c81b184a7b169827069451373f671fc42543c51 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 27 Mar 2025 13:09:22 -0400 Subject: [PATCH 223/483] Add sm_100 and sm_120 to the list of CUDA GPU achitectures for which we compile. --- .bazelrc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.bazelrc b/.bazelrc index 642fb15ed541..76f72b0848a9 100644 --- a/.bazelrc +++ b/.bazelrc @@ -136,7 +136,7 @@ build:cuda --repo_env TF_NEED_CUDA=1 build:cuda --repo_env TF_NCCL_USE_STUB=1 # "sm" means we emit only cubin, which is forward compatible within a GPU generation. # "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations. -build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90" +build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,sm_90,sm_100,compute_120" build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain build:cuda --@local_config_cuda//:enable_cuda From 0dbc1222657e318c31edad50e1f567b835574c0f Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 27 Mar 2025 10:11:35 -0700 Subject: [PATCH 224/483] Add the `jax` wheel as a required dependency for running the Bazel CUDA non RBE tests Since https://github.com/jax-ml/jax/pull/27113, the wheel is tested when `--//jax:build_jaxlib=false`. Previously, we could depend on the source repository. Fixes https://github.com/jax-ml/jax/actions/runs/14108610313/job/39521951667 PiperOrigin-RevId: 741195252 --- .github/workflows/bazel_cuda_non_rbe.yml | 1 + .github/workflows/wheel_tests_continuous.yml | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/bazel_cuda_non_rbe.yml b/.github/workflows/bazel_cuda_non_rbe.yml index 0b0e1cb62497..3d15f4211a3f 100644 --- a/.github/workflows/bazel_cuda_non_rbe.yml +++ b/.github/workflows/bazel_cuda_non_rbe.yml @@ -79,6 +79,7 @@ jobs: continue-on-error: true run: >- mkdir -p $(pwd)/dist && + gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ && gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/ diff --git a/.github/workflows/wheel_tests_continuous.yml b/.github/workflows/wheel_tests_continuous.yml index ecdf43b133cc..4b6e1e0a8712 100644 --- a/.github/workflows/wheel_tests_continuous.yml +++ b/.github/workflows/wheel_tests_continuous.yml @@ -148,7 +148,7 @@ jobs: # build job fails. E.g Windows build job fails but everything else succeeds. In this case, we # still want to run the tests for other platforms. if: ${{ !cancelled() }} - needs: [build-jaxlib-artifact, build-cuda-artifacts] + needs: [build-jax-artifact, build-jaxlib-artifact, build-cuda-artifacts] uses: ./.github/workflows/bazel_cuda_non_rbe.yml strategy: fail-fast: false # don't cancel all jobs on failure From 18521fef08f0d42f6001141abb793998323f72b3 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 25 Mar 2025 14:45:41 -0700 Subject: [PATCH 225/483] Deprecate jax.tree_* aliases --- CHANGELOG.md | 4 ++++ jax/__init__.py | 50 ++++++++++++++++--------------------------------- 2 files changed, 20 insertions(+), 34 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cfd8c2eb340b..5785f6193065 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * From `jax.lib.xla_client`: `FftType`, `PaddingType`, `dtype_to_etype`, and `shape_from_pyval`. * From `jax.lib.xla_extension`: `ArrayImpl`, `XlaRuntimeError`. + * From `jax`: `jax.treedef_is_leaf`, `jax.tree_flatten`, `jax.tree_map`, + `jax.tree_leaves`, `jax.tree_structure`, `jax.tree_transpose`, and + `jax.tree_unflatten`. Replacements can be found in {mod}`jax.tree` or + {mod}`jax.tree_util`. ## jax 0.5.3 (Mar 19, 2025) diff --git a/jax/__init__.py b/jax/__init__.py index 988c224e4772..32ae955ae5b8 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -141,16 +141,6 @@ make_array_from_process_local_data as make_array_from_process_local_data, ) -from jax._src.tree_util import ( - tree_map as _deprecated_tree_map, - treedef_is_leaf as _deprecated_treedef_is_leaf, - tree_flatten as _deprecated_tree_flatten, - tree_leaves as _deprecated_tree_leaves, - tree_structure as _deprecated_tree_structure, - tree_transpose as _deprecated_tree_transpose, - tree_unflatten as _deprecated_tree_unflatten, -) - # These submodules are separate because they are in an import cycle with # jax and rely on the names imported above. from jax import custom_derivatives as custom_derivatives @@ -184,54 +174,46 @@ del _ccache _deprecations = { - # Added July 2022 + # Finalized 2025-03-25; remove after 2025-06-25 "treedef_is_leaf": ( - "jax.treedef_is_leaf is deprecated: use jax.tree_util.treedef_is_leaf.", - _deprecated_treedef_is_leaf + "jax.treedef_is_leaf was removed in JAX v0.6.0: use jax.tree_util.treedef_is_leaf.", + None ), "tree_flatten": ( - "jax.tree_flatten is deprecated: use jax.tree.flatten (jax v0.4.25 or newer) " + "jax.tree_flatten was removed in JAX v0.6.0: use jax.tree.flatten (jax v0.4.25 or newer) " "or jax.tree_util.tree_flatten (any JAX version).", - _deprecated_tree_flatten + None ), "tree_leaves": ( - "jax.tree_leaves is deprecated: use jax.tree.leaves (jax v0.4.25 or newer) " + "jax.tree_leaves was removed in JAX v0.6.0: use jax.tree.leaves (jax v0.4.25 or newer) " "or jax.tree_util.tree_leaves (any JAX version).", - _deprecated_tree_leaves + None ), "tree_structure": ( - "jax.tree_structure is deprecated: use jax.tree.structure (jax v0.4.25 or newer) " + "jax.tree_structure was removed in JAX v0.6.0: use jax.tree.structure (jax v0.4.25 or newer) " "or jax.tree_util.tree_structure (any JAX version).", - _deprecated_tree_structure + None ), "tree_transpose": ( - "jax.tree_transpose is deprecated: use jax.tree.transpose (jax v0.4.25 or newer) " + "jax.tree_transpose was removed in JAX v0.6.0: use jax.tree.transpose (jax v0.4.25 or newer) " "or jax.tree_util.tree_transpose (any JAX version).", - _deprecated_tree_transpose + None ), "tree_unflatten": ( - "jax.tree_unflatten is deprecated: use jax.tree.unflatten (jax v0.4.25 or newer) " + "jax.tree_unflatten was removed in JAX v0.6.0: use jax.tree.unflatten (jax v0.4.25 or newer) " "or jax.tree_util.tree_unflatten (any JAX version).", - _deprecated_tree_unflatten + None ), - # Added Feb 28, 2024 "tree_map": ( - "jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) " + "jax.tree_map was removed in JAX v0.6.0: use jax.tree.map (jax v0.4.25 or newer) " "or jax.tree_util.tree_map (any JAX version).", - _deprecated_tree_map + None ), } import typing as _typing if _typing.TYPE_CHECKING: - from jax._src.tree_util import treedef_is_leaf as treedef_is_leaf - from jax._src.tree_util import tree_flatten as tree_flatten - from jax._src.tree_util import tree_leaves as tree_leaves - from jax._src.tree_util import tree_map as tree_map - from jax._src.tree_util import tree_structure as tree_structure - from jax._src.tree_util import tree_transpose as tree_transpose - from jax._src.tree_util import tree_unflatten as tree_unflatten - + pass else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr __getattr__ = _deprecation_getattr(__name__, _deprecations) From 289221af8be2979d9a1e25c7e61e1eee18274948 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 27 Mar 2025 10:23:08 -0700 Subject: [PATCH 226/483] Use h100x2 for tests rather than p100x2. PiperOrigin-RevId: 741199510 --- tests/BUILD | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/BUILD b/tests/BUILD index 1baeb4f83af7..1d021b0c7110 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -97,7 +97,7 @@ jax_multiplatform_test( "gpu", ], enable_configs = [ - "gpu_p100x2", + "gpu_h100x2", ], env = { "PYTHONWARNINGS": "default", # TODO(b/394123878): protobuf, via TensorFlow, issues a Python warning under Python 3.12+ sometimes. @@ -190,7 +190,7 @@ jax_multiplatform_test( name = "ffi_test", srcs = ["ffi_test.py"], enable_configs = [ - "gpu_p100x2", + "gpu_h100x2", ], # TODO(dfm): Remove after removal of jex.ffi imports. deps = ["//jax:extend"], @@ -274,7 +274,7 @@ jax_multiplatform_test( srcs = ["memories_test.py"], enable_configs = [ "cpu", - "gpu_p100x2", + "gpu_h100x2", "tpu_v3_2x2", "tpu_v4_2x2", "tpu_v5p_2x2", @@ -301,7 +301,7 @@ jax_multiplatform_test( "gpu_p100x2_shardy", "tpu_v3_2x2_shardy", "tpu_v3_2x2", - "gpu_p100x2", + "gpu_h100x2", ], shard_count = { "cpu": 5, @@ -725,7 +725,7 @@ jax_multiplatform_test( "cpu", ], enable_configs = [ - "gpu_p100x2", + "gpu_h100x2", "gpu_p100x2_shardy", "gpu_p100x2_pjrt_c_api", ], @@ -766,7 +766,7 @@ jax_multiplatform_test( srcs = ["multibackend_test.py"], enable_configs = [ "tpu_v3_2x2", - "gpu_p100x2", + "gpu_h100x2", ], ) From 1719fa0d5bd0d95be54a9327ae6dfdff142dafbe Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Thu, 27 Mar 2025 10:26:35 -0700 Subject: [PATCH 227/483] Make sure array is copied under this situation: ``` x = np.arange(1000) y = jax.device_put(x, device=jax.devices()[0], may_alias=False, donate=False) z = jax.device_put(y, device=jax.devices()[0], may_alias=False, donate=False) ``` This condition will be true after this change `z.unsafe_buffer_pointer() != y.unsafe_buffer_pointer()` Also lift the restrictions that CopyToMemorySpace doesn't work sometimes for matching src+dest memory spaces. We can always bounce through the host if there is no more efficient copy. PiperOrigin-RevId: 741200853 --- jax/_src/dispatch.py | 3 +++ jax/_src/interpreters/pxla.py | 2 +- jaxlib/xla/py_values.cc | 13 ++++++++----- jaxlib/xla/xla_client.py | 2 +- tests/pjit_test.py | 13 +++++++++++++ 5 files changed, 26 insertions(+), 7 deletions(-) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 2330f7628966..d205f860b214 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -495,6 +495,9 @@ def _device_put_sharding_impl(x, aval, device, copy): return _DeferredShardArg(x, x.sharding, aval, x.committed, copy) elif is_single_device_sharding(x.sharding): device = x.sharding._device_assignment[0] if device is None else device + if copy == CopySemantics.COPY: + return xc.batched_device_put(aval, SingleDeviceSharding(device), [x], + [device], True, True) return pxla.batched_device_put(aval, SingleDeviceSharding(device), [x], [device]) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 6f95b1b72281..51854b457b37 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -257,7 +257,7 @@ def _shard_abstract_array(size, axis: int, x): raise ValueError(f"Axis size {size} does not match dimension {axis} of " f"shape {x.shape}") except IndexError: - raise ValueError("Cannot split a {x.dim}D value along axis {axis}") from None + raise ValueError(f"Cannot split a {x.dim}D value along axis {axis}") from None if config.pmap_no_rank_reduction.value: return x.update(shape=tuple_update(x.shape, axis, 1)) else: diff --git a/jaxlib/xla/py_values.cc b/jaxlib/xla/py_values.cc index 1c7db0bec13a..e13a38197c0a 100644 --- a/jaxlib/xla/py_values.cc +++ b/jaxlib/xla/py_values.cc @@ -418,6 +418,7 @@ absl::StatusOr HandlePyArray( } if (ifrt_array->sharding().devices()->devices().front() == to_device && + options.allow_zero_copy && (!to_memory_kind.memory_kind().has_value() || !ifrt_array->sharding().memory_kind().memory_kind().has_value() || ifrt_array->sharding().memory_kind() == to_memory_kind)) { @@ -426,15 +427,17 @@ absl::StatusOr HandlePyArray( return [result = std::move(result)]() mutable { return std::move(result); }; } else { return [ifrt_array = tsl::FormRef(ifrt_array), to_device, to_memory_kind, - owning_pybuffer = py_array.weak_type()]() mutable + owning_pybuffer = py_array.weak_type(), + allow_zero_copy = options.allow_zero_copy]() mutable -> absl::StatusOr { auto* ifrt_client = ifrt_array->client(); TF_ASSIGN_OR_RETURN( auto copied_ifrt_arrays, - ifrt_client->CopyArrays(absl::MakeSpan(&ifrt_array, 1), - ifrt_client->MakeDeviceList({to_device}), - to_memory_kind, - ifrt::ArrayCopySemantics::kReuseInput)); + ifrt_client->CopyArrays( + absl::MakeSpan(&ifrt_array, 1), + ifrt_client->MakeDeviceList({to_device}), to_memory_kind, + allow_zero_copy ? ifrt::ArrayCopySemantics::kReuseInput + : ifrt::ArrayCopySemantics::kAlwaysCopy)); return DevicePutResult(std::move(copied_ifrt_arrays[0]), std::move(owning_pybuffer)); }; diff --git a/jaxlib/xla/xla_client.py b/jaxlib/xla/xla_client.py index 30e8443276c8..776a22444208 100644 --- a/jaxlib/xla/xla_client.py +++ b/jaxlib/xla/xla_client.py @@ -50,7 +50,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version. -_version = 322 +_version = 323 # An internal increasing version number for protecting jaxlib code against # ifrt changes. diff --git a/tests/pjit_test.py b/tests/pjit_test.py index d72ecc98e771..aa5afccb38d4 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -63,6 +63,7 @@ from jax._src import xla_bridge from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension +from jax._src.lib import jaxlib_extension_version from jax._src.util import curry, unzip2 config.parse_flags_with_absl() @@ -1400,6 +1401,18 @@ def test_zero_literal_equality(self): self.assertIn("stablehlo.constant dense<0.000000e+00>", ir) self.assertIn("stablehlo.constant dense<-0.000000e+00>", ir) + def test_device_put_copy_donate(self): + if jaxlib_extension_version < 323: + raise unittest.SkipTest("Copy not supported in device put.") + x = np.arange(1000) + y = jax.device_put(x, device=jax.devices()[0], may_alias=False, donate=False) + z = jax.device_put(y, device=jax.devices()[0], may_alias=False, donate=False) + a = jax.jit(lambda y: y * 2, donate_argnums=0)(y) + self.assertDeleted(y) + self.assertNotDeleted(z) + self.assertArraysEqual(a, x * 2) + + @jtu.pytest_mark_if_available('multiaccelerator') class CustomPartitionerTest(jtu.JaxTestCase): From 3f8e1925f7f47a9aac176feb6c57028f594a5e17 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 27 Mar 2025 10:30:49 -0700 Subject: [PATCH 228/483] Remove CUDA 12.3 from the CUDA test matrix Also, update the Docker image to one with cudnn 12.8 PiperOrigin-RevId: 741202254 --- .github/workflows/pytest_cuda.yml | 7 +++---- .github/workflows/wheel_tests_continuous.yml | 4 ++-- .github/workflows/wheel_tests_nightly_release.yml | 2 +- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/.github/workflows/pytest_cuda.yml b/.github/workflows/pytest_cuda.yml index b3d1b15a0052..671af873b48d 100644 --- a/.github/workflows/pytest_cuda.yml +++ b/.github/workflows/pytest_cuda.yml @@ -54,12 +54,11 @@ jobs: run-tests: defaults: run: - # Explicitly set the shell to bash + # Set the shell to bash as GitHub actions run with /bin/sh by default shell: bash runs-on: ${{ inputs.runner }} - # TODO: Update to the generic ML ecosystem test containers when they are ready. - container: ${{ (contains(inputs.cuda, '12.3') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.3-cudnn9.1:latest') || - (contains(inputs.cuda, '12.1') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.1-cudnn9.1:latest') || + # Test the oldest and newest supported CUDA versions. + container: ${{ (contains(inputs.cuda, '12.1') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.1-cudnn9.8:latest') || (contains(inputs.cuda, '12.8') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.8-cudnn9.8:latest') }} name: "Pytest CUDA (${{ inputs.runner }}, CUDA ${{ inputs.cuda }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})" diff --git a/.github/workflows/wheel_tests_continuous.yml b/.github/workflows/wheel_tests_continuous.yml index 4b6e1e0a8712..f48c39bf4721 100644 --- a/.github/workflows/wheel_tests_continuous.yml +++ b/.github/workflows/wheel_tests_continuous.yml @@ -111,9 +111,9 @@ jobs: matrix: # Python values need to match the matrix stategy in the artifact build jobs above # See exlusions for what is fully tested - runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu","linux-x86-a4-224-b200-1gpu"] + runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu", "linux-x86-a4-224-b200-1gpu"] python: ["3.10",] - cuda: ["12.1","12.3","12.8"] + cuda: ["12.1", "12.8"] enable-x64: [1, 0] exclude: # L4 does not run on cuda 12.8 but tests other configs diff --git a/.github/workflows/wheel_tests_nightly_release.yml b/.github/workflows/wheel_tests_nightly_release.yml index 9cd48c925cf3..fd4a52d296e0 100644 --- a/.github/workflows/wheel_tests_nightly_release.yml +++ b/.github/workflows/wheel_tests_nightly_release.yml @@ -58,7 +58,7 @@ jobs: # that build the wheels. runner: ["linux-x86-g2-48-l4-4gpu"] python: ["3.10","3.11", "3.12", "3.13", "3.13-nogil"] - cuda: ["12.3", "12.1"] + cuda: ["12.1", "12.8"] enable-x64: [0] name: "Pytest CUDA (JAX artifacts version = ${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }})" with: From a61785d2b6fbee58736ff5f570234894bc6d17d9 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 27 Mar 2025 10:49:41 -0700 Subject: [PATCH 229/483] Run include_cleaner over JAX C++ code. PiperOrigin-RevId: 741208842 --- jaxlib/BUILD | 1 + jaxlib/cuda/BUILD | 9 +++++++- jaxlib/cuda/versions_helpers.cc | 1 + jaxlib/gpu/BUILD | 1 + jaxlib/gpu/blas.cc | 2 +- jaxlib/gpu/gpu_kernel_helpers.cc | 5 +++- jaxlib/gpu/gpu_kernel_helpers.h | 3 +-- jaxlib/gpu/gpu_plugin_extension.cc | 1 + jaxlib/gpu/make_batch_pointers.cu.cc | 1 + jaxlib/gpu/prng.cc | 1 + jaxlib/gpu/prng_kernels.cc | 4 ---- jaxlib/gpu/prng_kernels.cu.cc | 3 +-- jaxlib/gpu/prng_kernels.h | 2 -- jaxlib/gpu/py_client_gpu.h | 1 - jaxlib/gpu/rnn.cc | 2 +- jaxlib/gpu/rnn_kernels.cc | 5 ++++ jaxlib/gpu/rnn_kernels.h | 1 + jaxlib/gpu/solver.cc | 2 +- jaxlib/gpu/sparse.cc | 13 +++++------ jaxlib/gpu/sparse_kernels.cc | 10 ++++---- jaxlib/gpu/sparse_kernels.h | 7 ++---- jaxlib/gpu/triton.cc | 15 +++++++----- jaxlib/gpu/triton_kernels.cc | 4 +++- jaxlib/gpu/triton_kernels.h | 3 +-- jaxlib/gpu/triton_utils.cc | 1 + jaxlib/gpu/triton_utils.h | 1 - jaxlib/gpu/vendor.h | 1 + jaxlib/kernel_helpers.h | 2 +- .../mlir/_mlir_libs/register_jax_dialects.cc | 23 ++++++++++--------- jaxlib/mlir/_mlir_libs/tpu_ext.cc | 5 +--- jaxlib/mlir/_mlir_libs/triton_ext.cc | 1 + jaxlib/mosaic/BUILD | 1 + jaxlib/mosaic/dialect/gpu/BUILD | 4 ++-- .../dialect/gpu/integrations/c/gpu_dialect.h | 2 +- jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc | 2 +- jaxlib/mosaic/dialect/gpu/mosaic_gpu.h | 4 +--- jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc | 4 ++-- .../dialect/tpu/integrations/c/tpu_dialect.cc | 5 +++- jaxlib/mosaic/dialect/tpu/layout.cc | 1 - jaxlib/mosaic/dialect/tpu/layout.h | 4 ++-- jaxlib/mosaic/dialect/tpu/tpu_dialect.cc | 8 ++----- jaxlib/mosaic/dialect/tpu/tpu_dialect.h | 11 ++++----- jaxlib/mosaic/dialect/tpu/tpu_ops.cc | 1 + .../tpu/transforms/apply_vector_layout.cc | 1 - .../tpu/transforms/canonicalize_mosaic.cc | 13 ++++------- .../dialect/tpu/transforms/communication.cc | 7 ++++-- .../tpu/transforms/infer_memref_layout.cc | 1 - .../tpu/transforms/infer_vector_layout.cc | 3 --- .../transforms/memory_space_specialization.cc | 2 ++ jaxlib/mosaic/dialect/tpu/transforms/serde.h | 1 + jaxlib/mosaic/dialect/tpu/util.cc | 1 + jaxlib/mosaic/dialect/tpu/util.h | 3 ++- jaxlib/mosaic/gpu/custom_call.cc | 1 + jaxlib/mosaic/gpu/mosaic_gpu_ext.cc | 4 ++-- jaxlib/mosaic/gpu/passes.cc | 4 ++++ jaxlib/rocm/BUILD | 10 +++++++- jaxlib/xla/BUILD | 10 ++++---- jaxlib/xla/callback.cc | 1 + jaxlib/xla/callback.h | 1 - jaxlib/xla/config.cc | 1 + jaxlib/xla/custom_call_sharding.cc | 1 + jaxlib/xla/dlpack.cc | 2 +- jaxlib/xla/ifrt_proxy.cc | 4 ++-- jaxlib/xla/jax_jit.h | 2 +- jaxlib/xla/mlir.cc | 4 ---- jaxlib/xla/pmap_lib.h | 3 --- jaxlib/xla/py_array.cc | 1 - jaxlib/xla/py_array.h | 1 + jaxlib/xla/py_client.cc | 2 -- jaxlib/xla/py_client.h | 1 - jaxlib/xla/py_device.h | 1 + jaxlib/xla/py_device_list.cc | 2 -- jaxlib/xla/py_device_list.h | 1 - jaxlib/xla/py_executable.h | 4 +--- jaxlib/xla/py_memory_space.h | 1 + jaxlib/xla/py_socket_transfer.cc | 2 ++ jaxlib/xla/python_ref_manager.cc | 2 ++ jaxlib/xla/sharded_device_array.h | 1 - jaxlib/xla/sharding.cc | 1 - jaxlib/xla/sharding.h | 3 ++- jaxlib/xla/to_ifrt_sharding.cc | 1 - jaxlib/xla/to_ifrt_sharding.h | 8 ++++++- jaxlib/xla/traceback.h | 3 ++- jaxlib/xla/xla.cc | 6 ++--- jaxlib/xla/xla_compiler.cc | 1 + 85 files changed, 161 insertions(+), 139 deletions(-) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 52c945482222..c8114b48835f 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -154,6 +154,7 @@ cc_library( features = ["-use_header_modules"], deps = [ "@com_google_absl//absl/base", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", ], ) diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index c47bc3c8126f..fac62c81dee7 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -160,6 +160,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", "@xla//xla/ffi/api:ffi", "@xla//xla/service:custom_call_status", @@ -336,6 +337,7 @@ nanobind_extension( ":cuda_vendor", ":cusparse_kernels", "//jaxlib:absl_status_casters", + "//jaxlib:kernel_helpers", "//jaxlib:kernel_nanobind_helpers", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", @@ -343,11 +345,13 @@ nanobind_extension( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/hash", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", "@nanobind", + "@xla//xla/service:custom_call_status", "@xla//xla/tsl/cuda:cudart", "@xla//xla/tsl/cuda:cusparse", "@xla//xla/tsl/python/lib/core:numpy", @@ -455,6 +459,7 @@ nanobind_extension( deps = [ ":cuda_gpu_kernel_helpers", ":cuda_prng_kernels", + ":cuda_vendor", "//jaxlib:kernel_nanobind_helpers", "@local_config_cuda//cuda:cuda_headers", "@nanobind", @@ -545,8 +550,8 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/synchronization", - "@tsl//tsl/platform:env", "@xla//xla/service:custom_call_status", "@xla//xla/stream_executor/cuda:cuda_asm_compiler", "@xla//xla/tsl/cuda:cudart", @@ -586,7 +591,9 @@ nanobind_extension( "//jaxlib:absl_status_casters", "//jaxlib:kernel_nanobind_helpers", "//jaxlib/gpu:triton_cc_proto", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@nanobind", ], ) diff --git a/jaxlib/cuda/versions_helpers.cc b/jaxlib/cuda/versions_helpers.cc index d42199d37467..508a92c326cb 100644 --- a/jaxlib/cuda/versions_helpers.cc +++ b/jaxlib/cuda/versions_helpers.cc @@ -16,6 +16,7 @@ limitations under the License. #include "jaxlib/cuda/versions_helpers.h" #include +#include #include #include "absl/base/dynamic_annotations.h" diff --git a/jaxlib/gpu/BUILD b/jaxlib/gpu/BUILD index 59c0ab8dc164..1fd2775ecf9a 100644 --- a/jaxlib/gpu/BUILD +++ b/jaxlib/gpu/BUILD @@ -133,6 +133,7 @@ cc_library( "@xla//xla/pjrt/c:pjrt_c_api_hdrs", "@xla//xla/pjrt/c:pjrt_c_api_helpers", "@xla//xla/pjrt/c:pjrt_c_api_triton_extension_hdrs", + "@xla//xla/tsl/platform:statusor", "@xla//xla/tsl/python/lib/core:numpy", ], ) diff --git a/jaxlib/gpu/blas.cc b/jaxlib/gpu/blas.cc index 4a58859016f1..cf391e07e31e 100644 --- a/jaxlib/gpu/blas.cc +++ b/jaxlib/gpu/blas.cc @@ -20,7 +20,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/str_format.h" #include "nanobind/nanobind.h" -#include "nanobind/stl/pair.h" +#include "nanobind/stl/pair.h" // IWYU pragma: keep #include "jaxlib/gpu/blas_kernels.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_nanobind_helpers.h" diff --git a/jaxlib/gpu/gpu_kernel_helpers.cc b/jaxlib/gpu/gpu_kernel_helpers.cc index 5a434f4b6ad5..5b509ad9912d 100644 --- a/jaxlib/gpu/gpu_kernel_helpers.cc +++ b/jaxlib/gpu/gpu_kernel_helpers.cc @@ -15,12 +15,15 @@ limitations under the License. #include "jaxlib/gpu/gpu_kernel_helpers.h" +#include +#include + #include "absl/base/optimization.h" #include "absl/log/check.h" -#include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "jaxlib/gpu/vendor.h" namespace jax { namespace JAX_GPU_NAMESPACE { diff --git a/jaxlib/gpu/gpu_kernel_helpers.h b/jaxlib/gpu/gpu_kernel_helpers.h index aecb8a4fdcf1..0326d7f44620 100644 --- a/jaxlib/gpu/gpu_kernel_helpers.h +++ b/jaxlib/gpu/gpu_kernel_helpers.h @@ -16,11 +16,10 @@ limitations under the License. #ifndef JAXLIB_GPU_GPU_KERNEL_HELPERS_H_ #define JAXLIB_GPU_GPU_KERNEL_HELPERS_H_ -#include +#include #include "absl/base/optimization.h" #include "absl/status/status.h" -#include "absl/status/statusor.h" #include "jaxlib/gpu/vendor.h" #define JAX_AS_STATUS(expr) \ diff --git a/jaxlib/gpu/gpu_plugin_extension.cc b/jaxlib/gpu/gpu_plugin_extension.cc index d026806e9479..cca615cfb260 100644 --- a/jaxlib/gpu/gpu_plugin_extension.cc +++ b/jaxlib/gpu/gpu_plugin_extension.cc @@ -35,6 +35,7 @@ limitations under the License. #include "xla/pjrt/c/pjrt_c_api_helpers.h" #include "xla/pjrt/c/pjrt_c_api_triton_extension.h" #include "xla/pjrt/status_casters.h" +#include "xla/tsl/platform/statusor.h" #include "xla/tsl/python/lib/core/numpy.h" #include "xla/util.h" diff --git a/jaxlib/gpu/make_batch_pointers.cu.cc b/jaxlib/gpu/make_batch_pointers.cu.cc index 3a24e355ead0..1d05fa8adcac 100644 --- a/jaxlib/gpu/make_batch_pointers.cu.cc +++ b/jaxlib/gpu/make_batch_pointers.cu.cc @@ -16,6 +16,7 @@ limitations under the License. #include "jaxlib/gpu/make_batch_pointers.h" #include +#include #include #include "jaxlib/gpu/vendor.h" diff --git a/jaxlib/gpu/prng.cc b/jaxlib/gpu/prng.cc index 1ce428d7f9dc..007e51b76de7 100644 --- a/jaxlib/gpu/prng.cc +++ b/jaxlib/gpu/prng.cc @@ -15,6 +15,7 @@ limitations under the License. #include "nanobind/nanobind.h" #include "jaxlib/gpu/prng_kernels.h" +#include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_nanobind_helpers.h" namespace jax { diff --git a/jaxlib/gpu/prng_kernels.cc b/jaxlib/gpu/prng_kernels.cc index f5d6abef83f8..1dac1e47bd44 100644 --- a/jaxlib/gpu/prng_kernels.cc +++ b/jaxlib/gpu/prng_kernels.cc @@ -17,16 +17,12 @@ limitations under the License. #include #include -#include #include "absl/algorithm/container.h" -#include "absl/status/status.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/ffi_helpers.h" -#include "jaxlib/kernel_helpers.h" #include "xla/ffi/api/ffi.h" -#include "xla/service/custom_call_status.h" namespace jax { namespace JAX_GPU_NAMESPACE { diff --git a/jaxlib/gpu/prng_kernels.cu.cc b/jaxlib/gpu/prng_kernels.cu.cc index d4aaec62320d..e42165f95d15 100644 --- a/jaxlib/gpu/prng_kernels.cu.cc +++ b/jaxlib/gpu/prng_kernels.cu.cc @@ -15,8 +15,7 @@ limitations under the License. #include "jaxlib/gpu/prng_kernels.h" -#include -#include +#include #include #include "jaxlib/gpu/vendor.h" diff --git a/jaxlib/gpu/prng_kernels.h b/jaxlib/gpu/prng_kernels.h index c98fd485700d..4d64d2b4a4e4 100644 --- a/jaxlib/gpu/prng_kernels.h +++ b/jaxlib/gpu/prng_kernels.h @@ -16,12 +16,10 @@ limitations under the License. #ifndef JAXLIB_GPU_PRNG_KERNELS_H_ #define JAXLIB_GPU_PRNG_KERNELS_H_ -#include #include #include "jaxlib/gpu/vendor.h" #include "xla/ffi/api/ffi.h" -#include "xla/service/custom_call_status.h" namespace jax { namespace JAX_GPU_NAMESPACE { diff --git a/jaxlib/gpu/py_client_gpu.h b/jaxlib/gpu/py_client_gpu.h index 8c5404570919..4d48858ad278 100644 --- a/jaxlib/gpu/py_client_gpu.h +++ b/jaxlib/gpu/py_client_gpu.h @@ -16,7 +16,6 @@ limitations under the License. #ifndef JAX_JAXLIB_GPU_PY_CLIENT_GPU_H_ #define JAX_JAXLIB_GPU_PY_CLIENT_GPU_H_ -#include #include "jaxlib/gpu/vendor.h" #include "xla/ffi/api/ffi.h" diff --git a/jaxlib/gpu/rnn.cc b/jaxlib/gpu/rnn.cc index eaa815d33e68..32e0842e3038 100644 --- a/jaxlib/gpu/rnn.cc +++ b/jaxlib/gpu/rnn.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include "nanobind/nanobind.h" -#include "nanobind/stl/pair.h" +#include "nanobind/stl/pair.h" // IWYU pragma: keep #include "jaxlib/absl_status_casters.h" #include "jaxlib/gpu/rnn_kernels.h" #include "jaxlib/gpu/vendor.h" diff --git a/jaxlib/gpu/rnn_kernels.cc b/jaxlib/gpu/rnn_kernels.cc index 45f8ba8187ba..d06535a668ac 100644 --- a/jaxlib/gpu/rnn_kernels.cc +++ b/jaxlib/gpu/rnn_kernels.cc @@ -16,14 +16,19 @@ limitations under the License. #include "jaxlib/gpu/rnn_kernels.h" #include +#include +#include #include #include #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "absl/synchronization/mutex.h" #include "jaxlib/gpu/ffi_wrapper.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/handle_pool.h" +#include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_helpers.h" #include "xla/service/custom_call_status.h" diff --git a/jaxlib/gpu/rnn_kernels.h b/jaxlib/gpu/rnn_kernels.h index e95b7788382a..36d8c25c6a9f 100644 --- a/jaxlib/gpu/rnn_kernels.h +++ b/jaxlib/gpu/rnn_kernels.h @@ -17,6 +17,7 @@ limitations under the License. #define JAXLIB_GPU_RNN_KERNELS_H_ #include +#include #include "absl/status/statusor.h" #include "jaxlib/gpu/vendor.h" diff --git a/jaxlib/gpu/solver.cc b/jaxlib/gpu/solver.cc index 1cf799bbb491..20fc308100c4 100644 --- a/jaxlib/gpu/solver.cc +++ b/jaxlib/gpu/solver.cc @@ -21,7 +21,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "nanobind/nanobind.h" -#include "nanobind/stl/pair.h" +#include "nanobind/stl/pair.h" // IWYU pragma: keep #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/solver_handle_pool.h" #include "jaxlib/gpu/solver_kernels.h" diff --git a/jaxlib/gpu/sparse.cc b/jaxlib/gpu/sparse.cc index a7f8dbebc2b3..592c0f454a55 100644 --- a/jaxlib/gpu/sparse.cc +++ b/jaxlib/gpu/sparse.cc @@ -13,24 +13,23 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include +#include #include +#include #include -#include -#include "absl/base/casts.h" -#include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" #include "absl/strings/str_format.h" -#include "absl/synchronization/mutex.h" #include "nanobind/nanobind.h" -#include "nanobind/stl/pair.h" +#include "nanobind/stl/pair.h" // IWYU pragma: keep #include "jaxlib/absl_status_casters.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/sparse_kernels.h" #include "jaxlib/gpu/vendor.h" +#include "jaxlib/kernel_helpers.h" #include "jaxlib/kernel_nanobind_helpers.h" +#include "xla/service/custom_call_status.h" #include "xla/tsl/python/lib/core/numpy.h" namespace nb = nanobind; diff --git a/jaxlib/gpu/sparse_kernels.cc b/jaxlib/gpu/sparse_kernels.cc index c66e96b6b89e..a9c08317e066 100644 --- a/jaxlib/gpu/sparse_kernels.cc +++ b/jaxlib/gpu/sparse_kernels.cc @@ -15,11 +15,9 @@ limitations under the License. #include "jaxlib/gpu/sparse_kernels.h" -#include -#include -#include -#include -#include +#include +#include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -27,8 +25,8 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "jaxlib/gpu/ffi_wrapper.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" -#include "jaxlib/gpu/vendor.h" #include "jaxlib/gpu/handle_pool.h" +#include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_helpers.h" #include "xla/service/custom_call_status.h" diff --git a/jaxlib/gpu/sparse_kernels.h b/jaxlib/gpu/sparse_kernels.h index 0d74ebc7d8e4..d735c320307c 100644 --- a/jaxlib/gpu/sparse_kernels.h +++ b/jaxlib/gpu/sparse_kernels.h @@ -16,15 +16,12 @@ limitations under the License. #ifndef JAXLIB_GPU_SPARSE_KERNELS_H_ #define JAXLIB_GPU_SPARSE_KERNELS_H_ -#include +#include #include -#include -#include -#include #include "absl/status/statusor.h" -#include "jaxlib/gpu/vendor.h" #include "jaxlib/gpu/handle_pool.h" +#include "jaxlib/gpu/vendor.h" #include "xla/ffi/api/ffi.h" #include "xla/service/custom_call_status.h" diff --git a/jaxlib/gpu/triton.cc b/jaxlib/gpu/triton.cc index d0c48eef492f..b3f313e4f7ea 100644 --- a/jaxlib/gpu/triton.cc +++ b/jaxlib/gpu/triton.cc @@ -13,20 +13,23 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include -#include #include #include #include +#include #include +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "nanobind/nanobind.h" -#include "nanobind/stl/pair.h" -#include "nanobind/stl/string.h" -#include "nanobind/stl/string_view.h" -#include "nanobind/stl/tuple.h" -#include "nanobind/stl/vector.h" +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/tuple.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep #include "jaxlib/absl_status_casters.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/triton.pb.h" diff --git a/jaxlib/gpu/triton_kernels.cc b/jaxlib/gpu/triton_kernels.cc index 6565b5b87be2..9e0dc6c855ac 100644 --- a/jaxlib/gpu/triton_kernels.cc +++ b/jaxlib/gpu/triton_kernels.cc @@ -40,6 +40,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/triton.pb.h" @@ -52,7 +53,8 @@ limitations under the License. #endif // JAX_GPU_CUDA #ifdef JAX_GPU_HIP -#include "tsl/platform/env.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" #endif // JAX_GPU_HIP #define GPU_RETURN_IF_ERROR(expr) JAX_RETURN_IF_ERROR(JAX_AS_STATUS(expr)) diff --git a/jaxlib/gpu/triton_kernels.h b/jaxlib/gpu/triton_kernels.h index d23a9a7395e0..3ab3e9143fb8 100644 --- a/jaxlib/gpu/triton_kernels.h +++ b/jaxlib/gpu/triton_kernels.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef JAXLIB_GPU_TRITON_H_ #define JAXLIB_GPU_TRITON_H_ +#include #include -#include #include #include #include @@ -25,7 +25,6 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/triton.pb.h" #include "jaxlib/gpu/vendor.h" #include "xla/service/custom_call_status.h" diff --git a/jaxlib/gpu/triton_utils.cc b/jaxlib/gpu/triton_utils.cc index f6bbe46c846d..fd63435da177 100644 --- a/jaxlib/gpu/triton_utils.cc +++ b/jaxlib/gpu/triton_utils.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/triton.pb.h" +#include "jaxlib/gpu/vendor.h" namespace jax::JAX_GPU_NAMESPACE { diff --git a/jaxlib/gpu/triton_utils.h b/jaxlib/gpu/triton_utils.h index 19c64a88c216..a79c098373d1 100644 --- a/jaxlib/gpu/triton_utils.h +++ b/jaxlib/gpu/triton_utils.h @@ -18,7 +18,6 @@ limitations under the License. #include -#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "jaxlib/gpu/vendor.h" diff --git a/jaxlib/gpu/vendor.h b/jaxlib/gpu/vendor.h index 58a02e7c568c..5deb8d4c650a 100644 --- a/jaxlib/gpu/vendor.h +++ b/jaxlib/gpu/vendor.h @@ -20,6 +20,7 @@ limitations under the License. #ifndef JAXLIB_GPU_VENDOR_H_ #define JAXLIB_GPU_VENDOR_H_ +#include #if defined(JAX_GPU_CUDA) // IWYU pragma: begin_exports diff --git a/jaxlib/kernel_helpers.h b/jaxlib/kernel_helpers.h index dac0355fbde6..5a053f833ce4 100644 --- a/jaxlib/kernel_helpers.h +++ b/jaxlib/kernel_helpers.h @@ -17,10 +17,10 @@ limitations under the License. #define JAXLIB_KERNEL_HELPERS_H_ #include -#include #include #include "absl/base/casts.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" namespace jax { diff --git a/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc b/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc index 0eb4a57a2f4b..b8432bf615c9 100644 --- a/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc +++ b/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc @@ -17,18 +17,19 @@ limitations under the License. // This module is called by mlir/__init__.py during initialization. #include -#include "mlir-c/Dialect/Arith.h" -#include "mlir-c/Dialect/Func.h" -#include "mlir-c/Dialect/GPU.h" -#include "mlir-c/Dialect/LLVM.h" -#include "mlir-c/Dialect/Math.h" -#include "mlir-c/Dialect/MemRef.h" -#include "mlir-c/Dialect/NVGPU.h" -#include "mlir-c/Dialect/NVVM.h" -#include "mlir-c/Dialect/SCF.h" -#include "mlir-c/Dialect/Vector.h" +#include "mlir-c/Dialect/Arith.h" // IWYU pragma: keep +#include "mlir-c/Dialect/Func.h" // IWYU pragma: keep +#include "mlir-c/Dialect/GPU.h" // IWYU pragma: keep +#include "mlir-c/Dialect/LLVM.h" // IWYU pragma: keep +#include "mlir-c/Dialect/Math.h" // IWYU pragma: keep +#include "mlir-c/Dialect/MemRef.h" // IWYU pragma: keep +#include "mlir-c/Dialect/NVGPU.h" // IWYU pragma: keep +#include "mlir-c/Dialect/NVVM.h" // IWYU pragma: keep +#include "mlir-c/Dialect/SCF.h" // IWYU pragma: keep +#include "mlir-c/Dialect/Vector.h" // IWYU pragma: keep +#include "mlir-c/IR.h" #include "mlir-c/Transforms.h" -#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" // IWYU pragma: keep #include "shardy/integrations/c/passes.h" #include "jaxlib/mosaic/gpu/integrations/c/passes.h" diff --git a/jaxlib/mlir/_mlir_libs/tpu_ext.cc b/jaxlib/mlir/_mlir_libs/tpu_ext.cc index 7d616968b9aa..8f751693e451 100644 --- a/jaxlib/mlir/_mlir_libs/tpu_ext.cc +++ b/jaxlib/mlir/_mlir_libs/tpu_ext.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -27,9 +26,8 @@ limitations under the License. #include #include -#include "llvm/ADT/ArrayRef.h" +#include "absl/types/span.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/SmallVectorExtras.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "mlir-c/AffineMap.h" @@ -41,7 +39,6 @@ limitations under the License. #include "mlir-c/Support.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" // IWYU pragma: keep // clang-format off -#include "mlir-c/Bindings/Python/Interop.h" // clang-format on #include "absl/log/check.h" #include "nanobind/nanobind.h" diff --git a/jaxlib/mlir/_mlir_libs/triton_ext.cc b/jaxlib/mlir/_mlir_libs/triton_ext.cc index 2a13c40d963f..7fba7e1dfe80 100644 --- a/jaxlib/mlir/_mlir_libs/triton_ext.cc +++ b/jaxlib/mlir/_mlir_libs/triton_ext.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include "mlir-c/IR.h" diff --git a/jaxlib/mosaic/BUILD b/jaxlib/mosaic/BUILD index 775c34c8e7c7..41584d7692aa 100644 --- a/jaxlib/mosaic/BUILD +++ b/jaxlib/mosaic/BUILD @@ -95,6 +95,7 @@ cc_library( "@xla//xla:shape_util", "@xla//xla:util", "@xla//xla/tsl/platform:errors", + "@xla//xla/tsl/platform:statusor", ] + mosaic_extension_deps, ) diff --git a/jaxlib/mosaic/dialect/gpu/BUILD b/jaxlib/mosaic/dialect/gpu/BUILD index e21c8756a4e2..f0e399da0575 100644 --- a/jaxlib/mosaic/dialect/gpu/BUILD +++ b/jaxlib/mosaic/dialect/gpu/BUILD @@ -127,7 +127,7 @@ cc_library( "@llvm-project//mlir:MemRefUtils", "@llvm-project//mlir:SCFUtils", "@llvm-project//mlir:Support", - "@tsl//tsl/platform:statusor", + "@xla//xla/tsl/platform:statusor", ], ) @@ -151,7 +151,7 @@ cc_test( "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:SCFUtils", "@llvm-project//mlir:Support", - "@tsl//tsl/platform:errors", + "@xla//xla/tsl/platform:errors", ], ) diff --git a/jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h b/jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h index bb6cf6e3af4a..5fd0ce7a4f7a 100644 --- a/jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h +++ b/jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "mlir/CAPI/Registration.h" +#include "mlir-c/IR.h" #ifdef __cplusplus extern "C" { diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc index 2358a97ba20d..1b3d08f91fb0 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc @@ -50,7 +50,7 @@ limitations under the License. #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" #include "mlir/Support/LLVM.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/statusor.h" // Generated definitions. #include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_dialect.cc.inc" diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h index 47b286aec302..474ed93806a1 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h @@ -21,14 +21,12 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/string_view.h" -#include "llvm/ADT/StringRef.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Value.h" -#include "mlir/Interfaces/InferTypeOpInterface.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" // IWYU pragma: keep #include "mlir/Support/LLVM.h" // Generated definitions. diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc b/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc index c259da3e737c..5458ba7fac88 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -25,7 +26,6 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" -#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" #include "mlir/Conversion/LLVMCommon/StructBuilder.h" @@ -44,7 +44,7 @@ limitations under the License. #include "mlir/IR/Verifier.h" #include "mlir/Interfaces/DataLayoutInterfaces.h" #include "mlir/Support/LLVM.h" -#include "tsl/platform/errors.h" +#include "xla/tsl/platform/errors.h" namespace mosaic_gpu { namespace { diff --git a/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc b/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc index ce7e90d45fb9..dee4f5de43d8 100644 --- a/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc +++ b/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc @@ -21,10 +21,13 @@ limitations under the License. #include #include #include +#include #include "absl/log/check.h" #include "absl/log/log.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/LogicalResult.h" #include "llvm/Support/MemAlloc.h" #include "llvm/Support/raw_ostream.h" #include "mlir-c/IR.h" @@ -33,11 +36,11 @@ limitations under the License. #include "mlir/CAPI/Registration.h" #include "mlir/CAPI/Utils.h" #include "mlir/CAPI/Wrap.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" #include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" diff --git a/jaxlib/mosaic/dialect/tpu/layout.cc b/jaxlib/mosaic/dialect/tpu/layout.cc index c54c99fc9825..7ae8681e6980 100644 --- a/jaxlib/mosaic/dialect/tpu/layout.cc +++ b/jaxlib/mosaic/dialect/tpu/layout.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include #include #include diff --git a/jaxlib/mosaic/dialect/tpu/layout.h b/jaxlib/mosaic/dialect/tpu/layout.h index bcfe205d58a9..12bf66cfcec0 100644 --- a/jaxlib/mosaic/dialect/tpu/layout.h +++ b/jaxlib/mosaic/dialect/tpu/layout.h @@ -18,16 +18,16 @@ limitations under the License. #include #include +#include #include -#include #include #include #include #include "absl/log/check.h" +#include "llvm/ADT/Hashing.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/bit.h" -#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/Builders.h" diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc index 73c119b70e1a..e0e061fbd6dd 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc @@ -15,27 +15,23 @@ limitations under the License. #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" -#include -#include #include -#include #include -#include -#include #include "absl/hash/hash.h" #include "absl/log/log.h" +#include "llvm/ADT/Hashing.h" #include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep. #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep. #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" +#include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.cc.inc" #include "jaxlib/mosaic/dialect/tpu/tpu_enums.cc.inc" #include "xla/layout.h" diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h index cf74689dd3e6..2afaf08f29ed 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h @@ -23,15 +23,14 @@ limitations under the License. #include #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LogicalResult.h" -#include "jaxlib/mosaic/dialect/tpu/layout.h" +#include "jaxlib/mosaic/dialect/tpu/layout.h" // IWYU pragma: keep #include "jaxlib/mosaic/dialect/tpu/tpu_enums.h.inc" -#include "jaxlib/mosaic/dialect/tpu/transforms/serde.h" -#include "xla/layout.h" +#include "xla/layout.h" // IWYU pragma: keep namespace mlir::tpu { class TPUDialect; @@ -63,11 +62,11 @@ struct ApplyVectorLayoutContext { // mxu_shape = {contracting_size, non_contracting_size} std::array mxu_shape = {128, 128}; int64_t max_sublanes_in_scratch = 0; - int64_t vmem_banks = -1; // -1 means "unspecified". + int64_t vmem_banks = -1; // -1 means "unspecified". int32_t max_shuffle_sublane_offset = -1; // -1 means "unspecified". }; -std::pair mightCommunicateBetweenChips(Operation* op); +std::pair mightCommunicateBetweenChips(Operation *op); std::unique_ptr> createInferMemRefLayoutPass( int hardware_generation = -1, diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index b69a6ae06a7f..41342efeb1b4 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -39,6 +39,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/util.h" +#include "xla/layout.h" namespace mlir { namespace tpu { diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index c9c8d22a1363..e68d5da466eb 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -39,7 +39,6 @@ limitations under the License. #include "llvm/ADT/SmallVectorExtras.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/iterator_range.h" -#include "llvm/Support/Compiler.h" #include "llvm/Support/LogicalResult.h" #include "llvm/Support/MathExtras.h" #include "mlir/Dialect/Arith/IR/Arith.h" diff --git a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc index a15947f48f78..373a5db6b4f6 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc @@ -22,26 +22,23 @@ limitations under the License. #include #include -#include "llvm/ADT/STLExtras.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -// It requires these headers, but does not include them. -// NOLINTNEXTLINE(misc-include-cleaner) -#include "mlir/Dialect/MemRef/IR/MemRef.h" -// NOLINTNEXTLINE(misc-include-cleaner) #include "absl/log/check.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringMap.h" -#include "llvm/ADT/StringSet.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" // IWYU pragma: keep +#include "mlir/Dialect/SCF/IR/SCF.h" // IWYU pragma: keep #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Block.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/OpDefinition.h" diff --git a/jaxlib/mosaic/dialect/tpu/transforms/communication.cc b/jaxlib/mosaic/dialect/tpu/transforms/communication.cc index 89e3a8bb9f70..7e99dd15611b 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/communication.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/communication.cc @@ -17,13 +17,16 @@ limitations under the License. #include #include +#include "absl/log/check.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Diagnostics.h" +#include "mlir/IR/ValueRange.h" #include "mlir/IR/Visitors.h" -#include "mlir/Support/LogicalResult.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "xla/layout.h" diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc index e2196088728f..bfb9be87dfd0 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc @@ -30,7 +30,6 @@ limitations under the License. #include "mlir/IR/Block.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Location.h" diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index c1a642b48f04..54ac777fc205 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -19,8 +19,6 @@ limitations under the License. #include #include #include -#include -#include #include #include "absl/log/check.h" @@ -28,7 +26,6 @@ limitations under the License. #include "absl/types/span.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVectorExtras.h" -#include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" diff --git a/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc b/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc index f78df135a45a..1cfb797c5478 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "absl/log/check.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinTypes.h" diff --git a/jaxlib/mosaic/dialect/tpu/transforms/serde.h b/jaxlib/mosaic/dialect/tpu/transforms/serde.h index 5da8a9e316e0..e5617ef151f7 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/serde.h +++ b/jaxlib/mosaic/dialect/tpu/transforms/serde.h @@ -21,6 +21,7 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/Support/CommandLine.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/Interfaces/DataLayoutInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassRegistry.h" diff --git a/jaxlib/mosaic/dialect/tpu/util.cc b/jaxlib/mosaic/dialect/tpu/util.cc index e61d9fa8d417..141f52ec125b 100644 --- a/jaxlib/mosaic/dialect/tpu/util.cc +++ b/jaxlib/mosaic/dialect/tpu/util.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/types/span.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include "mlir/IR/Attributes.h" diff --git a/jaxlib/mosaic/dialect/tpu/util.h b/jaxlib/mosaic/dialect/tpu/util.h index e2cf27811f09..eed0df14f707 100644 --- a/jaxlib/mosaic/dialect/tpu/util.h +++ b/jaxlib/mosaic/dialect/tpu/util.h @@ -27,6 +27,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/types/span.h" +#include "llvm/Support/Compiler.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" @@ -38,7 +39,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/statusor.h" // TODO: Instead of CHECK_EQs, can we do something like TF_RET_CHECK but with // MLIR diagnostics? diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index d4f4d1732b2e..d9a69c57e142 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include #include #include diff --git a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc index ee11b22020dc..decdbaef28e1 100644 --- a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc +++ b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc @@ -25,8 +25,8 @@ limitations under the License. #include "absl/cleanup/cleanup.h" #include "absl/strings/str_cat.h" #include "nanobind/nanobind.h" -#include "nanobind/stl/tuple.h" -#include "nanobind/stl/vector.h" +#include "nanobind/stl/tuple.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep #include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_nanobind_helpers.h" #include "xla/ffi/api/c_api.h" diff --git a/jaxlib/mosaic/gpu/passes.cc b/jaxlib/mosaic/gpu/passes.cc index 1705405d2f32..9fa6f8df78a8 100644 --- a/jaxlib/mosaic/gpu/passes.cc +++ b/jaxlib/mosaic/gpu/passes.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "jaxlib/mosaic/gpu/passes.h" + #include #include #include @@ -24,10 +25,13 @@ limitations under the License. #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/Visitors.h" #include "mlir/Pass/PassRegistry.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/DialectConversion.h" diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index 2c13228d3c51..d0c0c798abb8 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -148,6 +148,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", "@local_config_rocm//rocm:miopen", "@local_config_rocm//rocm:rocm_headers", "@xla//xla/ffi/api:ffi", @@ -318,6 +319,7 @@ nanobind_extension( ":hip_vendor", ":hipsparse_kernels", "//jaxlib:absl_status_casters", + "//jaxlib:kernel_helpers", "//jaxlib:kernel_nanobind_helpers", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", @@ -325,12 +327,14 @@ nanobind_extension( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/hash", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@local_config_rocm//rocm:hipsparse", "@local_config_rocm//rocm:rocm_headers", "@nanobind", + "@xla//xla/service:custom_call_status", "@xla//xla/tsl/python/lib/core:numpy", ], ) @@ -496,9 +500,11 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/synchronization", - "@tsl//tsl/platform:env", "@xla//xla/service:custom_call_status", + "@xla//xla/tsl/platform:env", + "@xla//xla/tsl/platform:errors", "@xla//xla/tsl/util:env_var", ], ) @@ -536,7 +542,9 @@ nanobind_extension( "//jaxlib:absl_status_casters", "//jaxlib:kernel_nanobind_helpers", "//jaxlib/gpu:triton_cc_proto", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@nanobind", ], ) diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD index 347da6998b57..2c2a76f29f9b 100644 --- a/jaxlib/xla/BUILD +++ b/jaxlib/xla/BUILD @@ -66,6 +66,7 @@ nanobind_extension( ":xla_compiler", "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/hash", "@com_google_absl//absl/log:initialize", "@com_google_absl//absl/status", @@ -164,7 +165,6 @@ cc_library( features = ["-use_header_modules"], deps = [ ":python_ref_manager", - "@com_google_absl//absl/base", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -178,8 +178,8 @@ cc_library( "@xla//xla/pjrt:host_callback", "@xla//xla/pjrt:transpose", "@xla//xla/python:nb_numpy", - "@xla//xla/service:custom_call_status", "@xla//xla/tsl/platform:statusor", + "@xla//xla/tsl/python/lib/core:numpy", ], ) @@ -319,13 +319,13 @@ cc_library( "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/time", "@nanobind", - "@tsl//tsl/platform:env", - "@tsl//tsl/platform:statusor", "@xla//xla/pjrt:status_casters", "@xla//xla/python/ifrt", "@xla//xla/python/ifrt:attribute_map", "@xla//xla/python/ifrt_proxy/client:grpc_client", "@xla//xla/python/ifrt_proxy/client:registry", + "@xla//xla/tsl/platform:env", + "@xla//xla/tsl/platform:statusor", ], ) @@ -752,7 +752,9 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", + "@llvm-project//llvm:Support", "@nanobind", + "@tsl//tsl/platform:casts", "@xla//xla:util", "@xla//xla/pjrt:pjrt_client", "@xla//xla/pjrt:status_casters", diff --git a/jaxlib/xla/callback.cc b/jaxlib/xla/callback.cc index bb238e6991ec..6f5644c3b0c7 100644 --- a/jaxlib/xla/callback.cc +++ b/jaxlib/xla/callback.cc @@ -40,6 +40,7 @@ limitations under the License. #include "xla/primitive_util.h" #include "xla/python/nb_numpy.h" #include "xla/tsl/platform/statusor.h" +#include "xla/tsl/python/lib/core/numpy.h" namespace nb = nanobind; diff --git a/jaxlib/xla/callback.h b/jaxlib/xla/callback.h index ebd0aaca4e6d..ee1f35ce34a3 100644 --- a/jaxlib/xla/callback.h +++ b/jaxlib/xla/callback.h @@ -18,7 +18,6 @@ limitations under the License. #include #include -#include #include #include diff --git a/jaxlib/xla/config.cc b/jaxlib/xla/config.cc index 82f0bd0b0f5a..c68ff7f4ac54 100644 --- a/jaxlib/xla/config.cc +++ b/jaxlib/xla/config.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_set.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" diff --git a/jaxlib/xla/custom_call_sharding.cc b/jaxlib/xla/custom_call_sharding.cc index f88bc93e3af3..3cb53b438e09 100644 --- a/jaxlib/xla/custom_call_sharding.cc +++ b/jaxlib/xla/custom_call_sharding.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ + #include "jaxlib/xla/custom_call_sharding.h" #include diff --git a/jaxlib/xla/dlpack.cc b/jaxlib/xla/dlpack.cc index 6c4c24bfe10e..d1cb91114b05 100644 --- a/jaxlib/xla/dlpack.cc +++ b/jaxlib/xla/dlpack.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "include/dlpack/dlpack.h" #include "llvm/Support/Casting.h" @@ -45,7 +46,6 @@ limitations under the License. #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_common.h" #include "xla/pjrt/pjrt_compiler.h" -#include "xla/pjrt/pjrt_layout.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/device.h" #include "xla/python/pjrt_ifrt/pjrt_array.h" diff --git a/jaxlib/xla/ifrt_proxy.cc b/jaxlib/xla/ifrt_proxy.cc index eda57be86ba5..a89941f8581c 100644 --- a/jaxlib/xla/ifrt_proxy.cc +++ b/jaxlib/xla/ifrt_proxy.cc @@ -42,8 +42,8 @@ #include "xla/python/ifrt/attribute_map.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt_proxy/client/registry.h" -#include "tsl/platform/env.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/statusor.h" namespace nb = ::nanobind; diff --git a/jaxlib/xla/jax_jit.h b/jaxlib/xla/jax_jit.h index a2e6d725f3b0..9eba2e9d3228 100644 --- a/jaxlib/xla/jax_jit.h +++ b/jaxlib/xla/jax_jit.h @@ -40,7 +40,7 @@ limitations under the License. #include "jaxlib/xla/pytree.h" #include "jaxlib/xla/sharding.h" #include "xla/pjrt/pjrt_client.h" -#include "xla/python/nb_helpers.h" +#include "xla/pjrt/pjrt_layout.h" #include "xla/tsl/platform/logging.h" namespace jax { diff --git a/jaxlib/xla/mlir.cc b/jaxlib/xla/mlir.cc index 5905c6c6ec8d..987856daa983 100644 --- a/jaxlib/xla/mlir.cc +++ b/jaxlib/xla/mlir.cc @@ -24,10 +24,7 @@ limitations under the License. #include "llvm/Support/raw_ostream.h" #include "mlir/Bytecode/BytecodeWriter.h" #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" -#include "mlir/Dialect/Func/Extensions/AllExtensions.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/OwningOpRef.h" @@ -44,7 +41,6 @@ limitations under the License. #include "xla/pjrt/mlir_to_hlo.h" #include "xla/pjrt/status_casters.h" #include "xla/python/refine_polymorphic_shapes.h" -#include "xla/service/llvm_ir/llvm_util.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/logging.h" #include "xla/tsl/platform/statusor.h" diff --git a/jaxlib/xla/pmap_lib.h b/jaxlib/xla/pmap_lib.h index e02311e03c73..2bad85e59671 100644 --- a/jaxlib/xla/pmap_lib.h +++ b/jaxlib/xla/pmap_lib.h @@ -16,9 +16,6 @@ limitations under the License. #ifndef JAXLIB_XLA_PMAP_LIB_H_ #define JAXLIB_XLA_PMAP_LIB_H_ -#include -#include -#include // placeholder for index annotation headers #include "nanobind/nanobind.h" diff --git a/jaxlib/xla/py_array.cc b/jaxlib/xla/py_array.cc index 1325f0cbd2bc..a1937bc80327 100644 --- a/jaxlib/xla/py_array.cc +++ b/jaxlib/xla/py_array.cc @@ -75,7 +75,6 @@ limitations under the License. #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_compiler.h" #include "xla/pjrt/pjrt_future.h" -#include "xla/pjrt/pjrt_layout.h" #include "xla/pjrt/status_casters.h" #include "xla/primitive_util.h" #include "xla/python/ifrt/array.h" diff --git a/jaxlib/xla/py_array.h b/jaxlib/xla/py_array.h index 645f51096c1d..7fa2434c7c9f 100644 --- a/jaxlib/xla/py_array.h +++ b/jaxlib/xla/py_array.h @@ -30,6 +30,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/cord.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/Support/Casting.h" #include "nanobind/nanobind.h" diff --git a/jaxlib/xla/py_client.cc b/jaxlib/xla/py_client.cc index 795a4fee29fa..1e41d9cf8a0d 100644 --- a/jaxlib/xla/py_client.cc +++ b/jaxlib/xla/py_client.cc @@ -48,7 +48,6 @@ limitations under the License. #include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep #include "nanobind/stl/variant.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep -#include "jaxlib/xla/callback.h" #include "jaxlib/xla/guard_lib.h" #include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_array.h" @@ -83,7 +82,6 @@ limitations under the License. #include "xla/python/pjrt_ifrt/xla_compiler.h" #include "xla/python/pprof_profile_builder.h" #include "xla/python/types.h" -#include "xla/service/custom_call_target_registry.h" #include "xla/service/platform_util.h" // IWYU pragma: keep #include "xla/shape.h" #include "xla/status_macros.h" diff --git a/jaxlib/xla/py_client.h b/jaxlib/xla/py_client.h index 898a40141307..29a506d48864 100644 --- a/jaxlib/xla/py_client.h +++ b/jaxlib/xla/py_client.h @@ -24,7 +24,6 @@ limitations under the License. #include #include #include -#include #include #include "absl/container/flat_hash_map.h" diff --git a/jaxlib/xla/py_device.h b/jaxlib/xla/py_device.h index 6071ede52305..4e74992fb2ee 100644 --- a/jaxlib/xla/py_device.h +++ b/jaxlib/xla/py_device.h @@ -23,6 +23,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "nanobind/nanobind.h" #include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_client.h" diff --git a/jaxlib/xla/py_device_list.cc b/jaxlib/xla/py_device_list.cc index 300e477dbbd0..205c971b9317 100644 --- a/jaxlib/xla/py_device_list.cc +++ b/jaxlib/xla/py_device_list.cc @@ -38,9 +38,7 @@ limitations under the License. #include "jaxlib/xla/python_ref_manager.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/device_list.h" -#include "xla/python/nb_helpers.h" #include "xla/python/types.h" -#include "xla/tsl/concurrency/ref_count.h" #include "xla/util.h" namespace jax { diff --git a/jaxlib/xla/py_device_list.h b/jaxlib/xla/py_device_list.h index 1d0f64003f8c..0fa9b3965dfe 100644 --- a/jaxlib/xla/py_device_list.h +++ b/jaxlib/xla/py_device_list.h @@ -26,7 +26,6 @@ limitations under the License. #include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_client.h" #include "xla/python/ifrt/device_list.h" -#include "xla/tsl/concurrency/ref_count.h" namespace jax { diff --git a/jaxlib/xla/py_executable.h b/jaxlib/xla/py_executable.h index 688eb779df8d..804682db717e 100644 --- a/jaxlib/xla/py_executable.h +++ b/jaxlib/xla/py_executable.h @@ -26,9 +26,9 @@ limitations under the License. #include #include -#include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/Support/Casting.h" #include "nanobind/nanobind.h" @@ -37,10 +37,8 @@ limitations under the License. #include "jaxlib/xla/py_client.h" #include "jaxlib/xla/traceback.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/layout.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/pjrt_client.h" -#include "xla/pjrt/pjrt_common.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_future.h" #include "xla/pjrt/pjrt_layout.h" diff --git a/jaxlib/xla/py_memory_space.h b/jaxlib/xla/py_memory_space.h index f111263497fb..f38038af4870 100644 --- a/jaxlib/xla/py_memory_space.h +++ b/jaxlib/xla/py_memory_space.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/strings/string_view.h" #include "nanobind/nanobind.h" #include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_client.h" diff --git a/jaxlib/xla/py_socket_transfer.cc b/jaxlib/xla/py_socket_transfer.cc index 05397cdf116f..55d84fd71bb7 100644 --- a/jaxlib/xla/py_socket_transfer.cc +++ b/jaxlib/xla/py_socket_transfer.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/synchronization/mutex.h" +#include "llvm/Support/Casting.h" #include "nanobind/nanobind.h" #include "nanobind/stl/array.h" // IWYU pragma: keep #include "nanobind/stl/string.h" // IWYU pragma: keep @@ -61,6 +62,7 @@ limitations under the License. #include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/platform/statusor.h" #include "xla/util.h" +#include "tsl/platform/casts.h" namespace aux { diff --git a/jaxlib/xla/python_ref_manager.cc b/jaxlib/xla/python_ref_manager.cc index a19622d94244..5b85d2ab84cb 100644 --- a/jaxlib/xla/python_ref_manager.cc +++ b/jaxlib/xla/python_ref_manager.cc @@ -15,6 +15,8 @@ limitations under the License. #include "jaxlib/xla/python_ref_manager.h" +#include + #include #include #include diff --git a/jaxlib/xla/sharded_device_array.h b/jaxlib/xla/sharded_device_array.h index 1b0ca20aa1fc..6e014789a289 100644 --- a/jaxlib/xla/sharded_device_array.h +++ b/jaxlib/xla/sharded_device_array.h @@ -20,7 +20,6 @@ limitations under the License. #include #include -#include "absl/types/variant.h" #include "nanobind/nanobind.h" #include "nanobind/stl/variant.h" // IWYU pragma: keep #include "xla/python/types.h" diff --git a/jaxlib/xla/sharding.cc b/jaxlib/xla/sharding.cc index 409dddb62268..5a80c03e01da 100644 --- a/jaxlib/xla/sharding.cc +++ b/jaxlib/xla/sharding.cc @@ -38,7 +38,6 @@ limitations under the License. #include "xla/pjrt/status_casters.h" #include "xla/python/ifrt/device_list.h" #include "xla/python/nb_numpy.h" -#include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/platform/logging.h" #include "xla/tsl/platform/statusor.h" #include "xla/xla_data.pb.h" diff --git a/jaxlib/xla/sharding.h b/jaxlib/xla/sharding.h index dac18a4160b5..698ff2ca9ca8 100644 --- a/jaxlib/xla/sharding.h +++ b/jaxlib/xla/sharding.h @@ -30,8 +30,9 @@ limitations under the License. #include "jaxlib/xla/sharded_device_array.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/pjrt/status_casters.h" -#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" #include "xla/python/nb_numpy.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" namespace jax { diff --git a/jaxlib/xla/to_ifrt_sharding.cc b/jaxlib/xla/to_ifrt_sharding.cc index 116ead49ad23..2a7c6707e766 100644 --- a/jaxlib/xla/to_ifrt_sharding.cc +++ b/jaxlib/xla/to_ifrt_sharding.cc @@ -37,7 +37,6 @@ limitations under the License. #include "xla/python/pjrt_ifrt/xla_sharding.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/platform/statusor.h" namespace xla { diff --git a/jaxlib/xla/to_ifrt_sharding.h b/jaxlib/xla/to_ifrt_sharding.h index 0fa7f17c4563..ebc999888297 100644 --- a/jaxlib/xla/to_ifrt_sharding.h +++ b/jaxlib/xla/to_ifrt_sharding.h @@ -16,12 +16,18 @@ limitations under the License. #ifndef JAXLIB_XLA_TO_IFRT_SHARDING_H_ #define JAXLIB_XLA_TO_IFRT_SHARDING_H_ +#include +#include +#include + +#include "absl/status/statusor.h" #include "nanobind/nanobind.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/python/ifrt/device_list.h" #include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" -#include "xla/tsl/platform/statusor.h" namespace xla { diff --git a/jaxlib/xla/traceback.h b/jaxlib/xla/traceback.h index 953d626439c4..685ecc5f8793 100644 --- a/jaxlib/xla/traceback.h +++ b/jaxlib/xla/traceback.h @@ -16,7 +16,8 @@ limitations under the License. #ifndef JAXLIB_XLA_TRACEBACK_H_ #define JAXLIB_XLA_TRACEBACK_H_ -#include +#include + #include #include #include diff --git a/jaxlib/xla/xla.cc b/jaxlib/xla/xla.cc index 668c96869479..e460a1773e94 100644 --- a/jaxlib/xla/xla.cc +++ b/jaxlib/xla/xla.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/base/casts.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" #include "absl/hash/hash.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -61,12 +62,12 @@ limitations under the License. #include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h" #include "xla/pjrt/status_casters.h" #include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" #include "xla/python/ifrt/device_list.h" #include "xla/python/ifrt/executable.h" #include "xla/python/ifrt/topology.h" -#include "xla/python/version.h" #include "xla/python/pjrt_ifrt/pjrt_attribute_map_util.h" -#include "xla/tsl/concurrency/ref_count.h" +#include "xla/python/version.h" #include "xla/tsl/python/lib/core/numpy.h" // NOLINT #if defined(__linux__) @@ -119,7 +120,6 @@ limitations under the License. #include "xla/python/nb_absl_span.h" // IWYU pragma: keep #include "xla/python/ops.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" -#include "xla/python/pjrt_ifrt/pjrt_executable.h" #include "xla/python/pjrt_ifrt/pjrt_topology.h" #include "xla/python/pprof_profile_builder.h" #include "xla/python/profiler.h" diff --git a/jaxlib/xla/xla_compiler.cc b/jaxlib/xla/xla_compiler.cc index 0098cc28160d..bea3062c64e4 100644 --- a/jaxlib/xla/xla_compiler.cc +++ b/jaxlib/xla/xla_compiler.cc @@ -56,6 +56,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_module_group.h" +#include "xla/hlo/ir/hlo_print_options.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/pass/hlo_pass_interface.h" From d8fc40f121d59019a31e867b8b1a97a837c15414 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Thu, 27 Mar 2025 18:51:48 +0000 Subject: [PATCH 230/483] allow saved_input_vjp functions to be jit inputs/outputs --- jax/_src/api.py | 4 ++-- tests/api_test.py | 19 +++++++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 4626714b5399..692f049b5f0c 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2053,8 +2053,8 @@ def saved_input_vjp(f: Callable, which: Sequence[bool], *primals, res_spec = [RSpec(id_map[id(r)], True) if id(r) in id_map else RSpec(opaque_residuals.append(r) or (len(opaque_residuals) - 1), False) # type: ignore for r in residuals] - f_vjp = Partial(_saved_input_vjpfun, res_spec, filt_tree, in_tree, out_tree(), - jaxpr, opaque_residuals) + f_vjp = Partial(partial(_saved_input_vjpfun, res_spec, filt_tree, in_tree, + out_tree(), jaxpr), opaque_residuals) if not allow_unused and not set(id_map).issubset(res_ids := {id(r) for r in residuals}): unused = [(i, core.get_aval(x)) for i, (x, w) in enumerate(zip(primals, which)) diff --git a/tests/api_test.py b/tests/api_test.py index 9d80b5fbed74..6a970051d56e 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -11520,6 +11520,25 @@ def f(x, y): self.assertAllClose(y, 6.) self.assertAllClose(arg_cts, (3., 2.)) + def test_basic_pass_through_jit(self): + def f(x, y): + return x * y + + @jax.jit + def g(): + primals = 2., 3. + y, f_vjp = api.si_vjp(f, [True, True], *primals) + return y, f_vjp + + @jax.jit + def h(f_vjp): + return f_vjp(1., 2., 3.) + + y, f_vjp = g() + arg_cts = h(f_vjp) + self.assertAllClose(y, 6.) + self.assertAllClose(arg_cts, (3., 2.)) + def test_basic_unused(self): f = jnp.sin primals = 3., From b02b1fe09267a5d4e7819763d6f9303d9d3a35e8 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 27 Mar 2025 12:49:31 -0700 Subject: [PATCH 231/483] Update Windows bazelrc configs to ltsc2022 PiperOrigin-RevId: 741249289 --- .bazelrc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.bazelrc b/.bazelrc index 76f72b0848a9..2d38dcc87044 100644 --- a/.bazelrc +++ b/.bazelrc @@ -260,8 +260,8 @@ build:ci_darwin_arm64 --color=yes # Windows x86 CI configs build:ci_windows_amd64 --config=avx_windows build:ci_windows_amd64 --compiler=clang-cl --config=clang --verbose_failures=true -build:ci_windows_amd64 --crosstool_top="@xla//tools/toolchains/win/20240424:toolchain" -build:ci_windows_amd64 --extra_toolchains="@xla//tools/toolchains/win/20240424:cc-toolchain-x64_windows-clang-cl" +build:ci_windows_amd64 --crosstool_top="@xla//tools/toolchains/win2022/20241118:toolchain" +build:ci_windows_amd64 --extra_toolchains="@xla//tools/toolchains/win2022/20241118:cc-toolchain-x64_windows-clang-cl" build:ci_windows_amd64 --host_linkopt=/FORCE:MULTIPLE --linkopt=/FORCE:MULTIPLE build:ci_windows_amd64 --color=yes @@ -329,9 +329,9 @@ common:rbe_windows_amd64 --remote_instance_name=projects/tensorflow-testing/inst build:rbe_windows_amd64 --config=rbe # Set the host, execution, and target platform -build:rbe_windows_amd64 --host_platform="@xla//tools/toolchains/win:x64_windows-clang-cl" -build:rbe_windows_amd64 --extra_execution_platforms="@xla//tools/toolchains/win:x64_windows-clang-cl" -build:rbe_windows_amd64 --platforms="@xla//tools/toolchains/win:x64_windows-clang-cl" +build:rbe_windows_amd64 --host_platform="@xla//tools/toolchains/win2022:windows_ltsc2022_clang" +build:rbe_windows_amd64 --extra_execution_platforms="@xla//tools/toolchains/win2022:windows_ltsc2022_clang" +build:rbe_windows_amd64 --platforms="@xla//tools/toolchains/win2022:windows_ltsc2022_clang" build:rbe_windows_amd64 --shell_executable=C:\\tools\\msys64\\usr\\bin\\bash.exe build:rbe_windows_amd64 --enable_runfiles From 358c55d06650fc9cea39943f6e91f78219d52eb8 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 27 Mar 2025 12:50:44 -0700 Subject: [PATCH 232/483] Update instructions for usage of `:build_jaxlib=false` flag. By adding [jax wheel testing](https://github.com/jax-ml/jax/pull/27113) functionality, we need to have pre-built jax and jaxlib wheels. PiperOrigin-RevId: 741249718 --- build/build.py | 5 ++++- docs/developer.md | 16 ++++++++++------ 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/build/build.py b/build/build.py index 4d16851f837c..1900073fc132 100755 --- a/build/build.py +++ b/build/build.py @@ -414,7 +414,10 @@ async def main(): for option in args.bazel_startup_options: bazel_command_base.append(option) - if not args.use_new_wheel_build_rule or args.command == "requirements_update": + if ( + not hasattr(args,"use_new_wheel_build_rule") + or args.command == "requirements_update" + ): bazel_command_base.append("run") else: bazel_command_base.append("build") diff --git a/docs/developer.md b/docs/developer.md index 0affbba9ed36..b1a978ffd0d6 100644 --- a/docs/developer.md +++ b/docs/developer.md @@ -1,7 +1,7 @@ (building-from-source)= # Building from source - + First, obtain the JAX source code: @@ -526,23 +526,27 @@ bazel test //tests:cpu_tests //tests:backend_independent_tests `//tests:gpu_tests` and `//tests:tpu_tests` are also available, if you have the necessary hardware. -To use a preinstalled `jaxlib` instead of building it you first need to -make it available in the hermetic Python. To install a specific version of -`jaxlib` within hermetic Python run (using `jaxlib >= 0.4.26` as an example): +To use the preinstalled `jax` and `jaxlib` instead of building them you first +need to make them available in the hermetic Python. To install the specific +versions of `jax` and `jaxlib` within hermetic Python run (using `jax >= 0.4.26` +and `jaxlib >= 0.4.26` as an example): ``` +echo -e "\njax >= 0.4.26" >> build/requirements.in echo -e "\njaxlib >= 0.4.26" >> build/requirements.in python build/build.py requirements_update ``` -Alternatively, to install `jaxlib` from a local wheel (assuming Python 3.12): +Alternatively, to install `jax` and `jaxlib` from the local wheels +(assuming Python 3.12): ``` +echo -e "\n$(realpath jax-0.4.26-py3-none-any.whl)" >> build/requirements.in echo -e "\n$(realpath jaxlib-0.4.26-cp312-cp312-manylinux2014_x86_64.whl)" >> build/requirements.in python build/build.py requirements_update --python_version=3.12 ``` -Once you have `jaxlib` installed hermetically, run: +Once you have `jax` and `jaxlib` installed hermetically, run: ``` bazel test --//jax:build_jaxlib=false //tests:cpu_tests //tests:backend_independent_tests From b290c132dd5b28e11e4d17f495b91a9bc8e88eac Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Thu, 27 Mar 2025 13:03:42 -0700 Subject: [PATCH 233/483] [jax:custom_partitioning] Raise an error when Shardy is used but the old sharding propagation callbacks instead of sharding rule are provided. PiperOrigin-RevId: 741253832 --- jax/_src/custom_partitioning.py | 6 +++++ tests/cache_key_test.py | 3 ++- tests/pjit_test.py | 39 +++++++++++++++++++++++++++++++++ tests/shard_map_test.py | 1 + 4 files changed, 48 insertions(+), 1 deletion(-) diff --git a/jax/_src/custom_partitioning.py b/jax/_src/custom_partitioning.py index 5374071517f1..feb1e0c39cc6 100644 --- a/jax/_src/custom_partitioning.py +++ b/jax/_src/custom_partitioning.py @@ -500,6 +500,12 @@ def __call__(self, *args, **kwargs): infer_sharding_from_operands = None sharding_rule = None if config.use_shardy_partitioner.value: + if (self.sharding_rule is None and + (self.propagate_user_sharding is not None or + self.infer_sharding_from_operands is not None)): + raise ValueError("Shardy is used, but sharding propagation callbacks " + "instead of sharding_rule are provided. Need to " + "provide sharding_rule to migrate to Shardy.") sharding_rule = self.sharding_rule else: propagate_user_sharding = self.propagate_user_sharding diff --git a/tests/cache_key_test.py b/tests/cache_key_test.py index ed80c7060e4c..a908d260d560 100644 --- a/tests/cache_key_test.py +++ b/tests/cache_key_test.py @@ -180,7 +180,8 @@ def _cp_add(x, y): _cp_add.def_partition( infer_sharding_from_operands=_infer_sharding_from_operands, - partition=_partition) + partition=_partition, + sharding_rule='i i -> i') devices = np.asarray(jax.devices()) with Mesh(devices, ('x',)) as m: diff --git a/tests/pjit_test.py b/tests/pjit_test.py index aa5afccb38d4..2cfe61cdf1fe 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -8207,5 +8207,44 @@ def f(x, y, static_arg0=1, static_arg1=2): self.assertArraysEqual(result, expected_result) self.assertEqual(result.sharding, NamedSharding(mesh, P(None, None, 'x'))) + def test_custom_partition_shardy_migration(self): + if jtu.is_cloud_tpu(): + raise unittest.SkipTest("Custom partitioning is not supported on libtpu.") + + def partition(mesh, arg_shapes, result_shape): + def lower_fn(x): + return x + + return ( + mesh, + lower_fn, + arg_shapes[0].sharding, + (arg_shapes[0].sharding,), + ) + + def infer_sharding_from_operands(mesh, arg_shapes, result_shape): + return arg_shapes[0].sharding + + def propagate_user_sharding(mesh, user_shape): + return user_shape.sharding + + @custom_partitioning + def f(x): + return x + + f.def_partition( + infer_sharding_from_operands=infer_sharding_from_operands, + partition=partition, + propagate_user_sharding=propagate_user_sharding, + ) + + mesh = jtu.create_mesh((4, 2), ('x', 'y')) + x = jax.device_put(np.arange(32 * 16).reshape(32, 16), + NamedSharding(mesh, P(None, 'x'))) + with self.assertRaisesRegex(ValueError, "provide sharding_rule to migrate " + "to Shardy"): + jax.jit(f)(x) + + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index ce01b6e6e944..1ffb3e1d137a 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -3105,6 +3105,7 @@ def f(x): infer_sharding_from_operands=infer_sharding_from_operands, partition=partition, propagate_user_sharding=propagate_user_sharding, + sharding_rule='i -> i', ) @jax.jit From 591c327e613507d1d4bb9706d4f353c2d4835eba Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 27 Mar 2025 14:09:15 -0700 Subject: [PATCH 234/483] Remove unused build dependencies in jaxlib/xla/... PiperOrigin-RevId: 741276224 --- jaxlib/xla/BUILD | 40 ---------------------------------------- 1 file changed, 40 deletions(-) diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD index 2c2a76f29f9b..2ca18afda13d 100644 --- a/jaxlib/xla/BUILD +++ b/jaxlib/xla/BUILD @@ -268,7 +268,6 @@ cc_library( "@xla//xla/pjrt:pjrt_client", "@xla//xla/pjrt:pjrt_common", "@xla//xla/pjrt:pjrt_compiler", - "@xla//xla/pjrt:pjrt_layout", "@xla//xla/python:types", "@xla//xla/python/ifrt", "@xla//xla/python/pjrt_ifrt", @@ -360,7 +359,6 @@ cc_library( "@xla//xla/pjrt:status_casters", "@xla//xla/python:nb_absl_inlined_vector", "@xla//xla/python:nb_absl_span", - "@xla//xla/python:nb_helpers", "@xla//xla/python:types", "@xla//xla/tsl/platform:logging", ], @@ -383,8 +381,6 @@ cc_library( "@com_google_absl//absl/strings:string_view", "@llvm-project//llvm:Support", "@llvm-project//mlir:BytecodeWriter", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:FuncExtensions", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", @@ -398,7 +394,6 @@ cc_library( "@xla//xla/pjrt:mlir_to_hlo", "@xla//xla/pjrt:status_casters", "@xla//xla/python:refine_polymorphic_shapes", - "@xla//xla/service/llvm_ir:llvm_util", "@xla//xla/tsl/platform:errors", "@xla//xla/tsl/platform:logging", "@xla//xla/tsl/platform:statusor", @@ -547,18 +542,15 @@ cc_library( features = ["-use_header_modules"], visibility = jax_visibility("jaxlib/xla/py_client"), deps = [ - ":callback", ":guard_lib", ":nb_class_ptr", ":py_client_cpu", ":py_host_callback", - ":py_host_callback_cc_proto", ":python_ref_manager", ":traceback", ":util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", - "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", @@ -571,56 +563,36 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:cord", "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", - "@com_google_absl//absl/types:variant", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@nanobind", - "@shardy//shardy/dialect/sdy/ir:dialect", - "@tsl//tsl/platform:casts", "@tsl//tsl/platform:fingerprint", "@tsl//tsl/platform:ml_dtypes", - "@tsl//tsl/profiler/lib:profiler_session", "@tsl//tsl/profiler/lib:traceme", - "@tsl//tsl/profiler/protobuf:profiled_instructions_proto_cc", "@xla//third_party/python_runtime:headers", # buildcleaner: keep - "@xla//xla:comparison_util", "@xla//xla:literal", "@xla//xla:shape_util", "@xla//xla:status_macros", "@xla//xla:types", "@xla//xla:util", "@xla//xla:xla_data_proto_cc", - "@xla//xla/hlo/builder:xla_builder", - "@xla//xla/hlo/builder:xla_computation", - "@xla//xla/hlo/builder/lib:arithmetic", "@xla//xla/hlo/ir:hlo", "@xla//xla/pjrt:exceptions", - "@xla//xla/pjrt:host_callback", - "@xla//xla/pjrt:host_memory_spaces", "@xla//xla/pjrt:lru_cache", "@xla//xla/pjrt:mlir_to_hlo", "@xla//xla/pjrt:pjrt_client", - "@xla//xla/pjrt:pjrt_common", "@xla//xla/pjrt:pjrt_compiler", - "@xla//xla/pjrt:pjrt_device_description", "@xla//xla/pjrt:pjrt_executable", "@xla//xla/pjrt:pjrt_future", "@xla//xla/pjrt:pjrt_layout", "@xla//xla/pjrt:status_casters", - "@xla//xla/pjrt:transpose", - "@xla//xla/pjrt/distributed", - "@xla//xla/pjrt/distributed:client", - "@xla//xla/python:aggregate_profile", "@xla//xla/python:nb_absl_span", "@xla//xla/python:nb_helpers", "@xla//xla/python:nb_numpy", "@xla//xla/python:pprof_profile_builder", "@xla//xla/python:types", - "@xla//xla/python:xplane_to_profile_instructions", "@xla//xla/python/compile_only_ifrt:client", "@xla//xla/python/ifrt", "@xla//xla/python/ifrt:attribute_map", @@ -630,20 +602,11 @@ cc_library( "@xla//xla/python/ifrt:user_context", "@xla//xla/python/ifrt/hlo:hlo_program", "@xla//xla/python/pjrt_ifrt", - "@xla//xla/python/pjrt_ifrt:pjrt_attribute_map_util", "@xla//xla/python/pjrt_ifrt:pjrt_dtype", - "@xla//xla/python/pjrt_ifrt:xla_host_callback_proto_cc", "@xla//xla/python/pjrt_ifrt:xla_ifrt", - "@xla//xla/service:computation_placer_hdr", - "@xla//xla/service:custom_call_status", - "@xla//xla/service:custom_call_target_registry", "@xla//xla/service:platform_util", - "@xla//xla/service/spmd/shardy:constants", - "@xla//xla/service/spmd/shardy:utils", - "@xla//xla/service/spmd/shardy/sdy_round_trip:pipelines", "@xla//xla/tsl/concurrency:ref_count", "@xla//xla/tsl/framework:allocator", - "@xla//xla/tsl/framework/mlir:status_scoped_diagnostic_handler", "@xla//xla/tsl/platform:env", "@xla//xla/tsl/platform:errors", "@xla//xla/tsl/platform:logging", @@ -700,7 +663,6 @@ cc_library( ":py_host_callback_cc_proto", ":python_ref_manager", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -713,7 +675,6 @@ cc_library( "@xla//xla:util", "@xla//xla:xla_data_proto_cc", "@xla//xla/pjrt:host_callback", - "@xla//xla/pjrt:pjrt_compiler", "@xla//xla/python:types", "@xla//xla/python/ifrt", "@xla//xla/python/pjrt_ifrt", @@ -852,7 +813,6 @@ cc_library( "@llvm-project//mlir:Support", "@nanobind", "@shardy//shardy/dialect/sdy/ir:dialect", - "@shardy//shardy/dialect/sdy/transforms/import:passes", "@xla//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo", "@xla//xla/mlir_hlo:all_passes", "@xla//xla/pjrt:mlir_to_hlo", From 71b36dca8406538898df5e61b21ba29f6ef79ad5 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 27 Mar 2025 14:42:40 -0700 Subject: [PATCH 235/483] Sort the replicated_axes wrt mesh names in Shardy PiperOrigin-RevId: 741287495 --- jax/_src/sharding_impls.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index efa1b4cfd5b6..d95b12f244ba 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -105,6 +105,9 @@ def modify_sdy_sharding_wrt_axis_types(sdy_sharding: SdyArraySharding, mesh): if not d.axes and d.is_closed else d) used_axes.extend(d.axes) remaining_axes = set(mesh.axis_names) - set(used_axes) + # Sort wrt mesh axis names so order is deterministic and doesn't hang in + # McJAX. + remaining_axes = [n for n in mesh.axis_names if n in remaining_axes] replicated_axes = tuple(r for r in remaining_axes if mesh._name_to_type[r] == mesh_lib.AxisType.Explicit) return SdyArraySharding(sdy_sharding.mesh_shape, dim_shardings, From d08676e927e4917a629d609ac80770d208187f9a Mon Sep 17 00:00:00 2001 From: Michael Whittaker Date: Thu, 27 Mar 2025 16:48:28 -0700 Subject: [PATCH 236/483] Disable `lax_numpy_test` tsan tests. PiperOrigin-RevId: 741325580 --- tests/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/BUILD b/tests/BUILD index 1d021b0c7110..2526be066635 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -496,6 +496,7 @@ jax_multiplatform_test( name = "lax_numpy_test", srcs = ["lax_numpy_test.py"], backend_tags = { + "tpu": ["notsan"], # Test times out. "cpu": ["notsan"], # Test times out. }, shard_count = { From 25c106d132d01856ac3e1ad40b7ff52c65fafc4c Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 27 Mar 2025 16:55:45 -0700 Subject: [PATCH 237/483] Add standard_insert_pbroadcasts and standard_vma_rule to all primitives in following files: (Don't add `standard_insert_broadcast` for unary ops though) * slicing.py * windowed_reductions.py * special.py * convolution.py * fft.py * linalg.py * ann.py PiperOrigin-RevId: 741327361 --- jax/_src/ad_util.py | 1 + jax/_src/core.py | 15 +++- jax/_src/lax/ann.py | 6 +- jax/_src/lax/control_flow/loops.py | 3 +- jax/_src/lax/convolution.py | 4 +- jax/_src/lax/fft.py | 4 +- jax/_src/lax/lax.py | 116 ++++++---------------------- jax/_src/lax/linalg.py | 10 ++- jax/_src/lax/slicing.py | 42 +++++++--- jax/_src/lax/special.py | 8 ++ jax/_src/lax/windowed_reductions.py | 29 +++++-- 11 files changed, 123 insertions(+), 115 deletions(-) diff --git a/jax/_src/ad_util.py b/jax/_src/ad_util.py index c729a57cfb11..c8e64ce5c2ef 100644 --- a/jax/_src/ad_util.py +++ b/jax/_src/ad_util.py @@ -31,6 +31,7 @@ map = safe_map def add_jaxvals(x: ArrayLike, y: ArrayLike) -> Array: + x, y = core.standard_insert_pbroadcast(x, y) return add_jaxvals_p.bind(x, y) add_jaxvals_p = Primitive('add_any') diff --git a/jax/_src/core.py b/jax/_src/core.py index 3a1558802682..ca353486afd5 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -2018,6 +2018,16 @@ def standard_insert_pbroadcast(*args): return [pbroadcast(arg, tuple(n for n in out_vma if n not in src)) if out_vma - src else arg for arg, src in zip(args, in_vma)] +def standard_vma_rule(prim_name, *avals, **kwargs): + vma, *vmas = [a.vma for a in avals] + if not all(vma == vma_ for vma_ in vmas): + raise ValueError( + f'Primitive {prim_name} requires varying manual axes ' + f'to match, but got {[vma, *vmas]}. Please open an issue at ' + 'https://github.com/jax-ml/jax/issues and as a temporary ' + 'workaround pass the check_rep=False argument to shard_map') + return vma + # Dynamic shape stuff below here! We keep the abstract values distinct just so # as not to interfere with any static shape machinery. @@ -2697,7 +2707,10 @@ def typematch(t1: AbstractValue, t2: AbstractValue) -> bool: # could try normalizing first and then doing simple equality. # TODO(yashkatariya): Also check `sharding` here. # See https://github.com/jax-ml/jax/issues/26474 - return t1.dtype == t2.dtype and definitely_equal_shape(t1.shape, t2.shape) + sh_dt = t1.dtype == t2.dtype and definitely_equal_shape(t1.shape, t2.shape) + if config.varying_axes_in_types.value: + return sh_dt and t1.vma == t2.vma # type: ignore + return sh_dt else: return False diff --git a/jax/_src/lax/ann.py b/jax/_src/lax/ann.py index 0e037ec774b5..0d2eb338da22 100644 --- a/jax/_src/lax/ann.py +++ b/jax/_src/lax/ann.py @@ -77,6 +77,7 @@ def pmap_mips(qy, db, db_offset, db_size, k, recall_target): from jax._src import ad_util from jax._src import core from jax._src import dispatch +from jax._src import config from jax._src import dtypes from jax._src.interpreters import ad from jax._src.interpreters import batching @@ -239,9 +240,10 @@ def _approx_top_k_abstract_eval(operand, *, k, reduction_dimension, "approx_top_k with aggregate_to_topk=False not yet implemented when " f"either the `k` ({k}) or the " f" reduction dimension size ({reduction_input_size}) are symbolic") + out_vma = operand.vma if config.varying_axes_in_types.value else frozenset() return (operand.update(shape=dims, dtype=operand.dtype, - weak_type=operand.weak_type), - operand.update(shape=dims, dtype=np.dtype(np.int32))) + weak_type=operand.weak_type, vma=out_vma), + operand.update(shape=dims, dtype=np.dtype(np.int32), vma=out_vma)) def _get_init_val_literal(op_type, is_max_k): return np.array(-np.inf if is_max_k else np.inf, dtype=op_type) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 88af7c24e5b8..c7bcb1cf6b09 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -2520,7 +2520,8 @@ def _cumred_dtype_rule(name, operand, *args, **kw): def _cumulative_reduction_primitive(name, reduce_fn, reduce_window_fn): reducer_p = lax.standard_primitive( _cumred_shape_rule, partial(_cumred_dtype_rule, name), - name, sharding_rule=_cumred_sharding_rule) + name, sharding_rule=_cumred_sharding_rule, + vma_rule=partial(core.standard_vma_rule, name)) batching.primitive_batchers[reducer_p] = partial(_cumred_batch_rule, reducer_p) diff --git a/jax/_src/lax/convolution.py b/jax/_src/lax/convolution.py index 290d027cc6bc..32294bbd72cf 100644 --- a/jax/_src/lax/convolution.py +++ b/jax/_src/lax/convolution.py @@ -158,6 +158,7 @@ def conv_general_dilated( preferred_element_type = ( None if preferred_element_type is None else dtypes.canonicalize_dtype(np.dtype(preferred_element_type))) + lhs, rhs = core.standard_insert_pbroadcast(lhs, rhs) return conv_general_dilated_p.bind( lhs, rhs, window_strides=tuple(window_strides), padding=tuple(padding), lhs_dilation=tuple(lhs_dilation), rhs_dilation=tuple(rhs_dilation), @@ -633,7 +634,8 @@ def _conv_general_dilated_batch_rule( conv_general_dilated_p = lax.standard_primitive( _conv_general_dilated_shape_rule, _conv_general_dilated_dtype_rule, - 'conv_general_dilated') + 'conv_general_dilated', + vma_rule=partial(core.standard_vma_rule, 'conv_general_dilated')) ad.defbilinear(conv_general_dilated_p, _conv_general_dilated_transpose_lhs, diff --git a/jax/_src/lax/fft.py b/jax/_src/lax/fft.py index 6ca1a4abd193..9044f48f278c 100644 --- a/jax/_src/lax/fft.py +++ b/jax/_src/lax/fft.py @@ -23,6 +23,7 @@ from jax import lax +from jax._src import config from jax._src import dispatch from jax._src import dtypes from jax._src.api import jit, linear_transpose, ShapeDtypeStruct @@ -124,7 +125,8 @@ def fft_abstract_eval(x, fft_type, fft_lengths): f"be equal to fft_lengths {fft_lengths}") shape = x.shape dtype = x.dtype - return x.update(shape=shape, dtype=dtype) + out_vma = x.vma if config.varying_axes_in_types.value else frozenset() + return x.update(shape=shape, dtype=dtype, vma=out_vma) def _fft_lowering(ctx, x, *, fft_type, fft_lengths): if not is_constant_shape(fft_lengths): diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 655ef763f1ef..fcd7aba380bb 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -296,7 +296,6 @@ def neg(x: ArrayLike) -> Array: .. _stablehlo.negate: https://openxla.org/stablehlo/spec#negate """ - x, = core.standard_insert_pbroadcast(x) return neg_p.bind(x) @export @@ -340,7 +339,6 @@ def sign(x: ArrayLike) -> Array: .. _stablehlo.sign: https://openxla.org/stablehlo/spec#sign """ - x, = core.standard_insert_pbroadcast(x) return sign_p.bind(x) @export @@ -393,7 +391,6 @@ def floor(x: ArrayLike) -> Array: .. _stablehlo.floor: https://openxla.org/stablehlo/spec#floor """ - x, = core.standard_insert_pbroadcast(x) return floor_p.bind(x) @export @@ -415,7 +412,6 @@ def ceil(x: ArrayLike) -> Array: .. _stablehlo.ceil: https://openxla.org/stablehlo/spec#ceil """ - x, = core.standard_insert_pbroadcast(x) return ceil_p.bind(x) class RoundingMethod(enum.IntEnum): @@ -465,7 +461,6 @@ def round(x: ArrayLike, .. _stablehlo.round: https://openxla.org/stablehlo/spec#round """ rounding_method = RoundingMethod(rounding_method) - x, = core.standard_insert_pbroadcast(x) return round_p.bind(x, rounding_method=rounding_method) @export @@ -487,7 +482,6 @@ def is_finite(x: ArrayLike) -> Array: .. _stablehlo.is_finite: https://openxla.org/stablehlo/spec#is_finite """ - x, = core.standard_insert_pbroadcast(x) return is_finite_p.bind(x) @export @@ -509,7 +503,6 @@ def exp(x: ArrayLike) -> Array: .. _stablehlo.exponential: https://openxla.org/stablehlo/spec#exponential """ - x, = core.standard_insert_pbroadcast(x) return exp_p.bind(x) @export @@ -533,7 +526,6 @@ def exp2(x: ArrayLike) -> Array: .. _stablehlo.exponential: https://openxla.org/stablehlo/spec#exponential .. _stablehlo.multiply: https://openxla.org/stablehlo/spec#multiply """ - x, = core.standard_insert_pbroadcast(x) return exp2_p.bind(x) @export @@ -557,7 +549,6 @@ def expm1(x: ArrayLike) -> Array: .. _stablehlo.exponential_minus_one: https://openxla.org/stablehlo/spec#exponential_minus_one """ - x, = core.standard_insert_pbroadcast(x) return expm1_p.bind(x) @export @@ -578,7 +569,6 @@ def log(x: ArrayLike) -> Array: .. _stablehlo.log: https://openxla.org/stablehlo/spec#log """ - x, = core.standard_insert_pbroadcast(x) return log_p.bind(x) @export @@ -602,7 +592,6 @@ def log1p(x: ArrayLike) -> Array: .. _stablehlo.log_plus_one: https://openxla.org/stablehlo/spec#log_plus_one """ - x, = core.standard_insert_pbroadcast(x) return log1p_p.bind(x) @export @@ -625,7 +614,6 @@ def tanh(x: ArrayLike) -> Array: .. _stablehlo.tanh: https://openxla.org/stablehlo/spec#tanh """ - x, = core.standard_insert_pbroadcast(x) return tanh_p.bind(x) @export @@ -645,7 +633,6 @@ def logistic(x: ArrayLike) -> Array: See also: - :func:`jax.nn.sigmoid`: an alternative API for this functionality. """ - x, = core.standard_insert_pbroadcast(x) return logistic_p.bind(x) @export @@ -670,7 +657,6 @@ def sin(x: ArrayLike) -> Array: .. _stablehlo.sine: https://openxla.org/stablehlo/spec#sine """ - x, = core.standard_insert_pbroadcast(x) return sin_p.bind(x) @export @@ -695,7 +681,6 @@ def cos(x: ArrayLike) -> Array: .. _stablehlo.cosine: https://openxla.org/stablehlo/spec#cosine """ - x, = core.standard_insert_pbroadcast(x) return cos_p.bind(x) @export @@ -743,7 +728,6 @@ def real(x: ArrayLike) -> Array: .. _stablehlo.real: https://openxla.org/stablehlo/spec#real """ - x, = core.standard_insert_pbroadcast(x) return real_p.bind(x) @export @@ -766,7 +750,6 @@ def imag(x: ArrayLike) -> Array: .. _stablehlo.imag: https://openxla.org/stablehlo/spec#imag """ - x, = core.standard_insert_pbroadcast(x) return imag_p.bind(x) @export @@ -819,7 +802,6 @@ def conj(x: ArrayLike) -> Array: .. _stablehlo.complex: https://openxla.org/stablehlo/spec#complex """ # TODO(mattjj): remove input_dtype, not needed anymore - x, = core.standard_insert_pbroadcast(x) return conj_p.bind(x, input_dtype=_dtype(x)) @export @@ -840,7 +822,6 @@ def abs(x: ArrayLike) -> Array: .. _stablehlo.abs: https://openxla.org/stablehlo/spec#abs """ - x, = core.standard_insert_pbroadcast(x) return abs_p.bind(x) @export @@ -888,7 +869,6 @@ def integer_pow(x: ArrayLike, y: int) -> Array: .. _stablehlo.multiply: https://openxla.org/stablehlo/spec#multiply """ - x, = core.standard_insert_pbroadcast(x) return integer_pow_p.bind(x, y=y) @export @@ -910,7 +890,6 @@ def sqrt(x: ArrayLike) -> Array: .. _stablehlo.sqrt: https://openxla.org/stablehlo/spec#sqrt """ - x, = core.standard_insert_pbroadcast(x) return sqrt_p.bind(x) @export @@ -933,7 +912,6 @@ def rsqrt(x: ArrayLike) -> Array: .. _stablehlo.rsqrt: https://openxla.org/stablehlo/spec#rsqrt """ - x, = core.standard_insert_pbroadcast(x) return rsqrt_p.bind(x) @export @@ -955,7 +933,6 @@ def cbrt(x: ArrayLike) -> Array: .. _stablehlo.cbrt: https://openxla.org/stablehlo/spec#cbrt """ - x, = core.standard_insert_pbroadcast(x) return cbrt_p.bind(x) @export @@ -980,7 +957,6 @@ def bitwise_not(x: ArrayLike) -> Array: .. _stablehlo.not: https://openxla.org/stablehlo/spec#not """ - x, = core.standard_insert_pbroadcast(x) return not_p.bind(x) @export @@ -1083,7 +1059,6 @@ def population_count(x: ArrayLike) -> Array: .. _stablehlo.popcnt: https://openxla.org/stablehlo/spec#popcnt """ - x, = core.standard_insert_pbroadcast(x) return population_count_p.bind(x) @export @@ -1104,7 +1079,6 @@ def clz(x: ArrayLike) -> Array: .. _stablehlo.count_leading_zeros: https://openxla.org/stablehlo/spec#count_leading_zeros """ - x, = core.standard_insert_pbroadcast(x) return clz_p.bind(x) @export @@ -1623,7 +1597,6 @@ def _convert_element_type( f"{dtype_to_string(old_dtype)} -> {dtype_to_string(old_rep_dtype)} " f"-> {dtype_to_string(new_rep_dtype)} -> {dtype_to_string(new_dtype)}") - operand, = core.standard_insert_pbroadcast(operand) if isinstance(new_dtype, dtypes.ExtendedDType): return to_edtype_p.bind(operand, edtype=new_dtype) return from_edtype_p.bind(operand, dtype=np.dtype(new_dtype)) @@ -1662,7 +1635,6 @@ def _convert_element_type( (sharding._is_concrete and getattr(operand, 'sharding', None) == sharding))): return operand else: - operand, = core.standard_insert_pbroadcast(operand) return convert_element_type_p.bind( operand, new_dtype=new_dtype, weak_type=bool(weak_type), sharding=sharding) @@ -1699,7 +1671,6 @@ def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array: .. _stablehlo.bitcast_convert: https://openxla.org/stablehlo/spec#bitcast_convert """ new_dtype = dtypes.canonicalize_dtype(new_dtype) - operand, = core.standard_insert_pbroadcast(operand) return bitcast_convert_type_p.bind(operand, new_dtype=new_dtype) def clamp(min: ArrayLike, x: ArrayLike, max: ArrayLike) -> Array: @@ -1956,7 +1927,6 @@ def split(operand: ArrayLike, sizes: Sequence[int], taken along ``axis``. """ operand = asarray(operand) - operand, = core.standard_insert_pbroadcast(operand) return split_p.bind(operand, sizes=tuple(sizes), axis=canonicalize_axis(axis, operand.ndim)) @@ -2662,7 +2632,6 @@ def broadcast_in_dim(operand: ArrayLike, shape: Shape, dyn_shape, static_shape = _extract_tracers_dyn_shape(shape) else: dyn_shape, static_shape = [], shape # type: ignore - operand, = core.standard_insert_pbroadcast(operand) return broadcast_in_dim_p.bind( operand, *dyn_shape, shape=tuple(static_shape), broadcast_dimensions=tuple(broadcast_dimensions), @@ -2729,7 +2698,6 @@ def reshape(operand: ArrayLike, new_sizes: Shape, else: dyn_shape, static_new_sizes = _extract_tracers_dyn_shape(new_sizes) out_sharding = canonicalize_sharding(out_sharding, 'reshape') - operand, = core.standard_insert_pbroadcast(operand) return reshape_p.bind( operand, *dyn_shape, new_sizes=tuple(static_new_sizes), dimensions=None if dims is None or same_dims else dims, @@ -2793,7 +2761,6 @@ def rev(operand: ArrayLike, dimensions: Sequence[int]) -> Array: `_ operator. """ - operand, = core.standard_insert_pbroadcast(operand) return rev_p.bind(operand, dimensions=tuple(dimensions)) def select(pred: ArrayLike, on_true: ArrayLike, on_false: ArrayLike) -> Array: @@ -2860,20 +2827,18 @@ def transpose(operand: ArrayLike, if permutation == tuple(range(np.ndim(operand))) and isinstance(operand, Array): return operand else: - operand, = core.standard_insert_pbroadcast(operand) + return transpose_p.bind(operand, permutation=permutation) def argmin(operand: ArrayLike, axis: int, index_dtype: DTypeLike) -> Array: """Computes the index of the minimum element along ``axis``.""" - operand, = core.standard_insert_pbroadcast(operand) return argmin_p.bind(operand, axes=(axis,), index_dtype=dtypes.canonicalize_dtype(index_dtype)) def argmax(operand: ArrayLike, axis: int, index_dtype: DTypeLike) -> Array: """Computes the index of the maximum element along ``axis``.""" - operand, = core.standard_insert_pbroadcast(operand) return argmax_p.bind(operand, axes=(axis,), index_dtype=dtypes.canonicalize_dtype(index_dtype)) @@ -3039,7 +3004,6 @@ def reduce_sum(operand: ArrayLike, axes: Sequence[int]) -> Array: :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_max`, :func:`jax.lax.reduce_min`, :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_or`, :func:`jax.lax.reduce_xor`. """ - operand, = core.standard_insert_pbroadcast(operand) return reduce_sum_p.bind(operand, axes=tuple(axes)) def reduce_prod(operand: ArrayLike, axes: Sequence[int]) -> Array: @@ -3066,7 +3030,6 @@ def reduce_prod(operand: ArrayLike, axes: Sequence[int]) -> Array: :func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_max`, :func:`jax.lax.reduce_min`, :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_or`, :func:`jax.lax.reduce_xor`. """ - operand, = core.standard_insert_pbroadcast(operand) return reduce_prod_p.bind(operand, axes=tuple(axes)) def reduce_max(operand: ArrayLike, axes: Sequence[int]) -> Array: @@ -3088,7 +3051,6 @@ def reduce_max(operand: ArrayLike, axes: Sequence[int]) -> Array: :func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_min`, :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_or`, :func:`jax.lax.reduce_xor`. """ - operand, = core.standard_insert_pbroadcast(operand) return reduce_max_p.bind(operand, axes=tuple(axes)) def reduce_min(operand: ArrayLike, axes: Sequence[int]) -> Array: @@ -3110,7 +3072,6 @@ def reduce_min(operand: ArrayLike, axes: Sequence[int]) -> Array: :func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_max`, :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_or`, :func:`jax.lax.reduce_xor`. """ - operand, = core.standard_insert_pbroadcast(operand) return reduce_min_p.bind(operand, axes=tuple(axes)) def reduce_or(operand: ArrayLike, axes: Sequence[int]) -> Array: @@ -3133,7 +3094,6 @@ def reduce_or(operand: ArrayLike, axes: Sequence[int]) -> Array: :func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_max`, :func:`jax.lax.reduce_min`, :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_xor`. """ - operand, = core.standard_insert_pbroadcast(operand) return reduce_or_p.bind(operand, axes=tuple(axes)) def reduce_and(operand: ArrayLike, axes: Sequence[int]) -> Array: @@ -3156,7 +3116,6 @@ def reduce_and(operand: ArrayLike, axes: Sequence[int]) -> Array: :func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_max`, :func:`jax.lax.reduce_min`, :func:`jax.lax.reduce_or`, :func:`jax.lax.reduce_xor`. """ - operand, = core.standard_insert_pbroadcast(operand) return reduce_and_p.bind(operand, axes=tuple(axes)) def reduce_xor(operand: ArrayLike, axes: Sequence[int]) -> Array: @@ -3179,7 +3138,6 @@ def reduce_xor(operand: ArrayLike, axes: Sequence[int]) -> Array: :func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_max`, :func:`jax.lax.reduce_min`, :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_or`. """ - operand, = core.standard_insert_pbroadcast(operand) return reduce_xor_p.bind(operand, axes=tuple(axes)) @overload @@ -3265,7 +3223,6 @@ def top_k(operand: ArrayLike, k: int) -> tuple[Array, Array]: k = int(k) if k < 0: raise ValueError(f"k argument to top_k must be nonnegative, got {k}") - operand, = core.standard_insert_pbroadcast(operand) return top_k_p.bind(operand, k=k) def tie_in(x: Any, y: T) -> T: @@ -3451,7 +3408,6 @@ def reduce_precision(operand: float | ArrayLike, operator.index, exponent_bits, "exponent_bits argument of lax.reduce_precision") mantissa_bits = core.concrete_or_error( operator.index, mantissa_bits, "mantissa_bits argument of lax.reduce_precision") - operand, = core.standard_insert_pbroadcast(operand) return reduce_precision_p.bind(operand, exponent_bits=exponent_bits, mantissa_bits=mantissa_bits) @@ -3461,7 +3417,6 @@ def squeeze(array: ArrayLike, dimensions: Sequence[int]) -> Array: dimensions = tuple(sorted(canonicalize_axis(i, ndim) for i in dimensions)) if not dimensions and isinstance(array, Array): return array - array, = core.standard_insert_pbroadcast(array) return squeeze_p.bind(array, dimensions=dimensions) def expand_dims(array: ArrayLike, dimensions: Sequence[int]) -> Array: @@ -3582,7 +3537,6 @@ def batch_matmul(lhs: Array, rhs: Array, def square(x: ArrayLike) -> Array: r"""Elementwise square: :math:`x^2`.""" - x, = core.standard_insert_pbroadcast(x) return square_p.bind(x) def reciprocal(x: ArrayLike) -> Array: @@ -3610,7 +3564,6 @@ def tan(x: ArrayLike) -> Array: .. _stablehlo.tangent: https://openxla.org/stablehlo/spec#tangent """ - x, = core.standard_insert_pbroadcast(x) return tan_p.bind(x) @export @@ -3631,7 +3584,6 @@ def asin(x: ArrayLike) -> Array: - :func:`jax.lax.acos`: elementwise arc cosine. - :func:`jax.lax.atan`: elementwise arc tangent. """ - x, = core.standard_insert_pbroadcast(x) return asin_p.bind(x) @export @@ -3652,7 +3604,6 @@ def acos(x: ArrayLike) -> Array: - :func:`jax.lax.asin`: elementwise arc sine. - :func:`jax.lax.atan`: elementwise arc tangent. """ - x, = core.standard_insert_pbroadcast(x) return acos_p.bind(x) @export @@ -3674,7 +3625,6 @@ def atan(x: ArrayLike) -> Array: - :func:`jax.lax.asin`: elementwise arc sine. - :func:`jax.lax.atan2`: elementwise 2-term arc tangent. """ - x, = core.standard_insert_pbroadcast(x) return atan_p.bind(x) @export @@ -3695,7 +3645,6 @@ def sinh(x: ArrayLike) -> Array: - :func:`jax.lax.cosh`: elementwise hyperbolic cosine. - :func:`jax.lax.tanh`: elementwise hyperbolic tangent. """ - x, = core.standard_insert_pbroadcast(x) return sinh_p.bind(x) @export @@ -3716,7 +3665,6 @@ def cosh(x: ArrayLike) -> Array: - :func:`jax.lax.sinh`: elementwise hyperbolic sine. - :func:`jax.lax.tanh`: elementwise hyperbolic tangent. """ - x, = core.standard_insert_pbroadcast(x) return cosh_p.bind(x) @export @@ -3737,7 +3685,6 @@ def asinh(x: ArrayLike) -> Array: - :func:`jax.lax.atanh`: elementwise inverse hyperbolic tangent. - :func:`jax.lax.sinh`: elementwise hyperbolic sine. """ - x, = core.standard_insert_pbroadcast(x) return asinh_p.bind(x) @export @@ -3758,7 +3705,6 @@ def acosh(x: ArrayLike) -> Array: - :func:`jax.lax.atanh`: elementwise inverse hyperbolic tangent. - :func:`jax.lax.cosh`: elementwise hyperbolic cosine. """ - x, = core.standard_insert_pbroadcast(x) return acosh_p.bind(x) @export @@ -3779,7 +3725,6 @@ def atanh(x: ArrayLike) -> Array: - :func:`jax.lax.asinh`: elementwise inverse hyperbolic sine. - :func:`jax.lax.tanh`: elementwise hyperbolic tangent. """ - x, = core.standard_insert_pbroadcast(x) return atanh_p.bind(x) @@ -3937,16 +3882,6 @@ def broadcasting_sharding_rule(name, *avals): f'{", ".join(map(str, map(tuple, specs)))}.') return NamedSharding(mesh, P(*result_specs)) -def standard_vma_rule(prim_name, *avals, **kwargs): - vma, *vmas = [a.vma for a in avals] - if not all(vma == vma_ for vma_ in vmas): - raise ValueError( - f'Primitive {prim_name} requires varying manual axes ' - f'to match, but got {[vma, *vmas]}. Please open an issue at ' - 'https://github.com/jax-ml/jax/issues and as a temporary ' - 'workaround pass the check_rep=False argument to shard_map') - return vma - def naryop(result_dtype, accepted_dtypes, name, allow_extended_dtype=False, require_same_dtypes=True): dtype_rule = partial(naryop_dtype_rule, result_dtype, accepted_dtypes, name, @@ -3956,7 +3891,7 @@ def naryop(result_dtype, accepted_dtypes, name, allow_extended_dtype=False, sharding_rule = partial(broadcasting_sharding_rule, name) prim = standard_primitive( shape_rule, dtype_rule, name, sharding_rule=sharding_rule, - vma_rule=partial(standard_vma_rule, name)) + vma_rule=partial(core.standard_vma_rule, name)) batching.defbroadcasting(prim) pe.def_trivial_padding(prim) return prim @@ -4808,7 +4743,7 @@ def _convert_element_type_bind_with_trace(trace, args, params): _convert_element_type_shape_rule, _convert_element_type_dtype_rule, _convert_element_type_weak_type_rule, _convert_element_type_sharding_rule, - partial(standard_vma_rule, convert_element_type_p.name))) + partial(core.standard_vma_rule, convert_element_type_p.name))) ad.defjvp2(convert_element_type_p, _convert_element_type_jvp_rule) ad.primitive_transposes[convert_element_type_p] = _convert_element_type_transpose_rule batching.defvectorized(convert_element_type_p) @@ -4974,7 +4909,7 @@ def _bitcast_convert_type_dtype_rule(operand, *, new_dtype): _bitcast_convert_type_shape_rule, _bitcast_convert_type_dtype_rule, 'bitcast_convert_type', weak_type_rule=_strip_weak_type, sharding_rule=_bitcast_convert_type_sharding_rule, - vma_rule=partial(standard_vma_rule, 'bitcast_convert_type')) + vma_rule=partial(core.standard_vma_rule, 'bitcast_convert_type')) ad.defjvp_zero(bitcast_convert_type_p) batching.defvectorized(bitcast_convert_type_p) @@ -5443,7 +5378,7 @@ def _dot_general_ragged_prop_rule(eqn_params, invar_raggedness, outvars): _dot_general_dtype_rule, 'dot_general', sharding_rule=_dot_general_sharding_rule, - vma_rule=partial(standard_vma_rule, 'dot_general') + vma_rule=partial(core.standard_vma_rule, 'dot_general') ) @@ -6444,7 +6379,7 @@ def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions, new_sharding = _broadcast_in_dim_sharding_rule( x, shape=shape, broadcast_dimensions=broadcast_dimensions, sharding=sharding) - new_vma = (standard_vma_rule('broadcast_in_dim', x) + new_vma = (core.standard_vma_rule('broadcast_in_dim', x) if config.varying_axes_in_types.value else frozenset()) return core.ShapedArray(shape, x.dtype, x.weak_type, sharding=new_sharding, vma=new_vma) @@ -6532,7 +6467,7 @@ def _clamp_batch_rule(batched_args, batch_dims, **params): clamp_p = standard_primitive(_clamp_shape_rule, _clamp_dtype_rule, 'clamp', sharding_rule=_clamp_sharding_rule, - vma_rule=partial(standard_vma_rule, 'clamp')) + vma_rule=partial(core.standard_vma_rule, 'clamp')) ad.defjvp(clamp_p, lambda g, min, operand, max: select(bitwise_and(gt(min, operand), lt(min, max)), @@ -6620,7 +6555,7 @@ def _concatenate_pad_rule(in_avals, out_avals, *operands, dimension): concatenate_p = standard_primitive( _concatenate_shape_rule, _concatenate_dtype_rule, 'concatenate', sharding_rule=_concatenate_sharding_rule, - vma_rule=partial(standard_vma_rule, 'concatenate')) + vma_rule=partial(core.standard_vma_rule, 'concatenate')) ad.deflinear2(concatenate_p, _concatenate_transpose_rule) ad.primitive_transposes[concatenate_p] = _concatenate_transpose_rule batching.primitive_batchers[concatenate_p] = _concatenate_batch_rule @@ -6693,7 +6628,7 @@ def _split_sharding_rule(operand, *, sizes, axis): for out_sh in out_shapes] def _split_vma_rule(operand, *, sizes, axis): - out_vma = standard_vma_rule('split', operand) + out_vma = core.standard_vma_rule('split', operand) out_shapes = _split_shape_rule(operand, sizes=sizes, axis=axis) return [out_vma] * len(out_shapes) @@ -6785,7 +6720,7 @@ def _pad_batch_rule(batched_args, batch_dims, *, padding_config): pad_p = standard_primitive(_pad_shape_rule, _pad_dtype_rule, 'pad', sharding_rule=_pad_sharding_rule, - vma_rule=partial(standard_vma_rule, 'pad')) + vma_rule=partial(core.standard_vma_rule, 'pad')) ad.deflinear2(pad_p, _pad_transpose) batching.primitive_batchers[pad_p] = _pad_batch_rule @@ -6850,7 +6785,7 @@ def _squeeze_batch_rule(batched_args, batch_dims, *, dimensions): squeeze_p = standard_primitive(_squeeze_shape_rule, _squeeze_dtype_rule, 'squeeze', sharding_rule=_squeeze_sharding_rule, - vma_rule=partial(standard_vma_rule, 'squeeze')) + vma_rule=partial(core.standard_vma_rule, 'squeeze')) ad.deflinear2(squeeze_p, _squeeze_transpose_rule) batching.primitive_batchers[squeeze_p] = _squeeze_batch_rule pe.def_trivial_padding(squeeze_p) @@ -7085,7 +7020,7 @@ def _reshape_staging_rule( reshape_p = standard_primitive(_reshape_shape_rule, _reshape_dtype_rule, 'reshape', sharding_rule=_reshape_sharding_rule, - vma_rule=partial(standard_vma_rule, 'reshape')) + vma_rule=partial(core.standard_vma_rule, 'reshape')) ad.deflinear2(reshape_p, _reshape_transpose_rule) batching.fancy_primitive_batchers[reshape_p] = _reshape_batch_rule batching.skippable_batchers[reshape_p] = lambda _: () @@ -7118,7 +7053,7 @@ def _rev_batch_rule(batched_args, batch_dims, *, dimensions): rev_p = standard_primitive(_rev_shape_rule, _input_dtype, 'rev', sharding_rule=_rev_sharding_rule, - vma_rule=partial(standard_vma_rule, 'rev')) + vma_rule=partial(core.standard_vma_rule, 'rev')) ad.deflinear2(rev_p, lambda t, _, dimensions: [rev(t, dimensions)]) batching.primitive_batchers[rev_p] = _rev_batch_rule @@ -7167,7 +7102,7 @@ def _transpose_lower(ctx, x, *, permutation): transpose_p = standard_primitive( _transpose_shape_rule, _input_dtype, 'transpose', sharding_rule=_transpose_sharding_rule, - vma_rule=partial(standard_vma_rule, 'transpose')) + vma_rule=partial(core.standard_vma_rule, 'transpose')) ad.deflinear2(transpose_p, lambda t, _, permutation: [transpose(t, np.argsort(permutation))]) batching.primitive_batchers[transpose_p] = _transpose_batch_rule @@ -7344,7 +7279,7 @@ def _select(offset, cases): select_n_p = standard_primitive( _select_shape_rule, _select_dtype_rule, 'select_n', weak_type_rule=_select_weak_type_rule, sharding_rule=_select_sharding_rule, - vma_rule=partial(standard_vma_rule, 'select_n')) + vma_rule=partial(core.standard_vma_rule, 'select_n')) ad.primitive_jvps[select_n_p] = _select_jvp ad.primitive_transposes[select_n_p] = _select_transpose_rule batching.primitive_batchers[select_n_p] = _select_batch_rule @@ -7526,7 +7461,7 @@ def _reduce_op_sharding_rule(operand, *, axes): reduce_sum_p = standard_primitive( _reduce_op_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_sum'), 'reduce_sum', sharding_rule=_reduce_op_sharding_rule, - vma_rule=partial(standard_vma_rule, 'reduce_sum')) + vma_rule=partial(core.standard_vma_rule, 'reduce_sum')) ad.deflinear2(reduce_sum_p, _reduce_sum_transpose_rule) batching.defreducer(reduce_sum_p, _get_sum_identity) pe.padding_rules[reduce_sum_p] = partial(_reducer_padding, reduce_sum, @@ -7542,7 +7477,7 @@ def _reduce_prod_jvp_rule(primals, tangents, *, axes): reduce_prod_p = standard_primitive( _reduce_op_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_prod'), 'reduce_prod', sharding_rule=_reduce_op_sharding_rule, - vma_rule=partial(standard_vma_rule, 'reduce_prod')) + vma_rule=partial(core.standard_vma_rule, 'reduce_prod')) ad.primitive_jvps[reduce_prod_p] = _reduce_prod_jvp_rule batching.defreducer(reduce_prod_p, _get_prod_identity) pe.padding_rules[reduce_prod_p] = partial(_reducer_padding, reduce_prod, @@ -7563,7 +7498,7 @@ def _reduce_chooser_jvp_rule(g, ans, operand, *, axes): reduce_max_p = standard_primitive( _reduce_op_shape_rule, _input_dtype, 'reduce_max', sharding_rule=_reduce_op_sharding_rule, - vma_rule=partial(standard_vma_rule, 'reduce_max')) + vma_rule=partial(core.standard_vma_rule, 'reduce_max')) ad.defjvp2(reduce_max_p, _reduce_chooser_jvp_rule) batching.defreducer(reduce_max_p, _get_max_identity) pe.padding_rules[reduce_max_p] = partial(_reducer_padding, reduce_max, @@ -7574,7 +7509,7 @@ def _reduce_chooser_jvp_rule(g, ans, operand, *, axes): reduce_min_p = standard_primitive( _reduce_op_shape_rule, _input_dtype, 'reduce_min', sharding_rule=_reduce_op_sharding_rule, - vma_rule=partial(standard_vma_rule, 'reduce_min')) + vma_rule=partial(core.standard_vma_rule, 'reduce_min')) ad.defjvp2(reduce_min_p, _reduce_chooser_jvp_rule) batching.defreducer(reduce_min_p, _get_min_identity) pe.padding_rules[reduce_min_p] = partial(_reducer_padding, reduce_min, @@ -7642,14 +7577,14 @@ def _compute_argminmax(value_comparator, get_identity, argmin_p = standard_primitive(_argminmax_shape_rule, _argminmax_dtype_rule, 'argmin', weak_type_rule=_strip_weak_type, sharding_rule=_argminmax_sharding_rule, - vma_rule=partial(standard_vma_rule, 'argmin')) + vma_rule=partial(core.standard_vma_rule, 'argmin')) batching.defreducer(argmin_p, _get_min_identity) ad.defjvp_zero(argmin_p) argmax_p = standard_primitive(_argminmax_shape_rule, _argminmax_dtype_rule, 'argmax', weak_type_rule=_strip_weak_type, sharding_rule=_argminmax_sharding_rule, - vma_rule=partial(standard_vma_rule, 'argmax')) + vma_rule=partial(core.standard_vma_rule, 'argmax')) batching.defreducer(argmax_p, _get_max_identity) ad.defjvp_zero(argmax_p) @@ -7673,14 +7608,14 @@ def _reduce_logical_sharding_rule(operand, *, axes): reduce_or_p = standard_primitive( _reduce_logical_shape_rule, _input_dtype, 'reduce_or', weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule, - vma_rule=partial(standard_vma_rule, 'reduce_or')) + vma_rule=partial(core.standard_vma_rule, 'reduce_or')) batching.defreducer(reduce_or_p, _get_bitwise_or_identity) reduce_and_p = standard_primitive( _reduce_logical_shape_rule, _input_dtype, 'reduce_and', weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule, - vma_rule=partial(standard_vma_rule, 'reduce_and')) + vma_rule=partial(core.standard_vma_rule, 'reduce_and')) batching.defreducer(reduce_and_p, _get_bitwise_and_identity) batching.ragged_prop_rules[reduce_and_p] = batching.ragged_mask_elementwise_rule @@ -7688,7 +7623,7 @@ def _reduce_logical_sharding_rule(operand, *, axes): reduce_xor_p = standard_primitive( _reduce_logical_shape_rule, _input_dtype, 'reduce_xor', weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule, - vma_rule=partial(standard_vma_rule, 'reduce_xor')) + vma_rule=partial(core.standard_vma_rule, 'reduce_xor')) batching.defreducer(reduce_xor_p, _get_bitwise_or_identity) @@ -7736,7 +7671,7 @@ def _reduce_precision_sharding_rule(operand, *, exponent_bits, mantissa_bits): _reduce_precision_shape_rule, partial(unop_dtype_rule, _identity, _float, 'reduce_precision'), name='reduce_precision', sharding_rule=_reduce_precision_sharding_rule, - vma_rule=partial(standard_vma_rule, 'reduce_precision')) + vma_rule=partial(core.standard_vma_rule, 'reduce_precision')) ad.deflinear(reduce_precision_p, lambda t, **kwargs: [reduce_precision_p.bind(t, **kwargs)]) batching.defvectorized(reduce_precision_p) @@ -8368,7 +8303,6 @@ def rng_bit_generator(key, shape, dtype=np.uint32, if np.dtype(dtype) not in {np.dtype('uint8'), np.dtype('uint16'), np.dtype('uint32'), np.dtype('uint64')}: raise TypeError(f'rng_bit_generator: unsupported dtype {dtype}') - key, = core.standard_insert_pbroadcast(key) return tuple( rng_bit_generator_p.bind( key, shape=shape, dtype=dtype, algorithm=algorithm, diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index b455257e107c..d53d54599527 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -121,6 +121,7 @@ def cholesky_update(r_matrix: ArrayLike, w_vector: ArrayLike) -> Array: A new upper-triangular matrix :math:`R` defining the Cholesky decomposition of :math:`A + w \, w^T`. """ + r_matrix, w_vector = core.standard_insert_pbroadcast(r_matrix, w_vector) return cholesky_update_p.bind(r_matrix, w_vector) @@ -268,6 +269,7 @@ def householder_product(a: ArrayLike, taus: ArrayLike) -> Array: A batch of orthogonal (unitary) matrices with the same shape as ``a``, containing the products of the elementary Householder reflectors. """ + a, taus = core.standard_insert_pbroadcast(a, taus) return householder_product_p.bind(a, taus) @@ -545,6 +547,7 @@ def symmetric_product( ``symmetrize_output`` is ``True``, the upper triangle is filled with the transpose of the lower triangle, and the whole matrix is valid. """ + a_matrix, c_matrix = core.standard_insert_pbroadcast(a_matrix, c_matrix) result = symmetric_product_p.bind(a_matrix, c_matrix, alpha=alpha, beta=beta) if symmetrize_output: upper_half = lax.transpose( @@ -602,6 +605,7 @@ def triangular_solve( singleton = np.ndim(b) == np.ndim(a) - 1 if singleton: b = lax.expand_dims(b, (-1 if left_side else -2,)) + a, b = core.standard_insert_pbroadcast(a, b) out = triangular_solve_p.bind( a, b, left_side=left_side, lower=lower, transpose_a=transpose_a, conjugate_a=conjugate_a, unit_diagonal=unit_diagonal) @@ -661,6 +665,7 @@ def tridiagonal_solve(dl: Array, d: Array, du: Array, b: Array) -> Array: Returns: Solution ``X`` of tridiagonal system. """ + dl, d, du, b = core.standard_insert_pbroadcast(dl, d, du, b) return tridiagonal_solve_p.bind(dl, d, du, b) @@ -742,7 +747,7 @@ def linalg_sharding_rule( def linalg_vma_rule(multiple_results, shape_rule, name, *avals, **kwargs): output_shapes = shape_rule(*avals, **kwargs) - out_vma = lax_internal.standard_vma_rule(name, *avals) + out_vma = core.standard_vma_rule(name, *avals) if multiple_results: return [out_vma] * len(output_shapes) else: @@ -775,7 +780,7 @@ def linalg_primitive(result_dtype, accepted_dtypes, ranks, result_shape, name, prim.def_abstract_eval( partial(lax_utils.standard_abstract_eval, prim, shape_rule, dtype_rule, lax_utils._standard_weak_type_rule, sharding_rule, - partial(lax_internal.standard_vma_rule, name))) + partial(core.standard_vma_rule, name))) if supports_batching: batching.primitive_batchers[prim] = partial( batching.expand_dims_batcher, prim) @@ -1768,6 +1773,7 @@ def geqp3(a: ArrayLike, jpvt: ArrayLike, *, elementary Householder reflectors, and ``jpvt`` is the column-pivot indices such that ``a[:, jpvt] = q @ r``. """ + a, jpvt = core.standard_insert_pbroadcast(a, jpvt) a_out, jpvt_out, taus = geqp3_p.bind(a, jpvt, use_magma=use_magma) return a_out, jpvt_out, taus diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index d3bcb6da2807..3f4d1b6d576f 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -234,6 +234,7 @@ def dynamic_update_slice( """ start_indices = _dynamic_slice_indices( operand, start_indices, allow_negative_indices) + operand, update = core.standard_insert_pbroadcast(operand, update) return dynamic_update_slice_p.bind(operand, update, *start_indices) @@ -416,6 +417,7 @@ def gather(operand: ArrayLike, start_indices: ArrayLike, raise ValueError(f"Unsupported dtype for gather fill_value {dtype}") else: fill_value = None + operand, start_indices = core.standard_insert_pbroadcast(operand, start_indices) return gather_p.bind( operand, start_indices, dimension_numbers=dimension_numbers, slice_sizes=core.canonicalize_shape(slice_sizes), @@ -505,6 +507,8 @@ def scatter_add( """ jaxpr, consts = lax._reduction_jaxpr(lax.add, core.get_aval(lax._const(operand, 0))) + operand, scatter_indices, updates = core.standard_insert_pbroadcast( + operand, scatter_indices, updates) return scatter_add_p.bind( operand, scatter_indices, updates, update_jaxpr=jaxpr, update_consts=consts, dimension_numbers=dimension_numbers, @@ -559,6 +563,8 @@ def scatter_sub( jaxpr, consts = lax._reduction_jaxpr( lax.sub, core.get_aval(lax._const(operand, 0)) ) + operand, scatter_indices, updates = core.standard_insert_pbroadcast( + operand, scatter_indices, updates) return scatter_sub_p.bind( operand, scatter_indices, @@ -613,6 +619,8 @@ def scatter_mul( """ jaxpr, consts = lax._reduction_jaxpr(lax.mul, core.get_aval(lax._const(operand, 1))) + operand, scatter_indices, updates = core.standard_insert_pbroadcast( + operand, scatter_indices, updates) return scatter_mul_p.bind( operand, scatter_indices, updates, update_jaxpr=jaxpr, update_consts=consts, dimension_numbers=dimension_numbers, @@ -660,6 +668,8 @@ def scatter_min( """ jaxpr, consts = lax._reduction_jaxpr(lax.min, core.get_aval(lax._const(operand, 0))) + operand, scatter_indices, updates = core.standard_insert_pbroadcast( + operand, scatter_indices, updates) return scatter_min_p.bind( operand, scatter_indices, updates, update_jaxpr=jaxpr, update_consts=consts, dimension_numbers=dimension_numbers, @@ -707,6 +717,8 @@ def scatter_max( """ jaxpr, consts = lax._reduction_jaxpr(lax.max, core.get_aval(lax._const(operand, 0))) + operand, scatter_indices, updates = core.standard_insert_pbroadcast( + operand, scatter_indices, updates) return scatter_max_p.bind( operand, scatter_indices, updates, update_jaxpr=jaxpr, update_consts=consts, dimension_numbers=dimension_numbers, @@ -854,6 +866,8 @@ def scatter( ... mode=lax.GatherScatterMode.PROMISE_IN_BOUNDS) Array([0., 2., 3., 0., 4.], dtype=float32) """ + operand, scatter_indices, updates = core.standard_insert_pbroadcast( + operand, scatter_indices, updates) return scatter_p.bind( operand, scatter_indices, updates, update_jaxpr=None, update_consts=(), dimension_numbers=dimension_numbers, @@ -1393,7 +1407,8 @@ def _slice_batching_rule(batched_args, batch_dims, *, start_indices, return out, bdim slice_p = standard_primitive(_slice_shape_rule, _input_dtype, 'slice', - sharding_rule=_slice_sharding_rule) + sharding_rule=_slice_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'slice')) ad.deflinear2(slice_p, _slice_transpose_rule) batching.primitive_batchers[slice_p] = _slice_batching_rule # TODO(mvoz): A better slice rule for ragged prop, enforcing boundaries @@ -1559,7 +1574,8 @@ def _dynamic_slice_padding_rule(in_avals, out_avals, x, *starts_and_dyn, dynamic_slice_p = standard_primitive( _dynamic_slice_shape_rule, _dynamic_slice_dtype_rule, 'dynamic_slice', weak_type_rule=_argnum_weak_type(0), - sharding_rule=_dynamic_slice_sharding_rule) + sharding_rule=_dynamic_slice_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'dynamic_slice')) ad.primitive_jvps[dynamic_slice_p] = _dynamic_slice_jvp ad.primitive_transposes[dynamic_slice_p] = _dynamic_slice_transpose_rule batching.primitive_batchers[dynamic_slice_p] = _dynamic_slice_batching_rule @@ -1679,7 +1695,8 @@ def _dynamic_update_slice_batching_rule(batched_args, batch_dims): dynamic_update_slice_p = standard_primitive( _dynamic_update_slice_shape_rule, _dynamic_update_slice_dtype_rule, - 'dynamic_update_slice', sharding_rule=_dynamic_update_slice_sharding_rule) + 'dynamic_update_slice', sharding_rule=_dynamic_update_slice_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'dynamic_update_slice')) ad.primitive_jvps[dynamic_update_slice_p] = _dynamic_update_slice_jvp ad.primitive_transposes[dynamic_update_slice_p] = \ _dynamic_update_slice_transpose_rule @@ -2117,7 +2134,8 @@ def _gather_pad_rule(in_avals, out_avals, operand, indices, *, gather_p = standard_primitive( _gather_shape_rule, _gather_dtype_rule, 'gather', - weak_type_rule=_argnum_weak_type(0), sharding_rule=_gather_sharding_rule) + weak_type_rule=_argnum_weak_type(0), sharding_rule=_gather_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'gather')) ad.defjvp(gather_p, _gather_jvp_rule, None) ad.primitive_transposes[gather_p] = _gather_transpose_rule batching.primitive_batchers[gather_p] = _gather_batching_rule @@ -2599,7 +2617,8 @@ def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *, scatter_add_p = standard_primitive( _scatter_shape_rule, _scatter_dtype_rule, 'scatter-add', - weak_type_rule=_argnum_weak_type(0)) + weak_type_rule=_argnum_weak_type(0), + vma_rule=partial(core.standard_vma_rule, 'scatter_add')) ad.primitive_jvps[scatter_add_p] = partial(_scatter_addsub_jvp, scatter_add_p) ad.primitive_transposes[scatter_add_p] = partial(_scatter_addsub_transpose_rule, scatter_add_p) batching.primitive_batchers[scatter_add_p] = ( @@ -2610,6 +2629,7 @@ def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *, _scatter_dtype_rule, "scatter-sub", weak_type_rule=_argnum_weak_type(0), + vma_rule=partial(core.standard_vma_rule, 'scatter_sub') ) ad.primitive_jvps[scatter_sub_p] = partial(_scatter_addsub_jvp, scatter_sub_p) ad.primitive_transposes[scatter_sub_p] = partial(_scatter_addsub_transpose_rule, scatter_sub_p) @@ -2619,7 +2639,8 @@ def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *, scatter_mul_p = standard_primitive( _scatter_shape_rule, _scatter_dtype_rule, 'scatter-mul', - weak_type_rule=_argnum_weak_type(0)) + weak_type_rule=_argnum_weak_type(0), + vma_rule=partial(core.standard_vma_rule, 'scatter_mul')) def _scatter_mul_jvp_rhs(g, x, i, y, *, dimension_numbers, indices_are_sorted, unique_indices, mode, **kw): @@ -2748,14 +2769,16 @@ def _scatter_extremal_jvp(scatter_op, primals, tangents, update_jaxpr, scatter_min_p = standard_primitive( _scatter_shape_rule, _scatter_dtype_rule, 'scatter-min', - weak_type_rule=_argnum_weak_type(0)) + weak_type_rule=_argnum_weak_type(0), + vma_rule=partial(core.standard_vma_rule, 'scatter_min')) batching.primitive_batchers[scatter_min_p] = ( partial(_scatter_batching_rule, scatter_min_p)) ad.primitive_jvps[scatter_min_p] = partial(_scatter_extremal_jvp, scatter_min_p) scatter_max_p = standard_primitive( _scatter_shape_rule, _scatter_dtype_rule, 'scatter-max', - weak_type_rule=_argnum_weak_type(0)) + weak_type_rule=_argnum_weak_type(0), + vma_rule=partial(core.standard_vma_rule, 'scatter_max')) batching.primitive_batchers[scatter_max_p] = ( partial(_scatter_batching_rule, scatter_max_p)) ad.primitive_jvps[scatter_max_p] = partial(_scatter_extremal_jvp, scatter_max_p) @@ -2913,7 +2936,8 @@ def _scatter_transpose_rule(t, operand, indices, updates, *, scatter_p = standard_primitive( _scatter_shape_rule, _scatter_dtype_rule, 'scatter', - weak_type_rule=_argnum_weak_type(0)) + weak_type_rule=_argnum_weak_type(0), + vma_rule=partial(core.standard_vma_rule, 'scatter')) ad.primitive_jvps[scatter_p] = _scatter_jvp ad.primitive_transposes[scatter_p] = _scatter_transpose_rule batching.primitive_batchers[scatter_p] = ( diff --git a/jax/_src/lax/special.py b/jax/_src/lax/special.py index ba2687d4acd7..041205156d58 100644 --- a/jax/_src/lax/special.py +++ b/jax/_src/lax/special.py @@ -21,6 +21,7 @@ import numpy as np from functools import partial +from jax._src import core from jax._src.lax.lax import (add, bitwise_and, bitwise_not, bitwise_or, broadcast_in_dim, broadcast_shapes, convert_element_type, div, eq, exp, full_like, ge, @@ -39,6 +40,7 @@ def betainc(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array: r"""Elementwise regularized incomplete beta integral.""" + a, b, x = core.standard_insert_pbroadcast(a, b, x) return regularized_incomplete_beta_p.bind(a, b, x) def lgamma(x: ArrayLike) -> Array: @@ -51,26 +53,32 @@ def digamma(x: ArrayLike) -> Array: def polygamma(m: ArrayLike, x: ArrayLike) -> Array: r"""Elementwise polygamma: :math:`\psi^{(m)}(x)`.""" + m, x = core.standard_insert_pbroadcast(m, x) return polygamma_p.bind(m, x) def igamma(a: ArrayLike, x: ArrayLike) -> Array: r"""Elementwise regularized incomplete gamma function.""" + a, x = core.standard_insert_pbroadcast(a, x) return igamma_p.bind(a, x) def igammac(a: ArrayLike, x: ArrayLike) -> Array: r"""Elementwise complementary regularized incomplete gamma function.""" + a, x = core.standard_insert_pbroadcast(a, x) return igammac_p.bind(a, x) def igamma_grad_a(a: ArrayLike, x: ArrayLike) -> Array: r"""Elementwise derivative of the regularized incomplete gamma function.""" + a, x = core.standard_insert_pbroadcast(a, x) return igamma_grad_a_p.bind(a, x) def random_gamma_grad(a: ArrayLike, x: ArrayLike) -> Array: r"""Elementwise derivative of samples from `Gamma(a, 1)`.""" + a, x = core.standard_insert_pbroadcast(a, x) return random_gamma_grad_p.bind(a, x) def zeta(x: ArrayLike, q: ArrayLike) -> Array: r"""Elementwise Hurwitz zeta function: :math:`\zeta(x, q)`""" + x, q = core.standard_insert_pbroadcast(x, q) return zeta_p.bind(x, q) def bessel_i0e(x: ArrayLike) -> Array: diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py index 42b2e9278889..00bdfe75f3e7 100644 --- a/jax/_src/lax/windowed_reductions.py +++ b/jax/_src/lax/windowed_reductions.py @@ -21,6 +21,7 @@ from jax import tree_util from jax._src import api_util from jax._src import core +from jax._src import config from jax._src import dispatch from jax._src import dtypes from jax._src import util @@ -97,6 +98,7 @@ def _reduce_window( raise ValueError( 'reduce_window output must have the same tree structure as the operands' f' {operand_tree} vs. {out_tree}') + flat_operands = core.standard_insert_pbroadcast(*flat_operands) out_flat = reduce_window_p.bind( *flat_operands, *flat_init_values, @@ -250,6 +252,8 @@ def _select_and_scatter(operand: Array, select: Callable, select, core.get_aval(init_value)) scatter_jaxpr, scatter_consts = lax._reduction_jaxpr( scatter, core.get_aval(init_value)) + operand, source, init_value = core.standard_insert_pbroadcast( + operand, source, init_value) return select_and_scatter_p.bind( operand, source, init_value, select_jaxpr=select_jaxpr, select_consts=select_consts, scatter_jaxpr=scatter_jaxpr, @@ -261,6 +265,7 @@ def _select_and_scatter_add(source: Array, operand: Array, window_dimensions: core.Shape, window_strides: Sequence[int], padding: Sequence[tuple[int, int]]) -> Array: + source, operand = core.standard_insert_pbroadcast(source, operand) return select_and_scatter_add_p.bind( source, operand, select_prim=select_prim, window_dimensions=tuple(window_dimensions), @@ -296,6 +301,7 @@ def _select_and_gather_add(tangents: Array, operand: Array, An array containing the elements in `tangents` corresponding to the output of the reduction of `operand` fin each window. """ + tangents, operand = core.standard_insert_pbroadcast(tangents, operand) return select_and_gather_add_p.bind( tangents, operand, select_prim=select_prim, window_dimensions=tuple(window_dimensions), @@ -332,7 +338,10 @@ def _reduce_window_abstract_eval_rule( out_sharding = reduce_window_sharding_rule( operand_avals[0], window_dimensions, window_strides, padding, base_dilation, window_dilation) - return tuple(ShapedArray(out_shape, op.dtype, sharding=out_sharding) + out_vma = (core.standard_vma_rule('reduce_window', operand_avals) + if config.varying_axes_in_types.value else frozenset()) + return tuple(ShapedArray(out_shape, op.dtype, sharding=out_sharding, + vma=out_vma) for op in operand_avals) @@ -532,7 +541,8 @@ def reduce_window_sharding_rule(operand, window_dimensions, window_strides, reduce_window_sum_p = lax.standard_primitive( _reduce_window_sum_shape_rule, lax._input_dtype, 'reduce_window_sum', - sharding_rule=reduce_window_sharding_rule) + sharding_rule=reduce_window_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'reduce_window_sum')) ad.deflinear2(reduce_window_sum_p, _reduce_window_sum_transpose_rule) batching.primitive_batchers[reduce_window_sum_p] = partial( _reduce_window_batch_rule, _reduce_window_sum) @@ -598,7 +608,8 @@ def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides, reduce_window_max_p = lax.standard_primitive( _common_reduce_window_shape_rule, lax._input_dtype, 'reduce_window_max', - sharding_rule=reduce_window_sharding_rule) + sharding_rule=reduce_window_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'reduce_window_max')) ad.defjvp(reduce_window_max_p, partial(_reduce_window_chooser_jvp_rule, lax.max_p)) batching.primitive_batchers[reduce_window_max_p] = partial( @@ -606,7 +617,8 @@ def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides, reduce_window_min_p = lax.standard_primitive( _common_reduce_window_shape_rule, lax._input_dtype, 'reduce_window_min', - sharding_rule=reduce_window_sharding_rule) + sharding_rule=reduce_window_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'reduce_window_min')) ad.defjvp(reduce_window_min_p, partial(_reduce_window_chooser_jvp_rule, lax.min_p)) @@ -671,7 +683,8 @@ def _select_and_scatter_shape_rule( return operand.shape select_and_scatter_p = lax.standard_primitive( - _select_and_scatter_shape_rule, lax._input_dtype, 'select_and_scatter') + _select_and_scatter_shape_rule, lax._input_dtype, 'select_and_scatter', + vma_rule=partial(core.standard_vma_rule, 'select_and_scatter')) def _select_and_scatter_lower( ctx, operand, source, init_value, *, select_jaxpr, @@ -766,7 +779,8 @@ def _select_and_scatter_add_batch_rule( select_and_scatter_add_p = lax.standard_primitive( _select_and_scatter_add_shape_rule, lax._input_dtype, - 'select_and_scatter_add') + 'select_and_scatter_add', + vma_rule=partial(core.standard_vma_rule, 'select_and_scatter_add')) ad.primitive_transposes[select_and_scatter_add_p] = \ _select_and_scatter_add_transpose @@ -1039,7 +1053,8 @@ def _select_and_gather_add_batching_rule( select_and_gather_add_p = lax.standard_primitive( _select_and_gather_add_shape_rule, lax._input_dtype, - 'select_and_gather_add', sharding_rule=_select_and_gather_add_sharding_rule) + 'select_and_gather_add', sharding_rule=_select_and_gather_add_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'select_and_gather_add')) ad.primitive_jvps[select_and_gather_add_p] = _select_and_gather_add_jvp ad.primitive_transposes[select_and_gather_add_p] = \ _select_and_gather_add_transpose From a52f7b26e7d5b2696a73a150518441204a2d9565 Mon Sep 17 00:00:00 2001 From: Rachel Han Date: Thu, 27 Mar 2025 17:12:08 -0700 Subject: [PATCH 238/483] Add accuracy field to unary ops * Cbrt * Cos * Exp, Exp2 * Expm1 * Log * Logistic * Log1p * Rsqrt * Sin * Sqrt * Tan * Tanh which allows users to select implementation that will satisfy the requested accuracy. PiperOrigin-RevId: 741331787 --- jax/_src/api.py | 13 +- jax/_src/internal_test_util/test_harnesses.py | 26 +- jax/_src/lax/lax.py | 248 +++++++++--- jax/_src/pallas/mosaic/lowering.py | 44 ++- jax/_src/pallas/mosaic_gpu/lowering.py | 24 +- jax/_src/pallas/triton/lowering.py | 6 +- jax/experimental/jax2tf/jax2tf.py | 27 +- jax/experimental/jet.py | 18 +- tests/BUILD | 14 + tests/api_test.py | 154 ++++---- tests/core_test.py | 8 +- tests/pallas/ops_test.py | 29 +- tests/pmap_test.py | 16 +- tests/unary_ops_accuracy_test.py | 373 ++++++++++++++++++ 14 files changed, 782 insertions(+), 218 deletions(-) create mode 100644 tests/unary_ops_accuracy_test.py diff --git a/jax/_src/api.py b/jax/_src/api.py index 692f049b5f0c..e01bdd4a9d81 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2269,13 +2269,16 @@ def make_jaxpr( >>> print(f(3.0)) -0.83602 >>> jax.make_jaxpr(f)(3.0) - { lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) } + { lambda ; a:f32[]. let + b:f32[] = cos[accuracy=None] a + c:f32[] = sin[accuracy=None] b + in (c,) } >>> jax.make_jaxpr(jax.grad(f))(3.0) { lambda ; a:f32[]. let - b:f32[] = cos a - c:f32[] = sin a - _:f32[] = sin b - d:f32[] = cos b + b:f32[] = cos[accuracy=None] a + c:f32[] = sin[accuracy=None] a + _:f32[] = sin[accuracy=None] b + d:f32[] = cos[accuracy=None] b e:f32[] = mul 1.0 d f:f32[] = neg e g:f32[] = mul f c diff --git a/jax/_src/internal_test_util/test_harnesses.py b/jax/_src/internal_test_util/test_harnesses.py index 02779c85977e..b557434ac7f3 100644 --- a/jax/_src/internal_test_util/test_harnesses.py +++ b/jax/_src/internal_test_util/test_harnesses.py @@ -408,11 +408,11 @@ def parameterized(harnesses: Iterable[Harness], ############################################################################### -def _make_unary_elementwise_harness(*, prim, shape=(20, 20), dtype): +def _make_unary_elementwise_harness(*, prim, shape=(20, 20), dtype, **kwargs): define( str(prim), f"shape={jtu.format_shape_dtype_string(shape, dtype)}", - prim.bind, [RandArg(shape, dtype)], + lambda x: prim.bind(x, **kwargs), [RandArg(shape, dtype)], prim=prim, dtype=dtype, shape=shape) @@ -429,19 +429,19 @@ def _make_unary_elementwise_harness(*, prim, shape=(20, 20), dtype): _make_unary_elementwise_harness(prim=lax.acos_p, dtype=dtype) _make_unary_elementwise_harness(prim=lax.atan_p, dtype=dtype) _make_unary_elementwise_harness(prim=lax.asin_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.cos_p, dtype=dtype) + _make_unary_elementwise_harness(prim=lax.cos_p, dtype=dtype, accuracy=None) _make_unary_elementwise_harness(prim=lax.cosh_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.exp_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.expm1_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.log_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.log1p_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.rsqrt_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.sin_p, dtype=dtype) + _make_unary_elementwise_harness(prim=lax.exp_p, dtype=dtype, accuracy=None) + _make_unary_elementwise_harness(prim=lax.expm1_p, dtype=dtype, accuracy=None) + _make_unary_elementwise_harness(prim=lax.log_p, dtype=dtype, accuracy=None) + _make_unary_elementwise_harness(prim=lax.log1p_p, dtype=dtype, accuracy=None) + _make_unary_elementwise_harness(prim=lax.rsqrt_p, dtype=dtype, accuracy=None) + _make_unary_elementwise_harness(prim=lax.sin_p, dtype=dtype, accuracy=None) _make_unary_elementwise_harness(prim=lax.sinh_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.sqrt_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.tan_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.tanh_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.logistic_p, dtype=dtype) + _make_unary_elementwise_harness(prim=lax.sqrt_p, dtype=dtype, accuracy=None) + _make_unary_elementwise_harness(prim=lax.tan_p, dtype=dtype, accuracy=None) + _make_unary_elementwise_harness(prim=lax.tanh_p, dtype=dtype, accuracy=None) + _make_unary_elementwise_harness(prim=lax.logistic_p, dtype=dtype, accuracy=None) for dtype in jtu.dtypes.all_floating: _make_unary_elementwise_harness(prim=lax.bessel_i0e_p, dtype=dtype) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index fcd7aba380bb..b79c81e19195 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -484,14 +484,41 @@ def is_finite(x: ArrayLike) -> Array: """ return is_finite_p.bind(x) +class Tolerance: + """Specify the tolerances used for computing unary functions. + + Maximum two tolerances can be specified: (atol and rtol) or (atol and ulps). + """ + + def __init__(self, atol: float = 0.0, rtol: float = 0.0, ulps: int = 0): + if atol < 0.0 or rtol < 0.0 or ulps < 0.0: + raise ValueError('Tolerances must be non-negative.') + if atol == 0.0 and rtol == 0.0 and ulps == 0: + raise ValueError('At least one of atol, rtol, or ulps must be set.') + + self.atol = atol + self.rtol = rtol + self.ulps = ulps + + +class AccuracyMode(enum.Enum): + HIGHEST = 1 + DEFAULT = 2 + @export -def exp(x: ArrayLike) -> Array: +def exp(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise exponential: :math:`e^x`. This function lowers directly to the `stablehlo.exponential`_ operation. Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -503,10 +530,10 @@ def exp(x: ArrayLike) -> Array: .. _stablehlo.exponential: https://openxla.org/stablehlo/spec#exponential """ - return exp_p.bind(x) + return exp_p.bind(x, accuracy=accuracy) -@export -def exp2(x: ArrayLike) -> Array: + +def exp2(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise base-2 exponential: :math:`2^x`. This function is implemented in terms of the `stablehlo.exponential`_ @@ -514,6 +541,12 @@ def exp2(x: ArrayLike) -> Array: Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -526,10 +559,10 @@ def exp2(x: ArrayLike) -> Array: .. _stablehlo.exponential: https://openxla.org/stablehlo/spec#exponential .. _stablehlo.multiply: https://openxla.org/stablehlo/spec#multiply """ - return exp2_p.bind(x) + return exp2_p.bind(x, accuracy=accuracy) @export -def expm1(x: ArrayLike) -> Array: +def expm1(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise :math:`e^{x} - 1`. This function lowers directly to the `stablehlo.exponential_minus_one`_ @@ -538,6 +571,12 @@ def expm1(x: ArrayLike) -> Array: Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -549,16 +588,22 @@ def expm1(x: ArrayLike) -> Array: .. _stablehlo.exponential_minus_one: https://openxla.org/stablehlo/spec#exponential_minus_one """ - return expm1_p.bind(x) + return expm1_p.bind(x, accuracy=accuracy) @export -def log(x: ArrayLike) -> Array: +def log(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise natural logarithm: :math:`\mathrm{log}(x)`. This function lowers directly to the `stablehlo.log`_ operation. Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -569,10 +614,10 @@ def log(x: ArrayLike) -> Array: .. _stablehlo.log: https://openxla.org/stablehlo/spec#log """ - return log_p.bind(x) + return log_p.bind(x, accuracy=accuracy) @export -def log1p(x: ArrayLike) -> Array: +def log1p(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise :math:`\mathrm{log}(1 + x)`. This function lowers directly to the `stablehlo.log_plus_one`_ operation. @@ -581,6 +626,12 @@ def log1p(x: ArrayLike) -> Array: Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -592,16 +643,22 @@ def log1p(x: ArrayLike) -> Array: .. _stablehlo.log_plus_one: https://openxla.org/stablehlo/spec#log_plus_one """ - return log1p_p.bind(x) + return log1p_p.bind(x, accuracy=accuracy) @export -def tanh(x: ArrayLike) -> Array: +def tanh(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise hyperbolic tangent: :math:`\mathrm{tanh}(x)`. This function lowers directly to the `stablehlo.tanh`_ operation. Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -614,10 +671,11 @@ def tanh(x: ArrayLike) -> Array: .. _stablehlo.tanh: https://openxla.org/stablehlo/spec#tanh """ - return tanh_p.bind(x) + return tanh_p.bind(x, accuracy=accuracy) @export -def logistic(x: ArrayLike) -> Array: + +def logistic(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`. There is no HLO logistic/sigmoid primitive, so this lowers to a sequence @@ -633,10 +691,10 @@ def logistic(x: ArrayLike) -> Array: See also: - :func:`jax.nn.sigmoid`: an alternative API for this functionality. """ - return logistic_p.bind(x) + return logistic_p.bind(x, accuracy=accuracy) @export -def sin(x: ArrayLike) -> Array: +def sin(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise sine: :math:`\mathrm{sin}(x)`. For floating-point inputs, this function lowers directly to the @@ -645,6 +703,12 @@ def sin(x: ArrayLike) -> Array: Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -657,10 +721,10 @@ def sin(x: ArrayLike) -> Array: .. _stablehlo.sine: https://openxla.org/stablehlo/spec#sine """ - return sin_p.bind(x) + return sin_p.bind(x, accuracy=accuracy) @export -def cos(x: ArrayLike) -> Array: +def cos(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise cosine: :math:`\mathrm{cos}(x)`. For floating-point inputs, this function lowers directly to the @@ -669,6 +733,12 @@ def cos(x: ArrayLike) -> Array: Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -681,7 +751,7 @@ def cos(x: ArrayLike) -> Array: .. _stablehlo.cosine: https://openxla.org/stablehlo/spec#cosine """ - return cos_p.bind(x) + return cos_p.bind(x, accuracy=accuracy) @export def atan2(x: ArrayLike, y: ArrayLike) -> Array: @@ -871,14 +941,21 @@ def integer_pow(x: ArrayLike, y: int) -> Array: """ return integer_pow_p.bind(x, y=y) + @export -def sqrt(x: ArrayLike) -> Array: +def sqrt(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise square root: :math:`\sqrt{x}`. This function lowers directly to the `stablehlo.sqrt`_ operation. Args: x: Input array. Must have floating or complex dtype. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: An array of the same shape and dtype as ``x`` containing the square root. @@ -890,16 +967,22 @@ def sqrt(x: ArrayLike) -> Array: .. _stablehlo.sqrt: https://openxla.org/stablehlo/spec#sqrt """ - return sqrt_p.bind(x) + return sqrt_p.bind(x, accuracy=accuracy) @export -def rsqrt(x: ArrayLike) -> Array: +def rsqrt(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise reciprocal square root: :math:`1 \over \sqrt{x}`. This function lowers directly to the `stablehlo.rsqrt`_ operation. Args: x: Input array. Must have floating or complex dtype. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: An array of the same shape and dtype as ``x`` containing the @@ -912,16 +995,22 @@ def rsqrt(x: ArrayLike) -> Array: .. _stablehlo.rsqrt: https://openxla.org/stablehlo/spec#rsqrt """ - return rsqrt_p.bind(x) + return rsqrt_p.bind(x, accuracy=accuracy) @export -def cbrt(x: ArrayLike) -> Array: +def cbrt(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise cube root: :math:`\sqrt[3]{x}`. This function lowers directly to the `stablehlo.cbrt`_ operation. Args: x: Input array. Must have floating or complex dtype. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: An array of the same shape and dtype as ``x`` containing the cube root. @@ -933,7 +1022,7 @@ def cbrt(x: ArrayLike) -> Array: .. _stablehlo.cbrt: https://openxla.org/stablehlo/spec#cbrt """ - return cbrt_p.bind(x) + return cbrt_p.bind(x, accuracy=accuracy) @export def bitwise_not(x: ArrayLike) -> Array: @@ -3544,13 +3633,19 @@ def reciprocal(x: ArrayLike) -> Array: return integer_pow(x, -1) @export -def tan(x: ArrayLike) -> Array: +def tan(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise tangent: :math:`\mathrm{tan}(x)`. This function lowers directly to the `stablehlo.tangent`_ operation. Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -3564,7 +3659,7 @@ def tan(x: ArrayLike) -> Array: .. _stablehlo.tangent: https://openxla.org/stablehlo/spec#tangent """ - return tan_p.bind(x) + return tan_p.bind(x, accuracy=accuracy) @export def asin(x: ArrayLike) -> Array: @@ -3958,8 +4053,9 @@ def multi_sharding_in_dim(ctx, ops, in_avals, out_aval): return out -def _nary_lower_hlo(op: Callable, ctx, - *args: ir.Value, **params) -> Sequence[ir.Value]: +def _nary_lower_hlo( + op: Callable, ctx, *args: ir.Value, accuracy=None, **params +) -> Sequence[ir.Value]: """Lowers an elementwise operator to its MLIR equivalent. """ del params @@ -3968,6 +4064,8 @@ def _nary_lower_hlo(op: Callable, ctx, args = multi_sharding_in_dim(ctx, args, avals_in, aval_out) out = op(*args) + if accuracy: + out = op(*args, result_accuracy=accuracy_attr(accuracy)) return [mlir.lower_with_sharding_in_types(ctx, out, aval_out)] @@ -4029,43 +4127,57 @@ def _round_lower(ctx, x, *, rounding_method): mlir.register_lowering(is_finite_p, partial(_nary_lower_hlo, hlo.is_finite)) exp_p = standard_unop(_float | _complex, 'exp') -ad.defjvp2(exp_p, lambda g, ans, x: mul(g, ans)) +ad.defjvp2(exp_p, lambda g, ans, x, **kwargs: mul(g, ans)) mlir.register_lowering(exp_p, partial(_nary_lower_hlo, hlo.exponential)) batching.ragged_prop_rules[exp_p] = batching.ragged_mask_elementwise_rule exp2_p = standard_unop(_float | _complex, 'exp2') -ad.defjvp2(exp2_p, lambda g, ans, x: mul(log(_const(x, 2)), mul(g, ans))) -def _exp2_lower(ctx, x): +ad.defjvp2( + exp2_p, lambda g, ans, x, **kwargs: mul(log(_const(x, 2)), mul(g, ans)) +) + +def _exp2_lower(ctx, x, accuracy): x_aval, = ctx.avals_in log2 = mlir.ir_constant(np.array(np.log(2), x_aval.dtype)) log2 = mlir.broadcast_in_dim(ctx, log2, x_aval, broadcast_dimensions=()) - return [hlo.exponential(hlo.multiply(log2, x))] + return [ + hlo.exponential( + hlo.multiply(log2, x), result_accuracy=accuracy_attr(accuracy) + ) + ] + mlir.register_lowering(exp2_p, _exp2_lower) log_p = standard_unop(_float | _complex, 'log') -ad.defjvp(log_p, lambda g, x: div(g, x)) +ad.defjvp(log_p, lambda g, x, **kwargs: div(g, x)) mlir.register_lowering(log_p, partial(_nary_lower_hlo, hlo.log)) expm1_p = standard_unop(_float | _complex, 'expm1') -ad.defjvp2(expm1_p, lambda g, ans, x: mul(g, add(ans, _one(ans)))) +ad.defjvp2(expm1_p, lambda g, ans, x, **kwargs: mul(g, add(ans, _one(ans)))) mlir.register_lowering(expm1_p, partial(_nary_lower_hlo, hlo.exponential_minus_one)) log1p_p = standard_unop(_float | _complex, 'log1p') -ad.defjvp(log1p_p, lambda g, x: div(g, add(x, _one(x)))) +ad.defjvp(log1p_p, lambda g, x, **kwargs: div(g, add(x, _one(x)))) mlir.register_lowering(log1p_p, partial(_nary_lower_hlo, hlo.log_plus_one)) tanh_p = standard_unop(_float | _complex, 'tanh') -ad.defjvp2(tanh_p, lambda g, ans, x: mul(add(g, mul(g, ans)), - sub(_one(x), ans))) +ad.defjvp2( + tanh_p, + lambda g, ans, x, **kwargs: mul(add(g, mul(g, ans)), sub(_one(x), ans)), +) mlir.register_lowering(tanh_p, partial(_nary_lower_hlo, hlo.tanh)) logistic_p = standard_unop(_float | _complex, 'logistic') -ad.defjvp2(logistic_p, lambda g, ans, x: mul(g, mul(ans, sub(_one(ans), ans)))) +ad.defjvp2( + logistic_p, + lambda g, ans, x, **kwargs: mul(g, mul(ans, sub(_one(ans), ans))), +) # TODO(phawkins): switch to LogisticOp lowering; debug numerical problems. # mlir.register_lowering(logistic_p, partial(_nary_lower_hlo, hlo.logistic)) -def logistic_impl(x): + +def logistic_impl(x, accuracy): one = _const(x, 1) return div(one, add(one, exp(neg(x)))) @@ -4088,20 +4200,26 @@ def _sin_complex(x): # avoid nan value when real(x) is zero and abs(x) is so large that abs(expm1(x)) is inf return select(a_is_zero, complex(_const(a, 0), im), complex(re, im)) -def _sin_lowering(ctx, x): +def _sin_lowering(ctx, x, accuracy): if dtypes.issubdtype(ctx.avals_in[0].dtype, np.complexfloating): sine = mlir.lower_fun(_sin_complex, multiple_results=False) return sine(ctx, x) - return _nary_lower_hlo(hlo.sine, ctx, x) + return _nary_lower_hlo(hlo.sine, ctx, x, accuracy=accuracy) + -def _sin_lin(nzs, x): +def _sin_p_lin(nzs, x, accuracy): nz, = nzs cos_x = cos(x) # TODO: allow this to happen in the linearized computation (need to fix backward_pass) - return (sin_p.bind(x), nz, cos_x, lambda cos_x_, t: mul(t, cos_x_)) + return ( + sin_p.bind(x, accuracy=accuracy), + nz, + cos_x, + lambda cos_x_, t: mul(t, cos_x_), + ) sin_p = standard_unop(_float | _complex, 'sin') -ad.defjvp(sin_p, lambda g, x: mul(g, cos(x))) -ad.primitive_linearizations[sin_p] = _sin_lin +ad.defjvp(sin_p, lambda g, x, accuracy: mul(g, cos(x, accuracy=accuracy))) +ad.primitive_linearizations[sin_p] = _sin_p_lin mlir.register_lowering(sin_p, _sin_lowering) batching.ragged_prop_rules[sin_p] = batching.ragged_mask_elementwise_rule @@ -4117,18 +4235,20 @@ def _cos_complex(x): re, im = mul(cs, csh), mul(neg(sn), snh) return select(a_is_zero, complex(re, _const(a, 0)), complex(re, im)) -def _cos_lowering(ctx, x): +def _cos_lowering(ctx, x, accuracy): if dtypes.issubdtype(ctx.avals_in[0].dtype, np.complexfloating): cosine = mlir.lower_fun(_cos_complex, multiple_results=False) return cosine(ctx, x) - return _nary_lower_hlo(hlo.cosine, ctx, x) + return _nary_lower_hlo(hlo.cosine, ctx, x, accuracy=accuracy) cos_p = standard_unop(_float | _complex, 'cos') -ad.defjvp(cos_p, lambda g, x: neg(mul(g, sin(x)))) +ad.defjvp( + cos_p, lambda g, x, accuracy: neg(mul(g, sin(x, accuracy=accuracy))) +) mlir.register_lowering(cos_p, _cos_lowering) tan_p = standard_unop(_float | _complex, 'tan') -ad.defjvp2(tan_p, lambda g, ans, x: mul(g, add(_const(x, 1), square(ans)))) +ad.defjvp2(tan_p, lambda g, ans, x, **kwargs: mul(g, add(_const(x, 1), square(ans)))) mlir.register_lowering(tan_p, partial(_nary_lower_hlo, hlo.tan)) asin_p = standard_unop(_float | _complex, 'asin') @@ -4245,18 +4365,23 @@ def _abs_jvp_rule(g, ans, x): _maybe_real = lambda x: real(x) if _iscomplex(x) else x sqrt_p = standard_unop(_float | _complex, 'sqrt') -ad.defjvp2(sqrt_p, lambda g, ans, x: mul(g, div(_const(x, 0.5), ans))) +ad.defjvp2(sqrt_p, lambda g, ans, x, **kwargs: mul(g, div(_const(x, 0.5), ans))) mlir.register_lowering(sqrt_p, partial(_nary_lower_hlo, hlo.sqrt)) rsqrt_p = standard_unop(_float | _complex, 'rsqrt') -ad.defjvp2(rsqrt_p, - lambda g, ans, x: - mul(g, mul(_const(x, -0.5), div(ans, x)))) +ad.defjvp2( + rsqrt_p, + lambda g, ans, x, **kwargs: mul(g, mul(_const(x, -0.5), div(ans, x))), +) mlir.register_lowering(rsqrt_p, partial(_nary_lower_hlo, hlo.rsqrt)) cbrt_p = standard_unop(_float, 'cbrt') -ad.defjvp2(cbrt_p, - lambda g, ans, x: mul(g, mul(_const(x, 1/3), integer_pow(ans, -2)))) +ad.defjvp2( + cbrt_p, + lambda g, ans, x, **kwargs: mul( + g, mul(_const(x, 1 / 3), integer_pow(ans, -2)) + ), +) mlir.register_lowering(cbrt_p, partial(_nary_lower_hlo, hlo.cbrt)) square_p = standard_unop(_int | _float | _complex, 'square') @@ -5463,6 +5588,17 @@ def maybe_convert_dtype(input_dtype, target_dtypes): return lhs_dtype, rhs_dtype, out_type +def accuracy_attr(accuracy) -> hlo.ResultAccuracyAttr: + if isinstance(accuracy, AccuracyMode): + return hlo.ResultAccuracyAttr.get(0.0, 0.0, int(0), str(accuracy.name)) + elif isinstance(accuracy, Tolerance): + return hlo.ResultAccuracyAttr.get( + atol=accuracy.atol, + rtol=accuracy.rtol, + ulps=accuracy.ulps, + mode='TOLERANCE', + ) + def _handle_dot_precision(ctx, lhs, rhs, precision, platform): def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes): fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2, diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 537a2cc07575..617324d43bf9 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -2549,14 +2549,18 @@ def _nextafter_lowering_rule(ctx: LoweringRuleContext, x, y): lowering_rules[lax.nextafter_p] = _nextafter_lowering_rule -def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x): +def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.rsqrt(x) lowering_rules[lax.rsqrt_p] = _rsqrt_lowering_rule -def _sqrt_lowering_rule(ctx: LoweringRuleContext, x): +def _sqrt_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.sqrt(x) @@ -2572,7 +2576,9 @@ def _square_lowering_rule(ctx: LoweringRuleContext, x): lowering_rules[lax.square_p] = _square_lowering_rule -def _exp_lowering_rule(ctx: LoweringRuleContext, x): +def _exp_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.exp(x) @@ -2605,9 +2611,11 @@ def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, *, y): lowering_rules[lax.integer_pow_p] = _integer_pow_lowering_rule -def _exp2_lowering_rule(ctx: LoweringRuleContext, x): +def _exp2_lowering_rule(ctx: LoweringRuleContext, x, accuracy): # exp2 in JAX lowers to exp(ln2 * x), not to pow2. We match that behavior # here. + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return lower_fun( lambda x: jnp.exp(jnp.astype(np.log(2), x.dtype) * x), multiple_results=False, @@ -2618,7 +2626,9 @@ def _exp2_lowering_rule(ctx: LoweringRuleContext, x): skip_mlir_conversions.add(lax.exp2_p) -def _logistic_lowering_rule(ctx: LoweringRuleContext, x): +def _logistic_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") neg_x = arith.negf(x) exp_neg_x = math.exp(neg_x) aval_out = ctx.avals_out[0] @@ -2636,42 +2646,54 @@ def _logistic_lowering_rule(ctx: LoweringRuleContext, x): lowering_rules[lax.logistic_p] = _logistic_lowering_rule -def _sin_lowering_rule(ctx: LoweringRuleContext, x): +def _sin_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.sin(x) lowering_rules[lax.sin_p] = _sin_lowering_rule -def _cos_lowering_rule(ctx: LoweringRuleContext, x): +def _cos_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.cos(x) lowering_rules[lax.cos_p] = _cos_lowering_rule -def _tan_lowering_rule(ctx: LoweringRuleContext, x): +def _tan_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.tan(x) lowering_rules[lax.tan_p] = _tan_lowering_rule -def _tanh_lowering_rule(ctx: LoweringRuleContext, x): +def _tanh_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.tanh(x) lowering_rules[lax.tanh_p] = _tanh_lowering_rule -def _log_lowering_rule(ctx: LoweringRuleContext, x): +def _log_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.log(x) lowering_rules[lax.log_p] = _log_lowering_rule -def _log1p_lowering_rule(ctx: LoweringRuleContext, x): +def _log1p_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.log1p(x) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 286fedfa44d5..0c9f70937873 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1584,7 +1584,9 @@ def _square_lowering_rule(ctx: LoweringRuleContext, x): @register_lowering_rule(lax.rsqrt_p, mgpu.ThreadSemantics.Lane) @register_lowering_rule(lax.rsqrt_p, mgpu.ThreadSemantics.Warpgroup) -def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x): +def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") [x_aval] = ctx.avals_in if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: return _ensure_fa(x, x_aval.dtype).rsqrt(approx=ctx.module_ctx.approx_math) @@ -1598,7 +1600,9 @@ def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x): @register_lowering_rule(lax.tanh_p, mgpu.ThreadSemantics.Lane) @register_lowering_rule(lax.tanh_p, mgpu.ThreadSemantics.Warpgroup) -def _tanh_lowering_rule(ctx: LoweringRuleContext, x): +def _tanh_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") [x_aval] = ctx.avals_in if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: return _ensure_fa(x, x_aval.dtype).tanh(approx=ctx.module_ctx.approx_math) @@ -1608,7 +1612,9 @@ def _tanh_lowering_rule(ctx: LoweringRuleContext, x): return math_dialect.tanh(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath) -def _logistic(x): +def _logistic(x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return 1.0 / (1 + lax.exp(-x)) @@ -1622,7 +1628,9 @@ def _logistic(x): @register_lowering_rule(lax.exp_p, mgpu.ThreadSemantics.Lane) @register_lowering_rule(lax.exp_p, mgpu.ThreadSemantics.Warpgroup) -def _exp_lowering_rule(ctx: LoweringRuleContext, x): +def _exp_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") [x_aval] = ctx.avals_in if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: return _ensure_fa(x, x_aval.dtype).exp(approx=ctx.module_ctx.approx_math) @@ -1633,7 +1641,9 @@ def _exp_lowering_rule(ctx: LoweringRuleContext, x): @register_lowering_rule(lax.exp2_p, mgpu.ThreadSemantics.Lane) -def _exp2_lowering_rule(ctx: LoweringRuleContext, x): +def _exp2_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") [x_aval] = ctx.avals_in if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: return _ensure_fa(x, x_aval.dtype).exp2(approx=ctx.module_ctx.approx_math) @@ -1645,7 +1655,9 @@ def _exp2_lowering_rule(ctx: LoweringRuleContext, x): @register_lowering_rule(lax.log_p, mgpu.ThreadSemantics.Lane) @register_lowering_rule(lax.log_p, mgpu.ThreadSemantics.Warpgroup) -def _log_lowering_rule(ctx: LoweringRuleContext, x): +def _log_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") [x_aval] = ctx.avals_in if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: return _ensure_fa(x, x_aval.dtype).log(approx=ctx.module_ctx.approx_math) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index c85c5f0a39c0..150ae9b8b2d7 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -654,7 +654,9 @@ def _make_dispatch_table( name: str, **tables: Sequence[_Extern | _Fallback] ) -> Callable[..., ir.Value]: - def inner(ctx: LoweringRuleContext, *args: ir.Value) -> ir.Value: + def inner( + ctx: LoweringRuleContext, *args: ir.Value, **_ + ) -> ir.Value: table = tables[ctx.context.platform] h = next((e for e in table if e.matches(ctx.avals_in)), None) if h is None: @@ -1404,7 +1406,7 @@ def _integer_pow_rule(ctx: LoweringRuleContext, x, *, y: int): _JAX_FN_MAPPING = { lax.clamp_p: lambda min, a, max: jnp.minimum(jnp.maximum(min, a), max), - lax.logistic_p: lambda a: 1 / (1 + jnp.exp(-a)), + lax.logistic_p: lambda a, accuracy: 1 / (1 + jnp.exp(-a)), } for prim, fn in _JAX_FN_MAPPING.items(): diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 3d71af38388b..492e070de1af 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -1666,17 +1666,18 @@ def _integer_pow(x, *, y: int, _in_avals: Sequence[core.ShapedArray], tf_impl_with_avals[lax.integer_pow_p] = _integer_pow -tf_impl[lax.exp_p] = tf.math.exp -tf_impl[lax_internal.exp2_p] = lambda x: \ - tf.math.exp(tf.math.multiply(tf.math.log(tf.constant(2, x.dtype)), x)) -tf_impl[lax.expm1_p] = tf.math.expm1 -tf_impl[lax.log_p] = tf.math.log -tf_impl[lax.log1p_p] = tf.math.log1p -tf_impl[lax.tan_p] = tf.math.tan -tf_impl[lax.tanh_p] = tf.math.tanh -tf_impl[lax.sin_p] = tf.math.sin +tf_impl[lax.exp_p] = lambda x, accuracy: tf.math.exp(x) +tf_impl[lax_internal.exp2_p] = lambda x, accuracy: tf.math.exp( + tf.math.multiply(tf.math.log(tf.constant(2, x.dtype)), x) +) +tf_impl[lax.expm1_p] = lambda x, accuracy: tf.math.expm1(x) +tf_impl[lax.log_p] = lambda x, accuracy: tf.math.log(x) +tf_impl[lax.log1p_p] = lambda x, accuracy: tf.math.log1p(x) +tf_impl[lax.tan_p] = lambda x, accuracy: tf.math.tan(x) +tf_impl[lax.tanh_p] = lambda x, accuracy: tf.math.tanh(x) +tf_impl[lax.sin_p] = lambda x, accuracy: tf.math.sin(x) tf_impl[lax.sinh_p] = tf.math.sinh -tf_impl[lax.cos_p] = tf.math.cos +tf_impl[lax.cos_p] = lambda x, accuracy: tf.math.cos(x) tf_impl[lax.cosh_p] = tf.math.cosh tf_impl_with_avals[lax.atan_p] = _convert_jax_impl( lax_internal.atan_impl, multiple_results=False) @@ -1706,11 +1707,11 @@ def _atan2(y, x, **kwargs): tf_impl[lax.asin_p] = tf.math.asin tf_impl[lax.acos_p] = tf.math.acos -tf_impl[lax.sqrt_p] = tf.math.sqrt +tf_impl[lax.sqrt_p] = lambda x, accuracy: tf.math.sqrt(x) tf_impl[lax.square_p] = tf.math.square -tf_impl[lax.rsqrt_p] = tf.math.rsqrt +tf_impl[lax.rsqrt_p] = lambda x, accuracy: tf.math.rsqrt(x) -def _cbrt(x): +def _cbrt(x, accuracy): return tf.math.sign(x) * tf.math.pow(tf.math.abs(x), 1/3) tf_impl[lax.cbrt_p] = _cbrt diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py index 15273f0fd02a..acf8885b0f98 100644 --- a/jax/experimental/jet.py +++ b/jax/experimental/jet.py @@ -76,7 +76,7 @@ from jax._src.util import unzip2, weakref_lru_cache, safe_zip -def jet(fun, primals, series): +def jet(fun, primals, series, **_): r"""Taylor-mode higher-order automatic differentiation. Args: @@ -405,11 +405,11 @@ def deriv_prop(prim, deriv, primals_in, series_in): lax.exp(lax.neg(lax.square(x))))) -def def_comp(prim, comp): +def def_comp(prim, comp, **kwargs): """ Define the jet rule for a primitive in terms of a composition of simpler primitives. """ - jet_rules[prim] = partial(jet, comp) + jet_rules[prim] = partial(jet, comp, **kwargs) def_comp(lax.expm1_p, lambda x: lax.exp(x) - 1) @@ -478,7 +478,7 @@ def _scale(k, j): def _scale2(k, j): return 1. / (fact(k - j) * fact(j)) -def _exp_taylor(primals_in, series_in): +def _exp_taylor(primals_in, series_in, **_): x, = primals_in series, = series_in u = [x] + series @@ -522,7 +522,7 @@ def _integer_pow_taylor(primals_in, series_in, *, y): jet_rules[lax.integer_pow_p] = _integer_pow_taylor -def _logistic_taylor(primals_in, series_in): +def _logistic_taylor(primals_in, series_in, **_): x, = primals_in series, = series_in u = [x] + series @@ -538,7 +538,7 @@ def _logistic_taylor(primals_in, series_in): jet_rules[lax.logistic_p] = _logistic_taylor -def _tanh_taylor(primals_in, series_in): +def _tanh_taylor(primals_in, series_in, **_): x, = primals_in series, = series_in u = [2*x] + [2 * series_ for series_ in series] @@ -548,7 +548,7 @@ def _tanh_taylor(primals_in, series_in): return 2 * primal_out - 1, series_out jet_rules[lax.tanh_p] = _tanh_taylor -def _log_taylor(primals_in, series_in): +def _log_taylor(primals_in, series_in, **_): x, = primals_in series, = series_in u = [x] + series @@ -590,7 +590,7 @@ def scale(k, j): return 1. / (fact(k - j) * fact(j)) return primal_out, series_out jet_rules[lax.div_p] = _div_taylor_rule -def _sinusoidal_rule(sign, prims, primals_in, series_in): +def _sinusoidal_rule(sign, prims, primals_in, series_in, **_): x, = primals_in series, = series_in u = [x] + series @@ -603,7 +603,7 @@ def _sinusoidal_rule(sign, prims, primals_in, series_in): return (s[0], s[1:]), (c[0], c[1:]) def _get_ind(f, ind): - return lambda *args: f(*args)[ind] + return lambda *args, **kwargs: f(*args, **kwargs)[ind] jet_rules[lax.sin_p] = _get_ind(partial(_sinusoidal_rule, -1, (lax.sin, lax.cos)), 0) jet_rules[lax.cos_p] = _get_ind(partial(_sinusoidal_rule, -1, (lax.sin, lax.cos)), 1) diff --git a/tests/BUILD b/tests/BUILD index 2526be066635..b501a614da39 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1640,6 +1640,20 @@ jax_multiplatform_test( deps = ["//jax:experimental"], ) +jax_multiplatform_test( + name = "unary_ops_accuracy_test", + srcs = ["unary_ops_accuracy_test.py"], + disable_configs = [ + "tpu_pjrt_c_api", + ], + enable_backends = [ + "tpu", + ], + deps = [ + "//jax:experimental", + ], +) + jax_py_test( name = "pretty_printer_test", srcs = ["pretty_printer_test.py"], diff --git a/tests/api_test.py b/tests/api_test.py index 6a970051d56e..9710131a92fa 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -4780,7 +4780,7 @@ def sin_of_sin(x): def test_deferred_primal_with_direct_linearize(self): def my_sin_lin(nzs, x): nz, = nzs - return (my_sin_p.bind(x), nz, x, lambda x, t: lax.mul(t, lax.cos(x))) + return (my_sin_p.bind(x, accuracy=None), nz, x, lambda x, t: lax.mul(t, lax.cos(x))) my_sin_p = core.Primitive("my_sin_p") my_sin_p.def_impl(lax.sin) @@ -4827,8 +4827,8 @@ def f(x): sin_impl = lax.sin_p.impl cos_impl = lax.cos_p.impl try: - lax.sin_p.def_impl(lambda x: sin_calls.append(1) or sin_impl(x)) - lax.cos_p.def_impl(lambda x: cos_calls.append(1) or cos_impl(x)) + lax.sin_p.def_impl(lambda x, **kwargs: sin_calls.append(1) or sin_impl(x, **kwargs)) + lax.cos_p.def_impl(lambda x, **kwargs: cos_calls.append(1) or cos_impl(x, **kwargs)) f_lin(3.) finally: lax.sin_p.def_impl(sin_impl) @@ -5092,11 +5092,11 @@ def f_yesremat(x): jaxpr = api.make_jaxpr(api.linearize(f_yesremat, 4.)[1])(1.) scan_eqn, = jaxpr.jaxpr.eqns - self.assertIn(' cos ', str(scan_eqn.params['jaxpr'])) + self.assertIn(' cos[', str(scan_eqn.params['jaxpr'])) jaxpr = api.make_jaxpr(api.vjp(f_yesremat, 4.)[1])(1.) scan_eqn, = jaxpr.jaxpr.eqns - self.assertIn(' cos ', str(scan_eqn.params['jaxpr'])) + self.assertIn(' cos[', str(scan_eqn.params['jaxpr'])) @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} @@ -5121,7 +5121,7 @@ def f(x, y): called = [] sin_impl = lax.sin_p.impl try: - lax.sin_p.def_impl(lambda x: called.append(1) or sin_impl(x)) + lax.sin_p.def_impl(lambda x, **kwargs: called.append(1) or sin_impl(x, **kwargs)) api.grad(g)(3.) finally: lax.sin_p.def_impl(sin_impl) @@ -5449,9 +5449,9 @@ def f(x): ('new_remat', new_checkpoint), ] for policy_name, policy, in_jaxpr2, not_in_jaxpr2 in [ - ('save_anything', lambda *_, **__: True, [], [' sin ', ' cos ']), - ('save_nothing', lambda *_, **__: False, [' sin ', ' cos '], []), - ('save_sin', lambda p, *_, **__: str(p) == 'sin', [' cos '], [' sin ']), + ('save_anything', lambda *_, **__: True, [], [' sin[', ' cos[[ ']), + ('save_nothing', lambda *_, **__: False, [' sin[', ' cos['], []), + ('save_sin', lambda p, *_, **__: str(p) == 'sin', [' cos['], [' sin[']), ]) def test_remat_custom_policy(self, remat, policy, in_jaxpr2, not_in_jaxpr2): for square in [lambda x: x * x, api.jit(lambda x: x * x)]: @@ -5481,8 +5481,8 @@ def test_remat_custom_policy_save_cos(self, remat): policy=save_cos) _, f_lin = api.linearize(f, 1.) jaxpr_text = str(f_lin.func.args[0]) - self.assertNotIn(' sin ', jaxpr_text) - self.assertNotIn(' cos ', jaxpr_text) + self.assertNotIn(' sin[', jaxpr_text) + self.assertNotIn(' cos[', jaxpr_text) jtu.check_grads(f, (3.,), order=2, modes=['fwd', 'rev']) @parameterized.named_parameters( @@ -5504,7 +5504,7 @@ def f(x): _, f_lin = api.linearize(f, jnp.ones((2, 2))) jaxpr_text = str(f_lin.func.args[0]) - self.assertEqual(jaxpr_text.count(' sin '), 2) + self.assertEqual(jaxpr_text.count(' sin['), 2) self.assertEqual(jaxpr_text.count(' dot_'), 6) jtu.check_grads(f, (jnp.ones((2, 2)),), order=2, modes=['fwd', 'rev']) @@ -5527,7 +5527,7 @@ def f(x): _, f_lin = api.linearize(f, jnp.ones((2, 2))) jaxpr_text = str(f_lin.func.args[0]) - self.assertEqual(jaxpr_text.count(' sin '), 2) + self.assertEqual(jaxpr_text.count(' sin['), 2) self.assertEqual(jaxpr_text.count(' dot_general'), 6) jtu.check_grads(f, (jnp.ones((2, 2)),), order=2, modes=['fwd', 'rev']) @@ -5550,7 +5550,7 @@ def f(x): _, f_lin = api.linearize(f, jnp.ones((3, 2, 2))) jaxpr_text = str(f_lin.func.args[0]) - self.assertEqual(jaxpr_text.count(' sin '), 2) + self.assertEqual(jaxpr_text.count(' sin['), 2) self.assertEqual(jaxpr_text.count(' dot_general'), 9) jtu.check_grads(f, (jnp.ones((3, 2, 2)),), order=2, modes=['fwd', 'rev']) @@ -5574,7 +5574,7 @@ def f(x): _, f_lin = api.linearize(f, jnp.ones((2, 2))) jaxpr_text = str(f_lin.func.args[0]) - self.assertEqual(jaxpr_text.count(' sin '), 2) + self.assertEqual(jaxpr_text.count(' sin['), 2) self.assertEqual(jaxpr_text.count(' dot_'), 6) jtu.check_grads(f, (jnp.ones((2, 2)),), order=2, modes=['fwd', 'rev']) @@ -5598,8 +5598,8 @@ def body(x, _): return f(x), None # Two sine calls in the backward pass because while we don't save sines # within the (rematted) body function, we can save the scan carry, which # effectively saves one sine. Three cosines for the Jacobian coefficients. - self.assertEqual(jaxpr_text.count(' sin '), 2) - self.assertEqual(jaxpr_text.count(' cos '), 3) + self.assertEqual(jaxpr_text.count(' sin['), 2) + self.assertEqual(jaxpr_text.count(' cos['), 3) # Six calls to dot_general in the backward pass because we save the primal # matmuls and only compure the backward pass ones (two for each primal one). self.assertEqual(jaxpr_text.count(' dot_'), 6) @@ -5905,8 +5905,8 @@ def test_remat_of_scan(self, remat): jtu.check_grads(remat(f), (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(remat(f), 4.)[1])(1.) - self.assertIn(' sin ', str(jaxpr)) - self.assertIn(' cos ', str(jaxpr)) + self.assertIn(' sin[', str(jaxpr)) + self.assertIn(' cos[', str(jaxpr)) @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} @@ -5951,8 +5951,8 @@ def body(x, _): return f(x), None jaxpr = f_vjp.args[0].func.args[1] jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 3) - self.assertEqual(jaxpr_text.count(' cos '), 3) + self.assertEqual(jaxpr_text.count(' sin['), 3) + self.assertEqual(jaxpr_text.count(' cos['), 3) # Six calls to dot_general in the backward pass because we save the primal # matmuls and only compute the backward pass ones (two for each primal one). self.assertEqual(jaxpr_text.count(' dot_'), 6) @@ -5969,8 +5969,8 @@ def test_remat_of_scan_policy(self): jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 0) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 0) def test_remat_of_scan_funky_custom_jvp(self): def scan_apply(f, x): @@ -5993,40 +5993,40 @@ def sin_jvp(primals, tangents): jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 0) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 0) save_sin = lambda prim, *_, **__: str(prim) == 'sin' f = new_checkpoint(partial(scan_apply, sin), policy=save_sin) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 1) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 1) f = new_checkpoint(partial(scan_apply, sin), policy=jax.checkpoint_policies.everything_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 0) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 0) f = new_checkpoint(partial(scan_apply, sin), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 1) # +1 b/c dce fixed point - self.assertEqual(jaxpr_text.count(' cos '), 1) + self.assertEqual(jaxpr_text.count(' sin['), 1) # +1 b/c dce fixed point + self.assertEqual(jaxpr_text.count(' cos['), 1) f = new_checkpoint(lambda x: scan_apply(sin, scan_apply(sin, x)), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 2) # +1 b/c dce fixed point - self.assertEqual(jaxpr_text.count(' cos '), 2) + self.assertEqual(jaxpr_text.count(' sin['), 2) # +1 b/c dce fixed point + self.assertEqual(jaxpr_text.count(' cos['), 2) def test_remat_of_scan_funky_custom_jvp2(self): # Like the above test but instead of using jit inside custom_jvp, use scan. @@ -6051,40 +6051,40 @@ def sin_jvp(primals, tangents): jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 1) # +1 b/c dce fixed point - self.assertEqual(jaxpr_text.count(' cos '), 0) + self.assertEqual(jaxpr_text.count(' sin['), 1) # +1 b/c dce fixed point + self.assertEqual(jaxpr_text.count(' cos['), 0) save_sin = lambda prim, *_, **__: str(prim) == 'sin' f = new_checkpoint(partial(scan_apply, sin), policy=save_sin) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 1) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 1) f = new_checkpoint(partial(scan_apply, sin), policy=jax.checkpoint_policies.everything_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 0) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 0) f = new_checkpoint(partial(scan_apply, sin), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 1) # +1 b/c dce fixed point - self.assertEqual(jaxpr_text.count(' cos '), 1) + self.assertEqual(jaxpr_text.count(' sin['), 1) # +1 b/c dce fixed point + self.assertEqual(jaxpr_text.count(' cos['), 1) f = new_checkpoint(lambda x: scan_apply(sin, scan_apply(sin, x)), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 2) # +1 b/c dce fixed point - self.assertEqual(jaxpr_text.count(' cos '), 2) + self.assertEqual(jaxpr_text.count(' sin['), 2) # +1 b/c dce fixed point + self.assertEqual(jaxpr_text.count(' cos['), 2) @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} @@ -6099,8 +6099,8 @@ def test_remat_of_cond(self, remat): jtu.check_grads(remat(f), (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(remat(f), 4.)[1])(1.) - self.assertNotIn(' sin ', str(jaxpr)) - self.assertIn(' cos ', str(jaxpr)) + self.assertNotIn(' sin[', str(jaxpr)) + self.assertIn(' cos[', str(jaxpr)) true_fn = lambda c: jnp.sin(jnp.sin(c)) false_fn = lambda c: c @@ -6108,8 +6108,8 @@ def test_remat_of_cond(self, remat): jtu.check_grads(remat(f), (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(remat(f), 4.)[1])(1.) - self.assertIn(' sin ', str(jaxpr)) - self.assertIn(' cos ', str(jaxpr)) + self.assertIn(' sin[', str(jaxpr)) + self.assertIn(' cos[', str(jaxpr)) @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} @@ -6149,8 +6149,8 @@ def f(x): _, f_vjp = api.vjp(f, jnp.ones((5, 5))) jaxpr_text = str(f_vjp.args[0].func.args[1]) - self.assertEqual(jaxpr_text.count(' sin '), 2) - self.assertEqual(jaxpr_text.count(' cos '), 3) + self.assertEqual(jaxpr_text.count(' sin['), 2) + self.assertEqual(jaxpr_text.count(' cos['), 3) # Five calls to dot_general in the backward pass because we have two for # each forward-pass dot, except for the first which only has one (as we are # differentiating with respect to only W and not x). @@ -6180,8 +6180,8 @@ def f(x): jaxpr = f_vjp.args[0].func.args[1] jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 2) - self.assertEqual(jaxpr_text.count(' cos '), 3) + self.assertEqual(jaxpr_text.count(' sin['), 2) + self.assertEqual(jaxpr_text.count(' cos['), 3) self.assertEqual(jaxpr_text.count(' dot_'), 5) jtu.check_grads(api.jit(f), (jnp.ones((5, 5)),), order=2, @@ -6195,8 +6195,8 @@ def test_remat_of_cond_policy(self): jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 0) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 0) def test_remat_of_cond_funky_custom_jvp(self): def cond_apply(f, x): @@ -6218,40 +6218,40 @@ def sin_jvp(primals, tangents): jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 0) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 0) save_sin = lambda prim, *_, **__: str(prim) == 'sin' f = new_checkpoint(partial(cond_apply, sin), policy=save_sin) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 1) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 1) f = new_checkpoint(partial(cond_apply, sin), policy=jax.checkpoint_policies.everything_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 0) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 0) f = new_checkpoint(partial(cond_apply, sin), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 1) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 1) f = new_checkpoint(lambda x: cond_apply(sin, cond_apply(sin, x)), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 1) - self.assertEqual(jaxpr_text.count(' cos '), 2) + self.assertEqual(jaxpr_text.count(' sin['), 1) + self.assertEqual(jaxpr_text.count(' cos['), 2) def test_remat_of_cond_funky_custom_jvp2(self): # Like the above test but instead of using jit inside custom_jvp, use cond. @@ -6275,40 +6275,40 @@ def sin_jvp(primals, tangents): jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 0) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 0) save_sin = lambda prim, *_, **__: str(prim) == 'sin' f = new_checkpoint(partial(cond_apply, sin), policy=save_sin) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 1) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 1) f = new_checkpoint(partial(cond_apply, sin), policy=jax.checkpoint_policies.everything_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 0) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 0) f = new_checkpoint(partial(cond_apply, sin), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 1) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 1) f = new_checkpoint(lambda x: cond_apply(sin, cond_apply(sin, x)), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 1) - self.assertEqual(jaxpr_text.count(' cos '), 2) + self.assertEqual(jaxpr_text.count(' sin['), 1) + self.assertEqual(jaxpr_text.count(' cos['), 2) @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} @@ -6333,8 +6333,8 @@ def f(x): self.assertArraysAllClose(y_dot, expected, check_dtypes=False) jaxpr = api.make_jaxpr(jax.linearize(remat(f), 4.)[1])(1.) - self.assertIn(' sin ', str(jaxpr)) - self.assertIn(' cos ', str(jaxpr)) + self.assertIn(' sin[', str(jaxpr)) + self.assertIn(' cos[', str(jaxpr)) def test_remat_of_while_loop_policy(self): def cond_fn(carry): @@ -6351,8 +6351,8 @@ def f(x): save_cos = lambda prim, *_, **__: str(prim) == 'cos' g = new_checkpoint(f, policy=save_cos) jaxpr = api.make_jaxpr(jax.linearize(g, 4.)[1])(1.) - self.assertIn(' sin ', str(jaxpr)) - self.assertIn(' cos ', str(jaxpr)) + self.assertIn(' sin[', str(jaxpr)) + self.assertIn(' cos[', str(jaxpr)) @jtu.thread_unsafe_test() # logging isn't thread-safe def test_remat_residual_logging(self): diff --git a/tests/core_test.py b/tests/core_test.py index c46d493bda54..03d6355cb257 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -474,8 +474,8 @@ def new_jaxpr(): # jaxpr is: # # { lambda ; a. - # let b = sin a - # c = cos a + # let b = sin[accuracy=None] a + # c = cos[accuracy=None] a # d = add b c # in (d,) } # @@ -487,7 +487,7 @@ def new_jaxpr(): self.assertRaisesRegex( core.JaxprTypeError, r"Value for variable 'b' inconsistently typed as f32\[\] " - r"for let-binder of type i32\[\]\n\nin equation:\n\nb:i32\[\] = sin a", + r"for let-binder of type i32\[\]\n\nin equation:\n\nb:i32\[\] = sin\[accuracy=None] a", lambda: core.check_jaxpr(jaxpr)) jaxpr = new_jaxpr() @@ -496,7 +496,7 @@ def new_jaxpr(): self.assertRaisesRegex( core.JaxprTypeError, r"Value for variable 'b' inconsistently typed as f32\[\] " - r"for let-binder of type f32\[2,3\]\n\nin equation:\n\nb:f32\[2,3\] = sin a", + r"for let-binder of type f32\[2,3\]\n\nin equation:\n\nb:f32\[2,3\] = sin\[accuracy=None] a", lambda: core.check_jaxpr(jaxpr)) def test_jaxpr_dropvar_from_jit_call(self): diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 8d5dc471e847..f5b70878533d 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -204,7 +204,7 @@ def select_n_strategy( # TODO(sharadmv,apaszke): enable zero dim sizes # TODO(sharadmv,apaszke): enable one dim sizes ( - lax.neg_p, + lax.neg_p, {}, make_shape_dtype_strategy( min_rank=2, max_rank=3, @@ -214,7 +214,7 @@ def select_n_strategy( ), ), ( - lax.not_p, + lax.not_p, {}, make_shape_dtype_strategy( min_rank=2, max_rank=3, @@ -226,6 +226,7 @@ def select_n_strategy( *[ ( prim, + params, make_shape_dtype_strategy( min_rank=2, max_rank=3, @@ -234,23 +235,23 @@ def select_n_strategy( valid_dtypes=[jnp.dtype("float32")], ), ) - for prim in [ - lax.exp_p, - lax.tanh_p, - lax.logistic_p, - lax.rsqrt_p, - lax.log_p, - lax.exp2_p, - lax.abs_p, - lax.log1p_p, - lax.sin_p, - lax.sqrt_p, + for prim, params in [ + (lax.abs_p, {}), + (lax.exp_p, {"accuracy": None}), + (lax.tanh_p, {"accuracy": None}), + (lax.logistic_p, {"accuracy": None}), + (lax.rsqrt_p, {"accuracy": None}), + (lax.log_p, {"accuracy": None}), + (lax.exp2_p, {"accuracy": None}), + (lax.log1p_p, {"accuracy": None}), + (lax.sin_p, {"accuracy": None}), + (lax.sqrt_p, {"accuracy": None}), ] ], ] UNARY_FUNCTIONS = [ - (prim.name, prim.bind, strategy) for prim, strategy in UNARY_PRIMITIVES + (prim.name, functools.partial(prim.bind, **params), strategy) for prim, params, strategy in UNARY_PRIMITIVES ] + [ ( name, diff --git a/tests/pmap_test.py b/tests/pmap_test.py index af2d03e2945d..d40293501edf 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -2082,8 +2082,8 @@ def test_remat_of_pmap(self, remat): x = jnp.arange(1.) jaxpr = jax.make_jaxpr(jax.linearize(f, x)[1])(x) - self.assertIn(' sin ', str(jaxpr)) - self.assertIn(' cos ', str(jaxpr)) + self.assertIn(' sin[', str(jaxpr)) + self.assertIn(' cos[', str(jaxpr)) @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} @@ -2100,24 +2100,24 @@ def test_remat_of_pmap_policy(self, remat): _, f_vjp = jax.vjp(f, x) jaxpr = f_vjp.args[0].func.args[1] jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 0) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 0) save_sin = lambda prim, *_, **__: str(prim) == 'sin' f = remat(g, policy=save_sin) _, f_vjp = jax.vjp(f, x) jaxpr = f_vjp.args[0].func.args[1] jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 2) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 2) save_nothing = lambda prim, *_, **__: False f = remat(g, policy=save_nothing) _, f_vjp = jax.vjp(f, x) jaxpr = f_vjp.args[0].func.args[1] jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 1) - self.assertEqual(jaxpr_text.count(' cos '), 2) + self.assertEqual(jaxpr_text.count(' sin['), 1) + self.assertEqual(jaxpr_text.count(' cos['), 2) def test_axis_name_shadowing_with_vmap(self): # vmap-of-pmap with mismatched axis sizes diff --git a/tests/unary_ops_accuracy_test.py b/tests/unary_ops_accuracy_test.py new file mode 100644 index 000000000000..fb370ab96923 --- /dev/null +++ b/tests/unary_ops_accuracy_test.py @@ -0,0 +1,373 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit test for result accuracy for unary ops.""" + +from typing import Any, Callable, NamedTuple, Union +import unittest + +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax._src import config +from jax._src import test_util as jtu +from jax._src.lax import lax +from jax._src.lib import xla_extension +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import hlo +import jax.numpy as jnp +import numpy as np + + +config.parse_flags_with_absl() + + +class TolerancePair(NamedTuple): + high: Union[lax.Tolerance, lax.AccuracyMode] = lax.AccuracyMode.DEFAULT + low: Union[lax.Tolerance, lax.AccuracyMode] = lax.AccuracyMode.DEFAULT + + +def make_unary_test_cases( + testcase_name: str, + op: Callable[..., Any], + x: np.ndarray, + tp: TolerancePair = None, + min_error_val: float = 0.0, +): + """Creates a single test case.""" + return [{ + "testcase_name": testcase_name, + "op": op, + "x": x, + "tp": tp, + "min_error_val": min_error_val, + }] + + +UNARY_OPS = { + "exp": make_unary_test_cases( + "exp", + lax.exp, + np.arange(84.0, 88.0, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=2**-5, rtol=2**-5, ulps=2), + low=lax.Tolerance(atol=1.5 * 2**-8, rtol=2**-18, ulps=2), + ), + ), + "exp2": make_unary_test_cases( + "exp2", + lax.exp2, + np.arange(84.0, 88.0, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=2**-5, rtol=2**-5, ulps=2), + low=lax.Tolerance(atol=1.5 * 2**-8, rtol=2**-18, ulps=2), + ), + ), + "expm1": make_unary_test_cases( + "expm1", + lax.expm1, + np.arange(84.0, 88.0, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=2**-5, rtol=2**-5, ulps=2), + low=lax.Tolerance(atol=1.5 * 2**-8, rtol=2**-18, ulps=2), + ), + ), + "log": make_unary_test_cases( + "log", + lax.log, + np.linspace(1e28, 2e28, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0), + low=lax.Tolerance(atol=2**-16, rtol=2**-20, ulps=0), + ), + 1.0, + ), + "log1p": make_unary_test_cases( + "log1p", + lax.log1p, + np.linspace(-9e-8, -8e-8, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=0, rtol=2**-11, ulps=0), + low=lax.Tolerance(atol=0, rtol=2**-14, ulps=0), + ), + 1.0, + ), + "tanh": make_unary_test_cases( + "tanh", + lax.tanh, + np.linspace(5.83, 5.86, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=2**-12, rtol=0, ulps=0), + low=lax.Tolerance(atol=2**-16, rtol=0, ulps=0), + ), + ), + "cos": make_unary_test_cases( + "cos", + lax.cos, + np.linspace(9.7e22, 9.8e22, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0), + low=lax.Tolerance(atol=0, rtol=2**-30, ulps=0), + ), + ), + "sin": make_unary_test_cases( + "sin", + lax.sin, + np.linspace(9.7e22, 9.8e22, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0), + low=lax.Tolerance(atol=0, rtol=2**-30, ulps=0), + ), + ), + "tan": make_unary_test_cases( + "tan", + lax.tan, + np.linspace(250.0, 252.0, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0), + low=lax.Tolerance(atol=0, rtol=2**-30, ulps=0), + ), + ), + "sqrt": make_unary_test_cases( + "sqrt", + lax.sqrt, + np.linspace(250.0, 252.0, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0), + low=lax.Tolerance(atol=0, rtol=2**-30, ulps=0), + ), + ), + "rsqrt": make_unary_test_cases( + "rsqrt", + lax.rsqrt, + np.linspace(250.0, 252.0, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0), + low=lax.Tolerance(atol=0, rtol=2**-30, ulps=0), + ), + ), +} + + +def generate_test_cases(op_names): + test_cases = [] + for op in op_names: + op_group = UNARY_OPS[op] + if op_group is None: + raise ValueError(f"No test cases found for op: {op}") + test_cases.extend(op_group) + return test_cases + + +@unittest.skipIf(not jtu.is_device_tpu(), "Skipping test on non TPU devices.") +class UnaryOpsAccuracyTest(jtu.JaxTestCase): + + def test_result_accuracy_mode_attr(self): + with ir.Context() as context: + hlo.register_dialect(context) + attr = hlo.ResultAccuracyModeAttr.get("DEFAULT") + assert attr is not None + assert attr.value == "DEFAULT" + + def test_result_accuracy_attr(self): + with ir.Context() as context: + hlo.register_dialect(context) + attr = hlo.ResultAccuracyAttr.get( + atol=1e-5, rtol=0.0, ulps=1, mode="TOLERANCE" + ) + assert attr is not None + assert attr.mode == "TOLERANCE" + assert attr.atol == 1e-5 + assert attr.rtol == 0.0 + assert attr.ulps == 1 + + @parameterized.named_parameters( + *generate_test_cases(["exp", "expm1", "exp2", "log", "log1p", "tanh"]) + ) + def test_unary_ops_choose_impl(self, op, x, tp, **kwargs): + @jax.jit + def f_default(x): + y = op(x, accuracy=tp.high) + return y + + @jax.jit + def f_accurate(x): + y = op(x, accuracy=tp.low) + return y + + # Input values that would cause large differences between the two + # implementations. + diff = abs(f_default(x) - f_accurate(x)) + if jtu.get_tpu_version() >= 5 and op in [ + lax.tanh, + jnp.tanh, + lax.log, + jnp.log, + ]: + # From tpu version 5 and onwards, even with tighter tolerance, the high performant + # implementation for tanh is chosen because the chip implementation has improved accuracy. + self.assertTrue(jnp.all(diff == 0)) + else: + self.assertTrue(jnp.any(diff > 0)) + + @parameterized.named_parameters( + *generate_test_cases(["exp", "expm1", "exp2", "log", "log1p", "tanh"]) + ) + def test_unary_vmap(self, op, x, tp, min_error_val): + @jax.jit + def f(x, y): + diff = lambda val: abs( + op(val, accuracy=tp.high) - op(val, accuracy=tp.low) + ) + return diff(x), diff(y) + + diff_x, diff_y = jax.vmap(f, in_axes=(None, 0), out_axes=0)( + min_error_val, x + ) + # diff(min_error_val) should be 0 + self.assertTrue(jnp.all(diff_x == 0)) + # diff(x) should be > 0 + if jtu.get_tpu_version() >= 5 and op in [ + lax.tanh, + jnp.tanh, + lax.log, + jnp.log, + ]: + # From tpu version 5 and onwards, even with tighter tolerance, the high performant + # implementation for tanh and log is chosen because the chip implementation has improved accuracy. + self.assertTrue(jnp.all(diff_y == 0)) + else: + self.assertTrue(jnp.any(diff_y > 0)) + + @parameterized.named_parameters( + *generate_test_cases(["exp", "expm1", "exp2"]) + ) + def test_diff_grad(self, op, x, tp, **kwargs): + @jax.jit + def f_default(x): + default_op = op(x, accuracy=tp.low) + return jnp.sum(default_op) + + f_default_grad = jax.grad(f_default) + + @jax.jit + def f_accurate(x): + high_op = op(x, accuracy=tp.high) + return jnp.sum(high_op) + + f_accurate_grad = jax.grad(f_accurate) + # Accuracy should be carried through to the gradient causing + # a large diff. + diff = abs(f_default_grad(x) - f_accurate_grad(x)) + self.assertTrue(jnp.any(diff > 0)) + + @parameterized.named_parameters( + *generate_test_cases(["log", "log1p", "tanh"]) + ) + def test_grad_unchanged(self, op, x, tp, **kwargs): + @jax.jit + def f(x): + return jnp.sum(op(x)) + + f_grad = jax.grad(f) + + @jax.jit + def f_default(x): + default_op = op(x, accuracy=tp.low) + return jnp.sum(default_op) + + f_default_grad = jax.grad(f_default) + + @jax.jit + def f_accurate(x): + high_op = op(x, accuracy=tp.high) + return jnp.sum(high_op) + + f_accurate_grad = jax.grad(f_accurate) + # Accuracy should be carried through to the gradient causing a large diff. + # Diff between f_default and f_accurate should follow diff(f_grad,f_default_grad). + expected_diff = abs(f_grad(x) - f_default_grad(x)) + if jnp.all(expected_diff > 0): + # Don't expect f_accurate_grad and f_default_grad to be equal. + self.assertFalse( + jnp.all(abs(f_default_grad(x) - f_accurate_grad(x)) == 0) + ) + elif jnp.all(expected_diff == 0): + # f_accurate_grad and f_default_grad should be equal. + diff = abs(f_default_grad(x) - f_accurate_grad(x)) + self.assertTrue(jnp.all(diff == 0)) + else: + raise ValueError("Unexpected diff: ", expected_diff) + + @parameterized.named_parameters( + *generate_test_cases(["cos", "sin", "tan", "sqrt", "rsqrt"]) + ) + def test_single_impl(self, op, x, tp, **kwargs): + @jax.jit + def f_tol(x): + return op(x, accuracy=tp.high) + + @jax.jit + def f(x): + return op(x) + + diff = abs(f_tol(x) - f(x)) + self.assertTrue(jnp.all(diff == 0)) + + @parameterized.named_parameters( + *generate_test_cases(["cos", "sin", "tan", "sqrt", "rsqrt"]) + ) + def test_default_grad(self, op, x, tp, **kwargs): + @jax.jit + def f_tol(x): + return jnp.sum(op(x, accuracy=tp.high)) + + @jax.jit + def f(x): + return jnp.sum(op(x)) + + self.assertTrue(jnp.all(abs(jax.grad(f_tol)(x) - jax.grad(f)(x)) == 0)) + + def test_invalid_accuracy(self): + with self.assertRaisesRegex( + ValueError, "At least one of atol, rtol, or ulps must be set." + ): + lax.exp(1.0, accuracy=lax.Tolerance(atol=0.0, rtol=0.0, ulps=0)) + with self.assertRaisesRegex(ValueError, "Tolerances must be non-negative."): + lax.exp(1.0, accuracy=lax.Tolerance(atol=-4e-10, rtol=0.0, ulps=0)) + + @parameterized.named_parameters( + *generate_test_cases([ + "exp", + "expm1", + "exp2", + "log", + "log1p", + "tanh", + "cos", + "sin", + "tan", + "sqrt", + "rsqrt", + ]) + ) + def test_low_tol(self, op, x, **kwargs): + with self.assertRaisesRegex( + xla_extension.XlaRuntimeError, "impl_type.ok()" + ): + op(x, accuracy=lax.Tolerance(atol=1e-60, rtol=1e-60, ulps=0)) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) From c5aa86a41a8a6ec1d66b072080377c26c09512a8 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 27 Mar 2025 22:25:12 -0700 Subject: [PATCH 239/483] Remove redundant filtering in the paged flash attention kernel Reason: `l_next >= 1.0` so the `jnp.where(l_next == 0.0, 1.0, l_next)` clause is not needed. PiperOrigin-RevId: 741400472 --- .../pallas/ops/tpu/paged_attention/paged_attention_kernel.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py index eb1e11df17da..99cb2c9c94c1 100644 --- a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py +++ b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py @@ -274,14 +274,13 @@ def prefetch_next_block(): # pylint: disable=unused-variable alpha = jnp.exp(m_prev - m_next) beta = jnp.exp(m_curr - m_next) l_next = alpha * l_prev + beta * l_curr - l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next) + m_ref[...], l_ref[...] = m_next, l_next v = async_copy_v.wait_and_get_loaded() o_curr_times_l_curr = jnp.dot(s_curr, v) - m_ref[...], l_ref[...] = m_next, l_next_safe o_ref[...] = ( - (l_prev * alpha * o_ref[...] + beta * o_curr_times_l_curr) / l_next_safe + (l_prev * alpha * o_ref[...] + beta * o_curr_times_l_curr) / l_next ).astype(o_ref.dtype) step_ref[0] = step + 1 From efa5ae8e9831dc0e510ff1b59d61615a2898925a Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 28 Mar 2025 04:26:18 -0700 Subject: [PATCH 240/483] Update XLA dependency to use revision http://github.com/openxla/xla/commit/edfd919df316d687b2d3b08bbc8d9c32f4bcc1c4. PiperOrigin-RevId: 741478215 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 625f33a072f5..43bba2fcc903 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "95abd7942747bd5d1884b309baecdf5a93ff928a" -XLA_SHA256 = "f8472323ffe621ade5317091fdf9acd66aaf67660fedd3143a96d9a347e88bac" +XLA_COMMIT = "edfd919df316d687b2d3b08bbc8d9c32f4bcc1c4" +XLA_SHA256 = "d82a7174a8a129180b180b08f5eedfa5fe6ff19fbd46dc11dae8cf64d87dfbf9" def repo(): tf_http_archive( From 063654000c148699383bad1656e23d808f76a97e Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 28 Mar 2025 11:19:06 +0000 Subject: [PATCH 241/483] Marked as thread_unsafe_test: - ShardingInTypesTest.test_set_mesh - APITest.test_cache_clear_pmap This helps to prevent errors like: 1) in pjit_test.py: ``` ValueError: For primitive mul, context mesh AbstractMesh('x': 2, axis_types=(Explicit,)) should match the aval mesh AbstractMesh('x': 2, 'y': 1, axis_types=(Auto, Auto)) for shape float32[8,2] ``` raised for example by ArrayPjitTest.test_pjit_array_multi_input_multi_output_mesh3 and also by ArrayPjitTest.test_convert_element_type_sharding, when pjit tests are run concurrently with `--local_test_jobs=32` and `--test_env=JAX_TEST_NUM_THREADS=8` 2) in api_test.py ``` AssertionError: Expected exactly 1 XLA compilations, but executed 2 ``` raised by APITest.test_pmap_global_cache. --- tests/api_test.py | 1 + tests/pjit_test.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tests/api_test.py b/tests/api_test.py index 82b673fe4b1e..e99189c671d0 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -4424,6 +4424,7 @@ def test_grad_conj_symbolic_zeros(self): out = jax.grad(f)(3.0) # doesn't crash self.assertAllClose(out, 1., check_dtypes=False) + @jtu.thread_unsafe_test() def test_cache_clear_pmap(self): @jax.pmap def f(i): diff --git a/tests/pjit_test.py b/tests/pjit_test.py index d6673c6b6d5a..0da97a6f0c14 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7162,6 +7162,7 @@ def f(x): out = f(np.arange(8)) self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + @jtu.thread_unsafe_test() def test_set_mesh(self): mesh = jtu.create_mesh((2,), ('x',), axis_types=(AxisType.Explicit,)) try: From 30451478c05e9bae9caaba09cf6b5a15805b3808 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 28 Mar 2025 05:09:30 -0700 Subject: [PATCH 242/483] [Pallas][NFC] Move the remainder of Semaphore-related extended dtypes to Pallas core This completes the move started in https://github.com/jax-ml/jax/pull/26673. PiperOrigin-RevId: 741487331 --- jax/_src/pallas/core.py | 52 ++++++++++++++++++++++++++-- jax/_src/pallas/mosaic/core.py | 41 +++------------------- jax/_src/pallas/mosaic/primitives.py | 2 +- 3 files changed, 54 insertions(+), 41 deletions(-) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 26101405fdeb..8602205eef22 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -67,9 +67,55 @@ def __repr__(self): SEMAPHORE_INTERPRET_DTYPE = jnp.int16 SEMAPHORE_MAX_VALUE = jnp.iinfo(SEMAPHORE_INTERPRET_DTYPE).max -class semaphore_dtype(dtypes.extended): pass -class semaphore(semaphore_dtype): pass -class barrier_semaphore(semaphore_dtype): pass +class AbstractSemaphoreTyRules: + @staticmethod + def pallas_interpret_element_aval(_) -> jax_core.ShapedArray: + return jax_core.ShapedArray((), SEMAPHORE_INTERPRET_DTYPE) + + @staticmethod + def physical_element_aval(_) -> jax_core.ShapedArray: + return jax_core.ShapedArray((), jnp.int32) + +# TODO(sharadmv): implement dtype rules for AbstractSemaphoreTy +class AbstractSemaphoreTy(dtypes.ExtendedDType): + name: str + _rules = AbstractSemaphoreTyRules + + def __repr__(self) -> str: + return self.name + + def __eq__(self, other): + return self.__class__ == other.__class__ + + def __hash__(self) -> int: + return hash(self.__class__) + +class semaphore_dtype(dtypes.extended): + """Common dtype for all kinds of semaphore dtypes. + + This is an abstract class that should never be instantiated, but rather + exists for the sake of `jnp.issubdtype`. + """ + +class semaphore(semaphore_dtype): + """Regular semaphore dtype. + + Like its superclass, this class should never be instantiated. + """ + +class Semaphore(AbstractSemaphoreTy): + name = "semaphore" + type = semaphore + +class barrier_semaphore(semaphore_dtype): + """Barrier semaphore dtype. + + Like its superclass, this class should never be instantiated. + """ + +class BarrierSemaphore(AbstractSemaphoreTy): + name = "barrier_semaphore" + type = barrier_semaphore @runtime_checkable class CompilerParams(Protocol): diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index 5df3c01a1934..37b6e51892c7 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -25,7 +25,6 @@ import jax from jax._src import config from jax._src import core as jax_core -from jax._src import dtypes from jax._src import util from jax._src.pallas import core as pallas_core import jax.numpy as jnp @@ -114,42 +113,10 @@ def __call__(self, shape: tuple[int, ...], dtype: jnp.dtype): class dma_semaphore(pallas_core.semaphore_dtype): pass -class AbstractSemaphoreTyRules: - @staticmethod - def pallas_interpret_element_aval(_) -> jax_core.ShapedArray: - return jax_core.ShapedArray((), pallas_core.SEMAPHORE_INTERPRET_DTYPE) - - @staticmethod - def physical_element_aval(_) -> jax_core.ShapedArray: - return jax_core.ShapedArray((), jnp.int32) - -class AbstractSemaphoreTy(dtypes.ExtendedDType): - name: str - _rules = AbstractSemaphoreTyRules - - def __repr__(self) -> str: - return self.name - - def __eq__(self, other): - return self.__class__ == other.__class__ - - def __hash__(self) -> int: - return hash(self.__class__) - -# TODO(sharadmv): implement dtype rules for AbstractSemaphoreTy - -class SemaphoreTy(AbstractSemaphoreTy): - type = pallas_core.semaphore - name = "sem" - -class DmaSemaphoreTy(AbstractSemaphoreTy): +class DMASemaphore(pallas_core.AbstractSemaphoreTy): type = dma_semaphore name = "dma_sem" -class BarrierSemaphoreTy(AbstractSemaphoreTy): - type = pallas_core.barrier_semaphore - name = "barrier_sem" - class SemaphoreType(enum.Enum): REGULAR = "regular" DMA = "dma" @@ -158,11 +125,11 @@ class SemaphoreType(enum.Enum): def __call__(self, shape: tuple[int, ...]): dtype: Any if self == SemaphoreType.DMA: - dtype = DmaSemaphoreTy() + dtype = DMASemaphore() elif self == SemaphoreType.BARRIER: - dtype = BarrierSemaphoreTy() + dtype = pallas_core.BarrierSemaphore() else: - dtype = SemaphoreTy() + dtype = pallas_core.Semaphore() return pallas_core.MemoryRef(shape, dtype, TPUMemorySpace.SEMAPHORE) def get_array_aval(self) -> pallas_core.ShapedArrayWithMemorySpace: diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py index 106f342bace8..c50a21218117 100644 --- a/jax/_src/pallas/mosaic/primitives.py +++ b/jax/_src/pallas/mosaic/primitives.py @@ -623,7 +623,7 @@ def async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id, @get_barrier_semaphore_p.def_abstract_eval def _get_barrier_semaphore_abstract_eval(): return pl_core.AbstractMemoryRef( - jax_core.ShapedArray((), tpu_core.BarrierSemaphoreTy()), + jax_core.ShapedArray((), pl_core.BarrierSemaphore()), tpu_core.TPUMemorySpace.SEMAPHORE, ) From 1c1e2e6dc0d521fa01750cbff7d61c1c130897f1 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 28 Mar 2025 05:11:12 -0700 Subject: [PATCH 243/483] [Mosaic GPU] Add support for stores to TMEM We can support reading and writing of both 32- and 16-bit types now. PiperOrigin-RevId: 741487690 --- jax/experimental/mosaic/gpu/tcgen05.py | 221 ++++++++++++++++++++----- tests/mosaic/gpu_test.py | 37 +++++ 2 files changed, 212 insertions(+), 46 deletions(-) diff --git a/jax/experimental/mosaic/gpu/tcgen05.py b/jax/experimental/mosaic/gpu/tcgen05.py index ac3b80b93689..53056ce594b2 100644 --- a/jax/experimental/mosaic/gpu/tcgen05.py +++ b/jax/experimental/mosaic/gpu/tcgen05.py @@ -327,24 +327,30 @@ def tmem_relinquish_alloc_permit(): has_side_effects=True, ) -def tmem_load(tmem_addr, shape, num, packing: int = 1): +def _tmem_access_helper(shape, num, packing: int = 1): if num.bit_count() != 1 or num > 128: raise ValueError(f"num must be a power of 2 and <= 128, got: {num}") match shape: case "16x128b": - num_out_regs = 2 + num_regs = 2 case "16x256b": - num_out_regs = 4 + num_regs = 4 case _: raise NotImplementedError(f"{shape=} is unsupported") - if num * num_out_regs >= 256: + num_regs *= num + if num_regs > 255: raise ValueError( - f"Loading too much TMEM at once: {num=} and each load requires" - f" {num_out_regs} registers, which exceeds the limit of 256" + f"TMEM transation too big : {shape=} and {num=} involve" + f" {num_regs} registers per-thread, which exceeds the limit of 255" ) - num_out_regs *= num + regs_vector = ",".join(f"${i}" for i in range(num_regs)) + regs_vector = "{" + regs_vector + "}" + return num_regs, regs_vector + + +def tmem_load(tmem_addr, shape, num, packing: int = 1): i32 = ir.IntegerType.get_signless(32) - out_regs = ",".join("$" + str(i) for i in range(num_out_regs)) + num_out_regs, regs_vector = _tmem_access_helper(shape, num, packing) if packing == 1: pack_mod = "" elif packing == 2: @@ -356,13 +362,30 @@ def tmem_load(tmem_addr, shape, num, packing: int = 1): "!llvm.struct<(" + ",".join("i32" for _ in range(num_out_regs)) + ")>" ), [tmem_addr], - f"tcgen05.ld.sync.aligned.{shape}.x{num}{pack_mod}.b32 {{{out_regs}}}, [${num_out_regs}];", + f"tcgen05.ld.sync.aligned.{shape}.x{num}{pack_mod}.b32 {regs_vector}, [${num_out_regs}];", "=r," * num_out_regs + "r", has_side_effects=True, ) return [llvm.extractvalue(i32, regs, [i]) for i in range(num_out_regs)] +def tmem_store(tmem_addr, shape, num, regs, packing: int = 1): + num_out_regs, regs_vector = _tmem_access_helper(shape, num, packing) + if packing == 1: + pack_mod = "" + elif packing == 2: + pack_mod = ".unpack::16b" + else: + raise ValueError(f"Unsupported packing: {packing}") + llvm.inline_asm( + ir.Type.parse("!llvm.void"), + [*regs, tmem_addr], + f"tcgen05.st.sync.aligned.{shape}.x{num}{pack_mod}.b32 [${num_out_regs}], {regs_vector};", + "r," * num_out_regs + "r", + has_side_effects=True, + ) + + @dataclasses.dataclass(frozen=True) class TMEMLayout: """Represents the way a shape is laid out in TMEM. @@ -562,62 +585,168 @@ def __getitem__(self, *idxs): ) return fa.FragmentedArray(_registers=registers, _layout=layout, _is_signed=None) + def __setitem__(self, idxs, value): + if not isinstance(idxs, tuple): + idxs = (idxs,) + base_idxs, slice_shape, is_squeezed = utils.parse_indices(idxs, self.shape) + if any(is_squeezed): + raise ValueError( + "TMEM stores don't support integer indexing (only slices allowed)" + ) + if any(idx != 0 for idx in base_idxs) or tuple(slice_shape) != self.shape: + raise NotImplementedError("Slicing parts of TMEM not implemented yet") + if self.shape[1] % 8: + raise NotImplementedError + if utils.bitwidth(self.dtype) not in {16, 32}: + raise NotImplementedError(f"Unsupported dtype: {self.dtype}") + if not isinstance(value, fa.FragmentedArray): + raise ValueError(f"TMEM stores expect a FragmentedArray, got: {value}") + if value.shape != self.shape: + raise ValueError( + f"Stored array has shape {value.shape}, but TMEM has shape" + f" {self.shape}" + ) + if value.mlir_dtype != self.dtype: + raise ValueError( + f"Stored array has dtype {value.mlir_dtype}, but TMEM has dtype" + f" {self.dtype}" + ) + if value.layout != LAYOUT: + raise ValueError( + f"Stored array has layout {value.layout}, but only tcgen05.LAYOUT is" + " supported" + ) + if self.layout == TMEMLayout(elements_in_tile=(TMEM_ROWS, 8)): + # store_32xcols needs a 4xN array, but the FA tiling we use here tiles + # columns before rows, and so it is Nx4 (after ignoring all 1 dims). + _store_32xcols( + self.address, value.registers.T.reshape((4, -1)) + ) + else: # TODO(apaszke): Collective MMA layout + raise NotImplementedError( + f"Stores only implemented for refs with standard layout, got: {self.layout}" + ) + + +def _transfer_32xcols(base_addr, cols): + i32 = ir.IntegerType.get_signless(32) + cols_per_num = 8 # Here we generate a plan compatible with tcgen05.LAYOUT. + assert cols % cols_per_num == 0 + total_num = cols // cols_per_num + if total_num <= 32: + instr_num = total_num + elif total_num == 64: + instr_num = 32 + else: + raise NotImplementedError(total_num) + # We transfer 16 lanes at a time, but have 32 to deal with. + for lane_step in range(2): + addr_row = arith.addi(base_addr, utils.c((lane_step * 16) << 16, i32)) + cols_per_instr = instr_num * cols_per_num + for num_step in range(total_num // instr_num): + num_slice = slice(num_step * instr_num, (num_step + 1) * instr_num) + addr_row_col = arith.addi(addr_row, utils.c(num_step * cols_per_instr, i32)) + yield addr_row_col, instr_num, lane_step, num_slice + + +def _store_32xcols(base_addr, vector_regs): + i32 = ir.IntegerType.get_signless(32) + assert vector_regs.ndim == 2 and vector_regs.shape[0] == 4 + cols = vector_regs.shape[1] * 8 + + packing = 64 // utils.bitwidth(vector_regs.flat[0].type) + if packing == 1: + store_shape = "16x256b" # 4 threads * 64 bits per vreg = 256 bits + regs = np.empty((4, vector_regs.shape[1], 2), dtype=object) + c0 = arith.constant(i32, 0) + c1 = arith.constant(i32, 1) + for idx, vreg in np.ndenumerate(vector_regs): + regs[(*idx, 0)] = llvm.extractelement(vreg, c0) + regs[(*idx, 1)] = llvm.extractelement(vreg, c1) + regs = regs.reshape(2, 2, vector_regs.shape[1], 2).swapaxes(1, 2) + # From a single lane perspective a num tile consists of a 2x2, with the + # minor dim traversing columns and major being 8 rows apart. + # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16256b + assert regs.shape[-2:] == (2, 2) + elif packing == 2: + store_shape = "16x128b" # 4 threads * 32 bits per vreg = 128 bits + # From a single lane perspective a num tile has 2 registers, 8 rows apart. + # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16128b + regs = vector_regs.reshape(2, 2, vector_regs.shape[1]).swapaxes(1, 2) + else: + raise NotImplementedError(packing) + + it = _transfer_32xcols(base_addr, cols) + for addr_row_col, instr_num, lane_step, num_slice in it: + regs_slice = regs[lane_step, num_slice].flat + tmem_store(addr_row_col, store_shape, instr_num, regs_slice, packing) + def _load_32xcols(base_addr, cols, dtype): - # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16256b i32 = ir.IntegerType.get_signless(32) + vec_ty = ir.VectorType.get((2,), dtype) packing = 32 // utils.bitwidth(dtype) if packing == 1: - load_shape = "16x256b" # 8 columns * 32 bits = 256 bits - cols_per_num_tile = 8 * packing + load_shape = "16x256b" # 4 threads * 64 bits per vreg = 256 bits elif packing == 2: - load_shape = "16x128b" # 8 columns * 16 bits = 128 bits - cols_per_num_tile = 4 * packing + load_shape = "16x128b" # 4 threads * 32 bits per vreg = 128 bits else: raise NotImplementedError(packing) - assert cols % cols_per_num_tile == 0 - num = cols // cols_per_num_tile - if num <= 32: - num_tiling = num - elif num == 64: - num_tiling = 32 - else: - raise NotImplementedError(num) + vector_regs = np.ndarray((4, cols // 8), dtype=object) - # We load 16 lanes at a time, but need 32 in total. - for row_group in range(2): - addr_row = arith.addi(base_addr, arith.constant(i32, (row_group * 16) << 16)) - regs = [] - for num_group in range(num // num_tiling): - addr_row_col = arith.addi( - addr_row, - arith.constant(i32, num_tiling * num_group * cols_per_num_tile), - ) - regs += tmem_load(addr_row_col, load_shape, num_tiling, packing) + + it = _transfer_32xcols(base_addr, cols) + c0 = arith.constant(i32, 0) + c1 = arith.constant(i32, 1) + for addr_row_col, instr_num, lane_step, num_slice in it: + regs = tmem_load(addr_row_col, load_shape, instr_num, packing) + row_slice = slice(lane_step * 2, (lane_step + 1) * 2) + # This aliases the original array, so updates will be reflected there. + vector_regs_update = vector_regs[row_slice, num_slice] + assert vector_regs_update.shape == (2, instr_num), (vector_regs_update.shape, instr_num) if packing == 1: regs = [llvm.bitcast(dtype, r) for r in regs] - undef = llvm.mlir_undef(ir.VectorType.get((2,), dtype)) - for r_low, r_high, idx in zip(regs[::2], regs[1::2], np.ndindex(cols // 8, 2), strict=True): - high_undef = llvm.insertelement(undef, r_low, utils.c(0, i32)) - vreg = llvm.insertelement(high_undef, r_high, utils.c(1, i32)) - vector_regs[idx[1] + 2 * row_group, idx[0]] = vreg + # From a single lane perspective a num tile consists of a 2x2, with the + # minor dim traversing columns and major being 8 rows apart. + # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16256b + regs = np.asarray(regs, dtype=object).reshape(instr_num, 2, 2).swapaxes(0, 1) + undef = llvm.mlir_undef(vec_ty) + assert regs.shape == (*vector_regs_update.shape, 2) + for idx in np.ndindex(vector_regs_update.shape): + high_undef = llvm.insertelement(undef, regs[(*idx, 0)], c0) + vreg = llvm.insertelement(high_undef, regs[(*idx, 1)], c1) + vector_regs_update[idx] = vreg else: assert packing == 2 - regs = [llvm.bitcast(ir.VectorType.get((2,), dtype), r) for r in regs] - for vreg, idx in zip(regs, np.ndindex(cols // 8, 2), strict=True): - vector_regs[idx[1] + 2 * row_group, idx[0]] = vreg + regs = [llvm.bitcast(vec_ty, r) for r in regs] + # From a single lane perspective a num tile has 2 registers, 8 rows apart. + # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16128b + regs = np.asarray(regs, dtype=object).reshape(instr_num, 2).swapaxes(0, 1) + vector_regs_update[...] = regs + return vector_regs -# Like WGMMA_LAYOUT, only each warp holds a 32xN strip instead of 16xN. def _m128_layout(shape: tuple[int, ...]): if len(shape) != 2: raise ValueError(f"Shape {shape} is not 2D") if shape[0] % 128 != 0 or shape[1] % 8 != 0: raise ValueError(f"Shape {shape} is not a multiple of 64x8") - return fa.TiledLayout( - fa.Tiling(((128, 8), (32, 8), (8, 8), (1, 2))), - warp_dim=-8, - lane_dims=(-4, -3), - vector_dim=-1, + return LAYOUT + +# Like WGMMA_LAYOUT, only each warp holds a 32xN strip instead of 16xN. +# The name is so short, because it's meant to be used qualified (tcgen05.LAYOUT) +LAYOUT = fa.TiledLayout( + fa.Tiling(((128, 8), (32, 8), (8, 8), (1, 2))), + warp_dim=-8, + lane_dims=(-4, -3), + vector_dim=-1, +) + + +def commit_tmem(): + void = ir.Type.parse("!llvm.void") + llvm.inline_asm( + void, [], "tcgen05.wait::st.sync.aligned;", "", has_side_effects=True, ) + utils.warpgroup_barrier() diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index d9f56ee1d454..6c63e3ce40e1 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -908,6 +908,43 @@ def setUp(self): if not any(jtu.is_cuda_compute_capability_equal(sm) for sm in capabilities): self.skipTest("Only works on GPU with capability sm_100a or sm_101a") + @parameterized.parameters([jnp.float32, jnp.float16]) + def test_load_store_tmem(self, jax_dtype): + swizzle = 128 + in_mlir_dtype = utils.dtype_to_ir_type(jax_dtype) + swizzle_elems = swizzle // bytewidth(in_mlir_dtype) + tiling = (8, swizzle_elems) + + def kernel(ctx, input, output, scratch): + smem, barrier, tmem = scratch + ctx.async_copy( + src_ref=input, + dst_ref=smem, + swizzle=swizzle, + gmem_transform=mgpu.TileTransform(tiling), + barrier=barrier, + ) + barrier.wait() + tmem[:] = fa.FragmentedArray.load_tiled(smem, swizzle, layout=tcgen05.LAYOUT) + tcgen05.commit_tmem() + tmem[:].store_tiled(smem, swizzle) + mgpu.commit_shared() + ctx.async_copy( + src_ref=smem, dst_ref=output, swizzle=swizzle, gmem_transform=mgpu.TileTransform(tiling), + ) + ctx.await_async_copy(0) + + x = self.prng.uniform(-1, 1, (128, 128)).astype(jax_dtype) + scratch_shape = [ + jax.ShapeDtypeStruct(tile_shape(x.shape, tiling), jax_dtype), + mgpu.TMABarrier(), + mgpu.TMEM(x.shape, jax_dtype), + ] + y = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, x, scratch_shape + )(x) + np.testing.assert_array_equal(x, y) + @parameterized.product( lhs_transpose=(False, True), rhs_transpose=(False, True), From 39fb2a00a6b4313e836266dfa4e6a6c73b65ca42 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 28 Mar 2025 05:43:05 -0700 Subject: [PATCH 244/483] [Mosaic GPU] Add support for allocation and lowering of scratch semaphores The semaphore arrays are allocated in GMEM and zeroed by XLA before the kernel begins. PiperOrigin-RevId: 741494241 --- jax/_src/pallas/mosaic_gpu/BUILD | 2 +- jax/_src/pallas/mosaic_gpu/core.py | 19 ++++++ jax/_src/pallas/mosaic_gpu/lowering.py | 60 ++++++++++++++----- .../mosaic_gpu/pallas_call_registration.py | 32 ++++++++-- jax/experimental/mosaic/gpu/core.py | 1 + jax/experimental/pallas/mosaic_gpu.py | 1 + tests/pallas/mosaic_gpu_test.py | 24 ++++++++ 7 files changed, 117 insertions(+), 22 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/BUILD b/jax/_src/pallas/mosaic_gpu/BUILD index ab35eebafc04..33883326e58c 100644 --- a/jax/_src/pallas/mosaic_gpu/BUILD +++ b/jax/_src/pallas/mosaic_gpu/BUILD @@ -48,7 +48,7 @@ pytype_strict_library( "//jax:mlir", "//jax:mosaic_gpu", "//jax/_src/pallas", - ], + ] + py_deps("numpy"), ) pytype_strict_library( diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 19007b6850fd..857daaefe38f 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -120,6 +120,25 @@ def __call__( return GPUMemoryRef(shape, dtype, memory_space=self, transforms=transforms) +class SemaphoreType(enum.Enum): + REGULAR = "regular" + BARRIER = "barrier" + + def __call__(self, shape: tuple[int, ...]): + dtype: Any + if self == SemaphoreType.BARRIER: + dtype = pallas_core.BarrierSemaphore() + else: + dtype = pallas_core.Semaphore() + return pallas_core.MemoryRef(shape, dtype, GPUMemorySpace.GMEM) + + def get_array_aval(self) -> jax_core.ShapedArray: + return self(()).get_array_aval() + + def get_ref_aval(self) -> pallas_core.TransformedRef | AbstractMemoryRef: + return self(()).get_ref_aval() + + def kernel( body: Callable[..., None], out_shape: object, diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 0c9f70937873..1b4aa33dc909 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -418,8 +418,9 @@ class LoweringResult: module: ir.Module grid: tuple[int, ...] block: tuple[int, ...] - out_structs: tuple[jax.ShapeDtypeStruct, ...] + new_out_shapes: tuple[jax.ShapeDtypeStruct, ...] # Does not include gmem scratch! profiler_context: ProfilerContext | None + gmem_scratch_shapes: tuple[jax.ShapeDtypeStruct, ...] @dataclasses.dataclass(frozen=True) @@ -588,16 +589,41 @@ def ref_for_aval(aval: jax_core.AbstractValue): else: return gpu_core.SMEM(aval.shape, aval.dtype) + sem_placeholder = None + semaphore_ref_avals = [] + scratch_avals = [] + # Need to unzip semaphores + for v in jaxpr.invars[grid_mapping.slice_scratch_ops]: + aval = v.aval + if (isinstance(aval, pallas_core.AbstractMemoryRef) and + jnp.issubdtype(aval.dtype, pallas_core.semaphore_dtype)): + if aval.memory_space != gpu_core.GPUMemorySpace.GMEM: + raise ValueError( + "Only GMEM memory space is supported for semaphores in Mosaic GPU." + ) + semaphore_ref_avals.append(aval) + scratch_avals.append(sem_placeholder) + else: + scratch_avals.append(aval) + def pipeline_fn(*refs): - return primitives.run_scoped( - functools.partial(scoped_pipeline_fn, *refs), + sem_refs = [] + if semaphore_ref_avals: + refs, sem_refs = util.split_list(refs, [-len(semaphore_ref_avals)]) + primitives.run_scoped( + functools.partial(scoped_pipeline_fn, *refs, sem_refs=sem_refs), scratch_refs=[ - ref_for_aval(v.aval) - for v in jaxpr.invars[grid_mapping.slice_scratch_ops] + ref_for_aval(aval) if aval is not sem_placeholder else aval + for aval in scratch_avals ], ) + return () # ``wrap_init`` does not support functions returning None. - def scoped_pipeline_fn(*refs, scratch_refs): + def scoped_pipeline_fn(*refs, sem_refs, scratch_refs): + sem_refs_it = iter(sem_refs) + scratch_refs = [ + next(sem_refs_it) if r is sem_placeholder else r for r in scratch_refs + ] def body_fn(*refs): grid_env = pallas_core.current_grid_env() assert grid_env is not None # Set by ``emit_pipeline``. @@ -628,17 +654,13 @@ def body_fn(*refs): with grid_mapping.trace_env(): new_jaxpr, _, new_consts, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init( - # ``wrap_init`` does not support functions returning None. - lambda *args: pipeline_fn(*args) or (), - debug_info=jaxpr.debug_info, - ), + lu.wrap_init(pipeline_fn, debug_info=jaxpr.debug_info), [ gpu_core.GMEM( bm.array_shape_dtype.shape, bm.array_shape_dtype.dtype ).get_ref_aval() for bm in block_mappings - ], + ] + semaphore_ref_avals, ) assert not new_consts @@ -655,6 +677,10 @@ def body_fn(*refs): mesh.cluster if mesh is not None else (), [bm.array_shape_dtype for bm in in_block_mappings], [bm.array_shape_dtype for bm in out_block_mappings], + [ + jax.ShapeDtypeStruct(r.shape, np.dtype(np.int32)) + for r in semaphore_ref_avals + ], new_jaxpr, compiler_params, new_consts, @@ -668,6 +694,7 @@ def lower_jaxpr_to_module( cluster: Sequence[int], in_shapes: Sequence[jax.ShapeDtypeStruct], out_shapes: Sequence[jax.ShapeDtypeStruct], + gmem_scratch_shapes: Sequence[jax.ShapeDtypeStruct], jaxpr: jax_core.Jaxpr, compiler_params: dict[str, Any], consts=(), @@ -754,14 +781,14 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): # Each range is 2 events, each event is 4 bytes. prof_spec = mgpu_profiler.ProfilerSpec(prof_space * 2 * 4) prof_ctx = ProfilerContext(params["profile_dir"], prof_spec) - module, out_structs_gmem, _, launch_ctx, scratch_arr = ( + module, new_out_shapes, _, launch_ctx, scratch_arr = ( mgpu_core._lower_as_gpu_kernel( body, grid=tuple(map(operator.mul, parallel_grid, cluster)), cluster=cluster, block=block, in_shapes=in_shapes, - out_shape=out_shapes, + out_shape=(*out_shapes, *gmem_scratch_shapes), smem_scratch_shape=scratch_buffers, module_name=mlir.sanitize_name(debug_info.func_name), prof_spec=prof_spec, @@ -777,8 +804,11 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): mgpu_core._initialize_scratch(launch_ctx, scratch_arr) + if gmem_scratch_shapes: + new_out_shapes = new_out_shapes[:-len(gmem_scratch_shapes)] + return LoweringResult( - module, parallel_grid, block, out_structs_gmem, prof_ctx + module, parallel_grid, block, new_out_shapes, prof_ctx, tuple(gmem_scratch_shapes) ) diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py index 40b12215c003..6dc958edbc53 100644 --- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py @@ -23,11 +23,13 @@ import warnings import jax +from jax import lax from jax._src import core as jax_core from jax._src.interpreters import mlir from jax._src.pallas import core as pallas_core from jax._src.pallas.mosaic_gpu import lowering from jax.experimental.mosaic import gpu as mgpu +import numpy as np def pallas_call_lowering( @@ -74,16 +76,30 @@ def pallas_call_lowering( print(lowering_result.module.operation) module = lowering_result.module - new_avals_out = [ - jax_core.ShapedArray(t.shape, t.dtype) for t in lowering_result.out_structs - ] + new_avals_in = list(ctx.avals_in) + new_avals_out = list(map(_as_shaped_array, lowering_result.new_out_shapes)) + scratch_args = () + if lowering_result.gmem_scratch_shapes: + input_output_aliases += tuple( + (len(new_avals_in) + i, len(new_avals_out) + i) + for i in range(len(lowering_result.gmem_scratch_shapes)) + ) + new_avals_in.extend(map(_as_shaped_array, lowering_result.gmem_scratch_shapes)) + new_avals_out.extend(map(_as_shaped_array, lowering_result.gmem_scratch_shapes)) + def zero_init_gmem_scratch(): + return [lax.zeros_like_array(s) for s in lowering_result.gmem_scratch_shapes] + scratch_args = mlir.lower_fun( + zero_init_gmem_scratch, multiple_results=True + )(ctx.replace(avals_in=())) outs = mgpu.core._mosaic_gpu_lowering_rule( - ctx.replace(avals_out=new_avals_out), - *args, + ctx.replace(avals_in=new_avals_in, avals_out=new_avals_out), + *args, *scratch_args, module=module, - out_types=lowering_result.out_structs, + out_types=(*lowering_result.new_out_shapes, *lowering_result.gmem_scratch_shapes), input_output_aliases=input_output_aliases, ) + if lowering_result.gmem_scratch_shapes: # Drop the GMEM scratch. + outs = outs[:-len(lowering_result.gmem_scratch_shapes)] if (prof_ctx := lowering_result.profiler_context) is not None: *outs, prof_buffer = outs if (dump_path := prof_ctx.dump_path) == "sponge": @@ -112,3 +128,7 @@ def do_callback(prof_buffer): ctx.replace(avals_in=(new_avals_out[-1],)), prof_buffer ) return outs + + +def _as_shaped_array(t: jax.ShapeDtypeStruct) -> jax_core.ShapedArray: + return jax_core.ShapedArray(t.shape, np.dtype(t.dtype)) diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index f5331eb1b56a..fcc5d3db6d60 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -104,6 +104,7 @@ def _mosaic_gpu_lowering_rule( out_types, input_output_aliases: tuple[tuple[int, int], ...] = (), ): + assert len(args) == len(ctx.avals_in) assert len(out_types) == len(ctx.avals_out) module = _run_serde_pass( module, diff --git a/jax/experimental/pallas/mosaic_gpu.py b/jax/experimental/pallas/mosaic_gpu.py index d5acb9b131ad..b791fbb8b573 100644 --- a/jax/experimental/pallas/mosaic_gpu.py +++ b/jax/experimental/pallas/mosaic_gpu.py @@ -23,6 +23,7 @@ from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace as GPUMemorySpace from jax._src.pallas.mosaic_gpu.core import GPUMesh as GPUMesh from jax._src.pallas.mosaic_gpu.core import kernel as kernel +from jax._src.pallas.mosaic_gpu.core import SemaphoreType as SemaphoreType from jax._src.pallas.mosaic_gpu.core import SwizzleTransform as SwizzleTransform from jax._src.pallas.mosaic_gpu.core import TilingTransform as TilingTransform from jax._src.pallas.mosaic_gpu.core import transpose_ref as transpose_ref diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index aea49b645ec6..c5da44d7b6fa 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -2408,6 +2408,30 @@ def compute(l_smem, r_smem, o_smem): out = plgpu.kernel(body, out_shape=x, grid=(2,), axis_names=("rows",))(x, x) np.testing.assert_allclose(out, x + x) + def test_semaphore_lowering(self): + # This is a smoke test until we add support for lowering of semaphore ops. + def body(i_ref1, i_ref2, o_ref, sem_ref): + del i_ref2 # Only here to have a different number of inputs and outputs. + assert sem_ref.shape == (4,) + assert jnp.issubdtype(sem_ref.dtype, pl.semaphore) + o_ref[...] = i_ref1[...] + x = jnp.arange(128, dtype=jnp.float32).reshape((128,)) + kernel = pl.pallas_call( + body, out_shape=x, scratch_shapes=[plgpu.SemaphoreType.REGULAR((4,))], + ) + text = jax.jit(kernel).lower(x, x).as_text() + self.assertIn( + r"output_operand_aliases =" + r" [#stablehlo.output_operand_alias]", + text, + ) + self.assertIn( + r"(tensor<128xf32>, tensor<128xf32>, tensor<4xi32>) ->" + r" (tensor<128xf32>, tensor<4xi32>)", + text, + ) + class ExamplesSm90ATest(PallasSm90ATest): From 5c61a69fd6dc46657c952e6f235c35cb5f3bcbd6 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Fri, 28 Mar 2025 09:17:20 -0400 Subject: [PATCH 245/483] Fixes failing FFI example builds. Breaking CI: https://github.com/jax-ml/jax/actions/runs/14126719325/job/39577362075?pr=27557 See breaking nanobind PR: https://github.com/wjakob/nanobind/pull/978 See fixing nanobind PR (not landed) https://github.com/wjakob/nanobind/pull/980 --- examples/ffi/pyproject.toml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/ffi/pyproject.toml b/examples/ffi/pyproject.toml index 130dd91bbc70..6f188ee037da 100644 --- a/examples/ffi/pyproject.toml +++ b/examples/ffi/pyproject.toml @@ -1,5 +1,7 @@ [build-system] -requires = ["scikit-build-core", "nanobind", "jax>=0.4.31"] +# TODO(dsuo): Remove nanobind pin after +# https://github.com/wjacob/nanobind/pull/980 lands. +requires = ["scikit-build-core", "nanobind==2.5.0", "jax>=0.4.31"] build-backend = "scikit_build_core.build" [project] From 4024897372cac19d2c17004babf1063d4975a38b Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 28 Mar 2025 06:51:48 -0700 Subject: [PATCH 246/483] Update CUDA tests matrix in the continuous jobs We now test only CUDA 12.1 and CUDA 12.8 PiperOrigin-RevId: 741509853 --- .github/workflows/wheel_tests_continuous.yml | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/.github/workflows/wheel_tests_continuous.yml b/.github/workflows/wheel_tests_continuous.yml index f48c39bf4721..530e0a9b0768 100644 --- a/.github/workflows/wheel_tests_continuous.yml +++ b/.github/workflows/wheel_tests_continuous.yml @@ -116,23 +116,16 @@ jobs: cuda: ["12.1", "12.8"] enable-x64: [1, 0] exclude: - # L4 does not run on cuda 12.8 but tests other configs - - runner: "linux-x86-g2-48-l4-4gpu" - cuda: "12.8" - # H100 runs only a single config, CUDA 12.3 Enable x64 1 - - runner: "linux-x86-a3-8g-h100-8gpu" - cuda: "12.8" + # H100 runs only a single config, CUDA 12.8 Enable x64 1 - runner: "linux-x86-a3-8g-h100-8gpu" cuda: "12.1" - runner: "linux-x86-a3-8g-h100-8gpu" enable-x64: "0" # B200 runs only a single config, CUDA 12.8 Enable x64 1 - - runner: "linux-x86-a4-224-b200-1gpu" - enable-x64: "0" - runner: "linux-x86-a4-224-b200-1gpu" cuda: "12.1" - runner: "linux-x86-a4-224-b200-1gpu" - cuda: "12.3" + enable-x64: "0" name: "Pytest CUDA (JAX artifacts version = ${{ format('{0}', 'head') }})" with: From 28f63ee27e751f0d4033131c881f157615acbfd9 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 28 Mar 2025 06:53:19 -0700 Subject: [PATCH 247/483] Use the Docker image with CUDA 12.8 and cudnn 9.8 in the Bazel CUDA non RBE job PiperOrigin-RevId: 741510217 --- .github/workflows/bazel_cuda_non_rbe.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/bazel_cuda_non_rbe.yml b/.github/workflows/bazel_cuda_non_rbe.yml index 3d15f4211a3f..ff1cf9900ce3 100644 --- a/.github/workflows/bazel_cuda_non_rbe.yml +++ b/.github/workflows/bazel_cuda_non_rbe.yml @@ -47,7 +47,7 @@ jobs: # Explicitly set the shell to bash shell: bash runs-on: ${{ inputs.runner }} - container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.3-cudnn9.1:latest" + container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.8-cudnn9.8:latest" env: JAXCI_HERMETIC_PYTHON_VERSION: ${{ inputs.python }} From f1ebb1e1e13e0ee57feb3916f6cb4919cfc0e62c Mon Sep 17 00:00:00 2001 From: Ayaka Date: Fri, 28 Mar 2025 07:16:57 -0700 Subject: [PATCH 248/483] Skip failing tests on TPU v6+ PiperOrigin-RevId: 741515935 --- tests/lax_numpy_reducers_test.py | 4 ++-- tests/lax_scipy_test.py | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index 0c3f1d1471fb..aa5e08e96a3e 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -905,8 +905,8 @@ def testCumulativeSumBool(self): @jtu.ignore_warning(category=NumpyComplexWarning) @jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion def testCumulativeProd(self, shape, axis, dtype, out_dtype, include_initial): - if jtu.is_device_tpu(6): - raise unittest.SkipTest("TODO(b/364258243): Test fails on TPU v6e") + if jtu.is_device_tpu_at_least(6): + raise unittest.SkipTest("TODO(b/364258243): Test fails on TPU v6+") rng = jtu.rand_some_zero(self.rng()) # We currently "cheat" to ensure we have JAX arrays, not NumPy arrays as diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py index 388d053d9608..bc80ed4e1cc2 100644 --- a/tests/lax_scipy_test.py +++ b/tests/lax_scipy_test.py @@ -339,8 +339,8 @@ def scipy_fun(z): ) @jtu.ignore_warning(category=DeprecationWarning, message=".*scipy.special.lpmn.*") def testLpmn(self, l_max, shape, dtype): - if jtu.is_device_tpu(6, "e"): - self.skipTest("TODO(b/364258243): fails on TPU v6e") + if jtu.is_device_tpu_at_least(6): + self.skipTest("TODO(b/364258243): fails on TPU v6+") rng = jtu.rand_uniform(self.rng(), low=-0.2, high=0.9) args_maker = lambda: [rng(shape, dtype)] @@ -461,8 +461,8 @@ def testSphHarmOrderOneDegreeOne(self): @jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion def testSphHarmForJitAndAgainstNumpy(self, l_max, num_z, dtype): """Tests against JIT compatibility and Numpy.""" - if jtu.is_device_tpu(6, "e"): - self.skipTest("TODO(b/364258243): fails on TPU v6e") + if jtu.is_device_tpu_at_least(6): + self.skipTest("TODO(b/364258243): fails on TPU v6+") n_max = l_max shape = (num_z,) rng = jtu.rand_int(self.rng(), -l_max, l_max + 1) @@ -508,8 +508,8 @@ def testSphHarmCornerCaseWithWrongNmax(self): ) @jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion def testSphHarmY(self, l_max, num_z, dtype): - if jtu.is_device_tpu(6, "e"): - self.skipTest("TODO(b/364258243): fails on TPU v6e") + if jtu.is_device_tpu_at_least(6): + self.skipTest("TODO(b/364258243): fails on TPU v6+") n_max = l_max shape = (num_z,) rng = jtu.rand_int(self.rng(), -l_max, l_max + 1) From 563c3e224425d0fe3bf8016105cae7769eb0474b Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 28 Mar 2025 07:19:07 -0700 Subject: [PATCH 249/483] Add standard pbroadcast rules to more primitives. This should cover all primitives from which shard_map registered standard_rewrite rules PiperOrigin-RevId: 741516445 --- jax/_src/core.py | 8 ++++++-- jax/_src/ffi.py | 9 ++++++--- jax/_src/lax/control_flow/solves.py | 14 ++++++++------ jax/_src/lax/windowed_reductions.py | 2 +- jax/_src/prng.py | 29 ++++++++++++++++++++--------- 5 files changed, 41 insertions(+), 21 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index ca353486afd5..1be60336f1a9 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1528,7 +1528,7 @@ def check_valid_jaxtype(x): def update_aval_with_sharding(aval, sharding): if isinstance(sharding, NamedSharding): - aval = aval.update(sharding=NamedSharding( + return aval.update(sharding=NamedSharding( sharding.mesh.abstract_mesh, sharding.spec._normalized_spec_for_aval(aval.ndim))) return aval @@ -1659,8 +1659,10 @@ def physical_aval(aval): elt_aval = physical_element_aval(aval.dtype) if isinstance(aval, ShapedArray): from jax._src.sharding_impls import physical_sharding # type: ignore + vma = aval.vma if config.varying_axes_in_types.value else frozenset() return ShapedArray((*aval.shape, *elt_aval.shape), elt_aval.dtype, - sharding=physical_sharding(aval, aval.sharding)) + sharding=physical_sharding(aval, aval.sharding), + vma=vma) return DShapedArray((*aval.shape, *elt_aval.shape), elt_aval.dtype) return aval @@ -2019,6 +2021,8 @@ def standard_insert_pbroadcast(*args): if out_vma - src else arg for arg, src in zip(args, in_vma)] def standard_vma_rule(prim_name, *avals, **kwargs): + if not avals: + return avals vma, *vmas = [a.vma for a in avals] if not all(vma == vma_ for vma_ in vmas): raise ValueError( diff --git a/jax/_src/ffi.py b/jax/_src/ffi.py index 05697f00e945..c867ec16b9b3 100644 --- a/jax/_src/ffi.py +++ b/jax/_src/ffi.py @@ -24,6 +24,7 @@ import jax from jax._src import core +from jax._src import config from jax._src import deprecations from jax._src import dispatch from jax._src import effects @@ -515,7 +516,7 @@ def wrapped(*args: ArrayLike, **kwargs: Any): "and an output with a different layout " f"{static_output_layouts[o_idx]}.") static_input_output_aliases += ((i_idx, o_idx),) - + args = core.standard_insert_pbroadcast(*args) results = ffi_call_p.bind( *args, result_avals=result_avals, @@ -638,9 +639,11 @@ def ffi_call_abstract_eval( has_side_effect: bool, **_, ): - del avals_in # unused + out_vma = (core.standard_vma_rule('ffi_call', *avals_in) + if config.varying_axes_in_types.value else frozenset()) effects = {_FfiEffect} if has_side_effect else core.no_effects - return result_avals, effects + return tuple(r if r is core.abstract_token else r.update(vma=out_vma) + for r in result_avals), effects def ffi_call_jvp(*args, target_name, **_): diff --git a/jax/_src/lax/control_flow/solves.py b/jax/_src/lax/control_flow/solves.py index acfcfd7ff3d3..4a0872bef4b2 100644 --- a/jax/_src/lax/control_flow/solves.py +++ b/jax/_src/lax/control_flow/solves.py @@ -23,6 +23,7 @@ from jax._src import api from jax._src import api_util from jax._src import core +from jax._src import config from jax._src import custom_derivatives from jax._src import linear_util as lu from jax._src.interpreters import ad @@ -309,24 +310,25 @@ def f_aux(x): jaxprs = _LinearSolveTuple( matvec_jaxpr, vecmat_jaxpr, solve_jaxpr, tr_solve_jaxpr) - out_flat = linear_solve_p.bind( - *(_flatten(all_consts) + b_flat), - const_lengths=const_lengths, jaxprs=jaxprs) + args = _flatten(all_consts) + b_flat + args = core.standard_insert_pbroadcast(*args) + out_flat = linear_solve_p.bind(*args, const_lengths=const_lengths, jaxprs=jaxprs) return tree_unflatten(out_tree, out_flat) def _linear_solve_abstract_eval(*args, const_lengths, jaxprs): args_to_raise = args[sum(const_lengths):] - # raise aux_args to shaped arrays as well if present # number of aux args is the difference in out_avals # of solve and matvec (since they map to the same vector space) - num_aux = len(jaxprs.solve.out_avals) - len(jaxprs.matvec.out_avals) if num_aux > 0: args_to_raise += tuple(jaxprs.solve.out_avals[-num_aux:]) - return args_to_raise, jaxprs.solve.effects + out_vma = (core.standard_vma_rule('linear_solve', *args_to_raise) + if config.varying_axes_in_types.value else frozenset()) + return (tuple(a.update(vma=out_vma) for a in args_to_raise), + jaxprs.solve.effects) def _custom_linear_solve_impl(*args, const_lengths, jaxprs): diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py index 00bdfe75f3e7..73fae7df40e1 100644 --- a/jax/_src/lax/windowed_reductions.py +++ b/jax/_src/lax/windowed_reductions.py @@ -338,7 +338,7 @@ def _reduce_window_abstract_eval_rule( out_sharding = reduce_window_sharding_rule( operand_avals[0], window_dimensions, window_strides, padding, base_dilation, window_dilation) - out_vma = (core.standard_vma_rule('reduce_window', operand_avals) + out_vma = (core.standard_vma_rule('reduce_window', *operand_avals) if config.varying_axes_in_types.value else frozenset()) return tuple(ShapedArray(out_shape, op.dtype, sharding=out_sharding, vma=out_vma) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index ead939d74351..17d16527bb71 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -178,7 +178,9 @@ def copy_to_host_async(self): def aval(self): logical_sharding = (self.sharding if hasattr(self._base_array, 'sharding') else None) - return keys_shaped_array(self._impl, self.shape, logical_sharding) + vma = (self._base_array.aval.vma if config.varying_axes_in_types.value else frozenset() + if hasattr(self._base_array, 'aval') else frozenset()) + return keys_shaped_array(self._impl, self.shape, logical_sharding, vma) @property def shape(self): @@ -329,8 +331,8 @@ def seed_with_impl(impl: PRNGImpl, seed: int | typing.ArrayLike) -> PRNGKeyArray return random_seed(seed, impl=impl) -def keys_shaped_array(impl, shape, sharding): - aval = core.ShapedArray(shape, KeyTy(impl)) +def keys_shaped_array(impl, shape, sharding, vma): + aval = core.ShapedArray(shape, KeyTy(impl), vma=vma) return core.update_aval_with_sharding(aval, sharding) def base_arr_shape_to_keys_shape(impl, base_arr_shape): @@ -550,7 +552,8 @@ def random_seed(seeds: int | typing.ArrayLike, impl: PRNGImpl) -> PRNGKeyArray: @random_seed_p.def_abstract_eval def random_seed_abstract_eval(seeds_aval, *, impl): - return keys_shaped_array(impl, seeds_aval.shape, seeds_aval.sharding) + out_vma = seeds_aval.vma if config.varying_axes_in_types.value else frozenset() + return keys_shaped_array(impl, seeds_aval.shape, seeds_aval.sharding, out_vma) @random_seed_p.def_impl def random_seed_impl(seeds, *, impl): @@ -584,8 +587,9 @@ def random_split_abstract_eval(keys_aval, *, shape): # TODO(yashkatariya): random_split should take sharding as an arg too so we # don't choose None here? new_spec = (*keys_aval.sharding.spec, *[None] * len(shape)) + out_vma = keys_aval.vma if config.varying_axes_in_types.value else frozenset() return keys_shaped_array(keys_aval.dtype._impl, (*keys_aval.shape, *shape), - keys_aval.sharding.with_spec(new_spec)) + keys_aval.sharding.with_spec(new_spec), out_vma) @random_split_p.def_impl def random_split_impl(keys, *, shape): @@ -611,7 +615,9 @@ def random_split_lowering(ctx, keys, *, shape): def random_fold_in(keys, msgs): - return random_fold_in_p.bind(keys, jnp.asarray(msgs)) + msgs = jnp.asarray(msgs) + keys, msgs = core.standard_insert_pbroadcast(keys, msgs) + return random_fold_in_p.bind(keys, msgs) random_fold_in_p = core.Primitive('random_fold_in') ad.defjvp_zero(random_fold_in_p) @@ -623,7 +629,9 @@ def random_fold_in_abstract_eval(keys_aval, msgs_aval): 'random_fold_in', keys_aval, msgs_aval) sharding = lax_internal.broadcasting_sharding_rule( 'random_fold_in', keys_aval, msgs_aval) - return core.ShapedArray(shape, keys_aval.dtype, sharding=sharding) + vma = (core.standard_vma_rule('random_fold_in', keys_aval, msgs_aval) + if config.varying_axes_in_types.value else frozenset()) + return core.ShapedArray(shape, keys_aval.dtype, sharding=sharding, vma=vma) @random_fold_in_p.def_impl def random_fold_in_impl(keys, msgs): @@ -661,7 +669,8 @@ def random_bits(keys, bit_width, shape): def random_bits_abstract_eval(keys_aval, *, bit_width, shape): out_shape = (*keys_aval.shape, *shape) out_dtype = dtypes.dtype(f'uint{bit_width}') - return core.ShapedArray(out_shape, out_dtype) + vma = keys_aval.vma if config.varying_axes_in_types.value else frozenset() + return core.ShapedArray(out_shape, out_dtype, vma=vma) @random_bits_p.def_impl def random_bits_impl(keys, *, bit_width, shape): @@ -718,7 +727,9 @@ def random_wrap(base_arr, *, impl): def random_wrap_abstract_eval(base_arr_aval, *, impl): shape = base_arr_shape_to_keys_shape(impl, base_arr_aval.shape) sharding = logical_sharding(shape, KeyTy(impl), base_arr_aval.sharding) - return keys_shaped_array(impl, shape, sharding) + out_vma = (base_arr_aval.vma if config.varying_axes_in_types.value else + frozenset()) + return keys_shaped_array(impl, shape, sharding, out_vma) @random_wrap_p.def_impl def random_wrap_impl(base_arr, *, impl): From e679811c4ae245bcc48e9f18a51361cabbdf5561 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Fri, 28 Mar 2025 07:20:57 -0700 Subject: [PATCH 250/483] [Mosaic GPU] Add warpgroup lowering for `Exp2` in Pallas. This change also enables tests for supported elementwise ops. PiperOrigin-RevId: 741516852 --- jax/_src/pallas/mosaic_gpu/lowering.py | 1 + jax/experimental/mosaic/gpu/dialect_lowering.py | 12 ++++++++++++ tests/pallas/mosaic_gpu_test.py | 1 - tests/pallas/ops_test.py | 9 ++++++++- 4 files changed, 21 insertions(+), 2 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 1b4aa33dc909..baac1e6eb316 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1671,6 +1671,7 @@ def _exp_lowering_rule(ctx: LoweringRuleContext, x, accuracy): @register_lowering_rule(lax.exp2_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.exp2_p, mgpu.ThreadSemantics.Warpgroup) def _exp2_lowering_rule(ctx: LoweringRuleContext, x, accuracy): if accuracy is not None: raise NotImplementedError("Not implemented: accuracy") diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index 936bba73915b..f00cff9a500c 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -560,6 +560,12 @@ def _mgpu_async_load_op_lowering_rule( v = idx if size < 0 else utils.DynamicSlice(idx, size) gmem_slice.append(v) + # TODO(dasenov): async_copy requires all GMEM strides except the last one + # to be a multiple of 16 bytes. This restriction could be loosned with + # strided layouts when they are contiguous in GMEM. In that case, we could do: + # flatten -> async_copy -> unflatted here, as long as flattened size is a + # multiple of 16. + # TODO(dasenov): Add support for the remaining op properties. ctx.launch_context.async_copy( src_ref=load_op.source, @@ -596,6 +602,12 @@ def _mgpu_async_store_op_lowering_rule( v = idx if size < 0 else utils.DynamicSlice(idx, size) gmem_slice.append(v) + # TODO(dasenov): async_copy requires all GMEM strides except the last one + # to be a multiple of 16 bytes. This restriction could be loosned with + # strided layouts when they are contiguous in GMEM. In that case, we could do: + # flatten -> async_copy -> unflatted here, as long as flattened size is a + # multiple of 16. + # TODO(dasenov): Add support for the remaining op properties. ctx.launch_context.async_copy( src_ref=transform_memref(store_op.source, transforms), diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index c5da44d7b6fa..874ecae93f3f 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1418,7 +1418,6 @@ def test_missing_primitive_lowerings_are_tracked(self): actual_missing_primitives = lane_lowered_primitives - wg_lowered_primitives expected_missing_primitives = { mgpu_primitives.broadcasted_iota_p, - lax.exp2_p, mgpu_primitives.layout_cast_p, mgpu_primitives.load_p, lax.slice_p, diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index f5b70878533d..aeb0ba1cca1a 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -562,7 +562,8 @@ def kernel(*refs): ) @hp.given(hps.data()) def test_unary_primitives(self, name, func, shape_dtype_strategy, data): - self.skip_if_mosaic_gpu() + if name in ["abs", "log1p", "pow2", "reciprocal", "relu", "sin", "sqrt"]: + self.skip_if_mosaic_gpu() if self.INTERPRET: self.skipTest("This hypothesis test is slow, even more so in interpret mode.") @@ -579,6 +580,12 @@ def test_unary_primitives(self, name, func, shape_dtype_strategy, data): def kernel(x_ref, y_ref): y_ref[...] = func(x_ref[...]) x_shape_dtype = data.draw(shape_dtype_strategy) + + sut_is_mosaic_gpu = jtu.test_device_matches(["gpu"]) and use_mosaic_gpu + if sut_is_mosaic_gpu: + hp.assume(math.prod(x_shape_dtype.shape) % 128 == 0) + hp.assume(x_shape_dtype.shape[-1] >= 16) + key = random.key(0) x = _random_value(key, x_shape_dtype) out = self.pallas_call(kernel, out_shape=x_shape_dtype)(x) From 431c2c080728a1c880f1facab0bc431631658fe9 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 28 Mar 2025 07:44:38 -0700 Subject: [PATCH 251/483] cleanup now that we depend on ml_dtypes>=0.5 --- jax/_src/dtypes.py | 119 +++++++++++-------------------- jax/_src/export/serialization.py | 12 ++-- jax/_src/interpreters/mlir.py | 22 ++---- jax/_src/lax/lax.py | 19 ++--- jax/_src/numpy/scalar_types.py | 18 ++--- jax/_src/public_test_util.py | 41 ++++------- jax/_src/test_util.py | 12 ++-- jax/numpy/__init__.py | 26 ++----- jax/tools/jax_to_ir.py | 8 +-- jaxlib/xla/xla_client.py | 11 ++- tests/dtypes_test.py | 25 ++----- tests/jax_to_ir_test.py | 6 +- 12 files changed, 104 insertions(+), 215 deletions(-) diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index 01500c008405..d1e5b7bf430b 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -90,19 +90,18 @@ def type(self) -> type: ... # fp8 support -# TODO: remove Optional when minimum ml_dtypes version >= 0.5.0 -float8_e3m4: type[np.generic] | None = None -float8_e4m3: type[np.generic] | None = None -float8_e8m0fnu: type[np.generic] | None = None +float8_e3m4: type[np.generic] = ml_dtypes.float8_e3m4 +float8_e4m3: type[np.generic] = ml_dtypes.float8_e4m3 +float8_e8m0fnu: type[np.generic] = ml_dtypes.float8_e8m0fnu float8_e4m3b11fnuz: type[np.generic] = ml_dtypes.float8_e4m3b11fnuz float8_e4m3fn: type[np.generic] = ml_dtypes.float8_e4m3fn float8_e4m3fnuz: type[np.generic] = ml_dtypes.float8_e4m3fnuz float8_e5m2: type[np.generic] = ml_dtypes.float8_e5m2 float8_e5m2fnuz: type[np.generic] = ml_dtypes.float8_e5m2fnuz -_float8_e3m4_dtype: np.dtype | None = None -_float8_e4m3_dtype: np.dtype | None = None -_float8_e8m0fnu_dtype: np.dtype | None = None +_float8_e3m4_dtype: np.dtype = np.dtype(float8_e3m4) +_float8_e4m3_dtype: np.dtype = np.dtype(float8_e4m3) +_float8_e8m0fnu_dtype: np.dtype = np.dtype(float8_e8m0fnu) _float8_e4m3b11fnuz_dtype: np.dtype = np.dtype(float8_e4m3b11fnuz) _float8_e4m3fn_dtype: np.dtype = np.dtype(float8_e4m3fn) _float8_e4m3fnuz_dtype: np.dtype = np.dtype(float8_e4m3fnuz) @@ -111,9 +110,9 @@ def type(self) -> type: ... # fp4 support # TODO: remove Optional when minimum ml_dtypes version >= 0.5.0 -float4_e2m1fn: type[np.generic] | None = None +float4_e2m1fn: type[np.generic] = ml_dtypes.float4_e2m1fn -_float4_e2m1fn_dtype: np.dtype | None = None +_float4_e2m1fn_dtype: np.dtype = np.dtype(float4_e2m1fn) def supports_inf(dtype: DTypeLike) -> bool: """Return true if the dtype supports infinity, else return False.""" @@ -127,6 +126,10 @@ def supports_inf(dtype: DTypeLike) -> bool: _bfloat16_dtype: np.dtype = np.dtype(bfloat16) _custom_float_scalar_types = [ + float4_e2m1fn, + float8_e3m4, + float8_e4m3, + float8_e8m0fnu, float8_e4m3b11fnuz, float8_e4m3fn, float8_e4m3fnuz, @@ -135,6 +138,10 @@ def supports_inf(dtype: DTypeLike) -> bool: bfloat16, ] _custom_float_dtypes = [ + _float4_e2m1fn_dtype, + _float8_e3m4_dtype, + _float8_e4m3_dtype, + _float8_e8m0fnu_dtype, _float8_e4m3b11fnuz_dtype, _float8_e4m3fn_dtype, _float8_e4m3fnuz_dtype, @@ -143,6 +150,9 @@ def supports_inf(dtype: DTypeLike) -> bool: _bfloat16_dtype, ] _float8_dtypes = [ + _float8_e3m4_dtype, + _float8_e4m3_dtype, + _float8_e8m0fnu_dtype, _float8_e4m3b11fnuz_dtype, _float8_e4m3fn_dtype, _float8_e4m3fnuz_dtype, @@ -150,58 +160,28 @@ def supports_inf(dtype: DTypeLike) -> bool: _float8_e5m2fnuz_dtype, ] -_float4_dtypes: list[np.dtype] = [] - -# TODO: remove the if statements below when minimum ml_dtypes version >= 0.5.0 -if hasattr(ml_dtypes, "float8_e4m3"): - float8_e4m3 = ml_dtypes.float8_e4m3 - _float8_e4m3_dtype = np.dtype(float8_e4m3) - _custom_float_scalar_types.insert(0, float8_e4m3) # type: ignore[arg-type] - _custom_float_dtypes.insert(0, _float8_e4m3_dtype) - _float8_dtypes.insert(0, _float8_e4m3_dtype) -if hasattr(ml_dtypes, "float8_e3m4"): - float8_e3m4 = ml_dtypes.float8_e3m4 - _float8_e3m4_dtype = np.dtype(float8_e3m4) - _custom_float_scalar_types.insert(0, float8_e3m4) # type: ignore[arg-type] - _custom_float_dtypes.insert(0, _float8_e3m4_dtype) - _float8_dtypes.insert(0, _float8_e3m4_dtype) -if hasattr(ml_dtypes, "float8_e8m0fnu"): - float8_e8m0fnu = ml_dtypes.float8_e8m0fnu - _float8_e8m0fnu_dtype = np.dtype(float8_e8m0fnu) - _custom_float_scalar_types.insert(0, float8_e8m0fnu) # type: ignore[arg-type] - _custom_float_dtypes.insert(0, _float8_e8m0fnu_dtype) - _float8_dtypes.insert(0, _float8_e8m0fnu_dtype) -if hasattr(ml_dtypes, "float4_e2m1fn"): - float4_e2m1fn = ml_dtypes.float4_e2m1fn - _float4_e2m1fn_dtype = np.dtype(float4_e2m1fn) - _custom_float_scalar_types.insert(0, float4_e2m1fn) # type: ignore[arg-type] - _custom_float_dtypes.insert(0, _float4_e2m1fn_dtype) - _float4_dtypes.insert(0, _float4_e2m1fn_dtype) - -# 2-bit integer support -int2: type[np.generic] | None = None -uint2: type[np.generic] | None = None - -_int2_dtype: np.dtype | None = None -_uint2_dtype: np.dtype | None = None - -_intn_dtypes = [] - -# Remove the condition once the minimum ml_dtypes version required by JAX -# contains https://github.com/jax-ml/ml_dtypes/pull/154. -if hasattr(ml_dtypes, 'int2'): - int2 = ml_dtypes.int2 - uint2 = ml_dtypes.uint2 - _int2_dtype = np.dtype(int2) - _uint2_dtype = np.dtype(uint2) - _intn_dtypes.extend([_int2_dtype, _uint2_dtype]) +_float4_dtypes: list[np.dtype] = [ + _float4_e2m1fn_dtype, +] + +int2: type[np.generic] = ml_dtypes.int2 +uint2: type[np.generic] = ml_dtypes.uint2 + +_int2_dtype: np.dtype = np.dtype(int2) +_uint2_dtype: np.dtype = np.dtype(uint2) # 4-bit integer support int4: type[np.generic] = ml_dtypes.int4 uint4: type[np.generic] = ml_dtypes.uint4 _int4_dtype = np.dtype(int4) _uint4_dtype = np.dtype(uint4) -_intn_dtypes.extend([_int4_dtype, _uint4_dtype]) + +_intn_dtypes = [ + _int2_dtype, + _uint2_dtype, + _int4_dtype, + _uint4_dtype, +] # Default types. bool_ = np.bool_ @@ -472,9 +452,9 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType, # to the normal scalar type hierarchy. if a_sctype in _custom_float_scalar_types: return b_sctype in {a_sctype, np.floating, np.inexact, np.number, np.generic} - if (int2 is not None and a_sctype == int2) or a_sctype == int4: + if a_sctype in [int2, int4]: return b_sctype in {a_sctype, np.signedinteger, np.integer, np.number, np.generic} - if (uint2 is not None and a_sctype == uint2) or a_sctype == uint4: + if a_sctype in [uint2, uint4]: return b_sctype in {a_sctype, np.unsignedinteger, np.integer, np.number, np.generic} # Otherwise, fall back to numpy.issubdtype @@ -491,6 +471,7 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType, _unsigned_types: list[JAXType] _int_types: list[JAXType] _unsigned_types = [ + np.dtype(uint2), np.dtype(uint4), np.dtype('uint8'), np.dtype('uint16'), @@ -498,6 +479,7 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType, np.dtype('uint64'), ] _signed_types = [ + np.dtype(int2), np.dtype(int4), np.dtype('int8'), np.dtype('int16'), @@ -505,11 +487,6 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType, np.dtype('int64'), ] -if _int2_dtype is not None: - _signed_types.insert(0, _int2_dtype) -if _uint2_dtype is not None: - _unsigned_types.insert(0, _uint2_dtype) - _int_types = _unsigned_types + _signed_types _float_types: list[JAXType] = [ @@ -622,11 +599,7 @@ def _type_promotion_lattice(jax_numpy_dtype_promotion: str) -> dict[JAXType, lis This DAG maps each type to its immediately higher type on the lattice. """ b1, = _bool_types - if _int2_dtype is not None: - assert _uint2_dtype is not None - _uint2, uint4, u1, u2, u4, u8, _int2, int4, i1, i2, i4, i8 = _int_types - else: - uint4, u1, u2, u4, u8, int4, i1, i2, i4, i8 = _int_types + uint2, uint4, u1, u2, u4, u8, int2, int4, i1, i2, i4, i8 = _int_types *f1_types, bf, f2, f4, f8 = _float_types c4, c8 = _complex_types i_, f_, c_ = _weak_types @@ -634,19 +607,13 @@ def _type_promotion_lattice(jax_numpy_dtype_promotion: str) -> dict[JAXType, lis out: dict[JAXType, list[JAXType]] out = { b1: [i_], - i_: [u1, uint4, i1, int4], - uint4: [], u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_], - int4: [], i1: [i2], i2: [i4], i4: [i8], i8: [f_], + i_: [u1, uint2, uint4, i1, int2, int4], + uint2: [], uint4: [], u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_], + int2: [], int4: [], i1: [i2], i2: [i4], i4: [i8], i8: [f_], f_: [*f1_types, bf, f2, c_], **{t: [] for t in f1_types}, bf: [f4], f2: [f4], f4: [f8, c4], f8: [c8], c_: [c4], c4: [c8], c8: [], } - if _int2_dtype is not None: - out[i_].append(_int2_dtype) - out[_int2_dtype] = [] - if _uint2_dtype is not None: - out[i_].append(_uint2_dtype) - out[_uint2_dtype] = [] return out elif jax_numpy_dtype_promotion == 'strict': return { diff --git a/jax/_src/export/serialization.py b/jax/_src/export/serialization.py index ac97c11d1177..94c0baf642b6 100644 --- a/jax/_src/export/serialization.py +++ b/jax/_src/export/serialization.py @@ -357,16 +357,12 @@ def _deserialize_pytreedef_to_pytree(p: ser_flatbuf.PyTreeDef): dtypes._float8_e4m3fnuz_dtype: ser_flatbuf.DType.f8_e4m3fnuz, dtypes._float8_e5m2_dtype: ser_flatbuf.DType.f8_e5m2, dtypes._float8_e5m2fnuz_dtype: ser_flatbuf.DType.f8_e5m2fnuz, + dtypes._float8_e3m4_dtype: ser_flatbuf.DType.f8_e3m4, + dtypes._float8_e4m3_dtype: ser_flatbuf.DType.f8_e4m3, + dtypes._float8_e8m0fnu_dtype: ser_flatbuf.DType.f8_e8m0fnu, + dtypes._float4_e2m1fn_dtype: ser_flatbuf.DType.f4_e2m1fn, } -if dtypes._float8_e3m4_dtype is not None: - _dtype_to_dtype_kind[dtypes._float8_e3m4_dtype] = ser_flatbuf.DType.f8_e3m4 -if dtypes._float8_e4m3_dtype is not None: - _dtype_to_dtype_kind[dtypes._float8_e4m3_dtype] = ser_flatbuf.DType.f8_e4m3 -if dtypes._float8_e8m0fnu_dtype is not None: - _dtype_to_dtype_kind[dtypes._float8_e8m0fnu_dtype] = ser_flatbuf.DType.f8_e8m0fnu -if dtypes._float4_e2m1fn_dtype is not None: - _dtype_to_dtype_kind[dtypes._float4_e2m1fn_dtype] = ser_flatbuf.DType.f4_e2m1fn _dtype_kind_to_dtype = { kind: dtype for dtype, kind in _dtype_to_dtype_kind.items() } diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index a707981f5403..23d1b5dd9d89 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -185,24 +185,14 @@ def _is_ir_values(x: IrValues) -> bool: np.dtype(np.float64): ir.F64Type.get, np.dtype(np.complex64): lambda: ir.ComplexType.get(ir.F32Type.get()), np.dtype(np.complex128): lambda: ir.ComplexType.get(ir.F64Type.get()), + np.dtype(dtypes.int2): partial(ir.IntegerType.get_signless, 2), + np.dtype(dtypes.uint2): partial(ir.IntegerType.get_unsigned, 2), + np.dtype(dtypes.float8_e3m4): ir.Float8E3M4Type.get, + np.dtype(dtypes.float8_e4m3): ir.Float8E4M3Type.get, + np.dtype(dtypes.float8_e8m0fnu): ir.Float8E8M0FNUType.get, + np.dtype(dtypes.float4_e2m1fn): ir.Float4E2M1FNType.get, } - -if dtypes.int2 is not None: - assert dtypes.uint2 is not None - _dtype_to_ir_type[np.dtype(dtypes.int2)] = partial(ir.IntegerType.get_signless, 2) - _dtype_to_ir_type[np.dtype(dtypes.uint2)] = partial(ir.IntegerType.get_unsigned, 2) - -if dtypes.float8_e3m4 is not None: - _dtype_to_ir_type[np.dtype(dtypes.float8_e3m4)] = ir.Float8E3M4Type.get -if dtypes.float8_e4m3 is not None: - _dtype_to_ir_type[np.dtype(dtypes.float8_e4m3)] = ir.Float8E4M3Type.get -if dtypes.float8_e8m0fnu is not None: - _dtype_to_ir_type[np.dtype(dtypes.float8_e8m0fnu)] = ir.Float8E8M0FNUType.get - -if dtypes.float4_e2m1fn is not None: - _dtype_to_ir_type[np.dtype(dtypes.float4_e2m1fn)] = ir.Float4E2M1FNType.get - def dtype_to_ir_type(dtype: core.bint | np.dtype | np.generic) -> ir.Type: if isinstance(dtype, core.bint): # TODO Support different-size underlying dtypes to take advantage of the diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index b79c81e19195..a4fb04698365 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -2346,13 +2346,10 @@ def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike, np.dtype(dtypes.float8_e4m3fnuz), np.dtype(dtypes.float8_e5m2), np.dtype(dtypes.float8_e5m2fnuz), + np.dtype(dtypes.float8_e3m4), + np.dtype(dtypes.float8_e4m3), + np.dtype(dtypes.float8_e8m0fnu), ] - if dtypes.float8_e3m4 is not None: - fp8_dtypes += [np.dtype(dtypes.float8_e3m4)] - if dtypes.float8_e4m3 is not None: - fp8_dtypes += [np.dtype(dtypes.float8_e4m3)] - if dtypes.float8_e8m0fnu is not None: - fp8_dtypes += [np.dtype(dtypes.float8_e8m0fnu)] if lhs_dtype not in fp8_dtypes or rhs_dtype not in fp8_dtypes: raise ValueError( f"The dot algorithm '{self}' requires both inputs to have float8 " @@ -5602,13 +5599,9 @@ def accuracy_attr(accuracy) -> hlo.ResultAccuracyAttr: def _handle_dot_precision(ctx, lhs, rhs, precision, platform): def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes): fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2, - dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz) - if dtypes.float8_e3m4 is not None: - fp8_dtypes += (dtypes.float8_e3m4,) - if dtypes.float8_e4m3 is not None: - fp8_dtypes += (dtypes.float8_e4m3,) - if dtypes.float8_e8m0fnu is not None: - fp8_dtypes += (dtypes.float8_e8m0fnu,) + dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz, + dtypes.float8_e3m4, dtypes.float8_e4m3, + dtypes.float8_e8m0fnu) return _lhs_dtypes in fp8_dtypes and _rhs_dtypes in fp8_dtypes # The *_ lets us reuse this for ragged_dot_general, which has group_sizes. diff --git a/jax/_src/numpy/scalar_types.py b/jax/_src/numpy/scalar_types.py index 2f9954488b41..2b0e04adc997 100644 --- a/jax/_src/numpy/scalar_types.py +++ b/jax/_src/numpy/scalar_types.py @@ -68,33 +68,27 @@ def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta: return meta bool_ = _make_scalar_type(np.bool_) -if dtypes.uint2 is not None: - uint2 = _make_scalar_type(dtypes.uint2) +uint2 = _make_scalar_type(dtypes.uint2) uint4 = _make_scalar_type(dtypes.uint4) uint8 = _make_scalar_type(np.uint8) uint16 = _make_scalar_type(np.uint16) uint32 = _make_scalar_type(np.uint32) uint64 = _make_scalar_type(np.uint64) -if dtypes.int2 is not None: - int2 = _make_scalar_type(dtypes.int2) +int2 = _make_scalar_type(dtypes.int2) int4 = _make_scalar_type(dtypes.int4) int8 = _make_scalar_type(np.int8) int16 = _make_scalar_type(np.int16) int32 = _make_scalar_type(np.int32) int64 = _make_scalar_type(np.int64) -if dtypes.float8_e3m4 is not None: - float8_e3m4 = _make_scalar_type(dtypes.float8_e3m4) -if dtypes.float8_e4m3 is not None: - float8_e4m3 = _make_scalar_type(dtypes.float8_e4m3) -if dtypes.float8_e8m0fnu is not None: - float8_e8m0fnu = _make_scalar_type(dtypes.float8_e8m0fnu) +float4_e2m1fn = _make_scalar_type(dtypes.float4_e2m1fn) +float8_e3m4 = _make_scalar_type(dtypes.float8_e3m4) +float8_e4m3 = _make_scalar_type(dtypes.float8_e4m3) +float8_e8m0fnu = _make_scalar_type(dtypes.float8_e8m0fnu) float8_e4m3fn = _make_scalar_type(dtypes.float8_e4m3fn) float8_e4m3fnuz = _make_scalar_type(dtypes.float8_e4m3fnuz) float8_e5m2 = _make_scalar_type(dtypes.float8_e5m2) float8_e5m2fnuz = _make_scalar_type(dtypes.float8_e5m2fnuz) float8_e4m3b11fnuz = _make_scalar_type(dtypes.float8_e4m3b11fnuz) -if dtypes.float4_e2m1fn is not None: - float4_e2m1fn = _make_scalar_type(dtypes.float4_e2m1fn) bfloat16 = _make_scalar_type(dtypes.bfloat16) float16 = _make_scalar_type(np.float16) float32 = single = _make_scalar_type(np.float32) diff --git a/jax/_src/public_test_util.py b/jax/_src/public_test_util.py index 59ddb73dc9e1..3b1e24bc9c50 100644 --- a/jax/_src/public_test_util.py +++ b/jax/_src/public_test_util.py @@ -46,16 +46,22 @@ def _dtype(x: Any) -> np.dtype: _default_tolerance: ToleranceDict = { _dtypes.float0: 0, np.dtype(np.bool_): 0, + np.dtype(_dtypes.int2): 0, np.dtype(_dtypes.int4): 0, np.dtype(np.int8): 0, np.dtype(np.int16): 0, np.dtype(np.int32): 0, np.dtype(np.int64): 0, + np.dtype(_dtypes.uint2): 0, np.dtype(_dtypes.uint4): 0, np.dtype(np.uint8): 0, np.dtype(np.uint16): 0, np.dtype(np.uint32): 0, np.dtype(np.uint64): 0, + np.dtype(_dtypes.float4_e2m1fn): 1e0, + np.dtype(_dtypes.float8_e3m4): 1e-1, + np.dtype(_dtypes.float8_e4m3): 1e-1, + np.dtype(_dtypes.float8_e8m0fnu): 1e0, np.dtype(_dtypes.float8_e4m3b11fnuz): 1e-1, np.dtype(_dtypes.float8_e4m3fn): 1e-1, np.dtype(_dtypes.float8_e4m3fnuz): 1e-1, @@ -69,16 +75,15 @@ def _dtype(x: Any) -> np.dtype: np.dtype(np.complex128): 1e-15, } -if _dtypes.int2 is not None: - assert _dtypes.uint2 is not None - _default_tolerance[np.dtype(_dtypes.int2)] = 0 - _default_tolerance[np.dtype(_dtypes.uint2)] = 0 - def default_tolerance(): return _default_tolerance default_gradient_tolerance: ToleranceDict = { + np.dtype(_dtypes.float4_e2m1fn): 1e0, + np.dtype(_dtypes.float8_e3m4): 1e-1, + np.dtype(_dtypes.float8_e4m3): 1e-1, + np.dtype(_dtypes.float8_e8m0fnu): 1e0, np.dtype(_dtypes.float8_e4m3b11fnuz): 1e-1, np.dtype(_dtypes.float8_e4m3fn): 1e-1, np.dtype(_dtypes.float8_e4m3fnuz): 1e-1, @@ -92,19 +97,6 @@ def default_tolerance(): np.dtype(np.complex128): 1e-5, } -# TODO: make this unconditional when ml_dtypes>=0.5.0 is required -if _dtypes.float8_e3m4 is not None: - _default_tolerance[np.dtype(_dtypes.float8_e3m4)] = 1e-1 - default_gradient_tolerance[np.dtype(_dtypes.float8_e3m4)] = 1e-1 -if _dtypes.float8_e4m3 is not None: - _default_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1 - default_gradient_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1 -if _dtypes.float8_e8m0fnu is not None: - _default_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0 - default_gradient_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0 -if _dtypes.float4_e2m1fn is not None: - _default_tolerance[np.dtype(_dtypes.float4_e2m1fn)] = 1e0 - default_gradient_tolerance[np.dtype(_dtypes.float4_e2m1fn)] = 1e0 def is_python_scalar(val: Any) -> bool: return not isinstance(val, np.generic) and isinstance(val, (bool, int, float, complex)) @@ -115,6 +107,10 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''): return custom_float_dtypes = [ + _dtypes.float4_e2m1fn, + _dtypes.float8_e8m0fnu, + _dtypes.float8_e3m4, + _dtypes.float8_e4m3, _dtypes.float8_e4m3b11fnuz, _dtypes.float8_e4m3fn, _dtypes.float8_e4m3fnuz, @@ -123,15 +119,6 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''): _dtypes.bfloat16, ] - if _dtypes.float8_e4m3 is not None: - custom_float_dtypes.insert(0, _dtypes.float8_e4m3) - if _dtypes.float8_e3m4 is not None: - custom_float_dtypes.insert(0, _dtypes.float8_e3m4) - if _dtypes.float8_e8m0fnu is not None: - custom_float_dtypes.insert(0, _dtypes.float8_e8m0fnu) - if _dtypes.float4_e2m1fn is not None: - custom_float_dtypes.insert(0, _dtypes.float4_e2m1fn) - def maybe_upcast(x): if x.dtype in custom_float_dtypes: return x.astype(np.float32) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 1cd9546a1655..c3f7fb4c4139 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -1632,15 +1632,11 @@ def custom_floats(self): _dtypes.float8_e4m3fnuz, _dtypes.float8_e5m2, _dtypes.float8_e5m2fnuz, + _dtypes.float8_e3m4, + _dtypes.float8_e4m3, + _dtypes.float8_e8m0fnu, + _dtypes.float4_e2m1fn, ] - if _dtypes.float8_e3m4 is not None: - float_dtypes += [_dtypes.float8_e3m4] - if _dtypes.float8_e4m3 is not None: - float_dtypes += [_dtypes.float8_e4m3] - if _dtypes.float8_e8m0fnu is not None: - float_dtypes += [_dtypes.float8_e8m0fnu] - if _dtypes.float4_e2m1fn is not None: - float_dtypes += [_dtypes.float4_e2m1fn] return self.supported(float_dtypes) @_cached_property diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 31cca3578916..b6cfb1ff06ac 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -211,13 +211,18 @@ double as double, float16 as float16, float32 as float32, + float4_e2m1fn as float4_e2m1fn, float64 as float64, + float8_e3m4 as float8_e3m4, + float8_e4m3 as float8_e4m3, float8_e4m3b11fnuz as float8_e4m3b11fnuz, float8_e4m3fn as float8_e4m3fn, float8_e4m3fnuz as float8_e4m3fnuz, float8_e5m2 as float8_e5m2, float8_e5m2fnuz as float8_e5m2fnuz, + float8_e8m0fnu as float8_e8m0fnu, float_ as float_, + int2 as int2, int4 as int4, int8 as int8, int16 as int16, @@ -226,6 +231,7 @@ int_ as int_, single as single, uint as uint, + uint2 as uint2, uint4 as uint4, uint8 as uint8, uint16 as uint16, @@ -295,26 +301,6 @@ unsignedinteger as unsignedinteger, ) -# TODO(slebedev): Remove the try-except once we upgrade to ml_dtypes 0.4.1. -try: - from jax._src.numpy.scalar_types import ( - int2 as int2, - uint2 as uint2, - ) -except ImportError: - pass - -# TODO: Remove the try-except once we upgrade to ml_dtypes 0.5.0 -try: - from jax._src.numpy.scalar_types import ( - float8_e3m4 as float8_e3m4, - float8_e4m3 as float8_e4m3, - float8_e8m0fnu as float8_e8m0fnu, - float4_e2m1fn as float4_e2m1fn, - ) -except ImportError: - pass - from jax._src.numpy.array_api_metadata import ( __array_api_version__ as __array_api_version__, __array_namespace_info__ as __array_namespace_info__, diff --git a/jax/tools/jax_to_ir.py b/jax/tools/jax_to_ir.py index 904ce509a87e..47b85382f8bf 100644 --- a/jax/tools/jax_to_ir.py +++ b/jax/tools/jax_to_ir.py @@ -240,16 +240,12 @@ def parse_shape_str(s): _DT = { 'pred': jnp.bool_, - 'u4': jnp.uint4, 'u8': jnp.uint8, 'u16': jnp.uint16, 'u32': jnp.uint32, 'u64': jnp.uint64, - 's4': jnp.int4, 's8': jnp.int8, 's16': jnp.int16, 's32': jnp.int32, 's64': jnp.int64, + 'u2': jnp.uint2, 'u4': jnp.uint4, 'u8': jnp.uint8, 'u16': jnp.uint16, 'u32': jnp.uint32, 'u64': jnp.uint64, + 's2': jnp.int2, 's4': jnp.int4, 's8': jnp.int8, 's16': jnp.int16, 's32': jnp.int32, 's64': jnp.int64, 'bf16': jnp.bfloat16, 'f16': jnp.float16, 'f32': jnp.float32, 'f64': jnp.float64, 'c64': jnp.complex64, 'c128': jnp.complex128 } -if hasattr(jnp, 'int2'): - _DT['s2'] = jnp.int2 -if hasattr(jnp, 'uint2'): - _DT['u2'] = jnp.uint2 _SHAPE_RE = re.compile(f"^({'|'.join(_DT)})\\[\\s*(\\d*[\\s*,\\d+]*)\\s*\\]$") diff --git a/jaxlib/xla/xla_client.py b/jaxlib/xla/xla_client.py index 776a22444208..21ea81ac6efa 100644 --- a/jaxlib/xla/xla_client.py +++ b/jaxlib/xla/xla_client.py @@ -238,13 +238,10 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): PrimitiveType = _xla.PrimitiveType bfloat16 = ml_dtypes.bfloat16 -# TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. -# Also, it would be better to conditionally import these based on whether they -# are in the current version of ml_dtypes. -# float4_e2m1fn = ml_dtypes.float4_e2m1fn -# float8_e3m4 = ml_dtypes.float8_e3m4 -# float8_e4m3 = ml_dtypes.float8_e4m3 -# float8_e8m0fnu = ml_dtypes.float8_e8m0fnu +float4_e2m1fn = ml_dtypes.float4_e2m1fn +float8_e3m4 = ml_dtypes.float8_e3m4 +float8_e4m3 = ml_dtypes.float8_e4m3 +float8_e8m0fnu = ml_dtypes.float8_e8m0fnu float8_e4m3fn = ml_dtypes.float8_e4m3fn float8_e4m3b11fnuz = ml_dtypes.float8_e4m3b11fnuz float8_e4m3fnuz = ml_dtypes.float8_e4m3fnuz diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index 87380443f4cb..d8fb30397b27 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -46,30 +46,19 @@ np.dtype('uint64')] unsigned_dtypes = list(np_unsigned_dtypes) -intn_dtypes = [np.dtype('int4'), np.dtype('uint4')] -signed_dtypes += [np.dtype('int4')] -unsigned_dtypes += [np.dtype('uint4')] -if dtypes.int2 is not None: - assert dtypes.uint2 is not None - intn_dtypes[:0] = [np.dtype('int2'), np.dtype('uint2')] - signed_dtypes[:0] = [np.dtype('int2')] - unsigned_dtypes[:0] = [np.dtype('uint2')] - -np_float_dtypes = [np.dtype('float16'), np.dtype('float32'), - np.dtype('float64')] +intn_dtypes = [np.dtype('int2'), np.dtype('uint2'), np.dtype('int4'), np.dtype('uint4')] +signed_dtypes += [np.dtype('int2'), np.dtype('int4')] +unsigned_dtypes += [np.dtype('uint2'), np.dtype('uint4')] + +np_float_dtypes = [np.dtype('float16'), np.dtype('float32'), np.dtype('float64')] float_dtypes = [np.dtype(dtypes.bfloat16)] + np_float_dtypes custom_float_dtypes = [np.dtype(dtypes.bfloat16)] fp8_dtypes = [np.dtype(dtypes.float8_e4m3b11fnuz), np.dtype(dtypes.float8_e4m3fn), np.dtype(dtypes.float8_e4m3fnuz), np.dtype(dtypes.float8_e5m2), - np.dtype(dtypes.float8_e5m2fnuz)] -if dtypes.float8_e3m4 is not None: - fp8_dtypes += [np.dtype(dtypes.float8_e3m4)] -if dtypes.float8_e4m3 is not None: - fp8_dtypes += [np.dtype(dtypes.float8_e4m3)] -if dtypes.float8_e8m0fnu is not None: - fp8_dtypes += [np.dtype(dtypes.float8_e8m0fnu)] + np.dtype(dtypes.float8_e5m2fnuz), np.dtype(dtypes.float8_e3m4), + np.dtype(dtypes.float8_e4m3), np.dtype(dtypes.float8_e8m0fnu)] float_dtypes += fp8_dtypes custom_float_dtypes += fp8_dtypes diff --git a/tests/jax_to_ir_test.py b/tests/jax_to_ir_test.py index f600a08f5dc4..4eb8190b712f 100644 --- a/tests/jax_to_ir_test.py +++ b/tests/jax_to_ir_test.py @@ -114,15 +114,13 @@ def test_parse_shape_str(self): self.assertParsedShape('f32[]', [], jnp.float32) self.assertParsedShape('f32[1,2,3]', [1, 2, 3], jnp.float32) self.assertParsedShape('pred[1]', [1], jnp.bool_) - if hasattr(jnp, 'int2'): - self.assertParsedShape('s2[1]', [1], jnp.int2) + self.assertParsedShape('s2[1]', [1], jnp.int2) self.assertParsedShape('s4[1]', [1], jnp.int4) self.assertParsedShape('s8[1]', [1], jnp.int8) self.assertParsedShape('s16[1]', [1], jnp.int16) self.assertParsedShape('s32[1]', [1], jnp.int32) self.assertParsedShape('s64[1]', [1], jnp.int64) - if hasattr(jnp, 'uint2'): - self.assertParsedShape('u2[1]', [1], jnp.uint2) + self.assertParsedShape('u2[1]', [1], jnp.uint2) self.assertParsedShape('u4[1]', [1], jnp.uint4) self.assertParsedShape('u8[1]', [1], jnp.uint8) self.assertParsedShape('u16[1]', [1], jnp.uint16) From 4a8f520a8747207cb7d5b7f08cc4a7d6418aa539 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 28 Mar 2025 07:53:06 -0700 Subject: [PATCH 252/483] Replace uses of deprecated `Shape::rank()` with: - `dimensions().size()` if it's OK for the result to be changed to an unsigned number, - `dimensions_size()` if it's important that the result is a signed number. This should be a pure refactoring that doesn't affect the code's behavior. Note that `rank()` returns `int64_t` and `dimensions().size()` returns `size_t`. Sometimes the change of the signedness is not desirable, and we use `dimensions_size()`, which returns `int`, in such cases. PiperOrigin-RevId: 741524661 --- jaxlib/gpu/py_client_gpu.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jaxlib/gpu/py_client_gpu.cc b/jaxlib/gpu/py_client_gpu.cc index 59cc385825a0..861ffce3e749 100644 --- a/jaxlib/gpu/py_client_gpu.cc +++ b/jaxlib/gpu/py_client_gpu.cc @@ -182,7 +182,7 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, options.dims = absl::Span( reinterpret_cast(array.shape()), array.ndim()); absl::InlinedVector reversed_layout; - reversed_layout.resize(expected_shape.rank()); + reversed_layout.resize(expected_shape.dimensions().size()); absl::c_reverse_copy(expected_shape.layout().minor_to_major(), reversed_layout.begin()); options.permutation = reversed_layout; From cf12cc5fc5cd9b76e3a09da99084fc9a1e943b09 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Fri, 28 Mar 2025 08:05:04 -0700 Subject: [PATCH 253/483] [Mosaic GPU] Ignore layouts that are already set when computing default vector size in layout inference. PiperOrigin-RevId: 741528085 --- .../mosaic/gpu/inference_utils.py | 20 +++++++++-- .../mosaic/gpu/layout_inference.py | 34 ++++++++++++------- 2 files changed, 39 insertions(+), 15 deletions(-) diff --git a/jax/experimental/mosaic/gpu/inference_utils.py b/jax/experimental/mosaic/gpu/inference_utils.py index 6362626404c5..73ce23c427cd 100644 --- a/jax/experimental/mosaic/gpu/inference_utils.py +++ b/jax/experimental/mosaic/gpu/inference_utils.py @@ -95,6 +95,22 @@ def has_out_transforms_set(op: MlirOperation) -> bool: return "out_transforms" in op.attributes +def attr_element( + attr_name: str, op: MlirOperation, index: int +) -> ir.Attribute | None: + """Returns `op.attributes[attr_name][index]` if it exists, otherwise None. + + If `op.attributes[attr_name]` exists, then `index` must be a valid index into + the attribute array. + """ + if attr_name not in op.attributes: + return None + attr = op.attributes[attr_name] + if not attr: + return None + return op.attributes[attr_name][index] # type: ignore + + def _in_attr_for_operand( op: MlirOperation, operand: ir.Value, @@ -109,9 +125,7 @@ def _in_attr_for_operand( operand_number = [o for o in op.operands if predicate(o)].index(operand) - if attr_name not in op.attributes: - return None - return op.attributes[attr_name][operand_number] # type: ignore + return attr_element(attr_name, op, operand_number) in_layout_for_operand = partial( diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py index 402a8c08a4ef..e49b3677b2ad 100644 --- a/jax/experimental/mosaic/gpu/layout_inference.py +++ b/jax/experimental/mosaic/gpu/layout_inference.py @@ -548,21 +548,31 @@ def inference_step(op: ir.Operation): # make sure to derive a single vector size in order to avoid relayouts at # lowering time. default_vector_size = math.inf - - def update_default_vector_size(op: ir.OpView): + def update_default_vector_size_from_vector(v: ir.Value): nonlocal default_vector_size - for v in list(op.operands) + list(op.results): - if ir.VectorType.isinstance(v.type): - max_vec_size_for_v = ( - np.prod(cast(ir.ShapedType, v.type).shape) // fa.WARPGROUP_SIZE - ) - desired_vec_size = 8 // utils.bytewidth(v.type.element_type) - default_vector_size = min( - default_vector_size, max_vec_size_for_v, desired_vec_size - ) + max_vec_size_for_v = ( + np.prod(cast(ir.ShapedType, v.type).shape) // fa.WARPGROUP_SIZE + ) + desired_vec_size = 8 // utils.bytewidth(v.type.element_type) + default_vector_size = min( + default_vector_size, max_vec_size_for_v, desired_vec_size + ) + + def update_default_vector_size_from_op(op: ir.OpView): + for i, v in enumerate( + filter(lambda v: ir.VectorType.isinstance(v.type), op.operands) + ): + if inference_utils.attr_element("in_layouts", op, i) is None: + update_default_vector_size_from_vector(v) + + for i, v in enumerate( + filter(lambda v: ir.VectorType.isinstance(v.type), op.results) + ): + if inference_utils.attr_element("out_layouts", op, i) is None: + update_default_vector_size_from_vector(v) for op in module.body: - traverse_op(op, update_default_vector_size) + traverse_op(op, update_default_vector_size_from_op) if default_vector_size == math.inf: # Nothing to annotate. return From 968bbd2bf25e3ace63a4e6938adc70d5e4540caa Mon Sep 17 00:00:00 2001 From: Ayaka Date: Fri, 28 Mar 2025 08:09:01 -0700 Subject: [PATCH 254/483] Add a small atol bump to `betainc` test in `LaxVmapOpTest` PiperOrigin-RevId: 741529177 --- jax/_src/internal_test_util/lax_test_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/internal_test_util/lax_test_util.py b/jax/_src/internal_test_util/lax_test_util.py index 4e28791e9cee..767b41dc8ba0 100644 --- a/jax/_src/internal_test_util/lax_test_util.py +++ b/jax/_src/internal_test_util/lax_test_util.py @@ -304,7 +304,7 @@ def lax_ops(): float_dtypes, test_util.rand_uniform, { - np.float32: 1e-5, + np.float32: 2e-5, np.float64: 1e-12, }, ), From d974b090565022ef7139c4c407a047a7f2e406ea Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 28 Mar 2025 08:29:13 -0700 Subject: [PATCH 255/483] Fix error in build.py when trying to build aarch64 jaxlib wheel. PiperOrigin-RevId: 741534342 --- build/build.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/build/build.py b/build/build.py index 1900073fc132..f8c0ccbfa6a4 100755 --- a/build/build.py +++ b/build/build.py @@ -414,10 +414,7 @@ async def main(): for option in args.bazel_startup_options: bazel_command_base.append(option) - if ( - not hasattr(args,"use_new_wheel_build_rule") - or args.command == "requirements_update" - ): + if args.command == "requirements_update" or not args.use_new_wheel_build_rule: bazel_command_base.append("run") else: bazel_command_base.append("build") From 98b763cfe48a14749252e29ceb862f9ca228ccbe Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 28 Mar 2025 08:45:45 -0700 Subject: [PATCH 256/483] Use a 16 core Windows runner when building artifacts Also, switch the Linux aarch64 runner type to t2a as we run the tests on t2a. PiperOrigin-RevId: 741538543 --- .github/workflows/build_artifacts.yml | 12 ++++++------ .github/workflows/wheel_tests_continuous.yml | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index c2e7acb91f7a..37a791784506 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -16,8 +16,8 @@ on: default: "linux-x86-n2-16" options: - "linux-x86-n2-16" - - "linux-arm64-c4a-64" - - "windows-x86-n2-64" + - "linux-arm64-t2a-48" + - "windows-x86-n2-16" artifact: description: "Which JAX artifact to build?" type: choice @@ -119,11 +119,11 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Enable RBE if building on Linux x86 - if: contains(inputs.runner, 'linux-x86') + - name: Enable RBE if building on Linux x86 or Windows x86 + if: contains(inputs.runner, 'linux-x86') || contains(inputs.runner, 'windows-x86') run: echo "JAXCI_BUILD_ARTIFACT_WITH_RBE=1" >> $GITHUB_ENV - - name: Enable Bazel remote cache (with writes enabled) if building on Linux Aarch64 or Windows x86 - if: contains(inputs.runner, 'linux-arm64') || contains(inputs.runner, 'windows-x86') + - name: Enable Bazel remote cache (with writes enabled) if building on Linux Aarch64 + if: contains(inputs.runner, 'linux-arm64') run: echo "JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE=1" >> $GITHUB_ENV # Halt for testing - name: Wait For Connection diff --git a/.github/workflows/wheel_tests_continuous.yml b/.github/workflows/wheel_tests_continuous.yml index 530e0a9b0768..3739c9267730 100644 --- a/.github/workflows/wheel_tests_continuous.yml +++ b/.github/workflows/wheel_tests_continuous.yml @@ -44,7 +44,7 @@ jobs: fail-fast: false # don't cancel all jobs on failure matrix: # Runner OS and Python values need to match the matrix stategy in the CPU tests job - runner: ["linux-x86-n2-16", "linux-arm64-t2a-48", "windows-x86-n2-64"] + runner: ["linux-x86-n2-16", "linux-arm64-t2a-48", "windows-x86-n2-16"] artifact: ["jaxlib"] python: ["3.10"] # Note: For reasons unknown, Github actions groups jobs with the same top-level name in the From 4bfe0d10e95cc6d14eec74c46dcaf897322044ee Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 28 Mar 2025 09:13:50 -0700 Subject: [PATCH 257/483] Remove get_emit_python_callback_descriptor from the type stubs. The function itself was already deleted. PiperOrigin-RevId: 741546212 --- jaxlib/xla/xla_extension/__init__.pyi | 6 ------ 1 file changed, 6 deletions(-) diff --git a/jaxlib/xla/xla_extension/__init__.pyi b/jaxlib/xla/xla_extension/__init__.pyi index 3a6435824b67..d002080b17bc 100644 --- a/jaxlib/xla/xla_extension/__init__.pyi +++ b/jaxlib/xla/xla_extension/__init__.pyi @@ -580,12 +580,6 @@ class Client: ) -> LoadedExecutable: ... def heap_profile(self) -> bytes: ... def defragment(self) -> _Status: ... - def get_emit_python_callback_descriptor( - self, - callable: Callable, - operand_shapes: Sequence[Shape], - results_shapes: Sequence[Shape], - ) -> Tuple[Any, Any]: ... def make_python_callback_from_host_send_and_recv( self, callable: Callable, From 5495c56990956c92f2c671a47c99cc85c018df05 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 28 Mar 2025 09:21:29 -0700 Subject: [PATCH 258/483] Remove a use of XlaComputation from call_tf. call_tf is the only remaining user of the XlaComputation type in JAX. Change it to use a new helper function that converts an HLO proto to stablehlo bytecode without using the XlaComputation Python bindings. Also port the code to parse types from the stablehlo rather than the HLO. Remove jax.interpreters.mlir.xla_computation_to_mlir_module. PiperOrigin-RevId: 741548298 --- jax/experimental/jax2tf/call_tf.py | 99 +++++++++++++++++++++--------- jax/interpreters/mlir.py | 1 - jaxlib/xla/mlir.cc | 14 +++++ jaxlib/xla/xla_client.py | 2 +- jaxlib/xla/xla_extension/mlir.pyi | 1 + 5 files changed, 85 insertions(+), 32 deletions(-) diff --git a/jax/experimental/jax2tf/call_tf.py b/jax/experimental/jax2tf/call_tf.py index 98c1c20cd6e5..3b175cd64c4c 100644 --- a/jax/experimental/jax2tf/call_tf.py +++ b/jax/experimental/jax2tf/call_tf.py @@ -41,11 +41,14 @@ from jax._src import effects from jax._src import util from jax._src.lib import xla_client +from jax._src.lib import xla_extension as _xla +from jax._src.lib import jaxlib_extension_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import func as func_dialect from jax._src.lib.mlir.dialects import hlo from jax.experimental.jax2tf import jax2tf as jax2tf_internal from jax._src.interpreters import mlir +import ml_dtypes import numpy as np import tensorflow as tf @@ -468,6 +471,47 @@ def is_fully_known_shape(s): call_tf_p.def_effectful_abstract_eval(_call_tf_abstract_eval) +def _mlir_type_to_numpy_dtype(type: ir.Type) -> np.dtype: + """Converts an MLIR scalar type to a NumPy dtype.""" + + if ir.IntegerType.isinstance(type): + type = ir.IntegerType(type) + width = type.width + if width == 1: + return np.dtype(np.bool_) + elif width == 8: + return np.dtype(np.uint8 if type.is_unsigned else np.int8) + elif width == 16: + return np.dtype(np.uint16 if type.is_unsigned else np.int16) + elif width == 32: + return np.dtype(np.uint32 if type.is_unsigned else np.int32) + elif width == 64: + return np.dtype(np.uint64 if type.is_unsigned else np.int64) + else: + raise ValueError(f"Unsupported integer width: {width}") + + elif ir.F16Type.isinstance(type): + return np.dtype(np.float16) + elif ir.F32Type.isinstance(type): + return np.dtype(np.float32) + elif ir.F64Type.isinstance(type): + return np.dtype(np.float64) + elif ir.BF16Type.isinstance(type): + return np.dtype(ml_dtypes.bfloat16) + + elif ir.ComplexType.isinstance(type): + element_type = ir.ComplexType(type).element_type + if ir.F32Type.isinstance(element_type): + return np.dtype(np.complex64) + elif ir.F64Type.isinstance(element_type): + return np.dtype(np.complex128) + else: + raise ValueError(f"Unsupported complex element type: {element_type}") + + else: + raise TypeError(f"Unsupported MLIR type for NumPy conversion: {type}") + + def _call_tf_lowering( ctx: mlir.LoweringRuleContext, *args_op, @@ -555,33 +599,12 @@ def convert_to_spec(x): "\n\nCaught TensorFlow exception: " + str(e)) raise ValueError(msg) from e - xla_comp = xla_client.XlaComputation(func_tf_hlo) - - # Canonicalize the results; e.g., makes them x32 if JAX is in 32-bit mode - def canonical_res_aval(res_shape: xla_client.Shape) -> core.ShapedArray: - if not res_shape.is_static(): - msg = ("Compiled TensorFlow function has dynamic output shape " + - f"{res_shape}. call_tf can used " + - "in a staged context (under jax.jit, lax.scan, etc.) only with " + - "compilable functions with static output shapes. " + - "See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion.") - raise ValueError(msg) - - res_dtype = res_shape.numpy_dtype() - jax_res_dtype = dtypes.canonicalize_dtype(res_dtype) - return core.ShapedArray(res_shape.dimensions(), jax_res_dtype) - - result_shape = xla_comp.program_shape().result_shape() - if not result_shape.is_tuple(): - # TF does not wrap singletons as tuples, but JAX expects tuples because - # call_tf is a multiple_results primitive. - result_shapes = (result_shape,) + if jaxlib_extension_version >= 324: + stablehlo = _xla.mlir.hlo_to_stablehlo(func_tf_hlo) else: - result_shapes = result_shape.tuple_shapes() # type: ignore - - result_avals = tuple(map(canonical_res_aval, result_shapes)) - - submodule = mlir.xla_computation_to_mlir_module(xla_comp) + xla_comp = xla_client.XlaComputation(func_tf_hlo) + stablehlo = _xla.mlir.xla_computation_to_mlir_module(xla_comp) + submodule = ir.Module.parse(stablehlo) symtab = ir.SymbolTable(submodule.operation) callee_result_types = symtab["main"].type.results fn = mlir.merge_mlir_modules(ctx.module_context.module, @@ -600,10 +623,26 @@ def canonical_res_aval(res_shape: xla_client.Shape) -> core.ShapedArray: ) outputs = [] - for op, res_aval, res_shape in zip(flat_results, result_avals, - result_shapes): - if res_aval.dtype != res_shape.numpy_dtype(): - op = hlo.ConvertOp(mlir.aval_to_ir_type(res_aval), op).result + for op, res_type in zip(flat_results, callee_result_types): + if not res_type.has_static_shape: + msg = ( + "Compiled TensorFlow function has dynamic output shape " + + f"{res_type}. call_tf can used in a staged context (under jax.jit," + " lax.scan, etc.) only with compilable functions with static" + " output shapes. See" + " https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf" + " for a discussion." + ) + raise ValueError(msg) + + res_dtype = _mlir_type_to_numpy_dtype(res_type.element_type) + # Canonicalize the results; e.g., makes them x32 if JAX is in 32-bit mode + jax_res_dtype = dtypes.canonicalize_dtype(res_dtype) + if res_dtype != jax_res_dtype: + op = hlo.ConvertOp( + mlir.aval_to_ir_type(core.ShapedArray(res_type.shape, jax_res_dtype)), + op, + ).result outputs.append(op) return outputs diff --git a/jax/interpreters/mlir.py b/jax/interpreters/mlir.py index 0f32799f7ea9..8a615be968a6 100644 --- a/jax/interpreters/mlir.py +++ b/jax/interpreters/mlir.py @@ -63,7 +63,6 @@ register_lowering as register_lowering, shape_tensor as shape_tensor, token_type as token_type, - xla_computation_to_mlir_module as xla_computation_to_mlir_module, ) from jax._src.mesh import Mesh as Mesh diff --git a/jaxlib/xla/mlir.cc b/jaxlib/xla/mlir.cc index 987856daa983..e045f7284ec6 100644 --- a/jaxlib/xla/mlir.cc +++ b/jaxlib/xla/mlir.cc @@ -75,6 +75,17 @@ void EnablePrintBeforeAndAfter(mlir::PassManager& pm) { pm.enableIRPrinting(print_before, print_after); } +absl::StatusOr HloToStableHlo(const nb::bytes& hlo_module_proto) { + mlir::MLIRContext context; + if (VLOG_IS_ON(3)) context.disableMultithreading(); + HloModuleProto proto; + proto.ParseFromArray(hlo_module_proto.c_str(), hlo_module_proto.size()); + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ConvertHloToStablehlo(context, &proto)); + TF_ASSIGN_OR_RETURN(std::string bytecode, SerializeUsingBytecode(*module)); + return nb::bytes(bytecode.data(), bytecode.size()); +} + // Converts an XlaComputation to a StableHLO mlir::Module string. // Exists for backwards compatibility. // TODO(phawkins): port remaining users of XlaComputations to use mlir::Modules @@ -180,6 +191,9 @@ absl::StatusOr PyDeserializePortableArtifact( void BuildMlirSubmodule(nb::module_& m) { nb::module_ mlir_module = m.def_submodule("mlir", "MLIR/XLA integration"); + mlir_module.def("hlo_to_stablehlo", xla::ValueOrThrowWrapper(HloToStableHlo), + nb::arg("computation")); + mlir_module.def("xla_computation_to_mlir_module", xla::ValueOrThrowWrapper(PyXlaComputationToMlirModule), nb::arg("computation")); diff --git a/jaxlib/xla/xla_client.py b/jaxlib/xla/xla_client.py index 776a22444208..7c4e2ccb427f 100644 --- a/jaxlib/xla/xla_client.py +++ b/jaxlib/xla/xla_client.py @@ -50,7 +50,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version. -_version = 323 +_version = 324 # An internal increasing version number for protecting jaxlib code against # ifrt changes. diff --git a/jaxlib/xla/xla_extension/mlir.pyi b/jaxlib/xla/xla_extension/mlir.pyi index 95eeae660c0c..961f01a0352c 100644 --- a/jaxlib/xla/xla_extension/mlir.pyi +++ b/jaxlib/xla/xla_extension/mlir.pyi @@ -16,6 +16,7 @@ from typing import Union from . import XlaComputation +def hlo_to_stablehlo(computation: bytes) -> bytes: ... def xla_computation_to_mlir_module(computation: XlaComputation) -> str: ... def mlir_module_to_xla_computation( mlir_module: Union[bytes, str], From 8c737993e94d8106e0641f565bc83a8632a03ec1 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 28 Mar 2025 09:32:42 -0700 Subject: [PATCH 259/483] Change the `step counter` to an `init flag` It is clearer to use a flag to indicate the first step than to use a step counter == 0, since in theory the step counter (a 32 bit integer in the code) can wrap around back to zero, even though this will unlikely happen since there are way less than 2**32 blocks. PiperOrigin-RevId: 741551623 --- .../paged_attention/paged_attention_kernel.py | 33 +++++++++---------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py index 99cb2c9c94c1..62f3101bef6e 100644 --- a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py +++ b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py @@ -114,7 +114,7 @@ def paged_flash_attention_kernel( lengths_ref, page_indices_ref, buffer_index_ref, - step_ref, + init_flag_ref, q_ref, k_pages_hbm_ref, k_scales_pages_hbm_ref, @@ -223,16 +223,12 @@ def create_kv_async_copy_descriptors(b, h, i, buffer_index): @pl.when(i * bk < length) def flash_attention(): # pylint: disable=unused-variable - step = step_ref[0] + init_flag = init_flag_ref[0] + init_flag_ref[0] = 0 buffer_index = buffer_index_ref[0] + next_b, next_h, next_i = compute_block_indices(b, h, i + 1) - @pl.when(i == 0) - def init(): # pylint: disable=unused-variable - m_ref[...] = jnp.full_like(m_ref, -jnp.inf) - l_ref[...] = jnp.zeros_like(l_ref) - o_ref[...] = jnp.zeros_like(o_ref) - - @pl.when(step == 0) + @pl.when(init_flag) def prefetch_first_block(): # pylint: disable=unused-variable async_copy_k, async_copy_v = create_kv_async_copy_descriptors( b, h, i, buffer_index @@ -240,7 +236,11 @@ def prefetch_first_block(): # pylint: disable=unused-variable async_copy_k.start() async_copy_v.start() - next_b, next_h, next_i = compute_block_indices(b, h, i + 1) + @pl.when(i == 0) + def init(): # pylint: disable=unused-variable + m_ref[...] = jnp.full_like(m_ref, -jnp.inf) + l_ref[...] = jnp.zeros_like(l_ref) + o_ref[...] = jnp.zeros_like(o_ref) @pl.when(next_b < batch_size) def prefetch_next_block(): # pylint: disable=unused-variable @@ -283,14 +283,12 @@ def prefetch_next_block(): # pylint: disable=unused-variable (l_prev * alpha * o_ref[...] + beta * o_curr_times_l_curr) / l_next ).astype(o_ref.dtype) - step_ref[0] = step + 1 - def paged_flash_attention_kernel_inline_seq_dim( lengths_ref, page_indices_ref, buffer_index_ref, - step_ref, + init_flag_ref, q_ref, k_pages_hbm_ref, k_scales_pages_hbm_ref, @@ -325,7 +323,7 @@ def body(i, _): lengths_ref, page_indices_ref, buffer_index_ref, - step_ref, + init_flag_ref, q_ref, k_pages_hbm_ref, k_scales_pages_hbm_ref, @@ -631,7 +629,7 @@ def paged_attention( ), grid_spec=pltpu.PrefetchScalarGridSpec( # There are 4 scalars prefetched per kernel call: `lengths_ref`, - # `page_indices_ref`, `buffer_index_ref`, `step_ref` + # `page_indices_ref`, `buffer_index_ref`, `init_flag_ref` num_scalar_prefetch=4, in_specs=in_specs, out_specs=[ @@ -643,7 +641,8 @@ def paged_attention( scratch_shapes=scratch_shapes, ), compiler_params=pltpu.TPUCompilerParams( - dimension_semantics=dimension_semantics), + dimension_semantics=dimension_semantics + ), out_shape=[ jax.ShapeDtypeStruct(q.shape, q_dtype_for_kernel_launch), jax.ShapeDtypeStruct((*q.shape[:-1], 1), jnp.float32), @@ -653,7 +652,7 @@ def paged_attention( lengths, page_indices.reshape(-1), jnp.zeros((1,), jnp.int32), # buffer index - jnp.zeros((1,), jnp.int32), # step + jnp.ones((1,), jnp.int32), # init flag q.astype(q_dtype_for_kernel_launch), k_pages, k_scales_pages, From 5950e722e292063f920f5be1d23296b10ce36074 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 28 Mar 2025 09:43:05 -0700 Subject: [PATCH 260/483] Make sure `vma` on ShapedArray exists by default to make development easier. The field is populated inside shard_map guarded on the varying_axes_in_types config though. PiperOrigin-RevId: 741554623 --- jax/_src/core.py | 39 ++++++++++++++++------------- jax/_src/ffi.py | 4 +-- jax/_src/lax/ann.py | 6 ++--- jax/_src/lax/control_flow/solves.py | 4 +-- jax/_src/lax/fft.py | 4 +-- jax/_src/lax/lax.py | 17 ++++++++++--- jax/_src/lax/utils.py | 7 ++---- jax/_src/lax/windowed_reductions.py | 7 ++---- jax/_src/prng.py | 21 ++++++---------- jax/experimental/shard_map.py | 3 ++- 10 files changed, 53 insertions(+), 59 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index 1be60336f1a9..ee6537650f20 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1659,10 +1659,9 @@ def physical_aval(aval): elt_aval = physical_element_aval(aval.dtype) if isinstance(aval, ShapedArray): from jax._src.sharding_impls import physical_sharding # type: ignore - vma = aval.vma if config.varying_axes_in_types.value else frozenset() return ShapedArray((*aval.shape, *elt_aval.shape), elt_aval.dtype, sharding=physical_sharding(aval, aval.sharding), - vma=vma) + vma=aval.vma) return DShapedArray((*aval.shape, *elt_aval.shape), elt_aval.dtype) return aval @@ -1917,6 +1916,7 @@ def get_vma(vma, mesh): raise ValueError( "Axes mentioned in `vma` field of ShapedArray should" f" be of type `Manual`. Got axis: {i} of type {mesh._name_to_type[i]}") + assert isinstance(vma, frozenset) return vma class ShapedArray(UnshapedArray): @@ -1929,8 +1929,7 @@ def __init__(self, shape, dtype, weak_type=False, *, sharding=None, self.dtype = _dtype_object(dtype) self.weak_type = weak_type self.sharding = get_sharding(sharding, self.shape) - if config.varying_axes_in_types.value: - self.vma = get_vma(vma, self.sharding.mesh) + self.vma = get_vma(vma, self.sharding.mesh) def update(self, shape=None, dtype=None, weak_type=None, **kwargs): if shape is None: @@ -1942,7 +1941,7 @@ def update(self, shape=None, dtype=None, weak_type=None, **kwargs): if 'sharding' not in kwargs: kwargs['sharding'] = self.sharding if 'vma' not in kwargs: - kwargs['vma'] = getattr(self, 'vma', frozenset()) + kwargs['vma'] = self.vma return ShapedArray(shape, dtype, weak_type, **kwargs) ndim = property(lambda self: len(self.shape)) @@ -1960,26 +1959,24 @@ def __eq__(self, other): and self.dtype == other.dtype and self.shape == other.shape and self.weak_type == other.weak_type and self.sharding == other.sharding - and (getattr(self, 'vma', frozenset()) == - getattr(other, 'vma', frozenset()))) + and self.vma == other.vma) def __hash__(self): # can use hash(self.dtype) and rely on the fact that numpy reuses base dtype # objects, e.g. `np.zeros(3).dtype is np.zeros(4).dtype`, or we can use # the unique character code via hash(self.dtype.char) return hash((self.shape, self.dtype, self.weak_type, self.sharding, - getattr(self, 'vma', frozenset()))) + self.vma)) def to_tangent_aval(self): return ShapedArray( self.shape, primal_dtype_to_tangent_dtype(self.dtype), - self.weak_type, sharding=self.sharding, - vma=getattr(self, 'vma', frozenset())) + self.weak_type, sharding=self.sharding, vma=self.vma) def str_short(self, short_dtypes=False, mesh_axis_types=False): return str_short_aval( self.shape, self.dtype, self.sharding.mesh, self.sharding.spec, - getattr(self, 'vma', frozenset()), short_dtypes, mesh_axis_types) + self.vma, short_dtypes, mesh_axis_types) def _len(self, ignored_tracer): try: @@ -2013,16 +2010,20 @@ def primal_dtype_to_tangent_dtype(primal_dtype): def standard_insert_pbroadcast(*args): if not config.varying_axes_in_types.value: return args + if not args: + return args # TODO(yashkatariya): Move pbroadcast out of shard_map from jax.experimental.shard_map import pbroadcast # type: ignore - in_vma = [get_aval(a).vma for a in args] + in_vma = [frozenset() if (aval := get_aval(a)) is abstract_token + else aval.vma for a in args] out_vma = frozenset.union(*in_vma) return [pbroadcast(arg, tuple(n for n in out_vma if n not in src)) if out_vma - src else arg for arg, src in zip(args, in_vma)] -def standard_vma_rule(prim_name, *avals, **kwargs): +def standard_vma_rule(prim_name, *avals, **kwargs) -> frozenset[AxisName]: + avals = tuple(a for a in avals if a is not abstract_token) if not avals: - return avals + return frozenset() vma, *vmas = [a.vma for a in avals] if not all(vma == vma_ for vma_ in vmas): raise ValueError( @@ -2078,6 +2079,10 @@ def update(self, shape=None, dtype=None, weak_type=None): def sharding(self): return NamedSharding(mesh_lib.empty_abstract_mesh, P()) + @property + def vma(self): + return frozenset() + def _len(self, tracer): return self.shape[0] @@ -2711,10 +2716,8 @@ def typematch(t1: AbstractValue, t2: AbstractValue) -> bool: # could try normalizing first and then doing simple equality. # TODO(yashkatariya): Also check `sharding` here. # See https://github.com/jax-ml/jax/issues/26474 - sh_dt = t1.dtype == t2.dtype and definitely_equal_shape(t1.shape, t2.shape) - if config.varying_axes_in_types.value: - return sh_dt and t1.vma == t2.vma # type: ignore - return sh_dt + return (t1.dtype == t2.dtype and definitely_equal_shape(t1.shape, t2.shape) + and t1.vma == t2.vma) # type: ignore else: return False diff --git a/jax/_src/ffi.py b/jax/_src/ffi.py index c867ec16b9b3..eb3e9aaa10fb 100644 --- a/jax/_src/ffi.py +++ b/jax/_src/ffi.py @@ -24,7 +24,6 @@ import jax from jax._src import core -from jax._src import config from jax._src import deprecations from jax._src import dispatch from jax._src import effects @@ -639,8 +638,7 @@ def ffi_call_abstract_eval( has_side_effect: bool, **_, ): - out_vma = (core.standard_vma_rule('ffi_call', *avals_in) - if config.varying_axes_in_types.value else frozenset()) + out_vma = core.standard_vma_rule('ffi_call', *avals_in) effects = {_FfiEffect} if has_side_effect else core.no_effects return tuple(r if r is core.abstract_token else r.update(vma=out_vma) for r in result_avals), effects diff --git a/jax/_src/lax/ann.py b/jax/_src/lax/ann.py index 0d2eb338da22..c9a68d84b024 100644 --- a/jax/_src/lax/ann.py +++ b/jax/_src/lax/ann.py @@ -77,7 +77,6 @@ def pmap_mips(qy, db, db_offset, db_size, k, recall_target): from jax._src import ad_util from jax._src import core from jax._src import dispatch -from jax._src import config from jax._src import dtypes from jax._src.interpreters import ad from jax._src.interpreters import batching @@ -240,10 +239,9 @@ def _approx_top_k_abstract_eval(operand, *, k, reduction_dimension, "approx_top_k with aggregate_to_topk=False not yet implemented when " f"either the `k` ({k}) or the " f" reduction dimension size ({reduction_input_size}) are symbolic") - out_vma = operand.vma if config.varying_axes_in_types.value else frozenset() return (operand.update(shape=dims, dtype=operand.dtype, - weak_type=operand.weak_type, vma=out_vma), - operand.update(shape=dims, dtype=np.dtype(np.int32), vma=out_vma)) + weak_type=operand.weak_type, vma=operand.vma), + operand.update(shape=dims, dtype=np.dtype(np.int32), vma=operand.vma)) def _get_init_val_literal(op_type, is_max_k): return np.array(-np.inf if is_max_k else np.inf, dtype=op_type) diff --git a/jax/_src/lax/control_flow/solves.py b/jax/_src/lax/control_flow/solves.py index 4a0872bef4b2..2c736f403044 100644 --- a/jax/_src/lax/control_flow/solves.py +++ b/jax/_src/lax/control_flow/solves.py @@ -23,7 +23,6 @@ from jax._src import api from jax._src import api_util from jax._src import core -from jax._src import config from jax._src import custom_derivatives from jax._src import linear_util as lu from jax._src.interpreters import ad @@ -325,8 +324,7 @@ def _linear_solve_abstract_eval(*args, const_lengths, jaxprs): num_aux = len(jaxprs.solve.out_avals) - len(jaxprs.matvec.out_avals) if num_aux > 0: args_to_raise += tuple(jaxprs.solve.out_avals[-num_aux:]) - out_vma = (core.standard_vma_rule('linear_solve', *args_to_raise) - if config.varying_axes_in_types.value else frozenset()) + out_vma = core.standard_vma_rule('linear_solve', *args_to_raise) return (tuple(a.update(vma=out_vma) for a in args_to_raise), jaxprs.solve.effects) diff --git a/jax/_src/lax/fft.py b/jax/_src/lax/fft.py index 9044f48f278c..2eebe6d91f22 100644 --- a/jax/_src/lax/fft.py +++ b/jax/_src/lax/fft.py @@ -23,7 +23,6 @@ from jax import lax -from jax._src import config from jax._src import dispatch from jax._src import dtypes from jax._src.api import jit, linear_transpose, ShapeDtypeStruct @@ -125,8 +124,7 @@ def fft_abstract_eval(x, fft_type, fft_lengths): f"be equal to fft_lengths {fft_lengths}") shape = x.shape dtype = x.dtype - out_vma = x.vma if config.varying_axes_in_types.value else frozenset() - return x.update(shape=shape, dtype=dtype, vma=out_vma) + return x.update(shape=shape, dtype=dtype, vma=x.vma) def _fft_lowering(ctx, x, *, fft_type, fft_lengths): if not is_constant_shape(fft_lengths): diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index b79c81e19195..53bf9a0c7ebf 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -6106,6 +6106,7 @@ def _ragged_dot_general_batch_rule( _ragged_dot_general_shape_rule, _ragged_dot_general_dtype_rule, 'ragged_dot_general', + vma_rule=partial(core.standard_vma_rule, 'ragged_dot') ) ad.primitive_jvps[ragged_dot_general_p] = _ragged_dot_general_jvp_rule ad.primitive_transposes[ragged_dot_general_p] = _ragged_dot_general_transpose_rule @@ -6515,8 +6516,7 @@ def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions, new_sharding = _broadcast_in_dim_sharding_rule( x, shape=shape, broadcast_dimensions=broadcast_dimensions, sharding=sharding) - new_vma = (core.standard_vma_rule('broadcast_in_dim', x) - if config.varying_axes_in_types.value else frozenset()) + new_vma = core.standard_vma_rule('broadcast_in_dim', x) return core.ShapedArray(shape, x.dtype, x.weak_type, sharding=new_sharding, vma=new_vma) # If any BInts in shape, or Tracers in dyn_shape, produce a DShapedArray @@ -7435,6 +7435,11 @@ def _reduce_sharding_rule(*avals, computation, jaxpr, dimensions): return [op.sharding.with_spec(tuple_delete(op.sharding.spec, dimensions)) for op in operand_avals] +def _reduce_vma_rule(*avals, computation, jaxpr, dimensions): + operand_avals, _ = split_list(avals, [len(avals) // 2]) + out_vma = core.standard_vma_rule('reduce', *operand_avals) + return [out_vma] * len(operand_avals) + def _reduce_dtype_rule(*avals, computation, jaxpr, dimensions): operand_avals, init_val_avals = split_list(avals, [len(avals) // 2]) operand_dtypes = [dtypes.canonicalize_dtype(op.dtype) for op in operand_avals] @@ -7522,7 +7527,7 @@ def _reduce_jvp_rule(primals, tangents, *, computation, jaxpr, dimensions): reduce_p.def_abstract_eval( partial(standard_multi_result_abstract_eval, reduce_p, _reduce_shape_rule, _reduce_dtype_rule, _reduce_weak_type_rule, _reduce_sharding_rule, - None)) + _reduce_vma_rule)) batching.primitive_batchers[reduce_p] = _reduce_batch_rule ad.primitive_jvps[reduce_p] = _reduce_jvp_rule @@ -8254,6 +8259,10 @@ def _rng_bit_generator_sharding_rule(key, *, shape, dtype, algorithm, out_sharding): return (key.sharding, out_sharding) +def _rng_bit_generator_vma_rule(key, *, shape, dtype, algorithm, out_sharding): + assert key.vma == frozenset() + return (key.vma, frozenset()) + def _rng_bit_generator_dtype_rule(key, *, shape, dtype, algorithm, out_sharding): del shape, algorithm return (key.dtype, dtype) @@ -8355,7 +8364,7 @@ def _rng_bit_generator_lowering( partial(standard_multi_result_abstract_eval, rng_bit_generator_p, _rng_bit_generator_shape_rule, _rng_bit_generator_dtype_rule, _rng_bit_generator_weak_type_rule, _rng_bit_generator_sharding_rule, - None)) + _rng_bit_generator_vma_rule)) mlir.register_lowering(rng_bit_generator_p, _rng_bit_generator_lowering) diff --git a/jax/_src/lax/utils.py b/jax/_src/lax/utils.py index 0a641c122064..8e97621912f1 100644 --- a/jax/_src/lax/utils.py +++ b/jax/_src/lax/utils.py @@ -19,7 +19,6 @@ from functools import partial from jax._src import core -from jax._src import config from jax._src import dispatch from jax._src import dtypes from jax._src import mesh as mesh_lib @@ -113,8 +112,7 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule, out_shape, out_dtype, out_sharding = call_shape_dtype_sharding_rule( prim, shape_rule, dtype_rule, sharding_rule, False, *avals, **kwargs) - out_vma = (vma_rule(*avals, **kwargs) if config.varying_axes_in_types.value - else frozenset()) + out_vma = vma_rule(*avals, **kwargs) out_aval = core.ShapedArray( out_shape, out_dtype, weak_type=weak_type, sharding=out_sharding, vma=out_vma) @@ -141,8 +139,7 @@ def standard_multi_result_abstract_eval( core.check_avals_context_mesh(avals, prim.name) out_shapes, out_dtypes, out_shardings = call_shape_dtype_sharding_rule( prim, shape_rule, dtype_rule, sharding_rule, True, *avals, **kwargs) - out_vmas = (vma_rule(*avals, **kwargs) if config.varying_axes_in_types.value - else [frozenset()] * len(out_shapes)) + out_vmas = vma_rule(*avals, **kwargs) if isinstance(weak_types, bool): weak_types = (weak_types,) * len(out_shapes) out_avals = [core.ShapedArray(s, d, weak_type=weak_type, sharding=sh, vma=vma) diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py index 73fae7df40e1..472b92d858f9 100644 --- a/jax/_src/lax/windowed_reductions.py +++ b/jax/_src/lax/windowed_reductions.py @@ -21,7 +21,6 @@ from jax import tree_util from jax._src import api_util from jax._src import core -from jax._src import config from jax._src import dispatch from jax._src import dtypes from jax._src import util @@ -338,10 +337,8 @@ def _reduce_window_abstract_eval_rule( out_sharding = reduce_window_sharding_rule( operand_avals[0], window_dimensions, window_strides, padding, base_dilation, window_dilation) - out_vma = (core.standard_vma_rule('reduce_window', *operand_avals) - if config.varying_axes_in_types.value else frozenset()) - return tuple(ShapedArray(out_shape, op.dtype, sharding=out_sharding, - vma=out_vma) + vma = core.standard_vma_rule('reduce_window', *operand_avals) + return tuple(ShapedArray(out_shape, op.dtype, sharding=out_sharding, vma=vma) for op in operand_avals) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 17d16527bb71..926a57446f5b 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -178,8 +178,8 @@ def copy_to_host_async(self): def aval(self): logical_sharding = (self.sharding if hasattr(self._base_array, 'sharding') else None) - vma = (self._base_array.aval.vma if config.varying_axes_in_types.value else frozenset() - if hasattr(self._base_array, 'aval') else frozenset()) + vma = (self._base_array.aval.vma if hasattr(self._base_array, 'aval') + else frozenset()) return keys_shaped_array(self._impl, self.shape, logical_sharding, vma) @property @@ -552,8 +552,8 @@ def random_seed(seeds: int | typing.ArrayLike, impl: PRNGImpl) -> PRNGKeyArray: @random_seed_p.def_abstract_eval def random_seed_abstract_eval(seeds_aval, *, impl): - out_vma = seeds_aval.vma if config.varying_axes_in_types.value else frozenset() - return keys_shaped_array(impl, seeds_aval.shape, seeds_aval.sharding, out_vma) + return keys_shaped_array(impl, seeds_aval.shape, seeds_aval.sharding, + seeds_aval.vma) @random_seed_p.def_impl def random_seed_impl(seeds, *, impl): @@ -587,9 +587,8 @@ def random_split_abstract_eval(keys_aval, *, shape): # TODO(yashkatariya): random_split should take sharding as an arg too so we # don't choose None here? new_spec = (*keys_aval.sharding.spec, *[None] * len(shape)) - out_vma = keys_aval.vma if config.varying_axes_in_types.value else frozenset() return keys_shaped_array(keys_aval.dtype._impl, (*keys_aval.shape, *shape), - keys_aval.sharding.with_spec(new_spec), out_vma) + keys_aval.sharding.with_spec(new_spec), keys_aval.vma) @random_split_p.def_impl def random_split_impl(keys, *, shape): @@ -629,8 +628,7 @@ def random_fold_in_abstract_eval(keys_aval, msgs_aval): 'random_fold_in', keys_aval, msgs_aval) sharding = lax_internal.broadcasting_sharding_rule( 'random_fold_in', keys_aval, msgs_aval) - vma = (core.standard_vma_rule('random_fold_in', keys_aval, msgs_aval) - if config.varying_axes_in_types.value else frozenset()) + vma = core.standard_vma_rule('random_fold_in', keys_aval, msgs_aval) return core.ShapedArray(shape, keys_aval.dtype, sharding=sharding, vma=vma) @random_fold_in_p.def_impl @@ -669,8 +667,7 @@ def random_bits(keys, bit_width, shape): def random_bits_abstract_eval(keys_aval, *, bit_width, shape): out_shape = (*keys_aval.shape, *shape) out_dtype = dtypes.dtype(f'uint{bit_width}') - vma = keys_aval.vma if config.varying_axes_in_types.value else frozenset() - return core.ShapedArray(out_shape, out_dtype, vma=vma) + return core.ShapedArray(out_shape, out_dtype, vma=keys_aval.vma) @random_bits_p.def_impl def random_bits_impl(keys, *, bit_width, shape): @@ -727,9 +724,7 @@ def random_wrap(base_arr, *, impl): def random_wrap_abstract_eval(base_arr_aval, *, impl): shape = base_arr_shape_to_keys_shape(impl, base_arr_aval.shape) sharding = logical_sharding(shape, KeyTy(impl), base_arr_aval.sharding) - out_vma = (base_arr_aval.vma if config.varying_axes_in_types.value else - frozenset()) - return keys_shaped_array(impl, shape, sharding, out_vma) + return keys_shaped_array(impl, shape, sharding, base_arr_aval.vma) @random_wrap_p.def_impl def random_wrap_impl(base_arr, *, impl): diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index c0306f0c5e91..44c2b569f947 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -578,7 +578,8 @@ def _shard_shaped_array(mesh: Mesh, auto: frozenset, names: AxisNames, for i, sz in enumerate(aval.shape)) manual_mesh = _as_manual_mesh(mesh, auto) new_sharding = NamedSharding(manual_mesh, aval.sharding.spec) - vma = frozenset({n for ns in names.values() for n in ns}) + vma = (frozenset({n for ns in names.values() for n in ns}) + if config.varying_axes_in_types.value else frozenset()) return aval.update(shape=new_shape, sharding=new_sharding, vma=vma) core.shard_aval_handlers[core.ShapedArray] = _shard_shaped_array From e1c866cd0af657240620683cdc230e031f504998 Mon Sep 17 00:00:00 2001 From: Michael Whittaker Date: Fri, 28 Mar 2025 09:55:21 -0700 Subject: [PATCH 261/483] Fixed failing `ExcessPrecisionTest.test_matmul_f32_out_simple` test. PiperOrigin-RevId: 741558343 --- tests/pallas/tpu_fusable_matmul_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pallas/tpu_fusable_matmul_test.py b/tests/pallas/tpu_fusable_matmul_test.py index df7c1221bb0c..5ee372ce92ab 100644 --- a/tests/pallas/tpu_fusable_matmul_test.py +++ b/tests/pallas/tpu_fusable_matmul_test.py @@ -924,7 +924,7 @@ def matmul(impl, x, y): atol = 0 if jtu.is_device_tpu_at_least(6): # 256 MXU changes some tols. - atol = 1e-6 + atol = 1e-5 self.assertAllClose(out, out_ref, atol=atol) def test_matmul_f32_out_fused_downcast(self): From fde7d16c6086981e0f4bfd62e0a4a0618ded9b25 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 28 Mar 2025 10:18:32 -0700 Subject: [PATCH 262/483] Clean up: num_groups = num_q_heads // num_kv_heads No code functionality change in this commit. PiperOrigin-RevId: 741566312 --- .../paged_attention/paged_attention_kernel.py | 35 ++++++++++--------- .../pallas/tpu_paged_attention_kernel_test.py | 25 +++++++------ 2 files changed, 30 insertions(+), 30 deletions(-) diff --git a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py index 62f3101bef6e..4c03fb01be2b 100644 --- a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py +++ b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py @@ -257,7 +257,7 @@ def prefetch_next_block(): # pylint: disable=unused-variable ) q = q_ref[...].astype(jnp.float32) k = async_copy_k.wait_and_get_loaded() - qk = jnp.einsum('hd,td->ht', q, k, preferred_element_type=jnp.float32) + qk = jnp.einsum("gd,td->gt", q, k, preferred_element_type=jnp.float32) if attn_logits_soft_cap is not None: capped_qk = jnp.tanh(qk / attn_logits_soft_cap) qk = capped_qk * attn_logits_soft_cap @@ -277,10 +277,10 @@ def prefetch_next_block(): # pylint: disable=unused-variable m_ref[...], l_ref[...] = m_next, l_next v = async_copy_v.wait_and_get_loaded() - o_curr_times_l_curr = jnp.dot(s_curr, v) + o_curr = jnp.einsum("gt,td->gd", s_curr, v) o_ref[...] = ( - (l_prev * alpha * o_ref[...] + beta * o_curr_times_l_curr) / l_next + (l_prev * alpha * o_ref[...] + beta * o_curr) / l_next ).astype(o_ref.dtype) @@ -384,7 +384,7 @@ def paged_attention( """Paged grouped query attention. Args: - q: A [batch_size, num_heads, head_dim] jax.Array. + q: A [batch_size, num_q_heads, head_dim] jax.Array. k_pages: A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array. v_pages: A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array. lengths: A i32[batch_size] jax.Array the length of each example. @@ -409,7 +409,7 @@ def paged_attention( one kernel. Returns: - The output of attention([batch_size, num_heads, head_dim]). + The output of attention([batch_size, num_q_heads, head_dim]). """ if isinstance(k_pages, quantization_utils.QuantizedTensor): k_pages, k_scales_pages = k_pages.weight, k_pages.scales @@ -428,7 +428,7 @@ def paged_attention( else: v_scales_pages = None - batch_size, num_heads, head_dim = q.shape + batch_size, num_q_heads, head_dim = q.shape num_kv_heads, _, page_size, head_dim_k = k_pages.shape batch_size_paged_indices, pages_per_sequence = page_indices.shape @@ -437,10 +437,10 @@ def paged_attention( f"k_pages and v_pages must have the same shape. Got {k_pages.shape} and" f" {v_pages.shape}" # pytype: disable=attribute-error ) - if num_heads % num_kv_heads != 0: + if num_q_heads % num_kv_heads != 0: raise ValueError( "Number of Q heads must be divisible by number of KV heads. Got" - f" {num_heads} and {num_kv_heads}." + f" {num_q_heads} and {num_kv_heads}." ) if head_dim_k != head_dim: raise ValueError( @@ -477,40 +477,41 @@ def paged_attention( else: raise ValueError("megacore_mode must be one of ['kv_head', 'batch', None]") - if (num_heads // num_kv_heads) % 8 != 0: + num_groups = num_q_heads // num_kv_heads + if (num_groups) % 8 != 0: # Reshape q to hint XLA to pick a <1x128> layout otherwise it will pick a # <8x128> layout for a <1x128> memref inside the kernel and error out. - q = q.reshape(batch_size, num_heads, 1, head_dim) + q = q.reshape(batch_size, num_q_heads, 1, head_dim) if megacore_mode == "kv_head": q_block_spec = pl.BlockSpec( - (None, num_heads // num_kv_heads, None, head_dim), + (None, num_groups, None, head_dim), lambda core_index, b, h, *_: (b, h * num_cores + core_index, 0, 0), ) elif megacore_mode == "batch": q_block_spec = pl.BlockSpec( - (None, num_heads // num_kv_heads, None, head_dim), + (None, num_groups, None, head_dim), lambda core_index, b, h, *_: (b * num_cores + core_index, h, 0, 0), ) else: q_block_spec = pl.BlockSpec( - (None, num_heads // num_kv_heads, None, head_dim), + (None, num_groups, None, head_dim), lambda core_index, b, h, *_: (b, h, 0, 0), ) q_dtype_for_kernel_launch = jnp.float32 else: if megacore_mode == "kv_head": q_block_spec = pl.BlockSpec( - (None, num_heads // num_kv_heads, head_dim), + (None, num_groups, head_dim), lambda core_index, b, h, *_: (b, h * num_cores + core_index, 0), ) elif megacore_mode == "batch": q_block_spec = pl.BlockSpec( - (None, num_heads // num_kv_heads, head_dim), + (None, num_groups, head_dim), lambda core_index, b, h, *_: (b * num_cores + core_index, h, 0), ) else: q_block_spec = pl.BlockSpec( - (None, num_heads // num_kv_heads, head_dim), + (None, num_groups, head_dim), lambda core_index, b, h, *_: (b, h, 0), ) q_dtype_for_kernel_launch = q.dtype @@ -659,4 +660,4 @@ def paged_attention( v_pages, v_scales_pages, ) - return out.reshape(batch_size, num_heads, head_dim).astype(q.dtype) + return out.reshape(batch_size, num_q_heads, head_dim).astype(q.dtype) diff --git a/tests/pallas/tpu_paged_attention_kernel_test.py b/tests/pallas/tpu_paged_attention_kernel_test.py index 7fbccdb338d4..e778c72a8278 100644 --- a/tests/pallas/tpu_paged_attention_kernel_test.py +++ b/tests/pallas/tpu_paged_attention_kernel_test.py @@ -22,15 +22,12 @@ import numpy as np -jax.config.parse_flags_with_absl() - - -def _generate_qkv( +def _generate_random_qkv( seq_lens, page_size, max_seq_len, num_kv_heads, - num_heads, + num_q_heads, head_dim, prng_key, dtype=jnp.float32, @@ -55,7 +52,7 @@ def _generate_qkv( page_indices = jnp.arange(batch_size * pages_per_sequence, dtype=jnp.int32) page_indices = jax.random.permutation(k3, page_indices, independent=True) page_indices = page_indices.reshape(batch_size, pages_per_sequence) - q = jax.random.normal(k4, (batch_size, num_heads, head_dim), dtype=dtype) + q = jax.random.normal(k4, (batch_size, num_q_heads, head_dim), dtype=dtype) return q, k_pages, v_pages, page_indices @@ -64,7 +61,7 @@ def _reconstruct_kv(page_indices, pages): pages = quantization_utils.unquantize_from_int8(pages, dtype=jnp.float32) batch_size = page_indices.shape[0] - num_heads, _, _, head_dim = pages.shape + num_kv_heads, _, _, head_dim = pages.shape def per_sequence_page_gather(pages, page_indices): return jnp.take(pages, page_indices, 1) @@ -72,15 +69,16 @@ def per_sequence_page_gather(pages, page_indices): gathered = jax.vmap(per_sequence_page_gather, in_axes=(None, 0))( pages, page_indices ) - return gathered.reshape(batch_size, num_heads, -1, head_dim) + return gathered.reshape(batch_size, num_kv_heads, -1, head_dim) def _grouped_query_attention_reference(q, k, v, lengths, attn_logits_soft_cap): - batch_size, num_heads, head_dim = q.shape + batch_size, num_q_heads, head_dim = q.shape _, num_kv_heads, max_seq_len, _ = k.shape assert k.shape == v.shape - assert num_heads % num_kv_heads == 0 - q = q.reshape(batch_size, num_kv_heads, num_heads // num_kv_heads, head_dim) + assert num_q_heads % num_kv_heads == 0 + num_groups = num_q_heads // num_kv_heads + q = q.reshape(batch_size, num_kv_heads, num_groups, head_dim) if isinstance(k, quantization_utils.QuantizedTensor): k = quantization_utils.unquantize_from_int8(k, dtype=jnp.float32) @@ -97,7 +95,7 @@ def _grouped_query_attention_reference(q, k, v, lengths, attn_logits_soft_cap): logits = logits + jnp.where(mask, 0.0, mask_value)[:, None, None, :] weights = jax.nn.softmax(logits, axis=-1) o = jnp.einsum("bhgt,bhtd->bhgd", weights.astype(v.dtype), v) - return o.reshape(batch_size, num_heads, head_dim) + return o.reshape(batch_size, num_q_heads, head_dim) def _megacore_enabled(): @@ -149,7 +147,7 @@ def test_paged_attention( max_kv_len = 2048 block_size = 512 seq_lens = np.asarray([0, 3, 256, 513, 1023, 2048]) - q, k_pages, v_pages, page_indices = _generate_qkv( + q, k_pages, v_pages, page_indices = _generate_random_qkv( seq_lens, page_size, max_kv_len, @@ -188,4 +186,5 @@ def test_paged_attention( if __name__ == "__main__": + jax.config.config_with_absl() absltest.main(testLoader=jtu.JaxTestLoader()) From 829deb68f62a9c5e51fac12f9f824d21a8f379be Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 28 Mar 2025 17:19:34 +0000 Subject: [PATCH 263/483] Set NB_DOMAIN=jax This is a precautionary measure to prevent conflicts with other packages using nanobind and registering the same types. We don't want JAX's nanobind registrations to conflict on, say, XLA types with other projects. --- .bazelrc | 1 + 1 file changed, 1 insertion(+) diff --git a/.bazelrc b/.bazelrc index 2d38dcc87044..422363644578 100644 --- a/.bazelrc +++ b/.bazelrc @@ -31,6 +31,7 @@ build -c opt build --output_filter=DONT_MATCH_ANYTHING build --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir. +build --copt=-DNB_DOMAIN=jax # ############################################################################# # Platform Specific configs below. These are automatically picked up by Bazel From ecd9f5ded81eede59986d90c10b52ca852b4325e Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 28 Mar 2025 10:27:01 -0700 Subject: [PATCH 264/483] Move aval_to_xla_shape into callback.py, which is its only user. Specialize it to one shape per aval, since that's the only case that exists. Remove some pointless assertions using this code. PiperOrigin-RevId: 741569024 --- jax/_src/api.py | 2 -- jax/_src/callback.py | 32 ++++++++++++++++++++++++-------- jax/_src/interpreters/xla.py | 22 +--------------------- jax/_src/prng.py | 1 - 4 files changed, 25 insertions(+), 32 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index e01bdd4a9d81..55e2b2126a68 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -82,7 +82,6 @@ from jax._src.interpreters import batching from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import pxla -from jax._src.interpreters import xla traceback_util.register_exclusion(__file__) @@ -2591,7 +2590,6 @@ def _device_put_replicated(x): sharding = PmapSharding(np.array(devices), sharding_spec) if dtypes.issubdtype(aval.dtype, dtypes.extended): return aval.dtype._rules.device_put_replicated(buf, aval, sharding, devices) - assert len(xla.aval_to_xla_shapes(aval)) == 1 return pxla.batched_device_put(aval, sharding, [buf] * len(devices), devices) with config.explicit_device_put_scope(): diff --git a/jax/_src/callback.py b/jax/_src/callback.py index 683da66638e6..20334b6cd269 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -670,6 +670,25 @@ def receive_from_host( return token, result + +def _aval_to_xla_shape(aval: core.AbstractValue) -> xc.Shape: + try: + return _xla_shape_handlers[type(aval)](aval) + except KeyError as err: + raise TypeError(f"No xla_shape_handler for type: {type(aval)}") from err + +_xla_shape_handlers: dict[type[core.AbstractValue], + Callable[[Any], xc.Shape]] = {} + +def _make_array_shape(aval: core.ShapedArray) -> xc.Shape: + aval = core.physical_aval(aval) + dtype = np.dtype('bool') if aval.dtype == dtypes.float0 else aval.dtype + return xc.Shape.array_shape(dtype, aval.shape) +_xla_shape_handlers[core.ShapedArray] = _make_array_shape + +_xla_shape_handlers[core.AbstractToken] = lambda _: xc.Shape.token_shape() + + def _emit_tpu_python_callback( backend: xb.XlaBackend, ctx: mlir.LoweringRuleContext, @@ -699,8 +718,7 @@ def _wrapped_callback(*args): # pylint: disable=function-redefined send_channel = ctx.module_context.new_channel() dummy_send_aval = core.ShapedArray((1,), np.float32) dummy_send_val = mlir.ir_constant(np.zeros(1, np.float32)) - operand_shapes = [*operand_shapes, - xla.aval_to_xla_shapes(dummy_send_aval)[0]] + operand_shapes = [*operand_shapes, _aval_to_xla_shape(dummy_send_aval)] token = send_to_host(send_channel, token, dummy_send_val, callback.__name__, sharding=sharding) send_channels.append(send_channel) @@ -763,10 +781,8 @@ def emit_python_callback( raise ValueError( f"`EmitPythonCallback` not supported on {platform} backend.") backend = ctx.module_context.get_backend() - result_shapes = util.flatten( - [xla.aval_to_xla_shapes(result_aval) for result_aval in result_avals]) - operand_shapes = util.flatten( - [xla.aval_to_xla_shapes(op_aval) for op_aval in operand_avals]) + result_shapes = [_aval_to_xla_shape(aval) for aval in result_avals] + operand_shapes = [_aval_to_xla_shape(aval) for aval in operand_avals] # Handling layouts if operand_layouts is None: operand_layouts = util.concatenate( @@ -836,10 +852,10 @@ def _wrapped_callback(token, *args): # type: ignore # pylint: disable=function return (token, *callback_without_token(*args)) operand_shapes = [ - xla.aval_to_xla_shapes(core.abstract_token)[0], *operand_shapes + _aval_to_xla_shape(core.abstract_token), *operand_shapes ] result_shapes = [ - xla.aval_to_xla_shapes(core.abstract_token)[0], *result_shapes + _aval_to_xla_shape(core.abstract_token), *result_shapes ] operands = [token, *operands] result_types = [mlir.token_type(), *result_types] diff --git a/jax/_src/interpreters/xla.py b/jax/_src/interpreters/xla.py index 33a8992a8be4..7fbb22923e0f 100644 --- a/jax/_src/interpreters/xla.py +++ b/jax/_src/interpreters/xla.py @@ -16,7 +16,7 @@ from __future__ import annotations -from collections.abc import Callable, Sequence +from collections.abc import Callable from functools import partial from typing import Any, Union @@ -25,7 +25,6 @@ from jax._src import core from jax._src import dtypes from jax._src.abstract_arrays import numpy_scalar_types -from jax._src.core import ShapedArray from jax._src.util import safe_zip, safe_map from jax._src.typing import Shape @@ -41,11 +40,6 @@ def identity(x): return x _scalar_types = dtypes.python_scalar_dtypes.keys() -def _make_array_shape(aval: ShapedArray) -> Sequence[xc.Shape]: - aval = core.physical_aval(aval) - dtype = np.dtype('bool') if aval.dtype == dtypes.float0 else aval.dtype - return (xc.Shape.array_shape(dtype, aval.shape),) - # Utilities # HLO instructions optionally can be annotated to say how the output should be @@ -90,20 +84,6 @@ def tuple_sharding_proto(elems): ### handlers -# JAX abstract values -> XLA shapes - -def aval_to_xla_shapes(aval: core.AbstractValue) -> Sequence[xc.Shape]: - try: - return _xla_shape_handlers[type(aval)](aval) - except KeyError as err: - raise TypeError(f"No xla_shape_handler for type: {type(aval)}") from err - -_xla_shape_handlers: dict[type[core.AbstractValue], - Callable[[Any], Sequence[xc.Shape]]] = { - ShapedArray: _make_array_shape, -} -_xla_shape_handlers[core.AbstractToken] = lambda _: (xc.Shape.token_shape(),) - # IR constants diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 926a57446f5b..0106aa310383 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -425,7 +425,6 @@ def device_put_sharded(vals, aval, sharding, devices): @staticmethod def device_put_replicated(val, aval, sharding, devices): physical_aval = core.physical_aval(aval) - assert len(xla.aval_to_xla_shapes(physical_aval)) == 1 physical_buf = random_unwrap(val) phys_sharding = physical_sharding(aval, sharding) physical_result = pxla.batched_device_put( From d4c42d7199f39a0b4639a32350abf3e8fb8a6043 Mon Sep 17 00:00:00 2001 From: Zac Cranko Date: Fri, 28 Mar 2025 10:54:48 -0700 Subject: [PATCH 265/483] implement nbytes for PRNGKeyArray --- jax/_src/prng.py | 4 ++++ tests/random_test.py | 5 +++++ 2 files changed, 9 insertions(+) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 2fa9b2b37aa4..ad96d9409083 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -188,6 +188,10 @@ def ndim(self): def dtype(self): return KeyTy(self._impl) + @property + def nbytes(self): + return self.itemsize * self.size + @property def itemsize(self): return self.dtype.itemsize diff --git a/tests/random_test.py b/tests/random_test.py index a51e387dca76..22df8b0b0649 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -657,6 +657,11 @@ def test_non_integer_seed(self): with self.assertRaisesRegex(TypeError, "PRNG key seed must be an integer"): random.key(seed) + def test_nbytes_property(self): + key = self.make_keys() + self.assertEqual(key.nbytes, key._base_array.nbytes) + self.assertEqual(key.nbytes, key.itemsize * key.size) + def test_dtype_property(self): k1, k2 = self.make_keys(), self.make_keys() self.assertEqual(k1.dtype, k2.dtype) From fbff338a8ef92f99b896ee8d1f0ac65d830edcfa Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 28 Mar 2025 11:00:30 -0700 Subject: [PATCH 266/483] [pallas:mosaic_gpu] `GPUMesh` now accepts axis names in a more structured way This is hopefully less confusing then bunching them together in a single argument. PiperOrigin-RevId: 741580827 --- jax/_src/pallas/mosaic_gpu/core.py | 36 +++++++++++-------- jax/_src/pallas/mosaic_gpu/lowering.py | 23 +++--------- .../pallas/ops/gpu/attention_mgpu.py | 6 ++-- tests/pallas/mosaic_gpu_test.py | 34 ++++++++++-------- 4 files changed, 50 insertions(+), 49 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 857daaefe38f..f8c1ebf442b0 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -575,22 +575,29 @@ def _as_accum(ref) -> WGMMAAbstractAccumulatorRef: @dataclasses.dataclass(frozen=True, kw_only=True) class GPUMesh: - grid: tuple[int, ...] = () - cluster: tuple[int, ...] = () + grid: Sequence[int] = () + grid_names: Sequence[str] = () + cluster: Sequence[int] = () + cluster_names: Sequence[str] = () # Those are NOT CUDA threads. On Hopper they correspond to warpgroups. num_threads: int | None = None - axis_names: tuple[str, ...] = () + thread_name: str | None = None def __post_init__(self): if len(self.cluster) > 3: raise ValueError(f"cluster= must be at most 3D, got {self}.") - num_axis_names = ( - len(self.grid) + len(self.cluster) + (self.num_threads is not None) - ) - if len(self.axis_names) != num_axis_names: + if len(self.grid_names) != len(self.grid): + raise ValueError( + f"grid_names must have the same length as grid, got {self}." + ) + if len(self.cluster_names) != len(self.cluster): raise ValueError( - "Need an axis name for each grid and cluster dimension plus " - f" an additional axis name when num_threads= is given, got {self}." + f"cluster_names must have the same length as cluster, got {self}." + ) + if (self.thread_name is None) != (self.num_threads is None): + raise ValueError( + "num_threads and thread_name must be either both set or both None," + f" got {self}" ) if self.num_threads is not None and self.num_threads > 2048 // 128: raise ValueError( @@ -607,14 +614,13 @@ def shape(self) -> collections.OrderedDict[object, int]: pairs: Iterable[tuple[object, int]] if self.num_threads is not None: pairs = zip( - self.axis_names, (*self.grid, *self.cluster, self.num_threads) + (*self.grid_names, *self.cluster_names, self.thread_name), + (*self.grid, *self.cluster, self.num_threads), ) else: - pairs = tuple( - zip( - (*self.axis_names, _WARPGROUP_AXIS_NAME), - (*self.grid, *self.cluster, 1), - ) + pairs = zip( + (*self.grid_names, *self.cluster_names), + (*self.grid, *self.cluster), ) return collections.OrderedDict(pairs) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index baac1e6eb316..42914c95085a 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -274,17 +274,6 @@ def __iter__(self) -> Iterable[Hashable]: self.grid, self.cluster, [self.wg] if self.wg is not None else [] ) - @classmethod - def from_mesh( - cls, mesh: gpu_core.GPUMesh, axis_names: Sequence[str] - ) -> "_AxisNames": - wg_name = None - if mesh.num_threads is not None: - wg_name = axis_names[-1] - axis_names = axis_names[:-1] - grid_names, cluster_names = util.split_list(axis_names, [len(mesh.grid)]) - return cls(grid_names, cluster_names, wg_name) - @dataclasses.dataclass class ModuleContext: @@ -552,12 +541,10 @@ def lower_pipelined_jaxpr_to_module( block_mappings, [grid_mapping.num_inputs] ) - if mesh is not None: + if mesh: assert isinstance(mesh, gpu_core.GPUMesh) - if mesh and mesh.num_threads is not None: - # Last dim corresponds to the warpgroup count. - block = (128 * grid_mapping.grid[-1], 1, 1) - grid = grid_mapping.grid[:-1] + block = (128 * (mesh.num_threads or 1), 1, 1) + grid = mesh.grid else: block = (128, 1, 1) grid = grid_mapping.grid @@ -665,9 +652,9 @@ def body_fn(*refs): assert not new_consts axis_names = ( - _AxisNames.from_mesh(mesh, grid_mapping.grid_names) + _AxisNames(mesh.grid_names, mesh.cluster_names, mesh.thread_name) if mesh is not None - else _AxisNames(grid_mapping.grid_names) + else _AxisNames(grid_mapping.grid_names or ()) ) with grid_mapping.trace_env(): return lower_jaxpr_to_module( diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py index 8883878f5f0e..534da419ed3b 100644 --- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py +++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py @@ -227,8 +227,9 @@ def entry(q_ref, k_ref, v_ref, out_ref): entry, out_shape=q, grid=(batch_size, num_q_tiles, num_q_heads), + grid_names=("batch", "q_seq", "heads"), num_threads=3, - axis_names=("batch", "q_seq", "heads", "wg"), + thread_name="wg", compiler_params=plgpu.GPUCompilerParams(approx_math=True), )(q, k, v) @@ -366,8 +367,9 @@ def compute_pv(acc_ref): pipeline(k_ref, v_ref) mesh = plgpu.GPUMesh( grid=(batch_size, num_q_tiles, num_q_heads), + grid_names=("batch", "q_seq", "heads"), num_threads=3, - axis_names=("batch", "q_seq", "heads", "wg"), + thread_name="wg", ) def run(refs): q_ref, k_ref, v_ref, out_ref = refs diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 874ecae93f3f..b6c0652c13fe 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1707,7 +1707,7 @@ def test_tmem_alloc(self): plgpu.SMEM((128, 128), jnp.float32), ], num_threads=1, - axis_names=("x",), + thread_name="x", ) def kernel(y_ref, tmem_ref, smem_ref): # Issue a write so the TMEM load is not DCE'd. @@ -2096,8 +2096,9 @@ def copy_kernel(x_smem, o_smem, o_last_block_smem, *consumed_barriers): ), compiler_params=plgpu.GPUCompilerParams(approx_math=True), grid=(1,), + grid_names=("_",), num_threads=3, - axis_names=("_", "wg"), + thread_name="wg", ) out, out_last_block = kernel(x) np.testing.assert_array_equal(out, x) @@ -2130,8 +2131,9 @@ def tiled_add_kernel(x_smem, y_smem, o_smem): out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32), compiler_params=plgpu.GPUCompilerParams(approx_math=True), grid=(1,), + grid_names=("_",), num_threads=num_compute_wgs + 1, - axis_names=("_", "wg"), + thread_name="wg", ) x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float32) y = jax.random.uniform(jax.random.key(1), (m, n), dtype=jnp.float32) @@ -2148,8 +2150,9 @@ def test_carry_accumulate(self, m=256, n=256, num_compute_wgs=2): ], compiler_params=plgpu.GPUCompilerParams(approx_math=True), grid=(1,), + grid_names=("_",), num_threads=num_compute_wgs + 1, - axis_names=("_", "wg"), + thread_name="wg", ) def kernel(x_gmem, acc_gmem, acc_smem): def _compute_thread(): @@ -2204,7 +2207,7 @@ def test_multiple_wg(self): plgpu.kernel, out_shape=jnp.zeros((2, 128), np.int32), num_threads=2, - axis_names=("wg",), + thread_name="wg", ) def kernel(o_ref): wg_idx = jax.lax.axis_index("wg") @@ -2219,8 +2222,9 @@ def test_multiple_wg_with_grid(self): plgpu.kernel, out_shape=jnp.zeros((4, 2, 128), np.int32), grid=(2, 2), + grid_names=("x", "y"), num_threads=2, - axis_names=("x", "y", "wg"), + thread_name="wg", ) def kernel(o_ref): xy_idx = jax.lax.axis_index(("x", "y")) @@ -2250,8 +2254,9 @@ def test_multiple_wg_with_squashed_grid(self): (b, x_dim, y_dim, z_dim, num_threads, 128), np.int32 ), grid=(b, x_dim, y_dim, z_dim), + grid_names=("b", "x", "y", "z"), num_threads=num_threads, - axis_names=("b", "x", "y", "z", "wg"), + thread_name="wg", ) def kernel(o_ref): b_idx = jax.lax.axis_index("b") @@ -2277,7 +2282,7 @@ def test_cross_wg_barrier(self): # Each warpgroup is a single logical thread! scratch_shapes=[plgpu.Barrier(num_arrivals=2)], num_threads=2, - axis_names=("wg",), + thread_name="wg", ) def kernel(o_ref, barrier): plgpu.barrier_arrive(barrier) @@ -2294,8 +2299,9 @@ def test_cluster(self): plgpu.kernel, out_shape=jnp.zeros(128, np.int32), grid=(2,), + grid_names=("x",), cluster=(2,), - axis_names=("x", "cluster"), + cluster_names=("cluster",), ) def kernel(ref): block_idx = jax.lax.axis_index("x") @@ -2336,7 +2342,7 @@ def body(l_ref, r_ref, o_ref): o_ref[my_slice] = l_ref[my_slice] + r_ref[my_slice] x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) - out = plgpu.kernel(body, out_shape=x, grid=(2,), axis_names=("rows",))(x, x) + out = plgpu.kernel(body, out_shape=x, grid=(2,), grid_names=("rows",))(x, x) np.testing.assert_allclose(out, x + x) # Async copies @@ -2351,7 +2357,7 @@ def test_stage3(self): plgpu.Barrier(num_arrivals=2), ], grid=(2,), - axis_names=("rows",), + grid_names=("rows",), ) def kernel(l_ref, r_ref, o_ref, l_smem, r_smem, o_smem, barrier): my_slice = pl.ds(lax.axis_index("rows") * row_block, row_block) @@ -2382,7 +2388,7 @@ def compute(l_smem, r_smem, o_smem): )(l_ref, r_ref, o_ref) x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) - out = plgpu.kernel(body, out_shape=x, grid=(2,), axis_names=("rows",))(x, x) + out = plgpu.kernel(body, out_shape=x, grid=(2,), grid_names=("rows",))(x, x) np.testing.assert_allclose(out, x + x) # Transforms @@ -2404,7 +2410,7 @@ def compute(l_smem, r_smem, o_smem): )(l_ref, r_ref, o_ref) x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) - out = plgpu.kernel(body, out_shape=x, grid=(2,), axis_names=("rows",))(x, x) + out = plgpu.kernel(body, out_shape=x, grid=(2,), grid_names=("rows",))(x, x) np.testing.assert_allclose(out, x + x) def test_semaphore_lowering(self): @@ -2456,7 +2462,7 @@ def do_wgmma(acc_ref): )(l_ref, r_ref, o_ref) x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) - out = plgpu.kernel(body, out_shape=x, grid=(2, 2), axis_names=("m", "n"))(x, x) + out = plgpu.kernel(body, out_shape=x, grid=(2, 2), grid_names=("m", "n"))(x, x) np.testing.assert_allclose(out, x @ x) # TODO(apaszke): Clusters and multicast From e838fe19d3b2a7b41c5ba0a8b7d98d7b9ea9e477 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 28 Mar 2025 13:00:24 -0700 Subject: [PATCH 267/483] [pallas:mosaic_gpu] Added support for collective GMEM->SMEM copies to lane-level lowering More work is needed to support these in the WG lowering. PiperOrigin-RevId: 741622096 --- jax/_src/pallas/core.py | 20 ++-- jax/_src/pallas/mosaic_gpu/core.py | 24 +++++ jax/_src/pallas/mosaic_gpu/lowering.py | 80 ++++++++++++--- jax/_src/pallas/mosaic_gpu/primitives.py | 19 +++- jax/experimental/pallas/mosaic_gpu.py | 1 + tests/pallas/mosaic_gpu_test.py | 121 +++++++++++++++++++++++ 6 files changed, 243 insertions(+), 22 deletions(-) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 8602205eef22..a74206c46ce7 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -1089,7 +1089,10 @@ def wrapped(f): debug_info=api_util.debug_info("pallas_core_map", f, (), {})), in_tree) - with jax_core.extend_axis_env_nd(mesh.shape.items()): + with ( + tracing_grid_env(tuple(mesh.shape.values()), mapped_dims=()), + jax_core.extend_axis_env_nd(mesh.shape.items()), + ): jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, flat_args) out = core_map_p.bind(*consts, jaxpr=jaxpr, mesh=mesh, compiler_params=compiler_params, @@ -1144,6 +1147,7 @@ def default_mesh_discharge_rule( interpret, cost_estimate, name, + memory_space=MemorySpace.ANY, ): """Discharges a ``core_map`` over a mesh to a ``pallas_call``.""" del out_avals # Unused. @@ -1160,13 +1164,9 @@ def body(*args): for eff in jaxpr.effects if isinstance(eff, state_types.WriteEffect) ) - any_spec = BlockSpec(memory_space=MemorySpace.ANY) - grid_spec = GridSpec( - grid=tuple(mesh.shape.items()), - in_specs=[any_spec] * len(in_avals), - out_specs=[any_spec] * len(modified_idxs), - ) + spec = BlockSpec(memory_space=memory_space) from jax._src.pallas import pallas_call # Avoid circular dependency. + outs = pallas_call._pallas_call( body, name=name, @@ -1174,7 +1174,11 @@ def body(*args): input_output_aliases={ in_idx: out_idx for out_idx, in_idx in enumerate(modified_idxs) }, - grid_spec=grid_spec, + grid_spec=GridSpec( + grid=tuple(mesh.shape.items()), + in_specs=[spec] * len(in_avals), + out_specs=[spec] * len(modified_idxs), + ), mesh=mesh, compiler_params=compiler_params, interpret=interpret, diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index f8c1ebf442b0..8522bdf651f4 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -502,6 +502,17 @@ def __str__(self): return self.name +@dataclasses.dataclass(frozen=True) +class ClusterBarrierType(dtypes.ExtendedDType): + type: ClassVar[Any] = barrier_dtype + name: ClassVar[str] = "cluster_barrier" + + collective_axes: tuple[str | tuple[str, ...], ...] + + def __str__(self): + return self.name + + @dataclasses.dataclass(frozen=True) class Barrier: num_arrivals: int @@ -514,6 +525,18 @@ def get_ref_aval(self) -> AbstractMemoryRef: return AbstractMemoryRef(aval, SMEM) +@dataclasses.dataclass(frozen=True) +class ClusterBarrier: + collective_axes: tuple[str | tuple[str, ...], ...] + num_barriers: int = 1 + + def get_ref_aval(self) -> AbstractMemoryRef: + aval = jax_core.ShapedArray( + [self.num_barriers], ClusterBarrierType(self.collective_axes) + ) + return AbstractMemoryRef(aval, SMEM) + + @dataclasses.dataclass(frozen=True) class WGMMAAccumulatorRef: shape: tuple[int, int] @@ -660,6 +683,7 @@ def _gpu_mesh_discharge_rule( interpret=interpret, cost_estimate=cost_estimate, name=name, + memory_space=GMEM, ) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 42914c95085a..e99feb4dc144 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -85,8 +85,9 @@ def _align_to(x: int, alignment: int): return x -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, kw_only=True) class ResourceEstimatorContext: + axis_names: _AxisNames thread_semantics: mgpu.ThreadSemantics @property @@ -98,11 +99,14 @@ def arrival_multiplier(self) -> int: ) +AnyBarrier = mgpu.Barrier | mgpu.ClusterBarrier + + @dataclasses.dataclass(kw_only=True, frozen=True) class Resources: smem_scratch_bytes: int = 0 tmem_scratch_cols: int = 0 - barrier_counts: collections.Counter[mgpu.Barrier] = dataclasses.field( + barrier_counts: collections.Counter[AnyBarrier] = dataclasses.field( default_factory=collections.Counter ) @@ -120,7 +124,7 @@ def __post_init__(self): ) @property - def barriers(self) -> Sequence[mgpu.Barrier]: + def barriers(self) -> Sequence[AnyBarrier]: return list(self.barrier_counts.elements()) def __add__(self, other: Resources) -> Resources: @@ -230,6 +234,16 @@ def _run_scoped_resource_estimator( ) ]) ) + elif isinstance(aval.dtype, gpu_core.ClusterBarrierType): + collective_dims = jax.tree.map( + lambda axis: _resolve_cluster_axis(ctx.axis_names, axis), + aval.dtype.collective_axes, + ) + rs += Resources( + barrier_counts=collections.Counter( + [mgpu.ClusterBarrier(collective_dims, *aval.shape)] + ) + ) elif aval.memory_space == gpu_core.TMEM: if aval.dtype.itemsize != 4: raise ValueError("TMEM only supports 32-bit types.") @@ -275,6 +289,9 @@ def __iter__(self) -> Iterable[Hashable]: ) +AnyBarrierRef = mgpu.BarrierRef | mgpu.CollectiveBarrierRef + + @dataclasses.dataclass class ModuleContext: name: str @@ -287,9 +304,7 @@ class ModuleContext: tmem_requested_cols: int tmem_used_cols: int tmem_base_ptr: ir.Value - runtime_barriers: MutableMapping[ - mgpu.Barrier, MutableSequence[mgpu.BarrierRef] - ] + runtime_barriers: MutableMapping[AnyBarrier, MutableSequence[AnyBarrierRef]] name_stack: source_info_util.NameStack traceback_caches: mlir.TracebackCaches squashed_dims: tuple[int, ...] @@ -399,7 +414,10 @@ class LoweringRuleContext: @property def estimator_ctx(self) -> ResourceEstimatorContext: - return ResourceEstimatorContext(thread_semantics=self.module_ctx.thread_semantics) + return ResourceEstimatorContext( + axis_names=self.module_ctx.axis_names, + thread_semantics=self.module_ctx.thread_semantics, + ) @dataclasses.dataclass(frozen=True) @@ -746,7 +764,12 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): module_ctx, launch_ctx, jaxpr, buffers_gmem, consts ) - rs = _estimate_resources(ResourceEstimatorContext(thread_semantics), jaxpr) + rs = _estimate_resources( + ResourceEstimatorContext( + axis_names=axis_names, thread_semantics=thread_semantics + ), + jaxpr, + ) smem_scratch_bytes = params.get("smem_scratch_bytes") if smem_scratch_bytes is None: smem_scratch_bytes = rs.smem_scratch_bytes @@ -1784,23 +1807,43 @@ def _block_id(ctx: LoweringRuleContext, dim: gpu_dialect.Dimension) -> ir.Value: return arith_dialect.divui(result, _as_index(cluster_size[dim.value])) +def _resolve_cluster_axis(axis_names: _AxisNames | None, axis_name: str): + if not axis_names: + raise LookupError( + "No axis names are available. Make sure you are using `pl.core_map`" + " with a `plgpu.GPUMesh`." + ) + if not axis_names or axis_name not in axis_names.cluster: + raise LookupError( + f"Unknown cluster axis {axis_name}, available axes:" + f" {[*axis_names.cluster]}" + ) + return gpu_dialect.Dimension(axis_names.cluster.index(axis_name)) + + @register_lowering_rule(lax.axis_index_p, mgpu.ThreadSemantics.Lane) @register_lowering_rule(lax.axis_index_p, mgpu.ThreadSemantics.Warpgroup) def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): axis_names = ctx.module_ctx.axis_names - if not axis_names or axis_name not in axis_names: - raise ValueError( - "Named axes can only refer to GPUMesh axes in Mosaic GPU kernels" + if not axis_names: + raise LookupError( + "No axis names are available. Make sure you are using `pl.core_map`" + " with a `plgpu.GPUMesh`." + ) + if axis_name not in axis_names: + raise LookupError( + f"Unknown axis {axis_name}, available axes: {[*axis_names]}" ) if axis_names.wg is not None and axis_name == axis_names.wg: return mgpu.warpgroup_idx(sync=True) if axis_name in axis_names.cluster: - idx = axis_names.cluster.index(axis_name) return arith_dialect.index_cast( ir.IntegerType.get_signless(32), - gpu_dialect.cluster_block_id(gpu_dialect.Dimension(idx)), + gpu_dialect.cluster_block_id( + gpu_dialect.Dimension(axis_names.cluster.index(axis_name)) + ), ) squashed_dims = ctx.module_ctx.squashed_dims @@ -1913,6 +1956,17 @@ def _run_scoped_lowering_rule( ) ) should_discharge.append(False) + elif isinstance(aval.dtype, gpu_core.ClusterBarrierType): + collective_dims = jax.tree.map( + lambda axis: _resolve_cluster_axis(ctx.module_ctx.axis_names, axis), + aval.dtype.collective_axes, + ) + input_refs.append( + ctx.module_ctx.reserve_barrier( + mgpu.ClusterBarrier(collective_dims, *aval.shape) + ) + ) + should_discharge.append(False) elif aval.memory_space == gpu_core.SMEM: [input_ref] = alloc_stack.enter_context( ctx.module_ctx.scratch_view( diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 9dc65c1bef88..a9bd91b26622 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -361,6 +361,7 @@ def _copy_gmem_to_smem_lowering( src_transforms_treedef, dst_transforms_treedef, barrier_transforms_treedef, + collective_axes, ): flat_src_transforms, flat_dst_transforms, flat_barrier_transforms = ( util.split_list( @@ -382,6 +383,12 @@ def _copy_gmem_to_smem_lowering( barrier = barrier.__getitem__( *map(lowering._as_index, barrier_indexer.indices) ) + collective = None + if collective_axes is not None: + collective = tuple( + lowering._resolve_cluster_axis(ctx.module_ctx.axis_names, axis) + for axis in collective_axes + ) dst_ty = ir.MemRefType(dst.type) bytes = math.prod(dst_ty.shape) * mgpu.bytewidth(dst_ty.element_type) if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: @@ -400,6 +407,7 @@ def _copy_gmem_to_smem_lowering( barrier=barrier, arrive=False, predicate=ctx.module_ctx.single_wg_lane_predicate, + collective=collective, **copy_params, ) return () @@ -425,7 +433,13 @@ def _copy_gmem_to_smem_lowering( return () -def copy_gmem_to_smem(src: _Ref, dst: _Ref, barrier: _Ref) -> None: +def copy_gmem_to_smem( + src: _Ref, + dst: _Ref, + barrier: _Ref, + *, + collective_axes: str | tuple[str, ...] | None = None, +) -> None: """Asynchronously copies a GMEM reference to a SMEM reference. See also: @@ -450,6 +464,8 @@ def copy_gmem_to_smem(src: _Ref, dst: _Ref, barrier: _Ref) -> None: flat_barrier_transforms, barrier_transforms_treedef = tree_util.tree_flatten( barrier_transforms ) + if isinstance(collective_axes, str): + collective_axes = (collective_axes,) copy_gmem_to_smem_p.bind( src, dst, @@ -460,6 +476,7 @@ def copy_gmem_to_smem(src: _Ref, dst: _Ref, barrier: _Ref) -> None: src_transforms_treedef=src_transforms_treedef, dst_transforms_treedef=dst_transforms_treedef, barrier_transforms_treedef=barrier_transforms_treedef, + collective_axes=collective_axes, ) return None diff --git a/jax/experimental/pallas/mosaic_gpu.py b/jax/experimental/pallas/mosaic_gpu.py index b791fbb8b573..e4c5ffe04093 100644 --- a/jax/experimental/pallas/mosaic_gpu.py +++ b/jax/experimental/pallas/mosaic_gpu.py @@ -18,6 +18,7 @@ """ from jax._src.pallas.mosaic_gpu.core import Barrier as Barrier +from jax._src.pallas.mosaic_gpu.core import ClusterBarrier as ClusterBarrier from jax._src.pallas.mosaic_gpu.core import GPUBlockSpec as GPUBlockSpec from jax._src.pallas.mosaic_gpu.core import GPUCompilerParams as GPUCompilerParams from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace as GPUMemorySpace diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index b6c0652c13fe..c8013f634c67 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -2322,6 +2322,127 @@ def kernel(ref): }, ) + def test_realistic_matmul_with_cluster(self): + dtype = jnp.float16 + swizzle = 128 + elems_128b = swizzle // jnp.dtype(dtype).itemsize + grid_m, grid_k, grid_n = 132, 10, 32 + # TODO(slebedev): Remove ``grid_tile_n`` to simplify the test. + grid_tile_n = 4 + assert grid_n % grid_tile_n == 0 + cluster_m = 2 + cluster_n = 2 + cluster_tile_n = min(cluster_n, grid_tile_n) + tile_m = tile_n = 128 + assert tile_m % elems_128b == 0 + tile_k = elems_128b + m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n + + transforms = ( + plgpu.TilingTransform((64, elems_128b)), + plgpu.SwizzleTransform(128), + ) + + max_concurrent_steps = 2 + delay_release = 1 + + @functools.partial( + plgpu.kernel, + out_shape=jax.ShapeDtypeStruct((m, n), dtype), + scratch_shapes=[ + plgpu.SMEM( + (max_concurrent_steps, tile_m, tile_k), + dtype, + transforms=transforms, + ), + plgpu.SMEM( + (max_concurrent_steps, tile_k, tile_n), + dtype, + transforms=transforms, + ), + plgpu.SMEM((tile_m, tile_n), dtype, transforms=transforms), + plgpu.ACC((tile_m, tile_n), jnp.float32), + plgpu.Barrier(num_arrivals=2, num_barriers=max_concurrent_steps), + plgpu.ClusterBarrier( + collective_axes=(("x", "z"), "y"), + num_barriers=max_concurrent_steps, + ), + ], + grid=(grid_tile_n, grid_m, grid_n // grid_tile_n), + grid_names=("tile_n", "m", "n"), + cluster=(cluster_tile_n, cluster_m, cluster_n // cluster_tile_n), + cluster_names=("x", "y", "z"), + ) + def kernel( + a_gmem, + b_gmem, + o_gmem, + a_smem, + b_smem, + o_smem, + acc, + barrier, + cluster_barrier, + ): + m_slice = pl.ds(lax.axis_index("m") * tile_m, tile_m) + n_slice = pl.ds( + (lax.axis_index("tile_n") + lax.axis_index("n") * grid_tile_n) + * tile_n, + tile_n, + ) + + def fetch(step, slot): + if not isinstance(slot, int): # Skip in initialization. + plgpu.barrier_arrive(cluster_barrier.at[slot]) + plgpu.barrier_wait(cluster_barrier.at[slot]) + + k_slice = pl.ds(step * tile_k, tile_k) + plgpu.copy_gmem_to_smem( + a_gmem.at[m_slice, k_slice], + a_smem.at[slot], + barrier.at[slot], + collective_axes=("x", "z"), + ) + plgpu.copy_gmem_to_smem( + b_gmem.at[k_slice, n_slice], + b_smem.at[slot], + barrier.at[slot], + collective_axes="y", + ) + + # Initialize the pipeline. + for slot in range(min(max_concurrent_steps, grid_k)): + fetch(slot, slot) + + def body(step, _): + slot = step % max_concurrent_steps + plgpu.barrier_wait(barrier.at[slot]) + + plgpu.wgmma(acc, a_smem.at[slot], b_smem.at[slot]) + plgpu.wgmma_wait(delay_release) + + fetch_step = step + (max_concurrent_steps - delay_release) + fetch_slot = lax.rem(fetch_step, max_concurrent_steps) + jax.lax.cond( + lax.bitwise_and(step >= delay_release, fetch_step < grid_k), + lambda: fetch(fetch_step, fetch_slot), + lambda: None, + ) + return () + + jax.lax.fori_loop(0, grid_k, body, ()) + + # Finalize the pipeline. + o_smem[...] = acc[...].astype(dtype) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(o_smem, o_gmem.at[m_slice, n_slice]) + plgpu.wait_smem_to_gmem(0) + + key1, key2 = jax.random.split(jax.random.key(42), 2) + a = jax.random.uniform(key1, shape=(m, k), dtype=dtype) + b = jax.random.uniform(key2, shape=(k, n), dtype=dtype) + np.testing.assert_array_equal(kernel(a, b), a @ b) + class ExamplesTest(PallasTest): From b3a2c5341db9ad04c464b878ec4e59ffe9498918 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 28 Mar 2025 14:14:07 -0700 Subject: [PATCH 268/483] [NFC] Fix linter errors in pipeline file PiperOrigin-RevId: 741644574 --- jax/_src/pallas/mosaic/pipeline.py | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index 184b1497adf9..9b0a9322c94d 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -213,8 +213,8 @@ class BufferedRef: spec: pl.BlockSpec # static metadata dtype: Any # static metadata buffer_type: BufferType # static metadata - window_ref: REF | None - accum_ref: REF | None + window_ref: ArrayRef | None + accum_ref: ArrayRef | None current_slot: ArrayRef | None # TODO(ramiroleal): Unused by class. Remove argument from # BufferedRef instantiations. @@ -337,6 +337,7 @@ def memory_space(self): def current_ref(self): buffer_slice = tuple( 0 if x is None else slice(None) for x in self.block_shape) + assert not (self.window_ref is None or isinstance(self.window_ref, REF)) if self.memory_space == VMEM: return self.window_ref.at[buffer_slice] else: @@ -368,10 +369,12 @@ def is_input_output(self): @property def current_slot_index(self): + """Index in double buffer corresponding to the current slot.""" return self.current_slot[0] @property def next_slot_index(self): + """Index in double buffer corresponding to the next slot.""" return lax.rem(self.current_slot_index + 1, 2) def bind_existing_ref(self, window_ref, indices): @@ -463,6 +466,8 @@ def copy_in(self, src_ref, grid_indices): """Starts copy of HBM dma slice into the current slot.""" assert self.is_input if self.memory_space == VMEM: return + assert not (self.window_ref is None or isinstance(self.window_ref, REF)) + assert self.sem_recvs is not None if self.swap is not None: self.swap[0] = True next_slot = self.next_slot_index @@ -470,7 +475,7 @@ def copy_in(self, src_ref, grid_indices): dst_slice = tuple(pl.ds(0, s.size) for s in src_slice) tpu_primitives.make_async_copy( src_ref.at[src_slice], - self.window_ref.at[next_slot].at[dst_slice], + self.window_ref.at[(next_slot, *dst_slice)], self.sem_recvs.at[next_slot], ).start() @@ -478,13 +483,15 @@ def copy_out(self, dst_ref, grid_indices): """Starts copy of HBM dma slice from the current slot.""" assert self.is_output if self.memory_space == VMEM: return + assert not (self.window_ref is None or isinstance(self.window_ref, REF)) + assert self.sem_sends is not None if self.swap is not None: self.swap[0] = True slot = self.current_slot_index dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices) src_slice = tuple(pl.ds(0, s.size) for s in dst_slice) tpu_primitives.make_async_copy( - self.window_ref.at[slot].at[src_slice], + self.window_ref.at[(slot, *src_slice)], dst_ref.at[dst_slice], self.sem_sends.at[slot], ).start() @@ -493,13 +500,15 @@ def wait_in(self, src_ref, grid_indices): """Waits for input copy to finish.""" assert self.is_input if self.memory_space == VMEM: return + assert not (self.window_ref is None or isinstance(self.window_ref, REF)) + assert self.sem_recvs is not None src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices) dst_slice = tuple(pl.ds(0, s.size) for s in src_slice) current_slot = self.current_slot_index tpu_primitives.make_async_copy( src_ref.at[src_slice], # nb: doesn't matter - self.window_ref.at[current_slot].at[ - dst_slice + self.window_ref.at[ + (current_slot, *dst_slice) ], # only dst shape is important self.sem_recvs.at[current_slot], ).wait() @@ -508,12 +517,14 @@ def wait_out(self, dst_ref, grid_indices): """Waits for output copy to finish.""" assert self.is_output if self.memory_space == VMEM: return + assert not (self.window_ref is None or isinstance(self.window_ref, REF)) + assert self.sem_sends is not None # In a double buffer, previous slot is the same as next slot. prev_slot = self.next_slot_index dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices) src_slice = tuple(pl.ds(0, s.size) for s in dst_slice) tpu_primitives.make_async_copy( - self.window_ref.at[prev_slot].at[src_slice], # nb: doesn't matter + self.window_ref.at[(prev_slot, *src_slice)], # nb: doesn't matter dst_ref.at[dst_slice], # only dst shape is important self.sem_sends.at[prev_slot], ).wait() @@ -533,16 +544,18 @@ def set_accumulator(self, init=False): """Set accumulator or zero it out to initialize.""" assert self.is_accumulator if self.accum_ref is not None: + accum_dtype = self.accum_ref.dtype def _init(): self.accum_ref[...] = jnp.zeros_like(self.accum_ref[...]) def _set(): - self.accum_ref[...] = self.current_ref[...].astype(self.accum_ref.dtype) + self.accum_ref[...] = self.current_ref[...].astype(accum_dtype) lax.cond(init, _init, _set) def accumulate(self): """Add into the current slot.""" assert self.is_accumulator if self.accum_ref is not None: + assert self.window_ref is not None accum_dtype = jnp.float32 if self.window_ref.dtype == jnp.int32: accum_dtype = jnp.int32 From 91dac631fb79297a947d4742fb79e5898ece31c5 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 28 Mar 2025 14:15:25 -0700 Subject: [PATCH 269/483] scan: improve docs & errors around dynamic length --- jax/_src/lax/control_flow/loops.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index c7bcb1cf6b09..0362c139570a 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -178,6 +178,11 @@ def scan(f, init, xs, length=None): :py:func:`scan` compiles ``f``, so while it can be combined with :py:func:`jit`, it's usually unnecessary. + .. note:: + :func:`scan` is designed for iterating with a static number of iterations. + For iteration with a dynamic number of iterations, use :func:`fori_loop` + or :func:`while_loop`. + Args: f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning that ``f`` accepts two arguments where the first is a value of the loop @@ -239,7 +244,9 @@ def scan(f, init, xs, length=None): try: length = int(length) except core.ConcretizationTypeError as err: - msg = 'The `length` argument to `scan` expects a concrete `int` value.' + msg = ('The `length` argument to `scan` expects a concrete `int` value.' + ' For scan-like iteration with a dynamic length, use `while_loop`' + ' or `fori_loop`.') raise core.ConcretizationTypeError(length, msg) from None # type: ignore[arg-type] if not all(length == l for l in lengths): msg = ("scan got `length` argument of {} which disagrees with " From b719ac00c63ebb74766e7be3b142c046213a18ec Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 28 Mar 2025 15:12:41 -0700 Subject: [PATCH 270/483] Use f32 scratch for output so we only need to transfer output with desired dtype back to HBM. We use f32 as the dtype inside the kernel. Before we write the result from vmem to hbm, we convert to the desired dtype (eg bf16). So we can save memory bandwidth. Also, made minor change by checking sliding window and logit soft capping in the function that checks the static value. PiperOrigin-RevId: 741660728 --- .../pallas/ops/tpu/ragged_paged_attention.py | 54 +++++++++++-------- .../pallas/tpu_ragged_paged_attention_test.py | 8 +-- 2 files changed, 35 insertions(+), 27 deletions(-) diff --git a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py b/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py index 255670c22e90..e1eacee550a7 100644 --- a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py +++ b/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py @@ -83,8 +83,8 @@ def ref_ragged_paged_attention( soft_cap: float | None = None, mask_value: float | None = DEFAULT_MASK_VALUE, ): - check_inputs_shapes( - queries, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs + validate_static_inputs( + queries, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs, sliding_window, soft_cap ) if mask_value is None: mask_value = DEFAULT_MASK_VALUE @@ -130,7 +130,7 @@ def ref_ragged_paged_attention( # Expect to run these checkes during runtime. -def validate_inputs_on_runtime( +def validate_dynamic_inputs( q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] kv_pages: jax.Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] kv_lens: jax.Array, # i32[max_num_seqs] @@ -140,7 +140,7 @@ def validate_inputs_on_runtime( sliding_window: int | None = None, soft_cap: float | None = None, ): - check_inputs_shapes(q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs) + validate_static_inputs(q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs, sliding_window, soft_cap) max_num_batched_tokens = q.shape[0] page_size = kv_pages.shape[1] max_num_seqs, pages_per_seq = page_indices.shape @@ -165,20 +165,18 @@ def validate_inputs_on_runtime( raise ValueError( f"{q_len=} must be less or equal to {kv_len=} at sequence {i}." ) - if sliding_window is not None and sliding_window <= 0: - raise ValueError(f"{sliding_window=} must be positive.") - if soft_cap is not None and soft_cap == 0.0: - raise ValueError(f"{soft_cap=} must not be 0.0.") # Expect to run these checks during compile time. -def check_inputs_shapes( +def validate_static_inputs( q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] kv_pages: jax.Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] kv_lens: jax.Array, # i32[max_num_seqs] page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] cu_q_lens: jax.Array, # i32[max_num_seqs + 1] num_seqs, # i32[1] + sliding_window: int | None = None, + soft_cap: float | None = None, ): _, num_q_heads, head_dim = q.shape _, _, num_combined_kv_heads, head_dim_k = kv_pages.shape @@ -213,6 +211,10 @@ def check_inputs_shapes( ) if num_q_heads % num_kv_heads != 0: raise ValueError(f"{num_q_heads=} must be divisible by {num_kv_heads=}") + if sliding_window is not None and sliding_window <= 0: + raise ValueError(f"{sliding_window=} must be positive.") + if soft_cap is not None and soft_cap == 0.0: + raise ValueError(f"{soft_cap=} must not be 0.0.") def ragged_paged_attention_kernel( @@ -233,6 +235,7 @@ def ragged_paged_attention_kernel( sems, # [2, 2] l_ref, # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128] m_ref, # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128] + acc_ref, # [num_q_per_blk, num_q_heads_per_blk, head_dim] *, sm_scale: float, sliding_window: int | None = None, @@ -357,7 +360,7 @@ def flash_attention( v, # [num_kv_per_blk, head_dim] head_l_ref, # [num_q_per_blk * num_q_heads_per_kv_head, 128] head_m_ref, # [num_q_per_blk * num_q_heads_per_kv_head, 128] - head_o_ref, # [num_q_per_blk, num_q_heads_per_kv_head, head_dim] + head_acc_ref, # [num_q_per_blk, num_q_heads_per_kv_head, head_dim] *, kv_blk_idx, ): @@ -378,7 +381,7 @@ def flash_attention( num_q_per_blk * num_q_heads_per_kv_head, 128, ) - assert head_o_ref.shape == ( + assert head_acc_ref.shape == ( num_q_per_blk, num_q_heads_per_kv_head, head_dim, @@ -414,8 +417,8 @@ def init_scratch_ref(): num_q_heads_per_kv_head, ) masked_store( - head_o_ref, - jnp.zeros_like(head_o_ref), + head_acc_ref, + jnp.zeros_like(head_acc_ref), store_start, store_end, ) @@ -481,17 +484,17 @@ def broadcast_to_shape(arr, shape): [arr for _ in range(shape[1] // arr.shape[1])], axis=1 ) - o_curr = head_o_ref[...].reshape(-1, head_dim) + o_curr = head_acc_ref[...].reshape(-1, head_dim) l_alpha = broadcast_to_shape(l_alpha, qkv.shape) beta = broadcast_to_shape(beta, qkv.shape) l_next_safe = broadcast_to_shape(l_next_safe, qkv.shape) out = lax.div( l_alpha * o_curr + beta * qkv, l_next_safe, - ).astype(head_o_ref.dtype) + ) masked_store( - head_o_ref, - out.reshape(head_o_ref.shape), + head_acc_ref, + out.reshape(head_acc_ref.shape), store_start, store_end, ) @@ -544,7 +547,7 @@ def prefetch_next_kv_blk(): v, l_ref.at[kv_head_idx], m_ref.at[kv_head_idx], - o_ref.at[:, q_head_idx : q_head_idx + num_q_heads_per_kv_head, :], + acc_ref.at[:, q_head_idx : q_head_idx + num_q_heads_per_kv_head, :], kv_blk_idx=kv_blk_idx, ) return kv_blk_idx + 1, next_buf_idx @@ -566,6 +569,7 @@ def prefetch_next_kv_blk(): # Reset seq_idx for next kv_heads_blk if run out of seqs! seq_buf_idx_ref[0] = lax.select(seq_idx < num_seqs, seq_idx, 0) seq_buf_idx_ref[1] = buf_idx + o_ref[...] = acc_ref[...].astype(q_ref.dtype) def cdiv(a, b): @@ -662,6 +666,7 @@ def ragged_paged_attention( num_seqs: the dynamic number of sequences. sm_scale: the softmax scale which will be applied to the Q@K^T. sliding_window: the sliding window size for the attention. + soft_cap: the logit soft cap for the attention. mask_value: mask value for causal mask. num_kv_pages_per_block: number of kv pages to be processed in one flash attention block in the pallas kernel. @@ -672,7 +677,7 @@ def ragged_paged_attention( Returns: The output of the attention. """ - check_inputs_shapes(q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs) + validate_static_inputs(q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs, sliding_window, soft_cap) if mask_value is None: mask_value = DEFAULT_MASK_VALUE _, num_q_heads, head_dim = q.shape @@ -710,6 +715,10 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_): (num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128), jnp.float32, ) + acc_scratch = pltpu.VMEM( + (num_q_per_blk, num_q_heads_per_blk, head_dim), + jnp.float32, + ) double_buf_scratch = pltpu.VMEM( ( 2, # For double buffering during DMA copies. @@ -725,6 +734,7 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_): pltpu.SemaphoreType.DMA((2,)), # Semaphores for double buffers. lm_scratch, # l_ref lm_scratch, # m_ref + acc_scratch, ] scalar_prefetches = ( kv_lens, @@ -755,10 +765,8 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_): ), vmem_limit_bytes=vmem_limit_bytes, ), - out_shape=jax.ShapeDtypeStruct(shape=q.shape, dtype=jnp.float32), + out_shape=jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype), name="ragged_paged_attention_kernel", ) - # TODO(jevinjiang): Use f32 acc scratch for output! So we only need - # to transfer output with desired dtype back to HBM. - return kernel(*scalar_prefetches, q, kv_pages).astype(q.dtype) + return kernel(*scalar_prefetches, q, kv_pages) diff --git a/tests/pallas/tpu_ragged_paged_attention_test.py b/tests/pallas/tpu_ragged_paged_attention_test.py index b76d30bd1dcf..8d48bc281400 100644 --- a/tests/pallas/tpu_ragged_paged_attention_test.py +++ b/tests/pallas/tpu_ragged_paged_attention_test.py @@ -21,7 +21,7 @@ from jax.experimental.pallas.ops.tpu.ragged_paged_attention import ( ragged_paged_attention, ref_ragged_paged_attention, - validate_inputs_on_runtime, + validate_dynamic_inputs, ) import jax.numpy as jnp @@ -91,15 +91,15 @@ def _test_ragged_paged_attention( num_seqs = jnp.array([len(seq_lens)], dtype=jnp.int32) - validate_inputs_on_runtime( + validate_dynamic_inputs( q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs, - sliding_window=sliding_window, - soft_cap=soft_cap, + sliding_window, + soft_cap, ) actual_num_q_tokens = cu_q_lens[num_seqs[0]] From 177193662cba6a228fc26cc5a08efb073ec775ab Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 28 Mar 2025 15:15:22 -0700 Subject: [PATCH 271/483] Add vma rules for all_gather, all_to_all, ppermute and reduce_scatter primitives PiperOrigin-RevId: 741661360 --- jax/_src/lax/parallel.py | 60 +++++++++++++++++++++++++++++++--------- tests/shard_map_test.py | 17 ++++++++++++ 2 files changed, 64 insertions(+), 13 deletions(-) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 28e6dbef4a2c..3ef0a2520378 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -25,6 +25,7 @@ import jax from jax import tree_util from jax._src import core +from jax._src import config from jax._src import dispatch from jax._src import dtypes from jax._src.sharding_impls import (SPMDAxisContext, ShardingContext, @@ -325,9 +326,10 @@ def ppermute(x, axis_name, perm): """ if not isinstance(axis_name, (list, tuple)): axis_name = (axis_name,) - return tree_util.tree_map( - partial(ppermute_p.bind, axis_name=axis_name, - perm=tuple(map(tuple, perm))), x) + def bind(leaf): + leaf = insert_collective_pbroadcast(axis_name, leaf) + return ppermute_p.bind(leaf, axis_name=axis_name, perm=tuple(map(tuple, perm))) + return tree_util.tree_map(bind, x) def pshuffle(x, axis_name, perm): """Convenience wrapper of jax.lax.ppermute with alternate permutation encoding @@ -447,6 +449,7 @@ def bind(x, split_axis=split_axis, concat_axis=concat_axis): else: # concat_axis < split_axis x = lax.expand_dims(x, (concat_axis,)) # insert the new axis split_axis += 1 # we have a new axis before split_axis now + x = insert_collective_pbroadcast(axis_name, x) result = all_to_all_p.bind(x, split_axis=split_axis, concat_axis=concat_axis, axis_name=axis_name, axis_index_groups=axis_index_groups, @@ -975,6 +978,7 @@ def _ppermute_batcher(axis_data, vals_in, dims_in, axis_name, perm): def _raise_to_shaped_abstract_eval(x, *, axis_name, **params): _check_axis_names(axis_name) + collective_vma_rule('ppermute', axis_name, x) return x ppermute_p = core.Primitive('ppermute') @@ -1189,7 +1193,8 @@ def _all_to_all_effectful_abstract_eval( assert shape[split_axis] % axis_size == 0, (shape[split_axis], axis_size) shape[split_axis] //= axis_size shape[concat_axis] *= axis_size - out_aval = input_aval.update(shape=tuple(shape), weak_type=False) + vma = collective_vma_rule('all_to_all', axis_name, input_aval) + out_aval = input_aval.update(shape=tuple(shape), weak_type=False, vma=vma) effects = {*map(core.NamedAxisEffect, axis_name)} return out_aval, effects @@ -1313,6 +1318,19 @@ def _ragged_all_to_all_transpose( mlir.register_lowering(ragged_all_to_all_p, _ragged_all_to_all_lowering) batching.skippable_batchers[ragged_all_to_all_p] = partial(_names_in_param, 'axis_name') +def insert_collective_pbroadcast(axis_name, x): + if not config.varying_axes_in_types.value: + return x + + from jax.experimental import shard_map + axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name + aval = core.get_aval(x) + names_union = set(axis_name) | aval.vma + pbroadcast_axis_name = tuple(n for n in names_union if n not in aval.vma) + if pbroadcast_axis_name: + x = shard_map.pbroadcast(x, pbroadcast_axis_name) + return x + def all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False): """Gather values of x across all replicas. @@ -1382,6 +1400,7 @@ def all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False): axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups) axis_size = psum(1, axis_name, axis_index_groups=axis_index_groups) def bind(leaf): + leaf = insert_collective_pbroadcast(axis_name, leaf) return all_gather_p.bind( leaf, all_gather_dimension=canonicalize_axis( @@ -1434,6 +1453,19 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name, **other_args).results +def collective_vma_rule(prim_name, axis_name, x_aval): + if not config.varying_axes_in_types.value: + return frozenset() + axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name + if any(a not in x_aval.vma for a in axis_name): + raise ValueError( + f"Collective {prim_name} must be applied to a device-varying " + f" type, but got {x_aval.vma} for collective acting " + f"over axis name {axis_name}. Please open an issue at " + "https://github.com/jax-ml/jax/issues and as a temporary " + "workaround pass the check_rep=False argument to shard_map") + return x_aval.vma + def _all_gather_effectful_abstract_eval( x_aval, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled ): @@ -1445,7 +1477,9 @@ def _all_gather_effectful_abstract_eval( new_shape[all_gather_dimension] *= axis_size else: new_shape.insert(all_gather_dimension, axis_size) - return x_aval.update(shape=new_shape), {*map(core.NamedAxisEffect, axis_name)} + out_vma = collective_vma_rule('all_gather', axis_name, x_aval) + return (x_aval.update(shape=new_shape, vma=out_vma), + {*map(core.NamedAxisEffect, axis_name)}) def _all_gather_transpose_rule(cts, x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled): return (psum_scatter(cts, axis_name=axis_name, @@ -1582,7 +1616,9 @@ def _reduce_scatter_effectful_abstract_eval( f"{scatter_dim_input_size} must match shard count " f"{axis_size}") del new_shape[scatter_dimension] - return x_aval.update(shape=new_shape), {*map(core.NamedAxisEffect, axis_name)} + vma = collective_vma_rule('reduce_scatter', axis_name, x_aval) + return (x_aval.update(shape=new_shape, vma=vma), + {*map(core.NamedAxisEffect, axis_name)}) def _reduce_scatter_transpose_rule(cts, x, *, axis_name, scatter_dimension, @@ -1726,13 +1762,11 @@ def psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None, axis_name = axis_name, axis_size = psum(1, axis_name, axis_index_groups=axis_index_groups) axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups) - bind = partial( - reduce_scatter_p.bind, - axis_name=axis_name, - scatter_dimension=scatter_dimension, - axis_index_groups=axis_index_groups, - axis_size=axis_size, - tiled=tiled) + def bind(leaf): + leaf = insert_collective_pbroadcast(axis_name, leaf) + return reduce_scatter_p.bind( + leaf, axis_name=axis_name, scatter_dimension=scatter_dimension, + axis_index_groups=axis_index_groups, axis_size=axis_size, tiled=tiled) return tree_util.tree_map(bind, x) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 1ffb3e1d137a..c1923f5b0ae3 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -2707,6 +2707,23 @@ def f(x): # return jnp.sum(f(x, y)) # print(jax.jit(jax.grad(g)).trace(x, y).jaxpr) + @config.varying_axes_in_types(True) + def test_all_gather_with_vma_in_types(self): + mesh = jtu.create_mesh((2,), ('x',)) + x = np.arange(8.) + + def f(x): + self.assertEqual(x.aval.vma, frozenset()) + out = jax.lax.all_gather(x, 'x') + self.assertEqual(out.aval.vma, frozenset({'x'})) + return out + + f = jax.jit(shard_map(f, mesh=mesh, in_specs=P(), out_specs=P('x'))) + jaxpr = f.trace(x).jaxpr + self.assertIn("pbroadcast[axes=('x',)", str(jaxpr)) + + f(x) # doesn't crash + class FunSpec(NamedTuple): name: str From dafebd0d7f2c79dafb3fa2e6f358bdb67d0dfaa9 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 28 Mar 2025 15:20:58 -0700 Subject: [PATCH 272/483] DOC: add documentation note about default dtypes --- docs/default_dtypes.md | 82 ++++++++++++++++++++++++++++++++ docs/notes.rst | 8 +++- docs/user_guides.rst | 1 - jax/_src/numpy/array_creation.py | 9 ++-- 4 files changed, 95 insertions(+), 5 deletions(-) create mode 100644 docs/default_dtypes.md diff --git a/docs/default_dtypes.md b/docs/default_dtypes.md new file mode 100644 index 000000000000..629f7fb5c314 --- /dev/null +++ b/docs/default_dtypes.md @@ -0,0 +1,82 @@ +(default-dtypes)= +# Default dtypes and the X64 flag +JAX strives to meet the needs of a range of numerical computing practitioners, who +sometimes have conflicting preferences. When it comes to default dtypes, there are +two different camps: + +- Classic scientific computing practitioners (i.e. users of tools like {mod}`numpy` or + {mod}`scipy`) tend to value accuracy of computations foremost: such users would + prefer that computations default to the **widest available representation**: e.g. + floating point values should default to `float64`, integers to `int64`, etc. +- AI researchers (i.e. folks implementing and training neural networks) tend to value + speed over accuracy, to the point where they have developed special data types like + [bfloat16](https://en.wikipedia.org/wiki/Bfloat16_floating-point_format) and others + which deliberately discard the least significant bits in order to speed up computation. + For these users, the mere presence of a float64 value in their computation can lead + to programs that are slow at best, and incompatible with their hardware at worst! + These users would prefer that computations default to `float32` or `int32`. + +The main mechanism JAX offers for this is the `jax_enable_x64` flag, which controls +whether 64-bit values can be created at all. By default this flag is set to `False` +(serving the needs of AI researchers and practitioners), but can be set to `True` +by users who value accuracy over computational speed. + +## Default setting: 32-bits everywhere +By default `jax_enable_x64` is set to False, and so {mod}`jax.numpy` array creation +functions will default to returning 32-bit values. + +For example: +```python +>>> import jax.numpy as jnp + +>>> jnp.arange(5) +Array([0, 1, 2, 3, 4], dtype=int32) + +>>> jnp.zeros(5) +Array([0., 0., 0., 0., 0.], dtype=float32) + +>>> jnp.ones(5, dtype=int) +Array([1, 1, 1, 1, 1], dtype=int32) + +``` + +Beyond defaults, because 64-bit values can be so poisonous to AI workflows, having +this flag set to False prevents you from creating 64-bit arrays at all! For example: +``` +>>> jnp.arange(5, dtype='float64') # doctest: +SKIP +UserWarning: Explicitly requested dtype float64 requested in arange is not available, and will be +truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the +JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more. +Array([0., 1., 2., 3., 4.], dtype=float32) +``` + +## The X64 flag: enabling 64-bit values +To work in the "other mode" where functions default to producing 64-bit values, you can set the +`jax_enable_x64` flag to `True`: +```python +import jax +import jax.numpy as jnp + +jax.config.update('jax_enable_x64', True) + +print(repr(jnp.arange(5))) +print(repr(jnp.zeros(5))) +print(repr(jnp.ones(5, dtype=int))) +``` +``` +Array([0, 1, 2, 3, 4], dtype=int64) +Array([0., 0., 0., 0., 0.], dtype=float64) +Array([1, 1, 1, 1, 1], dtype=int64) +``` + +The X64 configuration can also be set via the `JAX_ENABLE_X64` shell environment variable, +for example: +```bash +$ JAX_ENABLE_X64=1 python main.py +``` +The X64 flag is intended as a **global setting** that should have one value for your whole +program, set at the top of your main file. A common feature request is for the flag to +be contextually configurable (e.g. enabling X64 just for one section of a long program): +this turns out to be difficult to implement within JAX's programming model, where code +execution may happen in a different context than code compilation. There is ongoing work +exploring the feasibility of relaxing this constraint, so stay tuned! diff --git a/docs/notes.rst b/docs/notes.rst index 24a9dc8594cd..502385142b16 100644 --- a/docs/notes.rst +++ b/docs/notes.rst @@ -17,6 +17,10 @@ Memory and computation usage: Programmer guardrails: - :doc:`rank_promotion_warning` describes how to configure :mod:`jax.numpy` to avoid implicit rank promotion. +Arrays and data types: + - :doc:`type_promotion` describes JAX's implicit type promotion for functions of two or more values. + - :doc:`default_dtypes` describes how JAX determines the default dtype for array creation functions. + .. toctree:: :hidden: @@ -27,4 +31,6 @@ Programmer guardrails: async_dispatch concurrency gpu_memory_allocation - rank_promotion_warning \ No newline at end of file + rank_promotion_warning + type_promotion + default_dtypes diff --git a/docs/user_guides.rst b/docs/user_guides.rst index 6481da7a31dd..47984fc493f4 100644 --- a/docs/user_guides.rst +++ b/docs/user_guides.rst @@ -26,7 +26,6 @@ or deployed codebases. errors aot export/index - type_promotion transfer_guard .. toctree:: diff --git a/jax/_src/numpy/array_creation.py b/jax/_src/numpy/array_creation.py index a0495986fcd1..4f07f94fe8b4 100644 --- a/jax/_src/numpy/array_creation.py +++ b/jax/_src/numpy/array_creation.py @@ -50,7 +50,8 @@ def zeros(shape: Any, dtype: DTypeLike | None = None, *, Args: shape: int or sequence of ints specifying the shape of the created array. - dtype: optional dtype for the created array; defaults to floating point. + dtype: optional dtype for the created array; defaults to float32 or float64 + depending on the X64 configuration (see :ref:`default-dtypes`). device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` to which the created array will be committed. @@ -87,7 +88,8 @@ def ones(shape: Any, dtype: DTypeLike | None = None, *, Args: shape: int or sequence of ints specifying the shape of the created array. - dtype: optional dtype for the created array; defaults to floating point. + dtype: optional dtype for the created array; defaults to float32 or float64 + depending on the X64 configuration (see :ref:`default-dtypes`). device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` to which the created array will be committed. @@ -126,7 +128,8 @@ def empty(shape: Any, dtype: DTypeLike | None = None, *, Args: shape: int or sequence of ints specifying the shape of the created array. - dtype: optional dtype for the created array; defaults to floating point. + dtype: optional dtype for the created array; defaults to float32 or float64 + depending on the X64 configuration (see :ref:`default-dtypes`). device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` to which the created array will be committed. From 6fba4ecc58b21a478f223aeba3b8dfff6cef39c7 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Fri, 28 Mar 2025 15:20:27 -0700 Subject: [PATCH 273/483] PR #27576: [attrs] experimental appendattr Imported from GitHub PR https://github.com/jax-ml/jax/pull/27576 This is an experimental extension to attrs. Attrs should be considered both experimental and deprecated. This PR also includes some fixes for getattr/setattr. Copybara import of the project: -- 3b1ea1a5f90b28744522670d0498ce5a6b194274 by Matthew Johnson : [attrs] experimental appendattr Merging this change closes #27576 COPYBARA_INTEGRATE_REVIEW=https://github.com/jax-ml/jax/pull/27576 from mattjj:appendattr b93795201b39b8f75890c9228368c994ae1e38e8 PiperOrigin-RevId: 741662724 --- jax/_src/interpreters/partial_eval.py | 68 +++++--- jax/_src/lax/control_flow/loops.py | 92 +++++++---- jax/_src/pjit.py | 82 +++++----- jax/experimental/attrs.py | 63 +++++++- tests/attrs_test.py | 215 +++++++++++++++++++++++++- 5 files changed, 430 insertions(+), 90 deletions(-) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 58b97ce2f3da..0a8e3b7824ff 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -58,6 +58,13 @@ def identity(x): return x AvalId = int ConstId = int +AttrKind = Any +PyTree = Any + +# Attrs flavors, see jax/experimental/attrs.py +ReadWrite = type('ReadWrite', (), {})() +Append = type('Append', (), {})() + def _update_annotation_known( f: lu.WrappedFun, orig_type: InputType | None, @@ -1553,6 +1560,17 @@ def move_binders_to_back(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool] """Reorder `invars` by moving those indicated in `to_move` to the back.""" return move_binders_to_front(closed_jaxpr, map(op.not_, to_move)) +def move_outvars_to_back(jaxpr: ClosedJaxpr, to_move: Sequence[bool]) -> ClosedJaxpr: + return _move_outvars_to_back(jaxpr, tuple(to_move)) + +@weakref_lru_cache +def _move_outvars_to_back(jaxpr, to_move): + new_outvars = ([e for e, m in zip(jaxpr.jaxpr.outvars, to_move) if not m] + + [e for e, m in zip(jaxpr.jaxpr.outvars, to_move) if m]) + return jaxpr.replace(jaxpr=jaxpr.jaxpr.replace(outvars=new_outvars)) + + + class DynamicJaxprTracer(core.Tracer): __slots__ = ['aval', '_debug_info'] @@ -1657,7 +1675,7 @@ class JaxprStackFrame: eqns: list[JaxprEqn] invars: list[Var] effects: core.Effects - attrs_tracked: list[tuple[Any, str]] + attrs_tracked: list[tuple[Any, str, AttrKind]] attrs_inits: list attrs_vars: list[Var] debug_info: core.DebugInfo @@ -1679,10 +1697,14 @@ def __init__(self, debug_info: core.DebugInfo): def add_eqn(self, eqn: core.JaxprEqn): self.eqns.append(eqn) - def to_jaxpr(self, trace: DynamicJaxprTrace, - out_tracers: Sequence[Tracer], - debug_info: core.DebugInfo, - ) -> tuple[Jaxpr, list[Any], list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]: + def reset_states(self): + reset_states(self.attrs_tracked, self.attrs_inits) + + def to_jaxpr( + self, trace: DynamicJaxprTrace, + out_tracers: Sequence[Tracer], + debug_info: core.DebugInfo, + ) -> tuple[Jaxpr, list[Any], list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str, AttrKind]]]]: # It's not necessary, but we keep the tracer-to-var mapping injective: assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values())) invars = self.attrs_vars + self.invars @@ -1699,7 +1721,6 @@ def to_jaxpr(self, trace: DynamicJaxprTrace, jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals) jaxpr, constvals = _inline_literals(jaxpr, constvals) init_trees = [tree_structure(init_val) for init_val in self.attrs_inits] - set_states(self.attrs_tracked, self.attrs_inits) # reset to initial values return jaxpr, list(constvals), zip(init_trees, end_trees, self.attrs_tracked) def to_jaxpr2(self, out_tracers: Sequence[core.Tracer], @@ -1840,10 +1861,9 @@ def vars_in_shape(aval: AbstractValue) -> Sequence[Var]: outvars = [var(v) if v in used else dropvar(v.aval) for v in eqn.outvars] new_eqns.append(eqn.replace(invars=invars, outvars=outvars)) new_outvars = [lit_or_var(v) for v in jaxpr.outvars] - jaxpr_effects = make_jaxpr_effects(new_constvars, new_invars, new_outvars, - new_eqns) - new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns, - jaxpr_effects, jaxpr.debug_info) + effs = make_jaxpr_effects(new_constvars, new_invars, new_outvars, new_eqns) + new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns, effs, + jaxpr.debug_info) return new_jaxpr, new_constvals @@ -2172,19 +2192,23 @@ def trace_to_jaxpr_dynamic( *, keep_inputs: list[bool] | None = None, ) -> tuple[Jaxpr, list[AbstractValue], list[Any], - list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]: + list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str, AttrKind]]]]: keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs trace = DynamicJaxprTrace(fun.debug_info) with core.ensure_no_leaks(trace), source_info_util.reset_name_stack(): in_tracers = _input_type_to_tracers(trace.new_arg, in_avals) in_tracers = [t for t, keep in zip(in_tracers, keep_inputs) if keep] - with core.set_current_trace(trace): - ans = fun.call_wrapped(*in_tracers) + try: + with core.set_current_trace(trace): + ans = fun.call_wrapped(*in_tracers) - out_tracers = map(trace.to_jaxpr_tracer, ans) - _check_no_returned_refs(fun.debug_info, out_tracers) - jaxpr, consts, attrs_tracked = trace.to_jaxpr(out_tracers, fun.debug_info) - del trace, fun, in_tracers, out_tracers, ans + out_tracers = map(trace.to_jaxpr_tracer, ans) + _check_no_returned_refs(fun.debug_info, out_tracers) + jaxpr, consts, attrs_tracked = trace.to_jaxpr(out_tracers, fun.debug_info) + del fun, in_tracers, out_tracers, ans + finally: + trace.frame.reset_states() + del trace config.enable_checks.value and core.check_jaxpr(jaxpr) return jaxpr, [v.aval for v in jaxpr.outvars], consts, attrs_tracked @@ -2242,14 +2266,14 @@ def trace_to_jaxpr_dynamic2( tuple[AbstractedAxisName, ...], ] -AttrsTracked = list[tuple[Any, str]] +AttrsTracked = list[tuple[Any, str, AttrKind]] AttrStates = list -def set_states(attrs_tracked: AttrsTracked, vals: AttrStates): - for ((obj, attr), val) in zip(attrs_tracked, vals): +def reset_states(attrs_tracked: AttrsTracked, init_vals: AttrStates) -> None: + for ((obj, attr, _), val) in zip(attrs_tracked, init_vals): setattr(obj, attr, val) if val is not dne_sentinel else delattr(obj, attr) -def get_states(attrs_tracked: AttrsTracked): - return [getattr(obj, attr) for (obj, attr) in attrs_tracked] +def get_states(attrs_tracked: AttrsTracked) -> list[PyTree]: + return [getattr(obj, attr) for (obj, attr, kind) in attrs_tracked] @register_static class DoesNotExist: ... diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 0362c139570a..56323949a607 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -298,8 +298,17 @@ def _create_jaxpr(init): if len(out_tree_children) != 2: msg = "scan body output must be a pair, got {}." raise TypeError(msg.format(tree_unflatten(out_tree, jaxpr.out_avals))) - _, carry_avals_out, _ = split_list( - jaxpr.out_avals, [len(attrs_tracked), out_tree_children[0].num_leaves]) + + if attrs_tracked: + appends_out = [kind is pe.Append for *_, (_, _, kind) in attrs_tracked] + jaxpr = pe.move_outvars_to_back( + jaxpr, appends_out + [False] * (len(jaxpr.out_avals) - len(appends_out))) + num_attr_carry = sum(init_tree.num_leaves for init_tree, _, (_, _, kind) + in attrs_tracked if kind is pe.ReadWrite) + _, carry_avals_out, _ = split_list( + jaxpr.out_avals, [num_attr_carry, out_tree_children[0].num_leaves]) + else: + carry_avals_out, _ = split_list(jaxpr.out_avals, [out_tree_children[0].num_leaves]) return (init_flat, carry_avals, carry_avals_out, init_tree, in_flat, jaxpr, consts, out_tree, out_tree_children, attrs_tracked) @@ -332,9 +341,8 @@ def _create_jaxpr(init): raise ValueError("`unroll` must be a `bool` or a positive `int`.") if attrs_tracked: in_state = _get_states(attrs_tracked) - in_carry, in_ext = split_list(in_flat, [num_carry]) - in_flat = [*in_state, *in_carry, *in_ext] - num_carry += len(attrs_tracked) + in_flat = [*in_state, *in_flat] + num_carry += len(in_state) out = scan_p.bind(*consts, *in_flat, reverse=reverse, length=length, jaxpr=jaxpr, num_consts=len(consts), num_carry=num_carry, @@ -342,27 +350,50 @@ def _create_jaxpr(init): unroll=unroll, _split_transpose=_split_transpose) if attrs_tracked: - out_state, out = split_list(out, [len(attrs_tracked)]) - _set_states(attrs_tracked, out_state) + num_ext = (len(out) - len(in_state) + - sum(k is pe.Append for *_, (_, _, k) in attrs_tracked)) + out_state, out, out_append = split_list(out, [len(in_state), num_ext]) + out_attrs = _merge_attrs_out(attrs_tracked, out_state, out_append) + _set_states(attrs_tracked, out_attrs) return tree_unflatten(out_tree, out) def _set_states(attrs_tracked, vals): - from jax.experimental.attrs import jax_setattr + from jax.experimental.attrs import jax_setattr, jax_extendattr valss = split_list_checked(vals, [td.num_leaves for _, td, _ in attrs_tracked]) - for ((_, treedef, (obj, attr)), leaves) in zip(attrs_tracked, valss): - val = tree_unflatten(treedef, leaves) - jax_setattr(obj, attr, val) + for ((_, treedef, (obj, attr, kind)), leaves) in zip(attrs_tracked, valss): + if kind is pe.ReadWrite: + val = tree_unflatten(treedef, leaves) + jax_setattr(obj, attr, val) + elif kind is pe.Append: + val, = leaves + jax_extendattr(obj, attr, val.reshape(-1, *val.shape[2:])) + else: + assert False def _get_states(attrs_tracked): from jax.experimental.attrs import jax_getattr vals = [] - for treedef, _, (obj, attr) in attrs_tracked: - tree = jax_getattr(obj, attr) - leaves, treedef_ = tree_flatten(tree) - assert treedef == treedef_ - vals.extend(leaves) + for treedef, _, (obj, attr, kind) in attrs_tracked: + if kind is pe.ReadWrite: + tree = jax_getattr(obj, attr) + leaves, treedef_ = tree_flatten(tree) + assert treedef == treedef_ + vals.extend(leaves) + elif kind is pe.Append: + pass + else: + assert False return vals +def _merge_attrs_out(attrs_tracked, out_state, out_append): + out_state_, out_append_ = iter(out_state), iter(out_append) + out_attrs = [item for _, out_tree, (_, _, k) in attrs_tracked for item in + (itertools.islice(out_state_, out_tree.num_leaves) + if k is pe.ReadWrite else [next(out_append_)])] + assert next(out_state_, None) is next(out_append_, None) is None + return out_attrs + + def _capitalize(s): # s.capitalize() converts s[1:] to lowercase which we don't want. return s[0].capitalize() + s[1:] @@ -662,7 +693,7 @@ def _scan_partial_eval(trace, *tracers, reverse: bool, # The above trace_to_jaxpr_nounits call computed loop-invariant residuals # (known values in invar_pvals_out) and also computed loop-invariant values # needed by the new jaxpr_known (in jaxpr_known_consts, which replace the - # previous consts). We need to collect the computed inteisive residuals, and + # previous consts). We need to collect the computed intensive residuals, and # move corresponding intensive residual binders in jaxpr_unknown to the front. res_pvals = invar_pvals_out[len(invar_pvals_out) - num_res:] intensive_res = [pval.get_known() for pval in res_pvals if pval.is_known()] @@ -785,16 +816,21 @@ def _scan_transpose(cts, *args, reverse, length, num_consts, ct_consts = _map(ad_util.zeros_like_aval, jaxpr.in_avals[num_ires:num_consts]) # jaxpr :: [ires, T d] -> [T c] -> [T a, eres] -> ([T c], [T b]) - # jaxpr_trans :: [ires] -> [CT d, CT c] -> [CT b, eres] -> ([CT d, CT c], [CT a]) + # jaxpr_trans :: [ires] -> [CT d, CT c] -> [CT b, eres] -> ([CT d, CT c], [CT a, e]) jaxpr_trans, attrs_tracked = _transpose_scan_jaxpr( jaxpr, num_ires, num_consts - num_ires, num_eres, ct_ys_is_zeros) - linear_trans = ([False] * num_ires + [False] * len(attrs_tracked) + + appends_out = [kind is pe.Append for *_, (_, _, kind) in attrs_tracked] + jaxpr_trans = pe.move_outvars_to_back( + jaxpr_trans, appends_out + [False] * (len(jaxpr_trans.out_avals) - len(appends_out))) + num_attr_carry = sum(init_tree.num_leaves for init_tree, _, (_, _, kind) + in attrs_tracked if kind is pe.ReadWrite) + linear_trans = ([False] * num_ires + [False] * num_attr_carry + [True] * (len(ct_consts) + len(ct_carry) + len(ct_ys)) + [False] * num_eres) in_state = _get_states(attrs_tracked) transpose_inputs = *ires, *in_state, *ct_consts, *ct_carry, *ct_ys, *eres - transpose_num_out_carry = num_consts-num_ires+num_carry+len(attrs_tracked) + transpose_num_out_carry = num_consts-num_ires+num_carry+num_attr_carry if not _split_transpose: outs = scan_p.bind( @@ -889,8 +925,10 @@ def _scan_transpose(cts, *args, reverse, length, num_consts, for mask in outs_mask ] - out_state, outs = split_list(outs, [len(attrs_tracked)]) - _set_states(attrs_tracked, out_state) + num_outs = len(outs) - num_attr_carry - sum(appends_out) + out_state, outs, out_append = split_list(outs, [num_attr_carry, num_outs]) + out_attrs = _merge_attrs_out(attrs_tracked, out_state, out_append) + _set_states(attrs_tracked, out_attrs) ct_consts, ct_init, ct_xs = split_list(outs, [num_consts - num_ires, num_carry]) return [None] * num_ires + ct_consts + ct_init + ct_xs + [None] * num_eres @@ -935,12 +973,10 @@ def transposed(*res1_cbar_bbar_res2): return c_bar + a_bar # TODO(necula): fix arg names and results for transposed - transposed_wrapped = lu.wrap_init(transposed, - debug_info=jaxpr.jaxpr.debug_info) - return _make_closed_jaxpr_attrs( - transposed_wrapped, - tuple(res1_avals + c_avals + b_carry_avals + - b_ys_avals_stripped + res2_avals)) + transposed_wrapped = lu.wrap_init(transposed, debug_info=jaxpr.jaxpr.debug_info) + trans_avals = (*res1_avals, *c_avals, *b_carry_avals, *b_ys_avals_stripped, *res2_avals) + trans_jaxpr, attrs_tracked = _make_closed_jaxpr_attrs(transposed_wrapped, trans_avals) + return trans_jaxpr, attrs_tracked def _scan_batching_rule(axis_data, args, diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 03eb6835cb06..5727c36a646b 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -233,20 +233,30 @@ def _python_pjit_helper(fun: Callable, jit_info: PjitInfo, *args, **kwargs): def _set_states(attrs_tracked, vals): - from jax.experimental.attrs import jax_setattr + from jax.experimental.attrs import jax_setattr, jax_extendattr valss = split_list(vals, [td.num_leaves for _, td, _ in attrs_tracked[:-1]]) - for ((_, treedef, (obj, attr)), leaves) in zip(attrs_tracked, valss): - val = tree_unflatten(treedef, leaves) - jax_setattr(obj, attr, val) + for ((_, treedef, (obj, attr, kind)), leaves) in zip(attrs_tracked, valss): + if kind is pe.ReadWrite: + val = tree_unflatten(treedef, leaves) + jax_setattr(obj, attr, val) + elif kind is pe.Append: + del treedef + val, = leaves + jax_extendattr(obj, attr, val) def _get_states(attrs_tracked): from jax.experimental.attrs import jax_getattr, dne_sentinel vals = [] - for treedef, _, (obj, attr) in attrs_tracked: - tree = jax_getattr(obj, attr) if hasattr(obj, attr) else dne_sentinel - leaves, treedef_ = tree_flatten(tree) - assert treedef == treedef_ - vals.extend(leaves) + for treedef, _, (obj, attr, kind) in attrs_tracked: + if kind is pe.ReadWrite: + tree = jax_getattr(obj, attr) if hasattr(obj, attr) else dne_sentinel + leaves, treedef_ = tree_flatten(tree) + assert treedef == treedef_ + vals.extend(leaves) + elif kind is pe.Append: + pass + else: + assert False return vals def _need_to_rebuild_with_fdo(pgle_profiler): @@ -537,7 +547,7 @@ class PjitParams(NamedTuple): donated_invars: tuple[bool, ...] arg_names: tuple[str, ...] num_consts: int - attrs_tracked: list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]] + attrs_tracked: list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str, Any]]] def _infer_params_impl( @@ -613,14 +623,14 @@ def _infer_params_impl( ji.in_layouts_treedef, ji.in_layouts_leaves, in_avals, in_tree, flat_fun.debug_info, device_or_backend_set, have_kwargs) - attr_token = _attr_token(flat_fun, in_type) + attr_token = _attr_cache_index(flat_fun, in_type) jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr( flat_fun, in_type, attr_token, IgnoreKey(ji.inline)) if config.mutable_array_checks.value: _check_no_aliased_closed_over_refs(dbg, (*jaxpr.consts, *consts), explicit_args) - _attr_update(flat_fun, in_type, attr_token, attrs_tracked) + _attr_cachedata_update(flat_fun, in_type, attr_token, attrs_tracked) out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings( out_shardings_treedef, out_shardings_leaves, ji.out_layouts_treedef, @@ -636,13 +646,14 @@ def _infer_params_impl( implicit_args = [] args_flat = [*implicit_args, *explicit_args] - num_states_in = sum(init_tree.num_leaves for init_tree, _, _ in attrs_tracked) - num_extra_args = len(implicit_args) + num_states_in + len(consts) + num_attrs_in = sum(init_tree.num_leaves for init_tree, _, (_, _, kind) + in attrs_tracked if kind is pe.ReadWrite) + num_extra_args = len(implicit_args) + num_attrs_in + len(consts) in_shardings_flat = (UNSPECIFIED,) * num_extra_args + in_shardings_flat in_layouts_flat = (None,) * num_extra_args + in_layouts_flat donated_invars = (False,) * num_extra_args + donated_invars assert (len(in_shardings_flat) == len(in_layouts_flat) == - len(donated_invars) == num_states_in + len(consts) + len(args_flat)) + len(donated_invars) == num_attrs_in + len(consts) + len(args_flat)) params = dict( jaxpr=jaxpr, @@ -1274,7 +1285,7 @@ def _create_pjit_jaxpr( attr_data: int, ignored_inline: IgnoreKey ) -> tuple[core.ClosedJaxpr, list[Any], list[core.AbstractValue], - list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]: + list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str, Any]]]]: util.test_event("create_pjit_jaxpr") del ignored_inline # just for explain_cache_miss if config.no_tracing.value: @@ -1350,32 +1361,31 @@ def seen_attrs_get( assert fun.in_type is None or fun.in_type == in_type return cache[(fun.transforms, fun.params, in_type)] -def _attr_token( +def _attr_cache_index( fun: lu.WrappedFun, in_type: core.InputType | tuple[core.AbstractValue, ...] ) -> int: - from jax.experimental.attrs import jax_getattr, dne_sentinel + from jax.experimental.attrs import dne_sentinel cases = seen_attrs_get(fun, in_type) for i, records in enumerate(cases): - for obj, attr, treedef, avals in records: - val = jax_getattr(obj, attr) if hasattr(obj, attr) else dne_sentinel - vals, treedef_ = tree_flatten(val) - avals_ = map(core.shaped_abstractify, vals) - if treedef != treedef_ or avals != avals_: break + for obj, attr, kind, treedef, avals in records: + if kind is pe.ReadWrite: + val = getattr(obj, attr, dne_sentinel) + vals, treedef_ = tree_flatten(val) + avals_ = map(core.shaped_abstractify, vals) + if treedef != treedef_ or avals != avals_: break else: return i return len(cases) -def _attr_update(fun, in_type, i, attrs_tracked): - from jax.experimental.attrs import jax_getattr, dne_sentinel - leaves = lambda obj, attr: tree_leaves(jax_getattr(obj, attr) if hasattr(obj, attr) else dne_sentinel) - records = [(obj, attr, init_tree, map(core.shaped_abstractify, leaves(obj, attr))) - for init_tree, _, (obj, attr) in attrs_tracked] +def _attr_cachedata_update(fun, in_type, i, attrs_tracked): + from jax.experimental.attrs import dne_sentinel + leaves = lambda obj, attr: tree_leaves(getattr(obj, attr, dne_sentinel)) + records = [(obj, attr, kind, init_tree, map(core.typeof, leaves(obj, attr))) + for init_tree, _, (obj, attr, kind) in attrs_tracked] cases = seen_attrs_get(fun, in_type) if i == len(cases): cases.append(records) - else: - assert i < len(cases) and cases[i] == records @dataclasses.dataclass(frozen=True) @@ -1540,6 +1550,7 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding] committed_arg_shardings.append((arg_s, pxla.MismatchType.ARG_SHARDING, None)) resolved_in_shardings: list[PjitSharding] = [] + assert len(args) == len(pjit_in_shardings) for arg, pjit_in_s in zip(args, pjit_in_shardings): # arg sharding can be None in case of ShapeDtypeStruct. jax.Array does # not allow None as the sharding. @@ -2337,11 +2348,12 @@ def prune_type(ty, xs, maybe_zeros): if attrs_tracked: init_states = _get_states(attrs_tracked) + num_attr_outs = sum(final_tree.num_leaves for _, final_tree, _ in attrs_tracked) primals_and_nz_cts_in = [*init_states, *primals_and_nz_cts_in] - transpose_in_shardings = (UNSPECIFIED,) * len(attrs_tracked) + transpose_in_shardings - transpose_out_shardings = (UNSPECIFIED,) * len(attrs_tracked) + transpose_out_shardings - transpose_in_layouts = (None,) * len(attrs_tracked) + transpose_in_layouts - transpose_out_layouts = (None,) * len(attrs_tracked) + transpose_out_layouts + transpose_in_shardings = (UNSPECIFIED,) * len(init_states) + transpose_in_shardings + transpose_out_shardings = (UNSPECIFIED,) * num_attr_outs + transpose_out_shardings + transpose_in_layouts = (None,) * len(init_states) + transpose_in_layouts + transpose_out_layouts = (None,) * num_attr_outs + transpose_out_layouts try: nz_cts_out = pjit_p.bind( @@ -2370,7 +2382,7 @@ def prune_type(ty, xs, maybe_zeros): dispatch._raise_no_nan_in_deoptimized(e) if attrs_tracked: - final_states, nz_cts_out = split_list(nz_cts_out, [len(init_states)]) + final_states, nz_cts_out = split_list(nz_cts_out, [num_attr_outs]) _set_states(attrs_tracked, final_states) return tree_unflatten(cts_out_treedef, nz_cts_out) diff --git a/jax/experimental/attrs.py b/jax/experimental/attrs.py index bb4c7bf83b3f..0d40938a85c4 100644 --- a/jax/experimental/attrs.py +++ b/jax/experimental/attrs.py @@ -16,6 +16,7 @@ from typing import Any, Callable +import jax from jax._src import core from jax._src import source_info_util from jax._src import api_util @@ -32,20 +33,31 @@ map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip +Array = Any JaxVal = Any Pytree = Any +ReadWrite = pe.ReadWrite +Append = pe.Append + register = api_util.register_class_with_attrs dne_sentinel = pe.dne_sentinel -def jax_getattr(obj: Any, attr: str): +def jax_getattr(obj: Any, attr: str) -> Pytree: with core.take_current_trace() as t: return t.process_getattr(obj, attr) -def jax_setattr(obj: Any, attr: str, val: Pytree): +def jax_setattr(obj: Any, attr: str, val: Pytree) -> None: with core.take_current_trace() as t: return t.process_setattr(obj, attr, val) +def jax_appendattr(obj: Any, attr: str, val: Array) -> None: + return jax_extendattr(obj, attr, jax.numpy.expand_dims(val, 0)) + +def jax_extendattr(obj: Any, attr: str, val: Array) -> None: + with core.take_current_trace() as t: + return t.process_extendattr(obj, attr, val) + def _getattr_impl(_, obj, attr): return getattr(obj, attr) core.EvalTrace.process_getattr = _getattr_impl @@ -54,6 +66,25 @@ def _setattr_impl(_, obj, attr, val): setattr(obj, attr, val) core.EvalTrace.process_setattr = _setattr_impl +def _extendattr_impl(_, obj, attr, val): + cur = getattr(obj, attr, dne_sentinel) + if cur is dne_sentinel: + new = val + else: + _check_append_type_agreement(obj, attr, core.typeof(cur), core.typeof(val)) + new = jax.numpy.concatenate([cur, val]) + setattr(obj, attr, new) +core.EvalTrace.process_extendattr = _extendattr_impl + +def _check_append_type_agreement(_, attr, curtype, valtype): + expected = core.mapped_aval(curtype.shape[0], 0, curtype) + got = core.mapped_aval(valtype.shape[0], 0, valtype) + if not core.typematch(expected, got): + raise TypeError( + f"can only append to attr {attr} with values of trailing shape " + f"{expected.str_short()}, but appendattr got value of type " + f"{valtype.str_short()} which has trailing shape {got.str_short()}.") + def _ensure_tracked(trace: pe.DynamicJaxprTrace, obj: Any, attr: str): frame = trace.frame @@ -65,13 +96,16 @@ def new_tracer(x): frame.tracers.append(tracer) return tracer - if (obj, attr) not in frame.attrs_tracked: + if (obj, attr, Append) in frame.attrs_tracked: + raise TypeError(f"can't read/write to append-only attr {attr}") + + if (obj, attr, ReadWrite) not in frame.attrs_tracked: init_val = getattr(obj, attr, dne_sentinel) frame.attrs_inits.append(init_val) init_vals, init_tree = tree_flatten(init_val) tracers = map(new_tracer, init_vals) setattr(obj, attr, tree_unflatten(init_tree, tracers)) - frame.attrs_tracked.append((obj, attr)) + frame.attrs_tracked.append((obj, attr, ReadWrite)) pe.DynamicJaxprTrace._ensure_tracked = _ensure_tracked def _getattr_staging(trace, obj, attr): @@ -84,6 +118,27 @@ def _setattr_staging(trace, obj, attr, val): setattr(obj, attr, val) pe.DynamicJaxprTrace.process_setattr = _setattr_staging +def _extendattr_staging(trace, obj, attr, val): + frame = trace.frame + + if (obj, attr, ReadWrite) in frame.attrs_tracked: + raise TypeError("can't append to read/write-only attr {attr}") + + first_write = (obj, attr, Append) not in frame.attrs_tracked + init_val = getattr(obj, attr, dne_sentinel) + if init_val is not dne_sentinel: + _check_append_type_agreement(obj, attr, core.typeof(init_val), core.typeof(val)) + if first_write: + frame.attrs_inits.append(init_val) + frame.attrs_tracked.append((obj, attr, Append)) + tracer = val + else: + assert init_val is not dne_sentinel + with core.set_current_trace(trace): + tracer = jax.numpy.concatenate([init_val, val]) + setattr(obj, attr, tracer) +pe.DynamicJaxprTrace.process_extendattr = _extendattr_staging + def jvp(f, primals, tangents, attr_tangents): attrs, attr_tangents = unzip2(((o, a), t) for o, a, t in attr_tangents) diff --git a/tests/attrs_test.py b/tests/attrs_test.py index 169df3712899..8cf64790311b 100644 --- a/tests/attrs_test.py +++ b/tests/attrs_test.py @@ -15,6 +15,7 @@ from __future__ import annotations from dataclasses import dataclass +import itertools as it from absl.testing import absltest from absl.testing import parameterized @@ -28,7 +29,7 @@ from jax._src.util import safe_zip, safe_map from jax.experimental import attrs -from jax.experimental.attrs import jax_setattr, jax_getattr +from jax.experimental.attrs import jax_setattr, jax_getattr, jax_appendattr config.parse_flags_with_absl() @@ -66,6 +67,19 @@ def double_it() -> None: double_it() self.assertEqual(thing.x, 16.0) + def test_setattr_doesnt_leak(self): + thing = Thing(1.0) + + @jax.jit + def f(x): + jax_setattr(thing, 'x', x) + raise Exception + + try: f(1.) + except: pass + self.assertNotIsInstance(thing.x, jax.core.Tracer) + + @parameterized.parameters([True, False]) def test_jit_basic_tree(self, jit: bool): thing = Thing((1.0, 2.0)) @@ -260,6 +274,26 @@ def body(_, __): double_it_10() self.assertAllClose(thing.x, 1024., check_dtypes=False) + @parameterized.parameters([True, False]) + def test_scan_basic_pytree(self, jit): + class Thing: ... + thing = Thing() + thing.x = (1.0, 1.0) + + def double_it_10(): + def body(_, __): + cur_x, _ = jax_getattr(thing ,"x") + jax_setattr(thing, "x", (cur_x * 2.0, 3.0)) + return None, None + _, _ = jax.lax.scan(body, None, None, length=10) + + if jit: + double_it_10 = jax.jit(double_it_10) + + double_it_10() + self.assertAllClose(thing.x[0], 1024., check_dtypes=False) + self.assertAllClose(thing.x[1], 3., check_dtypes=False) + def test_scan_basic_consts_and_args(self): thing = Thing(1.0) @@ -402,6 +436,184 @@ def f(x): jax.make_jaxpr(f)(3.) self.assertFalse(hasattr(thing, 'x')) + tracing_ok = True + f(0.0) + self.assertAllClose(thing.x, 0.) + tracing_ok = False + f(1.0) + self.assertAllClose(thing.x, 1.) + + @parameterized.parameters(it.product([False, True], repeat=2)) + def test_appendattr_basic(self, jit, initialized): + class Thing: + ... + thing = Thing() + + if initialized: + thing.x = jnp.arange(0.) + + def f(x): + assert (not jit) or tracing_ok + jax_appendattr(thing, 'x', x) + jax_appendattr(thing, 'x', x + 1) + + if jit: + f = jax.jit(f) + + tracing_ok = True + f(0.0) + self.assertAllClose(thing.x, jnp.array([0., 1.])) + tracing_ok = False + f(2.0) + self.assertAllClose(thing.x, jnp.array([0., 1., 2., 3.])) + f(4.0) + self.assertAllClose(thing.x, jnp.array([0., 1., 2., 3., 4., 5.])) + + @parameterized.parameters(it.product([False, True], repeat=2)) + def test_appendattr_constant(self, jit, initialized): + class Thing: ... + thing = Thing() + + if initialized: + thing.x = jnp.arange(0.) + + def f(): + assert (not jit) or tracing_ok + jax_appendattr(thing, 'x', 0.0) + jax_appendattr(thing, 'x', 1.0) + + if jit: + f = jax.jit(f) + + tracing_ok = True + f() + self.assertAllClose(thing.x, jnp.array([0., 1.])) + tracing_ok = False + f() + self.assertAllClose(thing.x, jnp.array([0., 1., 0., 1.])) + + @parameterized.parameters([True, False]) + def test_appendattr_getattr_errors(self, initialized): + class Thing: ... + thing = Thing() + + if initialized: + thing.x = jnp.arange(0.) + + @jax.jit + def f(x): + jax_appendattr(thing, 'x', x) + jax_getattr(thing, 'x') + + with self.assertRaisesRegex(TypeError, "can't read/write"): + f(1.0) + + @jax.jit + def g(x): + jax_setattr(thing, 'x', x) + jax_appendattr(thing, 'x', x) + + with self.assertRaisesRegex(TypeError, "can't append"): + g(1.0) + + if initialized: + self.assertNotIsInstance(thing.x, jax.core.Tracer) + else: + self.assertFalse(hasattr(thing, 'x')) + + @parameterized.parameters(it.product([False, True], repeat=2)) + def test_appendattr_dtype_disagreement(self, jit, initialized): + class Thing: ... + thing = Thing() + + if initialized: + thing.x = jnp.array([], 'float32') + + def f(x): + jax_appendattr(thing, 'x', x) + jax_appendattr(thing, 'x', x.astype('complex64')) + + if jit: + f = jax.jit(f) + + msg = "can only append to attr x with values of trailing shape " + msg += "float32" if initialized else "int32" + with self.assertRaisesRegex(TypeError, msg): + f(jnp.array(1, 'int32')) + + @parameterized.parameters(it.product([False, True], repeat=2)) + def test_appendattr_shape_disagreement(self, jit, initialized): + class Thing: ... + thing = Thing() + + if initialized: + thing.x = jnp.array([]) + + def f(x): + jax_appendattr(thing, 'x', x) + jax_appendattr(thing, 'x', jnp.stack([x, x])) + + if jit: + f = jax.jit(f) + + msg = "can only append to attr x with values of trailing shape" + with self.assertRaisesRegex(TypeError, msg): + f(1) + + @parameterized.parameters(it.product([False, True], repeat=2)) + def test_appendattr_scan(self, jit, initialized): + class Thing: ... + thing = Thing() + + if initialized: + thing.x = jnp.array([]) + + def f(): + def body(c, x): + jax_appendattr(thing, 'x', 2 * x) + jax_appendattr(thing, 'x', 2 * x + 1) + return c, () + _, () = jax.lax.scan(body, 0, jnp.arange(3.)) + + if jit: + f = jax.jit(f) + + f() + + self.assertAllClose(thing.x, jnp.array([0., 1., 2., 3., 4., 5.])) + + @parameterized.parameters(it.product([False, True], repeat=2)) + def test_appendattr_scan_vjp(self, jit, initialized): + class Thing: ... + thing = Thing() + + if initialized: + thing.y_bar = jnp.array([]) + + def f(x): + def body(c, _): + return 0.5 * g(2 * c), () + y, _ = jax.lax.scan(body, x, (), length=5) + return y + + if jit: + f = jax.jit(f) + + @jax.custom_vjp + def g(x): + return x + + def g_fwd(x): + return g(x), None + + def g_bwd(_, y_bar): + jax_appendattr(thing, 'y_bar', y_bar) + return y_bar, + + g.defvjp(g_fwd, g_bwd) + jax.grad(f)(3.) + + self.assertAllClose(thing.y_bar, jnp.array([0.5] * 5)) class AttrsJVPTest(jtu.JaxTestCase): @@ -543,6 +755,7 @@ def g_ref(x, x_dot, y, y_dot): self.assertAllClose(w_ddot, w_ddot_, check_dtypes=False) self.assertAllClose(z_ddot, z_ddot_, check_dtypes=False) + class AttrsLinTest(jtu.JaxTestCase): @parameterized.parameters([True, False]) From eb54cd2c6109fafb52894cc1f2d687cb3a25fb4d Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 28 Mar 2025 15:22:48 -0700 Subject: [PATCH 274/483] Remove GPU-specific dependencies from backend-independent tests. The GPU-specific deps were added to the backend-independent tests by mistake [here](https://github.com/jax-ml/jax/pull/27113). These tests should pass using `jax` and `jaxlib` wheels only. PiperOrigin-RevId: 741663266 --- jaxlib/jax.bzl | 33 +++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 560db85d6a1e..1cc4fab12591 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -255,8 +255,8 @@ def if_building_jaxlib( "//conditions:default": [], }) -def _get_test_deps(deps): - jaxlib_build_deps = [ +def _get_test_deps(deps, backend_independent): + gpu_build_deps = [ "//jaxlib/cuda:gpu_only_test_deps", "//jaxlib/rocm:gpu_only_test_deps", "//jax_plugins:gpu_plugin_only_test_deps", @@ -273,12 +273,21 @@ def _get_test_deps(deps): "//jaxlib/tools:jaxlib_py_import", ] + if backend_independent: + jaxlib_build_deps = deps + gpu_pypi_wheel_deps = _CPU_PYPI_WHEEL_DEPS + gpu_py_import_deps = cpu_py_imports + else: + jaxlib_build_deps = gpu_build_deps + deps + gpu_pypi_wheel_deps = _GPU_PYPI_WHEEL_DEPS + gpu_py_import_deps = gpu_py_imports + return select({ - "//jax:enable_jaxlib_build": jaxlib_build_deps + deps, + "//jax:enable_jaxlib_build": jaxlib_build_deps, "//jax_plugins/cuda:disable_jaxlib_for_cpu_build": _CPU_PYPI_WHEEL_DEPS, - "//jax_plugins/cuda:disable_jaxlib_for_cuda12_build": _GPU_PYPI_WHEEL_DEPS, + "//jax_plugins/cuda:disable_jaxlib_for_cuda12_build": gpu_pypi_wheel_deps, "//jax_plugins/cuda:enable_py_import_for_cpu_build": cpu_py_imports, - "//jax_plugins/cuda:enable_py_import_for_cuda12_build": gpu_py_imports, + "//jax_plugins/cuda:enable_py_import_for_cuda12_build": gpu_py_import_deps, }) # buildifier: disable=function-docstring @@ -334,7 +343,7 @@ def jax_multiplatform_test( deps = _get_test_deps([ "//jax", "//jax:test_util", - ] + deps), + ] + deps, backend_independent = False), data = data, shard_count = test_shards, tags = test_tags, @@ -629,15 +638,15 @@ def jax_py_test( if "PYTHONWARNINGS" not in env: env["PYTHONWARNINGS"] = "error" deps = kwargs.get("deps", []) - kwargs.pop("deps") - test_deps = _get_test_deps(deps) - py_test(name = name, env = env, deps = test_deps, **kwargs) + test_deps = _get_test_deps(deps, backend_independent = True) + kwargs["deps"] = test_deps + py_test(name = name, env = env, **kwargs) def pytype_test(name, **kwargs): deps = kwargs.get("deps", []) - kwargs.pop("deps") - test_deps = _get_test_deps(deps) - native.py_test(name = name, deps = test_deps, **kwargs) + test_deps = _get_test_deps(deps, backend_independent = True) + kwargs["deps"] = test_deps + native.py_test(name = name, **kwargs) def if_oss(oss_value, google_value = []): """Returns one of the arguments based on the non-configurable build env. From 93c6bb72d3f550991969bb7bd13c4d5e0fbc46ae Mon Sep 17 00:00:00 2001 From: Zac Cranko Date: Thu, 20 Mar 2025 15:31:16 -0700 Subject: [PATCH 275/483] add discord release action Update community_release_actions.yml --- .../workflows/community_release_actions.yml | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 .github/workflows/community_release_actions.yml diff --git a/.github/workflows/community_release_actions.yml b/.github/workflows/community_release_actions.yml new file mode 100644 index 000000000000..d61bea3d7e4d --- /dev/null +++ b/.github/workflows/community_release_actions.yml @@ -0,0 +1,31 @@ +name: Release Actions + +on: + release: + types: [published] + +jobs: + discord_release: + if: github.repository_owner == 'jax-ml' + runs-on: ubuntu-latest + steps: + - name: Get release URL + id: get-release-url + run: | + URL="https://docs.jax.dev/en/latest/changelog.html" + echo "::set-output name=URL::$URL" + - name: Get content + uses: 2428392/gh-truncate-string-action@b3ff790d21cf42af3ca7579146eedb93c8fb0757 # v1.4.1 + id: get-content + with: + stringToTruncate: | + JAX [${{ github.event.release.tag_name }}](<${{ steps.get-release-url.outputs.URL }}>) was just released! + + ${{ github.event.release.body }} + maxLength: 2000 + truncationSymbol: "..." + - name: Discord Webhook Action + uses: tsickert/discord-webhook@c840d45a03a323fbc3f7507ac7769dbd91bfb164 # v5.3.0 + with: + webhook-url: ${{ secrets.DISCORD_WEBHOOK_URL }} + content: ${{ steps.get-content.outputs.string }} From 123ce5221b298d551e168029e12e4b53147b206c Mon Sep 17 00:00:00 2001 From: jeffcarp Date: Thu, 27 Feb 2025 15:51:00 -0800 Subject: [PATCH 276/483] Add scalar event logging function --- jax/_src/monitoring.py | 43 ++++++++++++++++++++++++++++++++++++++++ jax/monitoring.py | 2 ++ tests/monitoring_test.py | 28 ++++++++++++++++++++++++-- 3 files changed, 71 insertions(+), 2 deletions(-) diff --git a/jax/_src/monitoring.py b/jax/_src/monitoring.py index 99e957733ba2..de706ccbaef5 100644 --- a/jax/_src/monitoring.py +++ b/jax/_src/monitoring.py @@ -46,10 +46,18 @@ def __call__( ) -> None: ... +class ScalarListenerWithMetadata(Protocol): + + def __call__( + self, event: str, value: float | int, **kwargs: str | int, + ) -> None: + ... + _event_listeners: list[EventListenerWithMetadata] = [] _event_duration_secs_listeners: list[EventDurationListenerWithMetadata] = [] _event_time_span_listeners: list[EventTimeSpanListenerWithMetadata] = [] +_scalar_listeners: list[ScalarListenerWithMetadata] = [] def record_event(event: str, **kwargs: str | int) -> None: @@ -81,6 +89,14 @@ def record_event_time_span( callback(event, start_time, end_time, **kwargs) +def record_scalar( + event: str, value: float | int, **kwargs: str | int +) -> None: + """Record a scalar summary value.""" + for callback in _scalar_listeners: + callback(event, value, **kwargs) + + def register_event_listener( callback: EventListenerWithMetadata, ) -> None: @@ -100,6 +116,14 @@ def register_event_duration_secs_listener( """Register a callback to be invoked during record_event_duration_secs().""" _event_duration_secs_listeners.append(callback) + +def register_scalar_listener( + callback : ScalarListenerWithMetadata, +) -> None: + """Register a callback to be invoked during record_scalar().""" + _scalar_listeners.append(callback) + + def get_event_duration_listeners() -> list[EventDurationListenerWithMetadata]: """Get event duration listeners.""" return list(_event_duration_secs_listeners) @@ -114,12 +138,20 @@ def get_event_listeners() -> list[EventListenerWithMetadata]: """Get event listeners.""" return list(_event_listeners) + +def get_scalar_listeners() -> list[ScalarListenerWithMetadata]: + """Get scalar event listeners.""" + return list(_scalar_listeners) + + def clear_event_listeners(): """Clear event listeners.""" global _event_listeners, _event_duration_secs_listeners, _event_time_span_listeners _event_listeners = [] _event_duration_secs_listeners = [] _event_time_span_listeners = [] + _scalar_listeners = [] + def _unregister_event_duration_listener_by_callback( callback: EventDurationListenerWithMetadata) -> None: @@ -159,3 +191,14 @@ def _unregister_event_listener_by_callback( """ assert callback in _event_listeners _event_listeners.remove(callback) + + +def _unregister_scalar_listener_by_callback( + callback: ScalarListenerWithMetadata, +) -> None: + """Unregister a scalar event listener by callback. + + This function is supposed to be called for testing only. + """ + assert callback in _scalar_listeners + _scalar_listeners.remove(callback) diff --git a/jax/monitoring.py b/jax/monitoring.py index 4c9996da582c..f4ab8124f219 100644 --- a/jax/monitoring.py +++ b/jax/monitoring.py @@ -26,7 +26,9 @@ record_event_duration_secs as record_event_duration_secs, record_event_time_span as record_event_time_span, record_event as record_event, + record_scalar as record_scalar, register_event_duration_secs_listener as register_event_duration_secs_listener, register_event_listener as register_event_listener, register_event_time_span_listener as register_event_time_span_listener, + register_scalar_listener as register_scalar_listener, ) diff --git a/tests/monitoring_test.py b/tests/monitoring_test.py index 52b53895c2cc..89c7148a2a42 100644 --- a/tests/monitoring_test.py +++ b/tests/monitoring_test.py @@ -29,7 +29,7 @@ def tearDown(self): def test_record_event(self): events = [] - counters = {} # Map event names to frequency. + counters = {} # Map event names to frequency. def increment_event_counter(event): if event not in counters: counters[event] = 0 @@ -48,7 +48,7 @@ def increment_event_counter(event): "test_common_event": 2}) def test_record_event_durations(self): - durations = {} # Map event names to frequency. + durations = {} # Map event names to frequency. def increment_event_duration(event, duration): if event not in durations: durations[event] = 0. @@ -62,6 +62,30 @@ def increment_event_duration(event, duration): self.assertDictEqual(durations, {"test_short_event": 3, "test_long_event": 10}) + def test_record_scalar(self): + observed_keys = [] + observed_values = [] + + monitoring.register_scalar_listener( + lambda key, _: observed_keys.append(key), + ) + monitoring.register_scalar_listener( + lambda _, value: observed_values.append(value), + ) + + monitoring.record_scalar("test_unique_event", 1) + monitoring.record_scalar("test_common_event", 2.5) + monitoring.record_scalar("test_common_event", 5e5) + + self.assertListEqual( + observed_keys, + ["test_unique_event", "test_common_event", "test_common_event"], + ) + self.assertListEqual( + observed_values, + [1, 2.5, 5e5], + ) + def test_unregister_exist_callback_success(self): original_duration_listeners = jax_src_monitoring.get_event_duration_listeners() callback = lambda event, durations: None From 80061ad4c433e419961c7c6d40e3d0e5bc4d24b4 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 28 Mar 2025 16:54:23 -0700 Subject: [PATCH 277/483] Add vma rules for pmin and pmax PiperOrigin-RevId: 741685454 --- jax/_src/lax/parallel.py | 48 +++++++++++++++++++++++++++++++++-- jax/experimental/shard_map.py | 37 +++------------------------ tests/shard_map_test.py | 10 ++++++++ 3 files changed, 59 insertions(+), 36 deletions(-) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 3ef0a2520378..8fc8c336d61a 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -203,6 +203,7 @@ def pmax(x, axis_name, *, axis_index_groups=None): _validate_reduce_axis_index_groups(axis_index_groups) leaves, treedef = tree_util.tree_flatten(x) axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups) + leaves = map(partial(insert_collective_pbroadcast, axis_name), leaves) out_flat = pmax_p.bind(*leaves, axes=axis_name, axis_index_groups=axis_index_groups) return tree_util.tree_unflatten(treedef, out_flat) @@ -233,6 +234,7 @@ def pmin(x, axis_name, *, axis_index_groups=None): _validate_reduce_axis_index_groups(axis_index_groups) leaves, treedef = tree_util.tree_flatten(x) axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups) + leaves = map(partial(insert_collective_pbroadcast, axis_name), leaves) out_flat = pmin_p.bind(*leaves, axes=axis_name, axis_index_groups=axis_index_groups) return tree_util.tree_unflatten(treedef, out_flat) @@ -803,6 +805,48 @@ def _allreduce_effectful_abstract_eval(*args, axes, axis_index_groups): ] return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes} +def _psum2_abstract_eval(name, *args, axes, axis_index_groups): + if not config.varying_axes_in_types.value: + return psum_p.abstract_eval( + *args, axes=axes, axis_index_groups=axis_index_groups) + + assert isinstance(axes, tuple) + _check_axis_names(axes) + arg_vma = [a.vma for a in args] + # If intersection between arg_vma and axes is empty, error + if any(not set(axes) & a for a in arg_vma): + raise ValueError( + f"Collective {name} must be applied to a device-varying " + f"type, but got {arg_vma} for collective acting " + f"over axis name {axes}. Please open an issue at " + "https://github.com/jax-ml/jax/issues, and as a temporary " + "workaround pass the check_rep=False argument to shard_map") + + named_axes = tuple(axis for axis in axes if not isinstance(axis, int)) + pos_axes = tuple(axis for axis in axes if isinstance(axis, int)) + if axis_index_groups is not None: + if len(pos_axes) != 0: + raise ValueError( + "axis_index_groups can only be used with reductions over " + f"named axes, but got: {axes}") + core.check_avals_context_mesh(args, 'all_reduce') + out_avals = [ + core.ShapedArray( + lax._reduce_op_shape_rule(arg, axes=pos_axes), arg.dtype, + sharding=lax._reduce_op_sharding_rule(arg, axes=pos_axes), + vma=frozenset(a for a in arg.vma if a not in named_axes)) + for arg in args + ] + return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes} + +# TODO(yashkatariya): Replace this with _psum2_abstract_eval +def _pmin_pmax_abstract_eval(name, *args, axes, axis_index_groups): + if not config.varying_axes_in_types.value: + return _allreduce_effectful_abstract_eval( + *args, axes=axes, axis_index_groups=axis_index_groups) + return _psum2_abstract_eval(name, *args, axes=axes, + axis_index_groups=axis_index_groups) + def _check_axis_names(axes): named_axes = tuple(axis for axis in axes if not isinstance(axis, int)) axis_env = core.get_axis_env() @@ -902,7 +946,7 @@ def broadcast_positional(ct, arg): pmax_p = core.Primitive('pmax') pmax_p.multiple_results = True pmax_p.def_impl(partial(_allreduce_impl, pmax_p, lax.reduce_max)) -pmax_p.def_effectful_abstract_eval(_allreduce_effectful_abstract_eval) +pmax_p.def_effectful_abstract_eval(partial(_pmin_pmax_abstract_eval, 'pmax')) mlir.register_lowering( pmax_p, partial(_allreduce_lowering, lax.max_p, lax.reduce_max)) batching.fancy_primitive_batchers[pmax_p] = \ @@ -913,7 +957,7 @@ def broadcast_positional(ct, arg): pmin_p = core.Primitive('pmin') pmin_p.multiple_results = True pmin_p.def_impl(partial(_allreduce_impl, pmin_p, lax.reduce_min)) -pmin_p.def_effectful_abstract_eval(_allreduce_effectful_abstract_eval) +pmin_p.def_effectful_abstract_eval(partial(_pmin_pmax_abstract_eval, 'pmin')) mlir.register_lowering( pmin_p, partial(_allreduce_lowering, lax.min_p, lax.reduce_min)) batching.fancy_primitive_batchers[pmin_p] = \ diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 44c2b569f947..8e2d93af2639 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -1072,40 +1072,8 @@ def _device_put_eager_rule(mesh, *xs, srcs, devices, copy_semantics): psum2_p = core.Primitive('psum2') psum2_p.multiple_results = True psum2_p.def_impl(lax_parallel.psum_p.impl) - -def _psum2_abstract_eval(*args, axes, axis_index_groups): - if not config.varying_axes_in_types.value: - return lax_parallel.psum_p.abstract_eval( - *args, axes=axes, axis_index_groups=axis_index_groups) - - assert isinstance(axes, tuple) - lax_parallel._check_axis_names(axes) - arg_vma = [a.vma for a in args] - if any(not set(axes) & a for a in arg_vma): - raise ValueError( - "Collective psum must be applied to a device-varying " - f"type, but got {arg_vma} for collective acting " - f"over axis name {axes}. Please open an issue at " - "https://github.com/jax-ml/jax/issues, and as a temporary " - "workaround pass the check_rep=False argument to shard_map") - - named_axes = tuple(axis for axis in axes if not isinstance(axis, int)) - pos_axes = tuple(axis for axis in axes if isinstance(axis, int)) - if axis_index_groups is not None: - if len(pos_axes) != 0: - raise ValueError( - "axis_index_groups can only be used with reductions over " - f"named axes, but got: {axes}") - core.check_avals_context_mesh(args, 'all_reduce') - out_avals = [ - core.ShapedArray( - lax._reduce_op_shape_rule(arg, axes=pos_axes), arg.dtype, - sharding=lax._reduce_op_sharding_rule(arg, axes=pos_axes), - vma=frozenset(a for a in arg.vma if a not in named_axes)) - for arg in args - ] - return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes} -psum2_p.def_effectful_abstract_eval(_psum2_abstract_eval) +psum2_p.def_effectful_abstract_eval( + partial(lax_parallel._psum2_abstract_eval, psum2_p.name)) mlir.register_lowering(psum2_p, mlir._lowerings[lax_parallel.psum_p]) batching.fancy_primitive_batchers[psum2_p] = \ @@ -1135,6 +1103,7 @@ def _pbroadcast_abstract_eval(*args, axes, axis_index_groups): return args assert isinstance(axes, tuple) arg_vma = [a.vma for a in args] + # If there is intersection between arg_vma and axes, error if any(set(axes) & a for a in arg_vma): raise ValueError( "Collective pbroadcast must be applied to a " diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index c1923f5b0ae3..36966fde2a90 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -2685,6 +2685,16 @@ def test_pmax(self): )(x) # don't crash self.assertArraysEqual(y, np.array([6, 7], dtype=np.float32)) + @config.varying_axes_in_types(True) + def test_pmax_vma_in_types(self): + mesh = jtu.create_mesh((4,), ('i',)) + x = jnp.arange(8., dtype=np.float32) + f = jax.jit(shard_map(lambda x: jax.lax.pmax(x, 'i'), mesh=mesh, + in_specs=P(), out_specs=P())) + jaxpr = f.trace(x).jaxpr + self.assertIn("pbroadcast[axes=('i',)", str(jaxpr)) + f(x) # doesn't crash + @config.varying_axes_in_types(True) def test_mul_with_vma_in_types(self): mesh = jtu.create_mesh((2,), ('x',)) From 7ca50844f3d66ab2b158e22b76fcc62e4406f867 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 28 Mar 2025 21:42:42 -0700 Subject: [PATCH 278/483] Fix an edge-case in reshape sharding rule where the last splitting/merging dim was `1`. PiperOrigin-RevId: 741740811 --- jax/_src/lax/lax.py | 7 ++++++- tests/pjit_test.py | 14 ++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index c2f0876ce932..fd956136ccd3 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -6977,11 +6977,16 @@ def _split_on_one_axis(op_shape, new_sizes, name): ' the sharding of the output via the `sharding` argument of' f' jax.lax.reshape. Got operand.shape={op_shape} and {new_sizes=}') temp = [new_sizes[j]] - while math.prod(temp) != op_shape[i]: + next_j = j + 1 + while (math.prod(temp) != op_shape[i] or + (next_j < len(new_sizes) and new_sizes[next_j] == 1)): if math.prod(temp) > op_shape[i]: return False, [] j += 1 + if j >= len(new_sizes): + return False, [] temp.append(new_sizes[j]) + next_j += 1 out.append(temp) i += 1 j += 1 diff --git a/tests/pjit_test.py b/tests/pjit_test.py index b49ba19c72dc..0b2daee8ccff 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -5508,6 +5508,18 @@ def h2(x, y): ('4', (1, 4, 1, 6, 1), (1, 4, 6), P(None, 'x', None, None, None), P(None, 'x', None), False), ('5', (4, 6), (4, 6), P(None, 'x'), P(None, 'x'), False), + ('6', (1024, 4096), (1024, 2048, 2, 1, 1, 1, 1), + P('x', None), P('x', None, None, None, None, None, None), False), + ('7', (1024, 4096, 32), (1024, 2048, 2, 1, 1, 32), + P('x', None, None), P('x', None, None, None, None, None), False), + ('8', (1024, 4096), (1024, 1, 1, 4096), + P('x', None), P('x', None, None, None), False), + ('9', (1024, 4096), (1024, 1, 1, 4096), + P(None, 'x'), P(None, None, None, 'x'), False), + ('10', (1024, 2048, 2, 1, 1, 1), (1024, 4096), + P('x', None, None, None, None, None), P('x', None), False), + ('11', (1024, 2048, 2, 1, 1, 1), (1024, 4096), + P(None, 'x', None, None, None, None), P(None, 'x'), False), ) @jtu.with_user_mesh((2,), ('x',)) def test_reshape(self, src_shape, dst_shape, src_spec, dst_spec, @@ -5519,6 +5531,8 @@ def test_reshape(self, src_shape, dst_shape, src_spec, dst_spec, @partial(jax.jit, static_argnums=1) def f(x, new_sharding): y = lax.reshape(x, dst_shape, out_sharding=new_sharding) + self.assertEqual(y.aval.sharding.spec, dst_spec) + self.assertEqual(y.shape, dst_shape) y = y * 2 self.assertEqual(y.aval.sharding.spec, dst_spec) return y From e7ec418eba9ada336f755613948cbdf4a9e97d59 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sat, 29 Mar 2025 05:19:04 -0700 Subject: [PATCH 279/483] Update XLA dependency to use revision http://github.com/openxla/xla/commit/f50746ab3144d0bf59c8e5c2dcfb2e09e56338d0. PiperOrigin-RevId: 741809075 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 43bba2fcc903..f0a33d4c5e55 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "edfd919df316d687b2d3b08bbc8d9c32f4bcc1c4" -XLA_SHA256 = "d82a7174a8a129180b180b08f5eedfa5fe6ff19fbd46dc11dae8cf64d87dfbf9" +XLA_COMMIT = "f50746ab3144d0bf59c8e5c2dcfb2e09e56338d0" +XLA_SHA256 = "e4935a201c105a705d2a26c718663f9a7073f8a1d337c0e7eb885e2e2480797d" def repo(): tf_http_archive( From 5fda4c1b0e49761d71ce4addec80cdc6b479d2e7 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 30 Mar 2025 04:43:56 -0700 Subject: [PATCH 280/483] Update XLA dependency to use revision http://github.com/openxla/xla/commit/8df9390dc9444d900c7c7f2c123f23b549adf8e3. PiperOrigin-RevId: 741998725 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index f0a33d4c5e55..8b3ddfde019b 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "f50746ab3144d0bf59c8e5c2dcfb2e09e56338d0" -XLA_SHA256 = "e4935a201c105a705d2a26c718663f9a7073f8a1d337c0e7eb885e2e2480797d" +XLA_COMMIT = "8df9390dc9444d900c7c7f2c123f23b549adf8e3" +XLA_SHA256 = "8e97c395d1e50a49fab386ccc7da1f78dc86bf670b20a892656e2e75bbf64f0e" def repo(): tf_http_archive( From a865b4e4370d1301325db64005b92aacbf4c8c7a Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Sun, 30 Mar 2025 10:50:05 -0700 Subject: [PATCH 281/483] [mgpu] Register the mosaic_gpu dialect regardless of warpgroup/lane lowering. In `mgpu.bitwidth()` mosaic_gpu types are being checked even in Lane lowering which fails. PiperOrigin-RevId: 742044332 --- jax/_src/pallas/mosaic_gpu/pallas_call_registration.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py index 6dc958edbc53..1d4be26187ce 100644 --- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py @@ -61,8 +61,7 @@ def pallas_call_lowering( thread_semantics = compiler_params.get("mosaic_gpu", {}).get( "thread_semantics", mgpu.ThreadSemantics.Lane ) - if thread_semantics == mgpu.ThreadSemantics.Warpgroup: - mgpu.dialect.register_dialect(ctx.module_context.context) # pytype: disable=attribute-error + mgpu.dialect.register_dialect(ctx.module_context.context) # pytype: disable=attribute-error lowering_result = lowering.lower_pipelined_jaxpr_to_module( grid_mapping, From 0edd715e96850f0b2fd2fc13685fde1e426b603a Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Sun, 30 Mar 2025 16:11:49 -0700 Subject: [PATCH 282/483] [mgpu/pallas] Expose WGMMA_TRANSPOSED layout PiperOrigin-RevId: 742084936 --- jax/_src/pallas/mosaic_gpu/primitives.py | 4 ++++ jax/experimental/mosaic/gpu/__init__.py | 1 + 2 files changed, 5 insertions(+) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index a9bd91b26622..48a4cae62824 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -911,6 +911,7 @@ class Layout(enum.Enum): WGMMA_ROW = enum.auto() #: [n] matrix, where n % 8 == 0. WGMMA_COL = enum.auto() + WGMMA_TRANSPOSED = enum.auto() WG_SPLAT = enum.auto() WG_STRIDED = enum.auto() @@ -924,6 +925,9 @@ def check_no_args(): raise ValueError(f"Can't instantiate {self} with arguments.") match self: + case Layout.WGMMA_TRANSPOSED: + check_no_args() + return mgpu.WGMMA_TRANSPOSED_LAYOUT case Layout.WGMMA: check_no_args() return mgpu.WGMMA_LAYOUT diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index 867fd84b8b3c..afc87b5d96fa 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -55,6 +55,7 @@ WGMMA_LAYOUT as WGMMA_LAYOUT, WGMMA_ROW_LAYOUT as WGMMA_ROW_LAYOUT, WGMMA_COL_LAYOUT as WGMMA_COL_LAYOUT, + WGMMA_TRANSPOSED_LAYOUT as WGMMA_TRANSPOSED_LAYOUT, WGMMARowFragLayout as WGMMARowFragLayout, WGMMAColFragLayout as WGMMAColFragLayout, WGSplatFragLayout as WGSplatFragLayout, From 29bd01f8307449205a4894048927e6669dba2929 Mon Sep 17 00:00:00 2001 From: Amir Samani Date: Tue, 18 Mar 2025 17:21:40 -0700 Subject: [PATCH 283/483] add reduction support in copy_smem_to_gmem --- jax/_src/pallas/mosaic_gpu/primitives.py | 7 ++ jax/experimental/mosaic/gpu/launch_context.py | 77 ++++++++++++++++--- jaxlib/mosaic/gpu/runtime.cc | 48 ++++++++---- tests/pallas/mosaic_gpu_test.py | 23 ++++++ 4 files changed, 132 insertions(+), 23 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 48a4cae62824..e3f8e4c03f75 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -185,6 +185,7 @@ def _copy_smem_to_gmem_lowering( dst_transforms_treedef, has_user_predicate, commit_group, + reduction_op: Literal["add", "min", "max", "inc", "dec", "and", "or", "xor"] | None, ): if has_user_predicate: flat_args, user_predicate = flat_args[:-1], flat_args[-1] @@ -215,6 +216,7 @@ def _copy_smem_to_gmem_lowering( dst_ref=dst, predicate=predicate, arrive=commit_group, + reduction_op=reduction_op, **copy_params, ) return () @@ -293,6 +295,9 @@ def copy_smem_to_gmem( predicate: jax.Array | None = None, *, commit_group: bool = True, + reduction_op: Literal[ + "add","min","max","inc","dec","and","or","xor" + ] | None = None, ) -> None: """Asynchronously copies a SMEM reference to a GMEM reference. @@ -304,6 +309,7 @@ def copy_smem_to_gmem( commit_group: If ``True``, this and any previously uncommitted copies are committed to a group and can be awaited jointly via :func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem`. + reduction_op: if set, perform the specified reduction op when copy to gmem See also: :func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem` @@ -331,6 +337,7 @@ def copy_smem_to_gmem( dst_transforms_treedef=dst_transforms_treedef, has_user_predicate=predicate is not None, commit_group=commit_group, + reduction_op=reduction_op, ) return None diff --git a/jax/experimental/mosaic/gpu/launch_context.py b/jax/experimental/mosaic/gpu/launch_context.py index ce432f26dac2..41c15bc5492e 100644 --- a/jax/experimental/mosaic/gpu/launch_context.py +++ b/jax/experimental/mosaic/gpu/launch_context.py @@ -19,9 +19,10 @@ import enum import functools import math -from typing import Any +from typing import Any, Literal from jax._src.lib import mosaic_gpu_dialect as mgpu_dialect +from jax._src import lib as jaxlib from jaxlib.mlir import ir from jaxlib.mlir.dialects import arith from jaxlib.mlir.dialects import func @@ -309,6 +310,9 @@ def _get_tma_desc( gmem_transform: tuple[MemRefTransform, ...], transformed_slice_shape: tuple[int, ...], swizzle: int | None, + reduction_op: Literal[ + "add","min","max","inc","dec","and","or","xor" + ] | None, ): tma_desc_key = (gmem_ref, transformed_slice_shape, swizzle, gmem_transform) if (tma_desc := self.tma_descriptors.get(tma_desc_key, None)) is None: @@ -337,10 +341,38 @@ def init_tma_desc(host_ptr): ) # TODO(apaszke): Better verification (e.g. slice is non-zero) # TODO(apaszke): We always know strides statically. + if jaxlib.version < (0, 5, 4): + dtype_or_bitwidth = c(utils.bitwidth(ref_ty.element_type), i64) + else: + if isinstance(ref_ty.element_type, ir.IntegerType): + if reduction_op is not None: + raise ValueError( + f"TMA with reduction_op={reduction_op} is not supported with Integers" + ) + bitwidth = utils.bitwidth_impl(ref_ty.element_type) + if bitwidth == 4: + tma_dtype = 0 + elif bitwidth == 8: + tma_dtype = 1 + elif bitwidth == 16: + tma_dtype = 2 + elif bitwidth == 32: + tma_dtype = 3 + elif bitwidth == 64: + tma_dtype = 4 + elif ir.F16Type.isinstance(ref_ty.element_type): + tma_dtype = 5 + elif ir.F32Type.isinstance(ref_ty.element_type): + tma_dtype = 6 + elif ir.BF16Type.isinstance(ref_ty.element_type): + tma_dtype = 7 + else: + raise ValueError(f"unsupported TMA dtype {ref_ty.element_type}") + dtype_or_bitwidth = c(tma_dtype, i64) args = [ host_ptr, base_ptr, - c(utils.bitwidth(ref_ty.element_type), i64), + dtype_or_bitwidth, c(rank, i64), utils.pack_array([as_i64(i) for i in sizes_and_strides[:rank]]), utils.pack_array([as_i64(i) for i in sizes_and_strides[rank:]]), @@ -375,6 +407,9 @@ def async_copy( collective: Sequence[gpu.Dimension] | gpu.Dimension | None = None, partitioned: int | None = None, predicate: ir.Value | None = None, # Should select 0 or 1 threads from the WG. + reduction_op: Literal[ + "add","min","max","inc","dec","and","or","xor" + ] | None = None, ): """Initiates an async copy between GMEM and SMEM. @@ -453,6 +488,13 @@ def async_copy( " multiple of 16 bytes" ) + if reduction_op is not None and jaxlib.version < (0, 5, 4): + raise ValueError("TMA with reduction is only supported with jaxlib >= 0.5.4") + if reduction_op is not None and not isinstance(gmem_ref_ty.element_type, ir.FloatType): + raise ValueError("TMA with reduction is only supported with float dtype") + if reduction_op is not None and reduction_op != "add": + raise ValueError("TMA with reduction is only supported with add operation") + # NOTE: TMA supports OOB indices, so we skip the check. base_indices, slice_shape, is_squeezed = utils.parse_indices( gmem_slice, ir.MemRefType(gmem_ref.type).shape, check_oob=False @@ -597,7 +639,7 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int): multicast_mask = None tma_desc = self._get_tma_desc( - gmem_ref, gmem_transform, tuple(slice_shape), swizzle, + gmem_ref, gmem_transform, tuple(slice_shape), swizzle, reduction_op, ) # We constuct TMA descriptors in column-major order. @@ -641,6 +683,7 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int): ) barrier_ptr = barrier.get_ptr() with uniform_ctx(): + assert reduction_op is None if collective_size > 1 and partitioned is not None: if predicate is None: predicate = c(1, ir.IntegerType.get_signless(1)) @@ -679,12 +722,28 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int): ) else: assert multicast_mask is None - with uniform_ctx(): - nvvm.cp_async_bulk_tensor_global_shared_cta( - tma_desc, smem_ptr, rev_dyn_base_indices, predicate=predicate - ) - if arrive: - nvvm.cp_async_bulk_commit_group() + if reduction_op is not None: + with uniform_ctx(): + if predicate is None: + predicate = c(1, ir.IntegerType.get_signless(1)) + rank = len(slice_shape) + idx_operands = ",".join(f"${i}" for i in range(3, 3 + rank)) + llvm.inline_asm( + ir.Type.parse("!llvm.void"), + [predicate,smem_ptr,tma_desc,*rev_dyn_base_indices], + f"@$0 cp.reduce.async.bulk.tensor.{rank}d.global.shared::cta.{reduction_op}.tile.bulk_group [$2,{{{idx_operands}}}], [$1];", + "b,r,l" + ",r" * rank, + has_side_effects=True, + ) + if arrive: + nvvm.cp_async_bulk_commit_group() + else: + with uniform_ctx(): + nvvm.cp_async_bulk_tensor_global_shared_cta( + tma_desc, smem_ptr, rev_dyn_base_indices, predicate=predicate + ) + if arrive: + nvvm.cp_async_bulk_commit_group() def await_async_copy( self, allow_groups: int, await_read_only: bool = False diff --git a/jaxlib/mosaic/gpu/runtime.cc b/jaxlib/mosaic/gpu/runtime.cc index ad3cd0e19644..fd452e781c72 100644 --- a/jaxlib/mosaic/gpu/runtime.cc +++ b/jaxlib/mosaic/gpu/runtime.cc @@ -22,7 +22,7 @@ limitations under the License. extern "C" { void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr, - int64_t elem_bitwidth, int64_t rank, + int64_t elem_type, int64_t rank, int64_t *sizes, int64_t *strides, int64_t swizzle_bytes, int64_t *window_shape) { if (((uintptr_t)tma_desc) % 64 != 0) { @@ -32,6 +32,39 @@ void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr, abort(); } + CUtensorMapDataType data_type; + int64_t elem_bitwidth; + // types are defined in: LaunchContext._get_tma_desc() + if (elem_type == 0){ + // this is for int4s + data_type = CU_TENSOR_MAP_DATA_TYPE_UINT8; + elem_bitwidth = 4; + } else if (elem_type == 1){ + data_type = CU_TENSOR_MAP_DATA_TYPE_UINT8; + elem_bitwidth = 8; + } else if (elem_type == 2){ + data_type = CU_TENSOR_MAP_DATA_TYPE_UINT16; + elem_bitwidth = 16; + } else if (elem_type == 3){ + data_type = CU_TENSOR_MAP_DATA_TYPE_UINT32; + elem_bitwidth = 32; + } else if (elem_type == 4){ + data_type = CU_TENSOR_MAP_DATA_TYPE_UINT64; + elem_bitwidth = 64; + } else if (elem_type == 5){ + data_type = CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + elem_bitwidth = 16; + } else if (elem_type == 6){ + data_type = CU_TENSOR_MAP_DATA_TYPE_FLOAT32; + elem_bitwidth = 32; + } else if (elem_type == 7){ + data_type = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; + elem_bitwidth = 16; + } else{ + fprintf(stderr, "Unsupported element type: %ld \n", elem_type); + abort(); + } + // Pack 4 bit types in 8 bit pairs. int64_t elem_bytewidth; if (elem_bitwidth < 8) { @@ -54,19 +87,6 @@ void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr, elem_bytewidth = elem_bitwidth / 8; } - CUtensorMapDataType data_type; - if (elem_bytewidth == 1) { - data_type = CU_TENSOR_MAP_DATA_TYPE_UINT8; - } else if (elem_bytewidth == 2) { - data_type = CU_TENSOR_MAP_DATA_TYPE_UINT16; - } else if (elem_bytewidth == 4) { - data_type = CU_TENSOR_MAP_DATA_TYPE_UINT32; - } else if (elem_bytewidth == 8) { - data_type = CU_TENSOR_MAP_DATA_TYPE_UINT64; - } else { - fprintf(stderr, "Unsupported element size: %ld\n", elem_bytewidth); - abort(); - } if (rank < 1 || rank > 5) { fprintf(stderr, "Rank must be in [1, 5], but got %ld\n", rank); abort(); diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index c8013f634c67..040f994d0b2b 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -379,6 +379,29 @@ def kernel(x_ref, o_ref_gmem, scratch_ref): x = jnp.arange(256).astype(jnp.float32) np.testing.assert_array_equal(kernel(x)[indexer], x[indexer] + 1.0) + @parameterized.parameters(jnp.bfloat16, jnp.float16, jnp.float32) + def test_copy_smem_to_gmem_reduction(self, dtype): + @functools.partial( + pl.pallas_call, + grid=(200,), + in_specs=[pl.BlockSpec((128,), lambda *i: i), pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct([128], dtype), + scratch_shapes=[plgpu.SMEM((128,), dtype)], + input_output_aliases={1:0} + ) + def kernel(x_ref, o_ref_gmem, o_ref_gmem_alias, scratch_ref): + del o_ref_gmem_alias + scratch_ref[...] = x_ref[...] + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(scratch_ref.at[...], o_ref_gmem.at[...], reduction_op="add") + plgpu.wait_smem_to_gmem(0) + x = jnp.ones(200 * 128).astype(dtype) # 200 blocks + output = jnp.zeros(128).astype(dtype) + output = kernel(x, output) + output_val = x.reshape(-1, 128).sum(axis=0) + np.testing.assert_array_equal(output, output_val) + @parameterized.named_parameters( {"testcase_name": "1d_none", "shape": (256,), "indexers": (slice(0, 128), slice(None, 32))}, From 10425ae6a9ebd77ecd0de775f3f758b8978e18bd Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 28 Mar 2025 14:32:59 -0700 Subject: [PATCH 284/483] jax.core: finalize a number of deprecations for JAX v0.6.0 --- CHANGELOG.md | 8 ++++ jax/core.py | 125 ++++++++++++++++----------------------------------- 2 files changed, 47 insertions(+), 86 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5785f6193065..68450dca4057 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,14 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. `jax.tree_leaves`, `jax.tree_structure`, `jax.tree_transpose`, and `jax.tree_unflatten`. Replacements can be found in {mod}`jax.tree` or {mod}`jax.tree_util`. + * From `jax.core`: `AxisSize`, `ClosedJaxpr`, `EvalTrace`, `InDBIdx`, `InputType`, + `Jaxpr`, `JaxprEqn`, `Literal`, `MapPrimitive`, `OpaqueTraceState`, `OutDBIdx`, + `Primitive`, `Token`, `TRACER_LEAK_DEBUGGER_WARNING`, `Var`, `concrete_aval`, + `dedup_referents`, `escaped_tracer_error`, `extend_axis_env_nd`, `get_referent`, + `join_effects`, `leaked_tracer_error`, `maybe_find_leaked_tracers`, `raise_to_shaped`, + `raise_to_shaped_mappings`, `reset_trace_state`, `str_eqn_compact`, + `substitute_vars_in_output_ty`, `typecompat`, and `used_axis_names_jaxpr`. Most + have no public replacement, though a few are available at {mod}`jax.extend.core`. ## jax 0.5.3 (Mar 19, 2025) diff --git a/jax/core.py b/jax/core.py index b404e66c2691..688fa14d9ccf 100644 --- a/jax/core.py +++ b/jax/core.py @@ -81,75 +81,21 @@ from jax._src import core as _src_core _deprecations = { - # Added 2024-12-16 - "ClosedJaxpr": ("jax.core.ClosedJaxpr is deprecated. Use jax.extend.core.ClosedJaxpr instead, " - "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", - _src_core.ClosedJaxpr), - "Jaxpr": ("jax.core.Jaxpr is deprecated. Use jax.extend.core.Jaxpr instead, " - "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", - _src_core.Jaxpr), - "JaxprEqn": ("jax.core.JaxprEqn is deprecated. Use jax.extend.core.JaxprEqn instead, " - "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", - _src_core.JaxprEqn), - "Literal": ("jax.core.Literal is deprecated. Use jax.extend.core.Literal instead, " - "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", - _src_core.Literal), - "Primitive": ("jax.core.Primitive is deprecated. Use jax.extend.core.Primitive instead, " - "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", - _src_core.Primitive), - "Token": ("jax.core.Token is deprecated. Use jax.extend.core.Token instead, " - "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", - _src_core.Token), - "Var": ("jax.core.Var is deprecated. Use jax.extend.core.Var instead, " - "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", - _src_core.Var), # Added 2024-12-11 "axis_frame": ("jax.core.axis_frame is deprecated.", _src_core.axis_frame), "AxisName": ("jax.core.AxisName is deprecated.", _src_core.AxisName), - "AxisSize": ("jax.core.AxisSize is deprecated.", _src_core.AxisSize), "ConcretizationTypeError": ("jax.core.ConcretizationTypeError is deprecated; " "use jax.errors.ConcretizationTypeError.", _src_core.ConcretizationTypeError), - "EvalTrace": ("jax.core.EvalTrace is deprecated.", _src_core.EvalTrace), - "InDBIdx": ("jax.core.InDBIdx is deprecated.", _src_core.InDBIdx), - "InputType": ("jax.core.InputType is deprecated.", _src_core.InputType), - "MapPrimitive": ("jax.core.MapPrimitive is deprecated.", _src_core.MapPrimitive), - "OpaqueTraceState": ("jax.core.OpaqueTraceState is deprecated.", _src_core.OpaqueTraceState), - "OutDBIdx": ("jax.core.OutDBIdx is deprecated.", _src_core.OutDBIdx), - "TRACER_LEAK_DEBUGGER_WARNING": ("jax.core.TRACER_LEAK_DEBUGGER_WARNING is deprecated.", - _src_core.TRACER_LEAK_DEBUGGER_WARNING), "call_p": ("jax.core.call_p is deprecated. Use jax.extend.core.primitives.call_p", _src_core.call_p), "closed_call_p": ("jax.core.closed_call_p is deprecated. Use jax.extend.core.primitives.closed_call_p", _src_core.closed_call_p), - "concrete_aval": ("jax.core.concrete_aval is deprecated.", _src_core.abstractify), - "dedup_referents": ("jax.core.dedup_referents is deprecated.", _src_core.dedup_referents), - "escaped_tracer_error": ("jax.core.escaped_tracer_error is deprecated.", - _src_core.escaped_tracer_error), - "extend_axis_env_nd": ("jax.core.extend_axis_env_nd is deprecated.", - _src_core.extend_axis_env_nd), "get_type": ("jax.core.get_type is deprecated.", _src_core.get_aval), - "get_referent": ("jax.core.get_referent is deprecated.", _src_core.get_referent), - "join_effects": ("jax.core.join_effects is deprecated.", _src_core.join_effects), - "leaked_tracer_error": ("jax.core.leaked_tracer_error is deprecated.", - _src_core.leaked_tracer_error), - "maybe_find_leaked_tracers": ("jax.core.maybe_find_leaked_tracers is deprecated.", - _src_core.maybe_find_leaked_tracers), - "raise_to_shaped_mappings": ("jax.core.raise_to_shaped_mappings is deprecated." - " It is unused as of jax v0.4.36.", - _src_core.raise_to_shaped_mappings), - "reset_trace_state": ("jax.core.reset_trace_state is deprecated.", - _src_core.reset_trace_state), - "str_eqn_compact": ("jax.core.str_eqn_compact is deprecated.", _src_core.str_eqn_compact), - "substitute_vars_in_output_ty": ("jax.core.substitute_vars_in_output_ty is deprecated.", - _src_core.substitute_vars_in_output_ty), "trace_state_clean": ("jax.core.trace_state_clean is deprecated.", _src_core.trace_state_clean), "typecheck": ("jax.core.typecheck is deprecated.", _src_core.typecheck), - "typecompat": ("jax.core.typecompat is deprecated.", _src_core.typecompat), "typematch": ("jax.core.typematch is deprecated.", _src_core.typematch), - "used_axis_names_jaxpr": ("jax.core.used_axis_names_jaxpr is deprecated.", - _src_core.used_axis_names_jaxpr), # Added 2024-12-10 "full_lower": ("jax.core.full_lower is deprecated. It is a no-op as of JAX v0.4.36.", _src_core.full_lower), @@ -158,54 +104,61 @@ _src_core.jaxpr_as_fun), "lattice_join": ("jax.core.lattice_join is deprecated. It is a no-op as of JAX v0.4.36.", _src_core.lattice_join), - "raise_to_shaped": ("jax.core.raise_to_shaped is deprecated. It is a no-op as of JAX v0.4.36.", - _src_core.raise_to_shaped), + # Finalized 2025-03-25 for JAX v0.6.0; remove after 2025-06-25 + "AxisSize": ("jax.core.AxisSize was removed in JAX v0.6.0.", None), + "ClosedJaxpr": ("jax.core.ClosedJaxpr was removed in JAX v0.6.0. Use jax.extend.core.ClosedJaxpr instead, " + "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", None), + "EvalTrace": ("jax.core.EvalTrace was removed in JAX v0.6.0.", None), + "InDBIdx": ("jax.core.InDBIdx was removed in JAX v0.6.0.", None), + "InputType": ("jax.core.InputType was removed in JAX v0.6.0.", None), + "Jaxpr": ("jax.core.Jaxpr was removed in JAX v0.6.0. Use jax.extend.core.Jaxpr instead, " + "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", None), + "JaxprEqn": ("jax.core.JaxprEqn was removed in JAX v0.6.0. Use jax.extend.core.JaxprEqn instead, " + "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", None), + "Literal": ("jax.core.Literal was removed in JAX v0.6.0. Use jax.extend.core.Literal instead, " + "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", None), + "MapPrimitive": ("jax.core.MapPrimitive was removed in JAX v0.6.0.", None), + "OpaqueTraceState": ("jax.core.OpaqueTraceState was removed in JAX v0.6.0.", None), + "OutDBIdx": ("jax.core.OutDBIdx was removed in JAX v0.6.0.", None), + "Primitive": ("jax.core.Primitive was removed in JAX v0.6.0. Use jax.extend.core.Primitive instead, " + "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", None), + "Token": ("jax.core.Token was removed in JAX v0.6.0. Use jax.extend.core.Token instead, " + "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", None), + "TRACER_LEAK_DEBUGGER_WARNING": ("jax.core.TRACER_LEAK_DEBUGGER_WARNING was removed in JAX v0.6.0.", None), + "Var": ("jax.core.Var was removed in JAX v0.6.0. Use jax.extend.core.Var instead, " + "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", None), + "concrete_aval": ("jax.core.concrete_aval was removed in JAX v0.6.0.", None), + "dedup_referents": ("jax.core.dedup_referents was removed in JAX v0.6.0.", None), + "escaped_tracer_error": ("jax.core.escaped_tracer_error was removed in JAX v0.6.0.", None), + "extend_axis_env_nd": ("jax.core.extend_axis_env_nd was removed in JAX v0.6.0.", None), + "get_referent": ("jax.core.get_referent was removed in JAX v0.6.0.", None), + "join_effects": ("jax.core.join_effects was removed in JAX v0.6.0.", None), + "leaked_tracer_error": ("jax.core.leaked_tracer_error was removed in JAX v0.6.0.", None), + "maybe_find_leaked_tracers": ("jax.core.maybe_find_leaked_tracers was removed in JAX v0.6.0.", None), + "raise_to_shaped": ("jax.core.raise_to_shaped was removed in JAX v0.6.0. It is a no-op as of JAX v0.4.36.", None), + "raise_to_shaped_mappings": ("jax.core.raise_to_shaped_mappings was removed in JAX v0.6.0." + " It is unused as of jax v0.4.36.", None), + "reset_trace_state": ("jax.core.reset_trace_state was removed in JAX v0.6.0.", None), + "str_eqn_compact": ("jax.core.str_eqn_compact was removed in JAX v0.6.0.", None), + "substitute_vars_in_output_ty": ("jax.core.substitute_vars_in_output_ty was removed in JAX v0.6.0.", None), + "typecompat": ("jax.core.typecompat was removed in JAX v0.6.0.", None), + "used_axis_names_jaxpr": ("jax.core.used_axis_names_jaxpr was removed in JAX v0.6.0.", None), } import typing if typing.TYPE_CHECKING: AxisName = _src_core.AxisName - AxisSize = _src_core.AxisSize - ClosedJaxpr = _src_core.ClosedJaxpr ConcretizationTypeError = _src_core.ConcretizationTypeError - EvalTrace = _src_core.EvalTrace - InDBIdx = _src_core.InDBIdx - InputType = _src_core.InputType - Jaxpr = _src_core.Jaxpr - JaxprEqn = _src_core.JaxprEqn - Literal = _src_core.Literal - MapPrimitive = _src_core.MapPrimitive - OpaqueTraceState = _src_core.OpaqueTraceState - OutDBIdx = _src_core.OutDBIdx - Primitive = _src_core.Primitive - Token = _src_core.Token - TRACER_LEAK_DEBUGGER_WARNING = _src_core.TRACER_LEAK_DEBUGGER_WARNING - Var = _src_core.Var axis_frame = _src_core.axis_frame call_p = _src_core.call_p closed_call_p = _src_core.closed_call_p - concrete_aval = _src_core.abstractify - dedup_referents = _src_core.dedup_referents - escaped_tracer_error = _src_core.escaped_tracer_error - extend_axis_env_nd = _src_core.extend_axis_env_nd full_lower = _src_core.full_lower get_type = _src_core.get_aval - get_referent = _src_core.get_referent jaxpr_as_fun = _src_core.jaxpr_as_fun - join_effects = _src_core.join_effects lattice_join = _src_core.lattice_join - leaked_tracer_error = _src_core.leaked_tracer_error - maybe_find_leaked_tracers = _src_core.maybe_find_leaked_tracers - raise_to_shaped = _src_core.raise_to_shaped - raise_to_shaped_mappings = _src_core.raise_to_shaped_mappings - reset_trace_state = _src_core.reset_trace_state - str_eqn_compact = _src_core.str_eqn_compact - substitute_vars_in_output_ty = _src_core.substitute_vars_in_output_ty trace_state_clean = _src_core.trace_state_clean typecheck = _src_core.typecheck - typecompat = _src_core.typecompat typematch = _src_core.typematch - used_axis_names_jaxpr = _src_core.used_axis_names_jaxpr else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr __getattr__ = _deprecation_getattr(__name__, _deprecations) From aee27854f056bdabdb25d2af0eac5a2e5b35f63b Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 31 Mar 2025 00:53:30 -0700 Subject: [PATCH 285/483] [Pallas:MGPU] Only allow small tiling in Pallas programs This is part of the removal of support for large MMA tiling in Mosaic GPU. It should also let us simplify some of the transpose transforms that are no longer necessary, but I decided to separate this. PiperOrigin-RevId: 742168801 --- jax/_src/pallas/mosaic_gpu/lowering.py | 4 +- jax/_src/pallas/mosaic_gpu/primitives.py | 6 +- .../pallas/ops/gpu/attention_mgpu.py | 4 +- tests/pallas/mosaic_gpu_test.py | 62 +++++++++---------- 4 files changed, 36 insertions(+), 40 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index e99feb4dc144..daa718ff1ff2 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1129,7 +1129,7 @@ def _get_lowering_rule(ctx: LoweringRuleContext, x_ref, *leaves, tree): match transforms: case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)): - if tiling != (64, swizzle // x_aval.dtype.itemsize): + if tiling != (8, swizzle // x_aval.dtype.itemsize): raise NotImplementedError("Tiling does not fit swizzle") return mgpu.FragmentedArray.load_tiled( x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype), swizzle=swizzle @@ -1188,7 +1188,7 @@ def _swap_lowering_rule( x_smem, transforms = _handle_indexing(x_smem, transforms) match transforms: case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)): - if tiling != (64, swizzle // x_aval.dtype.itemsize): + if tiling != (8, swizzle // x_aval.dtype.itemsize): raise NotImplementedError("Tiling does not fit swizzle") old_value = mgpu.FragmentedArray.load_tiled( x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype), swizzle=swizzle diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 48a4cae62824..d0632728f2b6 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -93,7 +93,7 @@ def _load_p_lowering_rule( match transforms: case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)): - if tiling != (64, swizzle // x_aval.dtype.itemsize): + if tiling != (8, swizzle // x_aval.dtype.itemsize): raise NotImplementedError("Tiling does not fit swizzle") return mgpu.FragmentedArray.load_tiled( x_ref, is_signed=mgpu_utils.is_signed(x_aval.dtype), swizzle=swizzle, @@ -739,7 +739,7 @@ def _wgmma_lowering( match a_transforms: case (gpu_core.UnswizzleRef(lhs_swizzle), gpu_core.UntileRef(tiling)): swizzle_elems = lhs_swizzle // a_aval.dtype.itemsize - if tiling != (64, swizzle_elems): + if tiling != (8, swizzle_elems): raise NotImplementedError("WGMMA lhs tiling does not fit swizzle") case _: raise ValueError(f"WGMMA lhs has unsupported transforms: {a_transforms}.") @@ -790,7 +790,7 @@ def _wgmma_lowering( swizzle_elems = rhs_swizzle // a_aval.dtype.itemsize if rhs_swizzle != lhs_swizzle: raise NotImplementedError("WGMMA rhs swizzle must match lhs swizzle") - if rhs_tiling != (swizzle_elems, swizzle_elems): + if rhs_tiling != (8, swizzle_elems): raise NotImplementedError("WGMMA rhs tiling does not fit swizzle") if rhs_transpose: diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py index 534da419ed3b..48a0d18459cb 100644 --- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py +++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py @@ -193,7 +193,7 @@ def kv_loop(kv_step, _): def entry(q_ref, k_ref, v_ref, out_ref): compute_wgs = 2 - tiling = plgpu.TilingTransform((64, 64)) + tiling = plgpu.TilingTransform((8, 64)) swizzle = plgpu.SwizzleTransform(128) qo_scratch = plgpu.SMEM( (compute_wgs, block_q, head_dim), jnp.float16, @@ -263,7 +263,7 @@ def attention_with_pipeline_emitter(q, k, v, config: TuningConfig): if rem: raise NotImplementedError(f"{q_seq_len=} must be a multiple of {block_q * 2=}") - tiling = plgpu.TilingTransform((64, 64)) + tiling = plgpu.TilingTransform((8, 64)) swizzle = plgpu.SwizzleTransform(128) transpose = plgpu.TransposeTransform((0, 2, 1, 3, 4)) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index c8013f634c67..9c5795e49da7 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -523,7 +523,7 @@ def test_gmem_to_smem_with_multiple_smem_indexers_and_transforms(self): index_map=lambda i, j: (i, j), memory_space=plgpu.SMEM, transforms=( - plgpu.TilingTransform((64, 32)), + plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(128), ), ), @@ -584,7 +584,7 @@ def kernel(x_ref, o_ref, barrier_ref): (128, 128), lambda: (0, 0), transforms=( - plgpu.TilingTransform((64, 32)), + plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(128), ), memory_space=plgpu.SMEM, @@ -604,7 +604,7 @@ def kernel(x_ref, o_ref, barrier_ref): def test_scoped_copy_with_transforms(self): self.skip_if_wg_semantics() - ts = (plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128)) + ts = (plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(128)) def kernel(x_ref, o_ref, barrier_ref): def body(tmp_ref): plgpu.copy_gmem_to_smem(x_ref, tmp_ref, barrier_ref) @@ -639,7 +639,7 @@ def kernel(x_ref, o_ref, barrier_ref): (2, 128, 128), lambda: (0, 0, 0), transforms=( - plgpu.TilingTransform((64, 32)), + plgpu.TilingTransform((8, 32)), plgpu.TransposeTransform((0, 2, 1, 3, 4)), plgpu.SwizzleTransform(128), ), @@ -749,8 +749,7 @@ def compute(acc_ref): (k, n), lambda: (0, 0), transforms=( - plgpu.TilingTransform((64, 64)), - plgpu.SwizzleTransform(128), + plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128), ), ), ), @@ -881,8 +880,7 @@ def test_print_wgmma_tiled_layout(self): shape, lambda: (0, 0), transforms=( - plgpu.TilingTransform((64, 32)), - plgpu.SwizzleTransform(128), + plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(128), ), ) ], @@ -1061,8 +1059,7 @@ def test_swizzled_blockspec_shapes(self): (128, 64), lambda *i: i, transforms=( - plgpu.TilingTransform((64, 64)), - plgpu.SwizzleTransform(128), + plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128), ), ) @functools.partial( @@ -1243,8 +1240,7 @@ def test_tile_slicing(self): shape = (256, 128) block_spec = plgpu.GPUBlockSpec( transforms=( - plgpu.TilingTransform((64, 64)), - plgpu.SwizzleTransform(128), + plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128), ) ) @functools.partial( @@ -1297,7 +1293,7 @@ def rotate(src, dst): (128, 128), lambda: (0, 0), transforms=( - plgpu.TilingTransform((64, 64)), + plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128), ), ) @@ -1431,7 +1427,7 @@ class PallasCallSm90ATest(PallasSm90ATest): @parameterized.parameters(False, True) def test_fori_loop_accumulator(self, force_while): - transforms = (plgpu.TilingTransform((64, 64)), plgpu.SwizzleTransform(128)) + transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) @functools.partial( pl.pallas_call, in_specs=[plgpu.GPUBlockSpec((64, 64), lambda: (0, 0), transforms=transforms)], @@ -1488,21 +1484,21 @@ def _epilogue(): lhs_spec = plgpu.GPUBlockSpec( lhs_spec.block_shape, lhs_spec.index_map, transforms=( - plgpu.TilingTransform((64, elems_128b)), + plgpu.TilingTransform((8, elems_128b)), plgpu.SwizzleTransform(128), ) ) rhs_spec = plgpu.GPUBlockSpec( rhs_spec.block_shape, rhs_spec.index_map, transforms=( - plgpu.TilingTransform((elems_128b, elems_128b)), + plgpu.TilingTransform((8, elems_128b)), plgpu.SwizzleTransform(128), ) ) out_spec = plgpu.GPUBlockSpec( out_spec.block_shape, out_spec.index_map, transforms=( - plgpu.TilingTransform((64, elems_128b)), + plgpu.TilingTransform((8, elems_128b)), plgpu.SwizzleTransform(128), ) ) @@ -1546,7 +1542,7 @@ def scope(acc_ref): b_shape = b_shape[::-1] b = jax.random.uniform(key2, shape=b_shape, dtype=dtype) - rhs_transforms = (plgpu.TilingTransform((elems_128b, elems_128b)),) + rhs_transforms = (plgpu.TilingTransform((8, elems_128b)),) if rhs_transpose: rhs_transforms += (plgpu.TransposeTransform((1, 0, 2, 3)),) res = pl.pallas_call( @@ -1556,7 +1552,7 @@ def scope(acc_ref): (64, 128), lambda i, j: (i, j), transforms=( - plgpu.TilingTransform((64, elems_128b)), + plgpu.TilingTransform((8, elems_128b)), plgpu.SwizzleTransform(128), ), ), @@ -1585,7 +1581,7 @@ def scope(acc_ref): a = jax.random.uniform(key1, shape=(64, 128), dtype=jnp.float16) b = jax.random.uniform(key2, shape=(128, 192), dtype=jnp.float16) - transforms = (plgpu.TilingTransform((64, 64)), plgpu.SwizzleTransform(128)) + transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) res = pl.pallas_call( kernel, in_specs=[ @@ -1608,7 +1604,7 @@ def scope(acc_ref): b = jax.random.uniform(key2, shape=(128, 192), dtype=jnp.float16) i = jax.random.uniform(key3, shape=(64, 192), dtype=jnp.float16) * 10 - transforms = (plgpu.TilingTransform((64, 64)), plgpu.SwizzleTransform(128)) + transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) res = pl.pallas_call( kernel, in_specs=[ @@ -1639,14 +1635,14 @@ def scope(acc_ref): plgpu.GPUBlockSpec( (2, 64, 128), lambda: (0, 0, 0), transforms=( - plgpu.TilingTransform((64, 64)), + plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128), ), ), plgpu.GPUBlockSpec( (2, 128, 192), lambda: (0, 0, 0), transforms=( - plgpu.TilingTransform((64, 64)), + plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128), ), ), @@ -1676,7 +1672,7 @@ def scope(acc_ref): (64, 128), lambda i, j: (i, j), transforms=( - plgpu.TilingTransform((64, elems_128b)), + plgpu.TilingTransform((8, elems_128b)), plgpu.SwizzleTransform(128), ), ), @@ -1684,7 +1680,7 @@ def scope(acc_ref): (128, 128), lambda *i: i, transforms=( - plgpu.TilingTransform((elems_128b, elems_128b)), + plgpu.TilingTransform((8, elems_128b)), plgpu.SwizzleTransform(128), ), ), @@ -1820,7 +1816,7 @@ def body(step, _): @parameterized.parameters( ((),), - ((plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128)),), + ((plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(128)),), ) def test_emit(self, transforms): num_steps = 4 @@ -2005,7 +2001,7 @@ def kernel_body(a_smem, b_smem): (tile_m, tile_k), lambda k: (pid_m, k), transforms=( - plgpu.TilingTransform((64, elems_128b)), + plgpu.TilingTransform((8, elems_128b)), plgpu.SwizzleTransform(128), ), ), @@ -2013,7 +2009,7 @@ def kernel_body(a_smem, b_smem): (tile_k, tile_n), lambda k: (k, pid_n), transforms=( - plgpu.TilingTransform((elems_128b, elems_128b)), + plgpu.TilingTransform((8, elems_128b)), plgpu.SwizzleTransform(128), ), ), @@ -2039,7 +2035,7 @@ def kernel_body(a_smem, b_smem): (tile_m, tile_n), lambda m, n: (m, n), transforms=( - plgpu.TilingTransform((64, elems_128b)), + plgpu.TilingTransform((8, elems_128b)), plgpu.SwizzleTransform(128), ), ), @@ -2339,7 +2335,7 @@ def test_realistic_matmul_with_cluster(self): m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n transforms = ( - plgpu.TilingTransform((64, elems_128b)), + plgpu.TilingTransform((8, elems_128b)), plgpu.SwizzleTransform(128), ) @@ -2521,7 +2517,7 @@ def compute(l_smem, r_smem, o_smem): r = lax.axis_index("rows") block = plgpu.GPUBlockSpec( (row_block, col_block), lambda c: (r, c), - transforms=(plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(64)), + transforms=(plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(64)), ) plgpu.emit_pipeline( compute, @@ -2572,8 +2568,8 @@ def do_wgmma(acc_ref): return acc_ref[...] o_smem[...] += pl.run_scoped(do_wgmma, plgpu.ACC((m_block, n_block), jnp.float16)) m, n = lax.axis_index("m"), lax.axis_index("n") - lo_transforms = (plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(64)) - r_transforms = (plgpu.TilingTransform((32, 32)), plgpu.SwizzleTransform(64)) + lo_transforms = (plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(64)) + r_transforms = (plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(64)) plgpu.emit_pipeline( compute, grid=(l_ref.shape[1] // k_block,), From 05e15ba032841b13bd95f684a2f4f0b57bd75ada Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Mon, 31 Mar 2025 02:49:03 -0700 Subject: [PATCH 286/483] [pallas:mgpu] Allow more freedom for the user to transform references. Imlpemented untile_ref and unswizzle_ref in order to allow patterns where we need different transform stacks over the same memref. For example we may want to reg->smem transposed, then smem->gmem sliced and maybe load strided/print in between for sanity checking: ``` # Store registers transposed o_smem_swizzled = plgpu.unswizzle_ref(o_smem_raw, swizzle_out) o_smem_t = o_smem_swizzled.reshape(1, 1, config.block_n, config.block_m) o_smem_t = plgpu.untile_ref(o_smem_t, (n, m)) o_smem_t = plgpu.transpose_ref(o_smem_t, (1, 0)) o_smem_t[...] = plgpu.layout_cast((regs, plgpu.Layout.WGMMA_TRANSPOSED) plgpu.commit_smem() del o_smem_t # Now we need different transforms on the same smem to slice and async-store to gmem o_smem = o_smem_raw.reshape(n, m // swizzle_elems, swizzle_elems,) o_smem = plgpu.unswizzle_ref(o_smem, swizzle_out) o_smem = plgpu.tile_ref(o_smem, swizzle_out) o_smem = o_smem.at[...] plgpu.copy_smem_to_gmem(o_smem, o_ref.at[...],) ``` Which in turn lets us write PiperOrigin-RevId: 742194519 --- jax/_src/pallas/mosaic_gpu/core.py | 22 +++++++++++++++++----- jax/experimental/pallas/mosaic_gpu.py | 3 +++ tests/pallas/mosaic_gpu_test.py | 20 ++++++++++++++++++++ 3 files changed, 40 insertions(+), 5 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 8522bdf651f4..b0d4f23c792e 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -342,7 +342,7 @@ def to_gpu_transform(self) -> mgpu.MemRefTransform: @tree_util.register_dataclass @dataclasses.dataclass(frozen=True) class TransposeRef(state_types.Transform): - permutation: tuple[int, ...] + permutation: tuple[int, ...] = dataclasses.field(metadata=dict(static=True)) def transform_shape(self, shape): if shape is None: @@ -370,18 +370,30 @@ def undo_to_gpu_transform(self) -> mgpu.MemRefTransform: return mgpu.TransposeTransform(_perm_inverse(self.permutation)) -def transpose_ref( - ref: pallas_core.TransformedRef | Any, - permutation: tuple[int, ...], +def transform_ref( + ref: pallas_core.TransformedRef, + transform: state_types.Transform ) -> pallas_core.TransformedRef: if not isinstance(ref, pallas_core.TransformedRef): if not isinstance(jax_core.get_aval(ref), pallas_core.AbstractMemoryRef): raise TypeError("ref must be a reference") ref = pallas_core.TransformedRef(ref, transforms=()) return pallas_core.TransformedRef( - ref.ref, (*ref.transforms, TransposeRef(permutation)), + ref.ref, (*ref.transforms, transform), ) +def transpose_ref( + ref: pallas_core.TransformedRef | Any, + permutation: tuple[int, ...], +) -> pallas_core.TransformedRef: + return transform_ref(ref, TransposeRef(permutation)) + +def untile_ref(ref, tiling: tuple[int, ...]) -> pallas_core.TransformedRef: + return transform_ref(ref, UntileRef(tiling)) + +def unswizzle_ref(ref, swizzle: int) -> pallas_core.TransformedRef: + return transform_ref(ref, UnswizzleRef(swizzle)) + @dataclasses.dataclass(frozen=True) class SwizzleTransform(MemoryRefTransform): diff --git a/jax/experimental/pallas/mosaic_gpu.py b/jax/experimental/pallas/mosaic_gpu.py index e4c5ffe04093..b44c86ea7a4c 100644 --- a/jax/experimental/pallas/mosaic_gpu.py +++ b/jax/experimental/pallas/mosaic_gpu.py @@ -27,7 +27,10 @@ from jax._src.pallas.mosaic_gpu.core import SemaphoreType as SemaphoreType from jax._src.pallas.mosaic_gpu.core import SwizzleTransform as SwizzleTransform from jax._src.pallas.mosaic_gpu.core import TilingTransform as TilingTransform +from jax._src.pallas.mosaic_gpu.core import transform_ref as transform_ref from jax._src.pallas.mosaic_gpu.core import transpose_ref as transpose_ref +from jax._src.pallas.mosaic_gpu.core import untile_ref as untile_ref +from jax._src.pallas.mosaic_gpu.core import unswizzle_ref as unswizzle_ref from jax._src.pallas.mosaic_gpu.core import TransposeTransform as TransposeTransform from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC # noqa: F401 from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as WGMMAAccumulatorRef diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 9c5795e49da7..96532488a648 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -626,6 +626,26 @@ def body(tmp_ref): x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128) np.testing.assert_array_equal(f(x), x * 2) + def test_scoped_copy_with_user_transforms(self): + def kernel(x_ref, o_ref, barrier_ref): + def body(tmp_ref): + tmp_ref = plgpu.unswizzle_ref(tmp_ref, 128) + tmp_ref = plgpu.untile_ref(tmp_ref, (8, 32)) + plgpu.copy_gmem_to_smem(x_ref, tmp_ref, barrier_ref) + plgpu.barrier_wait(barrier_ref) + o_ref[...] = tmp_ref[...] * 2 + pl.run_scoped(body, plgpu.SMEM((16, 4, 8, 32), jnp.float32)) + + in_spec = pl.BlockSpec(memory_space=plgpu.GMEM) + f = pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct([128, 128], jnp.float32), + in_specs=(in_spec,), + scratch_shapes=[plgpu.Barrier(num_arrivals=1)], + ) + x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128) + np.testing.assert_array_equal(f(x), x * 2) + def test_copy_with_transforms_and_indexing(self): self.skip_if_wg_semantics() From d3ed327572e4075ee5b7b0ba3b4b9633a737a39d Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 31 Mar 2025 03:47:57 -0700 Subject: [PATCH 287/483] [Pallas:MGPU] Remove (now) unnecessary TransposeTransforms Now that we always use small tiles, we can lay out the tiled dimension in arbitrary order so there's no need to swap them during the TMA. PiperOrigin-RevId: 742206980 --- jax/_src/pallas/mosaic_gpu/primitives.py | 5 ++--- jax/experimental/pallas/ops/gpu/attention_mgpu.py | 5 ++--- tests/pallas/mosaic_gpu_test.py | 2 -- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index d0632728f2b6..07235c2fc830 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -759,9 +759,8 @@ def _wgmma_lowering( rhs_transpose = False case ( gpu_core.UnswizzleRef(rhs_swizzle), - gpu_core.TransposeRef((1, 0, 2, 3)), # Only transpose between tiles gpu_core.UntileRef(rhs_tiling), - gpu_core.TransposeRef((1, 0)), # Transpose the two logical dims + gpu_core.TransposeRef((1, 0)), ): rhs_transpose = True case ( @@ -794,7 +793,7 @@ def _wgmma_lowering( raise NotImplementedError("WGMMA rhs tiling does not fit swizzle") if rhs_transpose: - b = mgpu.memref_transpose(b, (0, 1, 3, 2)) + b = mgpu.memref_transpose(b, (1, 0, 3, 2)) new_acc = mgpu.wgmma(acc, a, b, swizzle=rhs_swizzle) nvvm_dialect.wgmma_commit_group_sync_aligned() return new_acc diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py index 48a0d18459cb..b19e371a1eb8 100644 --- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py +++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py @@ -201,7 +201,7 @@ def entry(q_ref, k_ref, v_ref, out_ref): ) k_scratch = plgpu.SMEM( (max_concurrent_steps, block_kv, head_dim), jnp.float16, - transforms=(tiling, plgpu.TransposeTransform((0, 2, 1, 3, 4)), swizzle), + transforms=(tiling, swizzle), ) v_scratch = plgpu.SMEM( (max_concurrent_steps, block_kv, head_dim), jnp.float16, @@ -265,7 +265,6 @@ def attention_with_pipeline_emitter(q, k, v, config: TuningConfig): tiling = plgpu.TilingTransform((8, 64)) swizzle = plgpu.SwizzleTransform(128) - transpose = plgpu.TransposeTransform((0, 2, 1, 3, 4)) def fa3_kernel(q_ref, k_ref, v_ref, out_ref, scoped): batch = lax.axis_index("batch") @@ -354,7 +353,7 @@ def compute_pv(acc_ref): plgpu.GPUBlockSpec( # k block_shape=(block_kv, head_dim), index_map=lambda i: (i, 0), - transforms=[tiling, transpose, swizzle]), + transforms=[tiling, swizzle]), plgpu.GPUBlockSpec( # v block_shape=(block_kv, head_dim), index_map=lambda i: (i, 0), diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 96532488a648..965539af52ab 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1563,8 +1563,6 @@ def scope(acc_ref): b = jax.random.uniform(key2, shape=b_shape, dtype=dtype) rhs_transforms = (plgpu.TilingTransform((8, elems_128b)),) - if rhs_transpose: - rhs_transforms += (plgpu.TransposeTransform((1, 0, 2, 3)),) res = pl.pallas_call( kernel, in_specs=[ From fc01058ee42cefc2502a5674eebfef79f2749ebe Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 31 Mar 2025 05:15:12 -0700 Subject: [PATCH 288/483] Update XLA dependency to use revision http://github.com/openxla/xla/commit/f4a53456b04acf9b63b3b30bd828cec29c4aa7de. PiperOrigin-RevId: 742228024 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 8b3ddfde019b..d078359af86a 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "8df9390dc9444d900c7c7f2c123f23b549adf8e3" -XLA_SHA256 = "8e97c395d1e50a49fab386ccc7da1f78dc86bf670b20a892656e2e75bbf64f0e" +XLA_COMMIT = "f4a53456b04acf9b63b3b30bd828cec29c4aa7de" +XLA_SHA256 = "2ee32b70af547fd13ce404d75c3fa9834bc8be46a488cd8f0caa10e9a6ec7ede" def repo(): tf_http_archive( From cb5168269119ca098d678869c74303982ec84b17 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 31 Mar 2025 07:07:12 -0700 Subject: [PATCH 289/483] [pallas:mosaic_gpu] Run all Mosaic GPU-specific tests under WG semantics We do skip quite a few due to missing features. I tried to make the reason for skipping clear in each case. PiperOrigin-RevId: 742252858 --- tests/pallas/mosaic_gpu_test.py | 367 +++++++++++++++++++++----------- 1 file changed, 241 insertions(+), 126 deletions(-) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 965539af52ab..6b1839a64580 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -85,6 +85,13 @@ def skip_if_wg_semantics(self): if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Warpgroup: self.skipTest("Not supported under WG semantics") + def kernel(self, *args, **kwargs): + compiler_params = dataclasses.replace( + kwargs.pop("compiler_params", plgpu.GPUCompilerParams()), + thread_semantics=self.THREAD_SEMANTICS, + ) + return plgpu.kernel(*args, compiler_params=compiler_params, **kwargs) + def pallas_call(self, *args, **kwargs): compiler_params = dataclasses.replace( kwargs.pop("compiler_params", plgpu.GPUCompilerParams()), @@ -975,7 +982,7 @@ def kernel(x_ref, o_ref): def test_load_scalar(self): @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct((128,), jnp.int32), in_specs=[plgpu.GPUBlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM)], ) @@ -987,7 +994,7 @@ def kernel(x_ref, o_ref): def test_run_scoped(self): @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), ) def kernel(x_ref, o_ref): @@ -1005,7 +1012,7 @@ def body(tmp_ref): def test_program_id(self): @functools.partial( - pl.pallas_call, + self.pallas_call, in_specs=(), out_specs=pl.BlockSpec((128,), lambda *i: i), out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.int32), @@ -1024,7 +1031,7 @@ def test_program_id_in_squashed_grid(self): # 3 CUDA grid dimensions. grid = (2, 3, 4, 5) @functools.partial( - pl.pallas_call, + self.pallas_call, in_specs=(), out_specs=pl.BlockSpec((1,) * len(grid) + (128,), lambda *i: (*i, 0)), out_shape=jax.ShapeDtypeStruct([*grid, 128], jnp.int32), @@ -1045,7 +1052,7 @@ def kernel(o_ref): def test_program_id_in_block_spec(self): @functools.partial( - pl.pallas_call, + self.pallas_call, in_specs=(pl.BlockSpec((2, 128), lambda i: (pl.program_id(0), i)),), out_specs=pl.BlockSpec((2, 128), lambda i: (pl.program_id(0), i)), out_shape=jax.ShapeDtypeStruct([2, 128], jnp.int32), @@ -1059,7 +1066,7 @@ def kernel(x_ref, o_ref): def test_num_programs(self): @functools.partial( - pl.pallas_call, + self.pallas_call, in_specs=(), out_specs=pl.BlockSpec((128,), lambda *i: i), out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.int32), @@ -1074,6 +1081,7 @@ def kernel(o_ref): ) def test_swizzled_blockspec_shapes(self): + self.skip_if_wg_semantics() spec = plgpu.GPUBlockSpec( (128, 64), @@ -1083,7 +1091,7 @@ def test_swizzled_blockspec_shapes(self): ), ) @functools.partial( - pl.pallas_call, + self.pallas_call, in_specs=[spec], out_specs=spec, out_shape=jax.ShapeDtypeStruct((128, 128), jnp.float16), @@ -1124,7 +1132,7 @@ def kernel(o_ref): def test_fori_loop_dynamic_bounds(self): @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32), grid=(1,) ) @@ -1201,8 +1209,10 @@ def body(acc): ) def test_while_loop_layout_mismatch(self): + self.skip_if_wg_semantics() # `plgpu.layout_cast` is not supported. + @functools.partial( - pl.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.int32) + self.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.int32) ) def kernel(o_ref): def cond(acc): @@ -1255,8 +1265,10 @@ def kernel(x_ref, o_ref): x = jnp.arange(256, dtype=jnp.int32) np.testing.assert_array_equal(kernel(x), jnp.broadcast_to(jnp.sum(x) * 3, [256])) - # Not testing with warpgroup semantics, because we want to enforce a layout. def test_tile_slicing(self): + # Not testing with warpgroup semantics, because we want to enforce a layout. + self.skip_if_wg_semantics() + shape = (256, 128) block_spec = plgpu.GPUBlockSpec( transforms=( @@ -1264,7 +1276,7 @@ def test_tile_slicing(self): ) ) @functools.partial( - pl.pallas_call, + self.pallas_call, in_specs=[block_spec], out_specs=block_spec, out_shape=jax.ShapeDtypeStruct((64, 64), jnp.uint16), @@ -1289,7 +1301,7 @@ def kernel(a_ref, b_ref): a_ref[...] = jnp.ones_like(a_ref) a = np.zeros((64, 64), dtype=jnp.float32) - b = pl.pallas_call( + b = self.pallas_call( kernel, in_specs=[plgpu.GPUBlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM)], out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM), @@ -1299,6 +1311,8 @@ def kernel(a_ref, b_ref): np.testing.assert_array_equal(b, np.ones_like(a)) def test_slicing(self): + self.skip_if_wg_semantics() + left = upper = slice(None, 64) right = lower = slice(64, None) # We rotate the four quadrants of the input clockwise. @@ -1317,14 +1331,16 @@ def rotate(src, dst): plgpu.SwizzleTransform(128), ), ) - f = pl.pallas_call(rotate, out_shape=x, in_specs=[spec], out_specs=spec) + f = self.pallas_call(rotate, out_shape=x, in_specs=[spec], out_specs=spec) expected = np.empty_like(x) rotate(x, expected) np.testing.assert_array_equal(f(x), expected) def test_layout_cast(self, shape=(256, 64)): + self.skip_if_wg_semantics() # `plgpu.layout_cast` is not supported. + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), ) def kernel(o_ref): @@ -1334,6 +1350,8 @@ def kernel(o_ref): np.testing.assert_array_equal(kernel(), x) def test_profiler(self): + self.skip_if_wg_semantics() # Transform inference fails. + def kernel(x_ref, o_ref): with jax.named_scope("add"): with jax.named_scope("load"): @@ -1343,7 +1361,7 @@ def kernel(x_ref, o_ref): o_ref[...] = o with tempfile.TemporaryDirectory() as tmpdir: x = jnp.arange(256).astype(jnp.float32) - y = pl.pallas_call( + y = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), compiler_params=plgpu.GPUCompilerParams( @@ -1447,10 +1465,15 @@ class PallasCallSm90ATest(PallasSm90ATest): @parameterized.parameters(False, True) def test_fori_loop_accumulator(self, force_while): + # ``pl.run_state`` is not supported in WG semantics. + self.skip_if_wg_semantics() + transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) @functools.partial( - pl.pallas_call, - in_specs=[plgpu.GPUBlockSpec((64, 64), lambda: (0, 0), transforms=transforms)], + self.pallas_call, + in_specs=[ + plgpu.GPUBlockSpec((64, 64), lambda: (0, 0), transforms=transforms) + ], out_shape=jax.ShapeDtypeStruct((64, 64), jnp.float16), out_specs=plgpu.GPUBlockSpec((64, 64), lambda: (0, 0)), ) @@ -1462,8 +1485,7 @@ def scope(acc_ref): acc_ini = jnp.ones((64, 64), dtype=jnp.float16) np.testing.assert_array_equal(kernel(acc_ini), jnp.full((64, 64), 5, dtype=jnp.float16)) - @parameterized.parameters([*plgpu.ThreadSemantics]) - def test_realistic_matmul(self, thread_semantics): + def test_realistic_matmul(self): dtype = jnp.float16 swizzle = 128 elems_128b = swizzle // jnp.dtype(dtype).itemsize @@ -1500,7 +1522,7 @@ def _epilogue(): lambda m, n, k: (m, n), ) - if thread_semantics == plgpu.ThreadSemantics.Lane: + if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Lane: lhs_spec = plgpu.GPUBlockSpec( lhs_spec.block_shape, lhs_spec.index_map, transforms=( @@ -1523,7 +1545,7 @@ def _epilogue(): ) ) - res = pl.pallas_call( + res = self.pallas_call( kernel, in_specs=[lhs_spec, rhs_spec], out_specs=out_spec, @@ -1534,13 +1556,14 @@ def _epilogue(): dimension_semantics=["parallel", "parallel", "sequential"], max_concurrent_steps=2, delay_release=1, - thread_semantics=thread_semantics, ), )(a, b) np.testing.assert_allclose(res, a @ b, rtol=1e-3) @parameterized.parameters(jnp.float16, jnp.float32) def test_wgmma(self, dtype): + self.skip_if_wg_semantics() + # TensorCores can only fuse transposes of 16-bit values, and RHS # is expected to be column major by default. rhs_transpose = jnp.dtype(dtype).itemsize != 2 @@ -1563,7 +1586,7 @@ def scope(acc_ref): b = jax.random.uniform(key2, shape=b_shape, dtype=dtype) rhs_transforms = (plgpu.TilingTransform((8, elems_128b)),) - res = pl.pallas_call( + res = self.pallas_call( kernel, in_specs=[ plgpu.GPUBlockSpec( @@ -1599,12 +1622,18 @@ def scope(acc_ref): a = jax.random.uniform(key1, shape=(64, 128), dtype=jnp.float16) b = jax.random.uniform(key2, shape=(128, 192), dtype=jnp.float16) - transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) - res = pl.pallas_call( + transforms = () + if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Lane: + transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) + res = self.pallas_call( kernel, in_specs=[ - plgpu.GPUBlockSpec((64, 128), lambda: (0, 0), transforms=transforms), - plgpu.GPUBlockSpec((128, 192), lambda: (0, 0), transforms=transforms), + plgpu.GPUBlockSpec( + (64, 128), lambda: (0, 0), transforms=transforms + ), + plgpu.GPUBlockSpec( + (128, 192), lambda: (0, 0), transforms=transforms + ), ], out_specs=plgpu.GPUBlockSpec((64, 192), lambda: (0, 0)), out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32), @@ -1612,6 +1641,9 @@ def scope(acc_ref): np.testing.assert_allclose(res, a @ b, rtol=1e-3) def test_wgmma_registers_init(self): + # ``pl.run_state`` is not supported in WG semantics. + self.skip_if_wg_semantics() + def kernel(a_ref, b_ref, i_ref, o_ref): def scope(acc_ref): plgpu.wgmma(acc_ref, a_ref[...], b_ref) @@ -1623,12 +1655,18 @@ def scope(acc_ref): i = jax.random.uniform(key3, shape=(64, 192), dtype=jnp.float16) * 10 transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) - res = pl.pallas_call( + res = self.pallas_call( kernel, in_specs=[ - plgpu.GPUBlockSpec((64, 128), lambda: (0, 0), transforms=transforms), - plgpu.GPUBlockSpec((128, 192), lambda: (0, 0), transforms=transforms), - plgpu.GPUBlockSpec((64, 192), lambda: (0, 0), transforms=transforms), + plgpu.GPUBlockSpec( + (64, 128), lambda: (0, 0), transforms=transforms + ), + plgpu.GPUBlockSpec( + (128, 192), lambda: (0, 0), transforms=transforms + ), + plgpu.GPUBlockSpec( + (64, 192), lambda: (0, 0), transforms=transforms + ), ], out_specs=plgpu.GPUBlockSpec((64, 192), lambda: (0, 0)), out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float16), @@ -1636,6 +1674,8 @@ def scope(acc_ref): np.testing.assert_allclose(res, i + a @ b, rtol=2e-3) def test_wgmma_sliced_ref(self): + self.skip_if_wg_semantics() # Needs WGMMA to support slices. + def kernel(a_ref, b_ref, o_ref): def scope(acc_ref): plgpu.wgmma(acc_ref, a_ref.at[0], b_ref.at[0]) @@ -1647,22 +1687,18 @@ def scope(acc_ref): a = jax.random.uniform(key1, shape=(2, 64, 128), dtype=jnp.float16) b = jax.random.uniform(key2, shape=(2, 128, 192), dtype=jnp.float16) - res = pl.pallas_call( + transforms = () + if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Lane: + transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) + + res = self.pallas_call( kernel, in_specs=[ plgpu.GPUBlockSpec( - (2, 64, 128), lambda: (0, 0, 0), - transforms=( - plgpu.TilingTransform((8, 64)), - plgpu.SwizzleTransform(128), - ), + (2, 64, 128), lambda: (0, 0, 0), transforms=transforms ), plgpu.GPUBlockSpec( - (2, 128, 192), lambda: (0, 0, 0), - transforms=( - plgpu.TilingTransform((8, 64)), - plgpu.SwizzleTransform(128), - ), + (2, 128, 192), lambda: (0, 0, 0), transforms=transforms ), ], out_specs=plgpu.GPUBlockSpec((64, 192), lambda: (0, 0)), @@ -1671,6 +1707,8 @@ def scope(acc_ref): np.testing.assert_allclose(res, a[0] @ b[0], rtol=1e-3) def test_wgmma_sliced_acc(self): + self.skip_if_wg_semantics() # Needs WGMMA to support slices. + swizzle = 128 elems_128b = swizzle // jnp.dtype(jnp.float16).itemsize def kernel(a_ref, b_ref, o_ref): @@ -1683,38 +1721,41 @@ def scope(acc_ref): key1, key2 = jax.random.split(jax.random.key(42), 2) a = jax.random.uniform(key1, shape=(64, 128), dtype=jnp.float16) b = jax.random.uniform(key2, shape=(128, 128), dtype=jnp.float16) - res = pl.pallas_call( + transforms = () + if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Lane: + transforms = ( + plgpu.TilingTransform((8, elems_128b)), + plgpu.SwizzleTransform(128), + ) + res = self.pallas_call( kernel, in_specs=[ plgpu.GPUBlockSpec( - (64, 128), - lambda i, j: (i, j), - transforms=( - plgpu.TilingTransform((8, elems_128b)), - plgpu.SwizzleTransform(128), - ), + (64, 128), lambda *ij: ij, transforms=transforms ), plgpu.GPUBlockSpec( - (128, 128), - lambda *i: i, - transforms=( - plgpu.TilingTransform((8, elems_128b)), - plgpu.SwizzleTransform(128), - ), + (128, 128), lambda *ij: ij, transforms=transforms ), ], - out_specs=plgpu.GPUBlockSpec((64, 128), lambda *i: i), + out_specs=plgpu.GPUBlockSpec((64, 128), lambda *ij: ij), out_shape=jax.ShapeDtypeStruct((64, 128), jnp.float32), grid=(1, 1), )(a, b) np.testing.assert_allclose(res, a @ b, rtol=1e-3) +class PallasCallSm90AWGTest( + PallasCallSm90ATest, thread_semantics=plgpu.ThreadSemantics.Warpgroup +): + ... + + class PallasCallSm100ATest(PallasSm100ATest): def test_tmem_alloc(self): + @functools.partial( - plgpu.kernel, + self.kernel, out_shape=jnp.zeros((128, 128), jnp.float32), scratch_shapes=[ plgpu.TMEM((128, 128), jnp.float32), @@ -1734,6 +1775,12 @@ def kernel(y_ref, tmem_ref, smem_ref): jax.block_until_ready(kernel()) +class PallasCallSm100AWGTest( + PallasCallSm100ATest, thread_semantics=plgpu.ThreadSemantics.Warpgroup +): + ... + + class PipelineTest(PallasTest): def test_pipeline_mode(self): @@ -1755,13 +1802,13 @@ def body(x_ref, y_ref, o_ref): @jax.jit def vadd(x, y): - return pl.pallas_call( - body, - out_shape=jax.ShapeDtypeStruct(x.shape, jnp.float32), - in_specs=in_specs, - out_specs=out_specs, - grid=data_size // block_size, - )(x, y) + return self.pallas_call( + body, + out_shape=jax.ShapeDtypeStruct(x.shape, jnp.float32), + in_specs=in_specs, + out_specs=out_specs, + grid=data_size // block_size, + )(x, y) with self.assertRaisesRegex(Exception, "Pipeline mode is not supported"): vadd(x, y) @@ -1823,7 +1870,7 @@ def body(step, _): plgpu.wait_smem_to_gmem(0) x = jnp.arange(32 * 4 * 64).reshape(32 * 4, 64).astype(jnp.float32) - kernel_fn = pl.pallas_call( + kernel_fn = self.pallas_call( kernel, in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), @@ -1837,6 +1884,9 @@ def body(step, _): ((plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(128)),), ) def test_emit(self, transforms): + if transforms: + self.skip_if_wg_semantics() + num_steps = 4 def kernel(x_gmem, o_gmem): @@ -1863,7 +1913,7 @@ def kernel_body(x_smem, o_smem): x = jnp.arange(64 * num_steps * 64) x = x.reshape(-1, num_steps * 64).astype(jnp.float32) - kernel_fn = pl.pallas_call( + kernel_fn = self.pallas_call( kernel, in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), @@ -1896,7 +1946,7 @@ def nested_kernel_body(x_smem, o_smem): x = jnp.arange(32 * num_steps * 16) x = x.reshape(-1, num_steps * 16).astype(jnp.float32) - kernel_fn = pl.pallas_call( + kernel_fn = self.pallas_call( kernel, in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), @@ -1921,7 +1971,7 @@ def kernel_body(x_smem, o_smem): x = jnp.arange(32 * num_steps * 16) x = x.reshape(-1, num_steps * 16).astype(jnp.float32) - kernel_fn = pl.pallas_call( + kernel_fn = self.pallas_call( kernel, in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), @@ -1954,7 +2004,7 @@ def kernel_body(x_smem, o_smem): x = jnp.arange(num_steps1 * 32 * num_steps2 * 16) x = x.reshape(-1, num_steps2 * 16).astype(jnp.float32) - kernel_fn = pl.pallas_call( + kernel_fn = self.pallas_call( kernel, in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), @@ -1982,7 +2032,7 @@ def kernel_body(x_smem, o_smem): x = jnp.arange(32 * num_steps1 * 16 * num_steps2 * 8) x = x.reshape(-1, num_steps1 * 16, num_steps2 * 8).astype(jnp.float32) - kernel_fn = pl.pallas_call( + kernel_fn = self.pallas_call( kernel, in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), @@ -1991,9 +2041,17 @@ def kernel_body(x_smem, o_smem): np.testing.assert_array_equal(kernel_fn(x), x + 1.0) +class PipelineWGTest( + PipelineTest, thread_semantics=plgpu.ThreadSemantics.Warpgroup +): + ... + + class PipelineSm90ATest(PallasSm90ATest): def test_realistic_matmul(self): + self.skip_if_wg_semantics() # Needs WGMMA to support slices. + dtype = jnp.float16 swizzle = 128 elems_128b = swizzle // jnp.dtype(dtype).itemsize @@ -2003,6 +2061,13 @@ def test_realistic_matmul(self): tile_k = elems_128b m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n + transforms = () + if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Lane: + transforms = ( + plgpu.TilingTransform((8, elems_128b)), + plgpu.SwizzleTransform(128), + ) + def kernel(a_gmem, b_gmem, o_smem, acc): def kernel_body(a_smem, b_smem): assert a_smem.shape == (tile_m, tile_k) @@ -2016,21 +2081,11 @@ def kernel_body(a_smem, b_smem): kernel_body, in_specs=[ plgpu.GPUBlockSpec( - (tile_m, tile_k), - lambda k: (pid_m, k), - transforms=( - plgpu.TilingTransform((8, elems_128b)), - plgpu.SwizzleTransform(128), - ), - ), - plgpu.GPUBlockSpec( - (tile_k, tile_n), - lambda k: (k, pid_n), - transforms=( - plgpu.TilingTransform((8, elems_128b)), - plgpu.SwizzleTransform(128), - ), - ), + (tile_m, tile_k), lambda k: (pid_m, k), transforms=transforms + ), + plgpu.GPUBlockSpec( + (tile_k, tile_n), lambda k: (k, pid_n), transforms=transforms + ), ], grid=(grid_k,), max_concurrent_steps=2, @@ -2043,19 +2098,14 @@ def kernel_body(a_smem, b_smem): a = jax.random.uniform(key1, shape=(m, k), dtype=dtype) b = jax.random.uniform(key2, shape=(k, n), dtype=dtype) - res = pl.pallas_call( + res = self.pallas_call( kernel, in_specs=[ pl.BlockSpec(memory_space=plgpu.GMEM), - pl.BlockSpec(memory_space=plgpu.GMEM) + pl.BlockSpec(memory_space=plgpu.GMEM), ], out_specs=plgpu.GPUBlockSpec( - (tile_m, tile_n), - lambda m, n: (m, n), - transforms=( - plgpu.TilingTransform((8, elems_128b)), - plgpu.SwizzleTransform(128), - ), + (tile_m, tile_n), lambda m, n: (m, n), transforms=transforms ), out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16), scratch_shapes=[plgpu.ACC((tile_m, tile_n), jnp.float32)], @@ -2064,11 +2114,19 @@ def kernel_body(a_smem, b_smem): np.testing.assert_array_equal(res, a @ b) +class PipelineSm90AWGTest( + PipelineSm90ATest, thread_semantics=plgpu.ThreadSemantics.Warpgroup +): + ... + + class WarpSpecializedPipelineTest(PallasTest): @parameterized.product(m=[512], n=[512], manual_consumed_barriers=[False, True]) def test_pipelined_copy(self, m, n, manual_consumed_barriers): + self.skip_if_wg_semantics() # Times out! + x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float16) blk_m = blk_n = 64 @@ -2102,7 +2160,7 @@ def copy_kernel(x_smem, o_smem, o_last_block_smem, *consumed_barriers): ), ], ) - kernel = plgpu.kernel( + kernel = self.kernel( pipeline, out_shape=( jax.ShapeDtypeStruct((m, n), jnp.float16), @@ -2119,6 +2177,8 @@ def copy_kernel(x_smem, o_smem, o_last_block_smem, *consumed_barriers): np.testing.assert_array_equal(out_last_block, x[-blk_m:, -blk_n:]) def test_elementwise_add(self, m=256, n=256, num_compute_wgs=2): + self.skip_if_wg_semantics() # Crashes! + blk_m = blk_n = 64 spec = pl.BlockSpec( block_shape=(blk_m, blk_n), index_map=lambda i, j: (i, j) @@ -2140,7 +2200,7 @@ def tiled_add_kernel(x_smem, y_smem, o_smem): in_specs=[spec, spec], out_specs=[spec], ) - kernel = plgpu.kernel( + kernel = self.kernel( pipeline, out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32), compiler_params=plgpu.GPUCompilerParams(approx_math=True), @@ -2154,10 +2214,12 @@ def tiled_add_kernel(x_smem, y_smem, o_smem): np.testing.assert_allclose(kernel(x, y), x + y, atol=1e-4) def test_carry_accumulate(self, m=256, n=256, num_compute_wgs=2): + self.skip_if_wg_semantics() # `plgpu.layout_cast` is not supported. + blk_m = blk_n = 64 @functools.partial( - plgpu.kernel, + self.kernel, out_shape=jax.ShapeDtypeStruct((blk_m, blk_n), jnp.float32), scratch_shapes=[ plgpu.SMEM((blk_m, blk_n), jnp.float32), @@ -2214,11 +2276,19 @@ def tiled_acc_kernel(x_smem, carry): np.testing.assert_allclose(kernel(x), ref, atol=1e-4) +class WarpSpecializedPipelineWGTest( + WarpSpecializedPipelineTest, + thread_semantics=plgpu.ThreadSemantics.Warpgroup, +): + ... + + class CoreMapTest(PallasTest): def test_multiple_wg(self): + @functools.partial( - plgpu.kernel, + self.kernel, out_shape=jnp.zeros((2, 128), np.int32), num_threads=2, thread_name="wg", @@ -2232,8 +2302,9 @@ def kernel(o_ref): ) def test_multiple_wg_with_grid(self): + @functools.partial( - plgpu.kernel, + self.kernel, out_shape=jnp.zeros((4, 2, 128), np.int32), grid=(2, 2), grid_names=("x", "y"), @@ -2263,7 +2334,7 @@ def test_multiple_wg_with_squashed_grid(self): num_threads = 2 @functools.partial( - plgpu.kernel, + self.kernel, out_shape=jnp.zeros( (b, x_dim, y_dim, z_dim, num_threads, 128), np.int32 ), @@ -2290,8 +2361,10 @@ def kernel(o_ref): np.testing.assert_array_equal(result, ref) def test_cross_wg_barrier(self): + self.skip_if_wg_semantics() # Times out! + @functools.partial( - plgpu.kernel, + self.kernel, out_shape=jnp.zeros((2, 128), np.int32), # Each warpgroup is a single logical thread! scratch_shapes=[plgpu.Barrier(num_arrivals=2)], @@ -2309,8 +2382,10 @@ def kernel(o_ref, barrier): ) def test_cluster(self): + self.skip_if_wg_semantics() # Needs debug_print in the MGPU dialect. + @functools.partial( - plgpu.kernel, + self.kernel, out_shape=jnp.zeros(128, np.int32), grid=(2,), grid_names=("x",), @@ -2337,6 +2412,8 @@ def kernel(ref): ) def test_realistic_matmul_with_cluster(self): + self.skip_if_wg_semantics() # Needs WGMMA to support slices. + dtype = jnp.float16 swizzle = 128 elems_128b = swizzle // jnp.dtype(dtype).itemsize @@ -2361,7 +2438,7 @@ def test_realistic_matmul_with_cluster(self): delay_release = 1 @functools.partial( - plgpu.kernel, + self.kernel, out_shape=jax.ShapeDtypeStruct((m, n), dtype), scratch_shapes=[ plgpu.SMEM( @@ -2458,34 +2535,44 @@ def body(step, _): np.testing.assert_array_equal(kernel(a, b), a @ b) +class CoreMapWGTest( + CoreMapTest, thread_semantics=plgpu.ThreadSemantics.Warpgroup +): + ... + + class ExamplesTest(PallasTest): # Basic def test_stage0(self): - def body(l_ref, r_ref, o_ref): + x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) + + @functools.partial(self.kernel, out_shape=x) + def kernel(l_ref, r_ref, o_ref): o_ref[...] = l_ref[...] + r_ref[...] - x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) - out = plgpu.kernel(body, out_shape=x)(x, x) - np.testing.assert_allclose(out, x + x) + np.testing.assert_allclose(kernel(x, x), x + x) # Multi-block kernels def test_stage1(self): row_block = 64 - def body(l_ref, r_ref, o_ref): + x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) + + @functools.partial( + self.kernel, out_shape=x, grid=(2,), grid_names=("rows",) + ) + def kernel(l_ref, r_ref, o_ref): my_slice = pl.ds(lax.axis_index("rows") * row_block, row_block) o_ref[my_slice] = l_ref[my_slice] + r_ref[my_slice] - x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) - out = plgpu.kernel(body, out_shape=x, grid=(2,), grid_names=("rows",))(x, x) - np.testing.assert_allclose(out, x + x) + np.testing.assert_allclose(kernel(x, x), x + x) # Async copies def test_stage3(self): row_block, col_block = 64, 128 @functools.partial( - plgpu.kernel, + self.kernel, out_shape=jax.ShapeDtypeStruct((128, 128), jnp.float16), scratch_shapes=[ *([plgpu.SMEM((row_block, col_block), jnp.float16)] * 3), @@ -2510,7 +2597,12 @@ def kernel(l_ref, r_ref, o_ref, l_smem, r_smem, o_smem, barrier): # Pipelining def test_stage4(self): row_block, col_block = 64, 32 - def body(l_ref, r_ref, o_ref): + x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) + + @functools.partial( + self.kernel, out_shape=x, grid=(2,), grid_names=("rows",) + ) + def kernel(l_ref, r_ref, o_ref): def compute(l_smem, r_smem, o_smem): o_smem[...] = l_smem[...] + r_smem[...] r = lax.axis_index("rows") @@ -2522,14 +2614,19 @@ def compute(l_smem, r_smem, o_smem): out_specs=[block], )(l_ref, r_ref, o_ref) - x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) - out = plgpu.kernel(body, out_shape=x, grid=(2,), grid_names=("rows",))(x, x) - np.testing.assert_allclose(out, x + x) + np.testing.assert_allclose(kernel(x, x), x + x) # Transforms def test_stage5(self): + self.skip_if_wg_semantics() # Needs WGMMA to support slices. + row_block, col_block = 64, 32 - def body(l_ref, r_ref, o_ref): + x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) + + @functools.partial( + self.kernel, out_shape=x, grid=(2,), grid_names=("rows",) + ) + def kernel(l_ref, r_ref, o_ref): def compute(l_smem, r_smem, o_smem): o_smem[...] = l_smem[...] + r_smem[...] r = lax.axis_index("rows") @@ -2544,9 +2641,7 @@ def compute(l_smem, r_smem, o_smem): out_specs=[block], )(l_ref, r_ref, o_ref) - x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) - out = plgpu.kernel(body, out_shape=x, grid=(2,), grid_names=("rows",))(x, x) - np.testing.assert_allclose(out, x + x) + np.testing.assert_allclose(kernel(x, x), x + x) def test_semaphore_lowering(self): # This is a smoke test until we add support for lowering of semaphore ops. @@ -2556,8 +2651,10 @@ def body(i_ref1, i_ref2, o_ref, sem_ref): assert jnp.issubdtype(sem_ref.dtype, pl.semaphore) o_ref[...] = i_ref1[...] x = jnp.arange(128, dtype=jnp.float32).reshape((128,)) - kernel = pl.pallas_call( - body, out_shape=x, scratch_shapes=[plgpu.SemaphoreType.REGULAR((4,))], + kernel = self.pallas_call( + body, + out_shape=x, + scratch_shapes=[plgpu.SemaphoreType.REGULAR((4,))], ) text = jax.jit(kernel).lower(x, x).as_text() self.assertIn( @@ -2573,19 +2670,33 @@ def body(i_ref1, i_ref2, o_ref, sem_ref): ) +class ExamplesWGTest( + ExamplesTest, thread_semantics=plgpu.ThreadSemantics.Warpgroup +): + ... + + class ExamplesSm90ATest(PallasSm90ATest): # WGMMA def test_stage6(self): + self.skip_if_wg_semantics() # Needs WGMMA to support slices. + m_block = n_block = 64 k_block = 32 - def body(l_ref, r_ref, o_ref): + x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) + + @functools.partial( + self.kernel, out_shape=x, grid=(2, 2), grid_names=("m", "n") + ) + def kernel(l_ref, r_ref, o_ref): def compute(l_smem, r_smem, o_smem): def do_wgmma(acc_ref): plgpu.wgmma(acc_ref, l_smem, r_smem) return acc_ref[...] o_smem[...] += pl.run_scoped(do_wgmma, plgpu.ACC((m_block, n_block), jnp.float16)) - m, n = lax.axis_index("m"), lax.axis_index("n") + m = lax.axis_index("m") + n = lax.axis_index("n") lo_transforms = (plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(64)) r_transforms = (plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(64)) plgpu.emit_pipeline( @@ -2596,12 +2707,16 @@ def do_wgmma(acc_ref): out_specs=[plgpu.GPUBlockSpec((m_block, n_block), lambda k: (m, n), transforms=lo_transforms)], )(l_ref, r_ref, o_ref) - x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) - out = plgpu.kernel(body, out_shape=x, grid=(2, 2), grid_names=("m", "n"))(x, x) - np.testing.assert_allclose(out, x @ x) + np.testing.assert_allclose(kernel(x, x), x @ x) # TODO(apaszke): Clusters and multicast +class ExamplesSm90AWGTest( + ExamplesSm90ATest, thread_semantics=plgpu.ThreadSemantics.Warpgroup +): + ... + + if __name__ == "__main__": absltest.main() From 12526ea11646a75fac201e26c1a2e901f94a4c76 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Mon, 31 Mar 2025 07:08:48 -0700 Subject: [PATCH 290/483] [jaxlib] Pack/unpack subbyte types to/from numpy arrays to support int2, uint2, int4, uint4, float4_e2m1fn subbyte types in CPU/GPU callbacks. PiperOrigin-RevId: 742253272 --- jaxlib/cuda/BUILD | 1 + jaxlib/gpu/py_client_gpu.cc | 89 +++++++++++++++++++++------------ jaxlib/rocm/BUILD | 1 + jaxlib/xla/BUILD | 1 + jaxlib/xla/py_client_cpu.cc | 81 ++++++++++++++++++++---------- tests/python_callback_test.py | 94 ++++++++++++++++++++--------------- 6 files changed, 168 insertions(+), 99 deletions(-) diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index fac62c81dee7..d35e421ef904 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -689,6 +689,7 @@ cc_library( "@xla//third_party/python_runtime:headers", # buildcleaner: keep "@xla//xla:comparison_util", "@xla//xla:shape_util", + "@xla//xla:util", "@xla//xla:xla_data_proto_cc", "@xla//xla/ffi:ffi_api", "@xla//xla/ffi/api:ffi", diff --git a/jaxlib/gpu/py_client_gpu.cc b/jaxlib/gpu/py_client_gpu.cc index 861ffce3e749..38f2ac1896e7 100644 --- a/jaxlib/gpu/py_client_gpu.cc +++ b/jaxlib/gpu/py_client_gpu.cc @@ -43,6 +43,7 @@ limitations under the License. #include "xla/python/nb_numpy.h" #include "xla/python/types.h" #include "xla/shape_util.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" namespace nb = nanobind; @@ -80,13 +81,14 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, for (size_t i = 0; i < arity; ++i) { auto arg = args.get(i); auto ptype = static_cast(arg->element_type()); - // TODO(b/395428868): Remove this check once we support subbyte types. - if (ptype == xla::S1 || ptype == xla::S2 || ptype == xla::S4 || - ptype == xla::U1 || ptype == xla::U2 || ptype == xla::U4) { + // TODO(b/402422886): Remove this once we form Jax arrays directly instead + // of packing/unpacking to/from numpy arrays. + if (ptype == xla::S1 || ptype == xla::U1) { return xla::ffi::Error(xla::ffi::ErrorCode::kUnimplemented, absl::StrFormat("Unsupported primitive type: %s", PrimitiveType_Name(ptype))); } + if (ptype == xla::TOKEN) { host_input_buffers[i] = nullptr; continue; @@ -112,9 +114,6 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, PyTuple_SET_ITEM(host_input_arrays.ptr(), i, nb::none().inc_ref().ptr()); continue; } - nb::capsule base(host_input_buffers[i], [](void* ptr) noexcept { - delete[] static_cast(ptr); - }); auto maybe_dtype = PrimitiveTypeToNbDtype(ptype); if (!maybe_dtype.ok()) { return xla::ffi::Error::Internal(maybe_dtype.status().ToString()); @@ -122,8 +121,22 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, auto dtype = maybe_dtype.value(); auto dims = absl::Span(arg->dimensions().begin(), arg->dimensions().size()); + // TODO(b/402422886): Remove this once we form Jax arrays directly instead + // of packing/unpacking to/from numpy arrays. + // We pass in data using default numpy layout i.e., std::nullopt. + size_t bits_per_element = xla::primitive_util::BitWidth(ptype); + if (bits_per_element == 2 || bits_per_element == 4) { + auto buffer = xla::UnpackIntN( + bits_per_element, static_cast(host_input_buffers[i]), + arg->size_bytes()); + delete[] static_cast(host_input_buffers[i]); + host_input_buffers[i] = buffer.release(); + } + nb::capsule base(host_input_buffers[i], [](void* ptr) noexcept { + delete[] static_cast(ptr); + }); auto array = xla::nb_numpy_ndarray(dtype, dims, std::nullopt, - host_input_buffers[i], base); + host_input_buffers[i], /*base=*/base); array.attr("flags").attr("writeable") = nb::bool_(false); PyTuple_SET_ITEM(host_input_arrays.ptr(), i, array.inc_ref().ptr()); } @@ -146,8 +159,7 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, auto ret = rets.get(i).value(); auto ptype = static_cast(ret->element_type()); // TODO(b/395428868): Remove this check once we support subbyte types. - if (ptype == xla::S1 || ptype == xla::S2 || ptype == xla::S4 || - ptype == xla::U1 || ptype == xla::U2 || ptype == xla::U4) { + if (ptype == xla::S1 || ptype == xla::U1) { return xla::ffi::Error(xla::ffi::ErrorCode::kUnimplemented, absl::StrFormat("Unsupported primitive type: %s", PrimitiveType_Name(ptype))); @@ -168,32 +180,43 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, } auto expected_shape = maybe_expected_shape.value(); auto expected_strides = xla::ByteStridesForShape(expected_shape); - if (strides == expected_strides) { - auto gpu_res = - gpuMemcpyAsync(ret->untyped_data(), array.data(), ret->size_bytes(), - gpuMemcpyHostToDevice, stream); - CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; - continue; + + const void* data = array.data(); + if (strides != expected_strides) { + xla::TransposePlan::Options options; + options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); + options.dims = absl::Span( + reinterpret_cast(array.shape()), array.ndim()); + absl::InlinedVector reversed_layout; + reversed_layout.resize(expected_shape.dimensions().size()); + absl::c_reverse_copy(expected_shape.layout().minor_to_major(), + reversed_layout.begin()); + options.permutation = reversed_layout; + options.input_layout = xla::TransposePlan::Striding{strides}; + auto maybe_plan = transpose_cache->cache.GetOrCreate(options); + if (!maybe_plan.ok()) { + return xla::ffi::Error::Internal(maybe_plan.status().ToString()); + } + auto plan = maybe_plan.value(); + void* temp = new char[ret->size_bytes()]; + temp_buffers.push_back(temp); + plan->Execute(data, temp); + data = temp; } - void* temp = new char[ret->size_bytes()]; - temp_buffers.push_back(temp); - xla::TransposePlan::Options options; - options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); - options.dims = absl::Span( - reinterpret_cast(array.shape()), array.ndim()); - absl::InlinedVector reversed_layout; - reversed_layout.resize(expected_shape.dimensions().size()); - absl::c_reverse_copy(expected_shape.layout().minor_to_major(), - reversed_layout.begin()); - options.permutation = reversed_layout; - options.input_layout = xla::TransposePlan::Striding{strides}; - auto maybe_plan = transpose_cache->cache.GetOrCreate(options); - if (!maybe_plan.ok()) { - return xla::ffi::Error::Internal(maybe_plan.status().ToString()); + + // TODO(b/402422886): Remove this once we form Jax arrays directly instead + // of packing/unpacking to/from numpy arrays. + std::unique_ptr buffer; + size_t bits_per_element = xla::primitive_util::BitWidth(ptype); + if (bits_per_element == 2 || bits_per_element == 4) { + // NOTE(dsuo): FFI arguments and return buffers are sized assuming + // minimum 1-byte element sizes, even if the data itself is packed. + buffer = xla::PackIntN(bits_per_element, static_cast(data), + ret->size_bytes()); + data = buffer.get(); } - auto plan = maybe_plan.value(); - plan->Execute(array.data(), temp); - auto gpu_res = gpuMemcpyAsync(ret->untyped_data(), temp, ret->size_bytes(), + + auto gpu_res = gpuMemcpyAsync(ret->untyped_data(), data, ret->size_bytes(), gpuMemcpyHostToDevice, stream); CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; } diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index d0c0c798abb8..358a6d1cc9aa 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -588,6 +588,7 @@ cc_library( "@xla//third_party/python_runtime:headers", # buildcleaner: keep "@xla//xla:comparison_util", "@xla//xla:shape_util", + "@xla//xla:util", "@xla//xla:xla_data_proto_cc", "@xla//xla/ffi:ffi_api", "@xla//xla/ffi/api:ffi", diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD index 2ca18afda13d..5b532c1dc501 100644 --- a/jaxlib/xla/BUILD +++ b/jaxlib/xla/BUILD @@ -637,6 +637,7 @@ cc_library( "@nanobind", "@xla//third_party/python_runtime:headers", # buildcleaner: keep "@xla//xla:shape_util", + "@xla//xla:util", "@xla//xla:xla_data_proto_cc", "@xla//xla/ffi:ffi_api", "@xla//xla/ffi/api:ffi", diff --git a/jaxlib/xla/py_client_cpu.cc b/jaxlib/xla/py_client_cpu.cc index ac4e7bee5680..fc4f895af6aa 100644 --- a/jaxlib/xla/py_client_cpu.cc +++ b/jaxlib/xla/py_client_cpu.cc @@ -41,6 +41,7 @@ limitations under the License. #include "xla/python/nb_numpy.h" #include "xla/python/types.h" #include "xla/shape_util.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" namespace nb = nanobind; @@ -78,9 +79,9 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks, for (size_t i = 0; i < args.size(); ++i) { auto arg = args.get(i); auto ptype = static_cast(arg->element_type()); - // TODO(b/395428868): Remove this check once we support subbyte types. - if (ptype == S1 || ptype == S2 || ptype == S4 || ptype == U1 || - ptype == U2 || ptype == U4) { + // TODO(b/402422886): Remove this once we form Jax arrays directly instead + // of packing/unpacking to/from numpy arrays. + if (ptype == S1 || ptype == U1) { return ffi::Error(ffi::ErrorCode::kUnimplemented, absl::StrFormat("Unsupported primitive type: %s", PrimitiveType_Name(ptype))); @@ -96,9 +97,18 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks, auto dtype = maybe_dtype.value(); auto dims = absl::Span(arg->dimensions().begin(), arg->dimensions().size()); + // TODO(b/402422886): Remove this once we form Jax arrays directly instead + // of packing/unpacking to/from numpy arrays. + std::unique_ptr buffer; + const void* data = arg->untyped_data(); + size_t bits_per_element = xla::primitive_util::BitWidth(ptype); + if (bits_per_element == 2 || bits_per_element == 4) { + buffer = xla::UnpackIntN(bits_per_element, static_cast(data), + arg->size_bytes()); + data = buffer.get(); + } // We pass in data using default numpy layout i.e., std::nullopt. - auto array = - nb_numpy_ndarray(dtype, dims, std::nullopt, arg.value().untyped_data()); + auto array = nb_numpy_ndarray(dtype, dims, std::nullopt, data); array.attr("flags").attr("writeable") = nb::bool_(false); PyTuple_SET_ITEM(nb_args.ptr(), i, array.release().ptr()); } @@ -119,9 +129,9 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks, for (size_t i = 0; i < rets.size(); ++i) { auto ret = rets.get(i).value(); auto ptype = static_cast(ret->element_type()); - // TODO(b/395428868): Remove this check once we support subbyte types. - if (ptype == S1 || ptype == S2 || ptype == S4 || ptype == U1 || - ptype == U2 || ptype == U4) { + // TODO(b/402422886): Remove this once we form Jax arrays directly instead + // of packing/unpacking to/from numpy arrays. + if (ptype == S1 || ptype == U1) { return ffi::Error(ffi::ErrorCode::kUnimplemented, absl::StrFormat("Unsupported primitive type: %s", PrimitiveType_Name(ptype))); @@ -141,26 +151,45 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks, } auto expected_shape = maybe_expected_shape.value(); auto expected_strides = ByteStridesForShape(expected_shape); - if (strides == expected_strides) { - std::memcpy(ret->untyped_data(), array.data(), ret->size_bytes()); - continue; + + const void* data = array.data(); + if (strides != expected_strides) { + xla::TransposePlan::Options options; + options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); + options.dims = absl::Span( + reinterpret_cast(array.shape()), array.ndim()); + absl::InlinedVector reversed_layout; + reversed_layout.resize(expected_shape.dimensions().size()); + absl::c_reverse_copy(expected_shape.layout().minor_to_major(), + reversed_layout.begin()); + options.permutation = reversed_layout; + options.input_layout = xla::TransposePlan::Striding{strides}; + auto maybe_plan = transpose_cache->cache.GetOrCreate(options); + if (!maybe_plan.ok()) { + return ffi::Error::Internal(maybe_plan.status().ToString()); + } + auto plan = maybe_plan.value(); + plan->Execute(data, ret->untyped_data()); + data = ret->untyped_data(); } - xla::TransposePlan::Options options; - options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); - options.dims = absl::Span( - reinterpret_cast(array.shape()), array.ndim()); - absl::InlinedVector reversed_layout; - reversed_layout.resize(expected_shape.dimensions_size()); - absl::c_reverse_copy(expected_shape.layout().minor_to_major(), - reversed_layout.begin()); - options.permutation = reversed_layout; - options.input_layout = xla::TransposePlan::Striding{strides}; - auto maybe_plan = transpose_cache->cache.GetOrCreate(options); - if (!maybe_plan.ok()) { - return ffi::Error::Internal(maybe_plan.status().ToString()); + + // TODO(b/402422886): Remove this once we form Jax arrays directly instead + // of packing/unpacking to/from numpy arrays. + std::unique_ptr buffer; + size_t bits_per_element = xla::primitive_util::BitWidth(ptype); + if (bits_per_element == 2 || bits_per_element == 4) { + // NOTE(dsuo): FFI arguments and return buffers are sized assuming + // minimum 1-byte element sizes, even if the data itself is packed. + buffer = xla::PackIntN(bits_per_element, static_cast(data), + ret->size_bytes()); + data = buffer.get(); + } + + // Copy data to output buffer if haven't already or modified the data to + // write back. + if (data != ret->untyped_data()) { + std::memcpy(ret->untyped_data(), data, ret->size_bytes()); } - auto plan = maybe_plan.value(); - plan->Execute(array.data(), ret->untyped_data()); } return ffi::Error::Success(); diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index a8442b4a1356..34ab20c05644 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -586,10 +586,15 @@ def fun(x): self.assertAllClose(2 * x, fun(x)) self.assertEqual(count(), 1) - @parameterized.parameters("int2", "int4", "uint2", "uint4") + @parameterized.parameters("int2", "int4", "uint2", "uint4", "float4_e2m1fn") def test_subbyte_operands(self, dtype: str): - if jaxlib_extension_version <= 321: - self.skipTest("Requires jaxlib_extension_version >= 322.") + if jaxlib_extension_version < 323: + self.skipTest("Requires jaxlib_extension_version >= 323.") + if "2" in dtype and jtu.test_device_matches(["tpu"]): + self.skipTest( + "TODO(dsuo): TPU callbacks send SIGABRT for int2, uint2, and" + " float4_e2m1fn." + ) def get(x): return x def f(x): @@ -600,21 +605,17 @@ def f(x): ) return y x = np.arange(8, dtype=dtype) - # TODO(b/395428868): Remove this check once we support subbyte types. - if jtu.test_device_matches(["tpu"]): - if "2" in dtype: - self.skipTest("TODO(dsuo): TPU callbacks send SIGABRT for int2/uint2.") - np.testing.assert_array_equal(jax.jit(f)(x), np.arange(8, dtype=dtype)) - else: - with self.assertRaisesRegex( - Exception, "Unsupported primitive type" - ): - _ = jax.jit(f)(x) + np.testing.assert_array_equal(jax.jit(f)(x), np.arange(8, dtype=dtype)) - @parameterized.parameters("int2", "int4", "uint2", "uint4") + @parameterized.parameters("int2", "int4", "uint2", "uint4", "float4_e2m1fn") def test_subbyte_results(self, dtype: str): - if jaxlib_extension_version <= 321: - self.skipTest("Requires jaxlib_extension_version >= 322.") + if jaxlib_extension_version < 323: + self.skipTest("Requires jaxlib_extension_version >= 323.") + if "2" in dtype and jtu.test_device_matches(["tpu"]): + self.skipTest( + "TODO(dsuo): TPU callbacks send SIGABRT for int2, uint2, and" + " float4_e2m1fn." + ) def get(): return np.arange(8, dtype=dtype) @@ -625,16 +626,43 @@ def f(): ) return y - # TODO(b/395428868): Remove this check once we support subbyte types. - if jtu.test_device_matches(["tpu"]): - if "2" in dtype: - self.skipTest("TODO(dsuo): TPU callbacks send SIGABRT for int2/uint2.") - np.testing.assert_array_equal(jax.jit(f)(), np.arange(8, dtype=dtype)) - else: - with self.assertRaisesRegex( - Exception, "Unsupported primitive type" - ): - _ = jax.jit(f)() + np.testing.assert_array_equal(jax.jit(f)(), np.arange(8, dtype=dtype)) + + @parameterized.parameters("int2", "int4", "uint2", "uint4", "float4_e2m1fn") + def test_non_default_stride_subbyte_results(self, dtype: str): + if jaxlib_extension_version < 323: + self.skipTest("Requires jaxlib_extension_version >= 323.") + if "2" in dtype and jtu.test_device_matches(["tpu"]): + self.skipTest( + "TODO(dsuo): TPU callbacks send SIGABRT for int2, uint2, and" + " float4_e2m1fn." + ) + x = jnp.arange(24, dtype=dtype).reshape(2, 3, 4) + def callback(x): + return np.asfortranarray(x) + + @jax.jit + def f(x): + return jax.pure_callback( + callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x + ) + + result = f(x) + np.testing.assert_array_equal(x, result) + + def test_non_default_stride(self): + x = jnp.arange(24, dtype=jnp.float32).reshape(2, 3, 4) + def callback(x): + return np.asfortranarray(x) + + @jax.jit + def f(x): + return jax.pure_callback( + callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x + ) + + result = f(x) + np.testing.assert_array_equal(x, result) class PureCallbackTest(jtu.JaxTestCase): @@ -1108,20 +1136,6 @@ def fun(x): result += fun(jnp.ones((500, 500), jnp.complex64))[1] jax.block_until_ready(result) # doesn't deadlock - def test_non_default_stride(self): - x = jnp.arange(24, dtype=jnp.float32).reshape(2, 3, 4) - def callback(x): - return np.asfortranarray(x) - - @jax.jit - def f(x): - return jax.pure_callback( - callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x - ) - - result = f(x) - np.testing.assert_array_equal(x, result) - class IOCallbackTest(jtu.JaxTestCase): From b3d851d722ea5efb893d96b3c03a739ba9763bd0 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Mon, 31 Mar 2025 07:34:32 -0700 Subject: [PATCH 291/483] Add Jax tracing micro benchmarks. Add a first benchmark for tracing/lowering pallas splash attention. Sample results below taken on a GCP n2d-standard-128 instance with 512GB Ram and 128 vCPU AMD EPYC Milan. --------------------------------------------------------------------------------- Benchmark Time CPU Iterations --------------------------------------------------------------------------------- test_pallas_mqa_splash_attention_trace 39.8 ms 39.8 ms 19 test_pallas_mqa_splash_attention_lower 42.1 ms 41.9 ms 18 PiperOrigin-RevId: 742259409 --- benchmarks/tracing_benchmark.py | 76 +++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 benchmarks/tracing_benchmark.py diff --git a/benchmarks/tracing_benchmark.py b/benchmarks/tracing_benchmark.py new file mode 100644 index 000000000000..e06ad538d476 --- /dev/null +++ b/benchmarks/tracing_benchmark.py @@ -0,0 +1,76 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Benchmarks for Jax tracing.""" + +import google_benchmark +import jax +from jax import random +from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel as splash +from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib +import numpy as np + + +def make_mqa_splash_attention_fn_and_args(): + seed = 0 + key = random.key(seed) + k1, k2, k3 = random.split(key, 3) + + q_seq_len = 1024 + kv_seq_len = 1024 + num_q_heads = 2 + head_dim_qk = 128 + head_dim_v = 128 + dtype = np.dtype("float32") + + q = random.uniform(k1, (num_q_heads, q_seq_len, head_dim_qk), dtype=dtype) + k = random.uniform(k2, (kv_seq_len, head_dim_qk), dtype=dtype) + v = random.uniform(k3, (kv_seq_len, head_dim_v), dtype=dtype) + + mask = mask_lib.NumpyMask( + mask_lib.make_random_mask((q_seq_len, kv_seq_len), sparsity=0.5, seed=0) + ) + mask = mask_lib.MultiHeadMask(tuple(mask for _ in range(num_q_heads))) + block_sizes = splash.BlockSizes.get_default() + + return ( + jax.jit( + splash.make_splash_mqa_single_device(mask, block_sizes=block_sizes) + ) + ), (q, k, v) + + +@google_benchmark.register +@google_benchmark.option.unit(google_benchmark.kMillisecond) +def test_pallas_mqa_splash_attention_trace(state): + attn, (q, k, v) = make_mqa_splash_attention_fn_and_args() + + while state: + _ = attn.trace(q, k, v) + jax.clear_caches() + + +@google_benchmark.register +@google_benchmark.option.unit(google_benchmark.kMillisecond) +def test_pallas_mqa_splash_attention_lower(state): + attn, (q, k, v) = make_mqa_splash_attention_fn_and_args() + traced = attn.trace(q, k, v) + + while state: + _ = traced.lower(lowering_platforms=("tpu",)) + jax.clear_caches() + + +if __name__ == "__main__": + google_benchmark.main() From 95497ca2f0d41af0ca97af408932982fa3fa7160 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 31 Mar 2025 07:42:16 -0700 Subject: [PATCH 292/483] Remove legacy GPU kernel for LU decomposition. Following the compatibility timeline described here: https://docs.jax.dev/en/latest/export/export.html#ensuring-forward-and-backward-compatibility It has been 6 months since the release of 0.4.33 which is the relevant release for this kernel. PiperOrigin-RevId: 742261532 --- jaxlib/gpu/blas.cc | 10 ----- jaxlib/gpu/blas_kernels.cc | 60 -------------------------- jaxlib/gpu/blas_kernels.h | 10 ----- jaxlib/gpu/gpu_kernels.cc | 3 -- jaxlib/gpu/solver.cc | 41 ------------------ jaxlib/gpu/solver_kernels.cc | 83 ------------------------------------ jaxlib/gpu/solver_kernels.h | 10 ----- 7 files changed, 217 deletions(-) diff --git a/jaxlib/gpu/blas.cc b/jaxlib/gpu/blas.cc index cf391e07e31e..59bf2c4603f6 100644 --- a/jaxlib/gpu/blas.cc +++ b/jaxlib/gpu/blas.cc @@ -49,14 +49,6 @@ BlasType DtypeToBlasType(const dtype& np_type) { return it->second; } -// Returns the descriptor for a GetrfBatched operation. -std::pair BuildGetrfBatchedDescriptor(const dtype& dtype, - int b, int n) { - BlasType type = DtypeToBlasType(dtype); - size_t size = b * sizeof(void*); - return {size, PackDescriptor(GetrfBatchedDescriptor{type, b, n})}; -} - // Returns the descriptor for a GetrfBatched operation. std::pair BuildGeqrfBatchedDescriptor(const dtype& dtype, int b, int m, int n) { @@ -67,7 +59,6 @@ std::pair BuildGeqrfBatchedDescriptor(const dtype& dtype, nb::dict Registrations() { nb::dict dict; - dict[JAX_GPU_PREFIX "blas_getrf_batched"] = EncapsulateFunction(GetrfBatched); dict[JAX_GPU_PREFIX "blas_geqrf_batched"] = EncapsulateFunction(GeqrfBatched); return dict; } @@ -76,7 +67,6 @@ NB_MODULE(_blas, m) { tsl::ImportNumpy(); m.def("registrations", &Registrations); - m.def("build_getrf_batched_descriptor", &BuildGetrfBatchedDescriptor); m.def("build_geqrf_batched_descriptor", &BuildGeqrfBatchedDescriptor); } diff --git a/jaxlib/gpu/blas_kernels.cc b/jaxlib/gpu/blas_kernels.cc index ac30aa9cc520..cdcc154d026d 100644 --- a/jaxlib/gpu/blas_kernels.cc +++ b/jaxlib/gpu/blas_kernels.cc @@ -52,66 +52,6 @@ int SizeOfBlasType(BlasType type) { } // namespace -// Batched LU decomposition: getrfbatched - -static absl::Status GetrfBatched_(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const GetrfBatchedDescriptor& d = **s; - auto h = BlasHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - if (buffers[0] != buffers[1]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( - buffers[1], buffers[0], SizeOfBlasType(d.type) * d.batch * d.n * d.n, - gpuMemcpyDeviceToDevice, stream))); - } - - int* ipiv = static_cast(buffers[2]); - int* info = static_cast(buffers[3]); - MakeBatchPointersAsync(stream, buffers[1], buffers[4], d.batch, - SizeOfBlasType(d.type) * d.n * d.n); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError())); - switch (d.type) { - case BlasType::F32: { - float** batch_ptrs = static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasSgetrfBatched( - handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch))); - break; - } - case BlasType::F64: { - double** batch_ptrs = static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasDgetrfBatched( - handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch))); - break; - } - case BlasType::C64: { - gpublasComplex** batch_ptrs = static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasCgetrfBatched( - handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch))); - break; - } - case BlasType::C128: { - gpublasDoubleComplex** batch_ptrs = - static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasZgetrfBatched( - handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch))); - break; - } - } - return absl::OkStatus(); -} - -void GetrfBatched(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = GetrfBatched_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - // Batched QR decomposition: geqrfbatched static absl::Status GeqrfBatched_(gpuStream_t stream, void** buffers, diff --git a/jaxlib/gpu/blas_kernels.h b/jaxlib/gpu/blas_kernels.h index 724565ea73d1..8ca7b4db4668 100644 --- a/jaxlib/gpu/blas_kernels.h +++ b/jaxlib/gpu/blas_kernels.h @@ -32,16 +32,6 @@ enum class BlasType { C128, }; -// Batched LU decomposition: getrfbatched - -struct GetrfBatchedDescriptor { - BlasType type; - int batch, n; -}; - -void GetrfBatched(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - // Batched QR decomposition: geqrfbatched struct GeqrfBatchedDescriptor { diff --git a/jaxlib/gpu/gpu_kernels.cc b/jaxlib/gpu/gpu_kernels.cc index 242078357254..840c313f2fa3 100644 --- a/jaxlib/gpu/gpu_kernels.cc +++ b/jaxlib/gpu/gpu_kernels.cc @@ -33,13 +33,10 @@ namespace jax { namespace JAX_GPU_NAMESPACE { namespace { -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cublas_getrf_batched", GetrfBatched, - "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cublas_geqrf_batched", GeqrfBatched, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cudnn_rnn", RNNForward, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cudnn_rnn_bwd", RNNBackward, "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_getrf", Getrf, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_getrf_ffi", "CUDA", GetrfFfi); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_syrk_ffi", "CUDA", diff --git a/jaxlib/gpu/solver.cc b/jaxlib/gpu/solver.cc index 20fc308100c4..8013d9877ed5 100644 --- a/jaxlib/gpu/solver.cc +++ b/jaxlib/gpu/solver.cc @@ -54,45 +54,6 @@ SolverType DtypeToSolverType(const dtype& np_type) { return it->second; } -// getrf: LU decomposition - -// Returns the workspace size and a descriptor for a getrf operation. -std::pair BuildGetrfDescriptor(const dtype& dtype, int b, int m, - int n) { - SolverType type = DtypeToSolverType(dtype); - auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); - JAX_THROW_IF_ERROR(h.status()); - auto& handle = *h; - int lwork; - switch (type) { - case SolverType::F32: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnSgetrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); - break; - case SolverType::F64: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnDgetrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); - break; - case SolverType::C64: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnCgetrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); - break; - case SolverType::C128: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnZgetrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); - break; - } - return {lwork, PackDescriptor(GetrfDescriptor{type, b, m, n, lwork})}; -} - // geqrf: QR decomposition // Returns the workspace size and a descriptor for a geqrf operation. @@ -462,7 +423,6 @@ std::pair BuildSytrdDescriptor(const dtype& dtype, bool lower, nb::dict Registrations() { nb::dict dict; - dict[JAX_GPU_PREFIX "solver_getrf"] = EncapsulateFunction(Getrf); dict[JAX_GPU_PREFIX "solver_geqrf"] = EncapsulateFunction(Geqrf); dict[JAX_GPU_PREFIX "solver_orgqr"] = EncapsulateFunction(Orgqr); dict[JAX_GPU_PREFIX "solver_syevd"] = EncapsulateFunction(Syevd); @@ -496,7 +456,6 @@ nb::dict Registrations() { NB_MODULE(_solver, m) { tsl::ImportNumpy(); m.def("registrations", &Registrations); - m.def("build_getrf_descriptor", &BuildGetrfDescriptor); m.def("build_geqrf_descriptor", &BuildGeqrfDescriptor); m.def("build_orgqr_descriptor", &BuildOrgqrDescriptor); m.def("build_syevd_descriptor", &BuildSyevdDescriptor); diff --git a/jaxlib/gpu/solver_kernels.cc b/jaxlib/gpu/solver_kernels.cc index 8c22dfcdbca7..8971619d7f34 100644 --- a/jaxlib/gpu/solver_kernels.cc +++ b/jaxlib/gpu/solver_kernels.cc @@ -50,89 +50,6 @@ static int SizeOfSolverType(SolverType type) { } } -// getrf: LU decomposition - -static absl::Status Getrf_(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const GetrfDescriptor& d = **s; - auto h = SolverHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - if (buffers[1] != buffers[0]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( - buffers[1], buffers[0], - SizeOfSolverType(d.type) * static_cast(d.batch) * - static_cast(d.m) * static_cast(d.n), - gpuMemcpyDeviceToDevice, stream))); - } - - int* ipiv = static_cast(buffers[2]); - int* info = static_cast(buffers[3]); - void* workspace = buffers[4]; - switch (d.type) { - case SolverType::F32: { - float* a = static_cast(buffers[1]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnSgetrf( - handle.get(), d.m, d.n, a, d.m, static_cast(workspace), - d.lwork, ipiv, info))); - a += d.m * d.n; - ipiv += std::min(d.m, d.n); - ++info; - } - break; - } - case SolverType::F64: { - double* a = static_cast(buffers[1]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnDgetrf( - handle.get(), d.m, d.n, a, d.m, static_cast(workspace), - d.lwork, ipiv, info))); - a += d.m * d.n; - ipiv += std::min(d.m, d.n); - ++info; - } - break; - } - case SolverType::C64: { - gpuComplex* a = static_cast(buffers[1]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnCgetrf( - handle.get(), d.m, d.n, a, d.m, static_cast(workspace), - d.lwork, ipiv, info))); - a += d.m * d.n; - ipiv += std::min(d.m, d.n); - ++info; - } - break; - } - case SolverType::C128: { - gpuDoubleComplex* a = static_cast(buffers[1]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnZgetrf( - handle.get(), d.m, d.n, a, d.m, - static_cast(workspace), d.lwork, ipiv, info))); - a += d.m * d.n; - ipiv += std::min(d.m, d.n); - ++info; - } - break; - } - } - return absl::OkStatus(); -} - -void Getrf(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = Getrf_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - // geqrf: QR decomposition static absl::Status Geqrf_(gpuStream_t stream, void** buffers, diff --git a/jaxlib/gpu/solver_kernels.h b/jaxlib/gpu/solver_kernels.h index 51082f2fe812..6372e55b930d 100644 --- a/jaxlib/gpu/solver_kernels.h +++ b/jaxlib/gpu/solver_kernels.h @@ -33,16 +33,6 @@ enum class SolverType { C128, }; -// getrf: LU decomposition - -struct GetrfDescriptor { - SolverType type; - int batch, m, n, lwork; -}; - -void Getrf(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - // geqrf: QR decomposition struct GeqrfDescriptor { From 6b719496ed83f3ca18e0e42f32892eb63102af3b Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 31 Mar 2025 07:59:05 -0700 Subject: [PATCH 293/483] [pallas:mosaic_gpu] Fixed lane-level lowering of `lax.optimization_barrier` PiperOrigin-RevId: 742265860 --- jax/_src/pallas/mosaic_gpu/lowering.py | 14 +++++++------- tests/pallas/mosaic_gpu_test.py | 6 ------ 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index daa718ff1ff2..f027d5bcb76d 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -2369,20 +2369,20 @@ def _bitcast_convert_type_lowering_rule( @register_lowering_rule(lax.optimization_barrier_p, mgpu.ThreadSemantics.Lane) def _optimization_barrier_lowering(ctx: LoweringRuleContext, *args): - args = (_ensure_fa(arg, aval.dtype) for arg, aval in zip(args, ctx.avals_in)) - return mgpu.optimization_barrier(*args) + result = mgpu.optimization_barrier( + *(_ensure_fa(arg, aval.dtype) for arg, aval in zip(args, ctx.avals_in)) + ) + return (result,) if len(ctx.avals_in) == 1 else result @register_lowering_rule( lax.optimization_barrier_p, mgpu.ThreadSemantics.Warpgroup ) def _optimization_barrier_lowering_wg(ctx: LoweringRuleContext, *args): - args = [ + result = mgpu.dialect.optimization_barrier([ _ensure_ir_value(arg, aval.dtype) for arg, aval in zip(args, ctx.avals_in) - ] - result = mgpu.dialect.optimization_barrier(args) - - return (result,) if len(args) == 1 else result + ]) + return (result,) if len(ctx.avals_in) == 1 else result def _bcast( diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 6b1839a64580..77e934f656e7 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1405,9 +1405,6 @@ def convert(x_ref, y_ref): ) def test_optimization_barrier(self): - if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Lane: - self.skipTest("This test crashes with lane semantics") - @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((128,), jnp.float32), @@ -1419,9 +1416,6 @@ def kernel(x_ref, o_ref): np.testing.assert_array_equal(kernel(x), x) def test_optimization_barrier_multiple_inputs(self): - if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Lane: - self.skipTest("This test crashes with lane semantics") - @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((128,), jnp.float32), From 200f8263980bb1346c15f4616e28f129cf0b4f85 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 31 Mar 2025 08:50:39 -0700 Subject: [PATCH 294/483] [array api] return all devices in devices() --- jax/_src/numpy/array_api_metadata.py | 9 +++++++-- tests/array_api_test.py | 6 +++++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/jax/_src/numpy/array_api_metadata.py b/jax/_src/numpy/array_api_metadata.py index d8d2c2d1a2a4..5267e51215ee 100644 --- a/jax/_src/numpy/array_api_metadata.py +++ b/jax/_src/numpy/array_api_metadata.py @@ -24,7 +24,9 @@ import jax from jax._src.sharding import Sharding from jax._src.lib import xla_client as xc -from jax._src import dtypes as _dtypes, config +from jax._src import config +from jax._src import dtypes as _dtypes +from jax._src import xla_bridge as xb __array_api_version__ = '2023.12' @@ -73,7 +75,10 @@ def default_device(self): return None def devices(self): - return jax.devices() + out = [None] # None indicates "uncommitted" + for backend in xb.backends(): + out.extend(jax.devices(backend)) + return out def capabilities(self): return self._capabilities diff --git a/tests/array_api_test.py b/tests/array_api_test.py index d509fe78c35f..8e4ba275fdd3 100644 --- a/tests/array_api_test.py +++ b/tests/array_api_test.py @@ -26,6 +26,7 @@ import jax.numpy as jnp from jax._src import config, test_util as jtu from jax._src.dtypes import _default_types, canonicalize_dtype +from jax._src import xla_bridge as xb ARRAY_API_NAMESPACE = jnp @@ -283,7 +284,10 @@ def test_default_device_info(self): assert self.info.default_device() is None def test_devices_info(self): - assert self.info.devices() == jax.devices() + devices = set(self.info.devices()) + assert None in devices + for backend in xb.backends(): + assert devices.issuperset(jax.devices(backend)) def test_default_dtypes_info(self): _default_dtypes = { From aaa3ebfb8a135e4c82c08e551fd756ad5db85716 Mon Sep 17 00:00:00 2001 From: Shu Wang Date: Mon, 31 Mar 2025 12:05:30 -0500 Subject: [PATCH 295/483] Add optimization barrier. --- jax/_src/cudnn/scaled_matmul_stablehlo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/cudnn/scaled_matmul_stablehlo.py b/jax/_src/cudnn/scaled_matmul_stablehlo.py index 60cdbee7fa20..6766e3992202 100644 --- a/jax/_src/cudnn/scaled_matmul_stablehlo.py +++ b/jax/_src/cudnn/scaled_matmul_stablehlo.py @@ -495,7 +495,7 @@ def quantize(x, config): SCALE_MAX = jnp.finfo(config.scale_type).max.astype(x.dtype) scales_q = jnp.clip(scales / config.global_scale, 0, SCALE_MAX) - scales_q = scales_q.astype(config.scale_type) + scales_q = jax.lax.optimization_barrier(scales_q.astype(config.scale_type)) scaled_x = x / scales_q.astype(jnp.float32) else: raise ValueError(f"Unrecognized mode: {config.mode}.") From 05039fe520906a2cd9562593406dd4544828515e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 31 Mar 2025 17:49:43 +0000 Subject: [PATCH 296/483] Bump tsickert/discord-webhook from 5.3.0 to 7.0.0 Bumps [tsickert/discord-webhook](https://github.com/tsickert/discord-webhook) from 5.3.0 to 7.0.0. - [Release notes](https://github.com/tsickert/discord-webhook/releases) - [Commits](https://github.com/tsickert/discord-webhook/compare/c840d45a03a323fbc3f7507ac7769dbd91bfb164...b217a69502f52803de774ded2b1ab7c282e99645) --- updated-dependencies: - dependency-name: tsickert/discord-webhook dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- .github/workflows/community_release_actions.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/community_release_actions.yml b/.github/workflows/community_release_actions.yml index d61bea3d7e4d..1980e803ba9b 100644 --- a/.github/workflows/community_release_actions.yml +++ b/.github/workflows/community_release_actions.yml @@ -25,7 +25,7 @@ jobs: maxLength: 2000 truncationSymbol: "..." - name: Discord Webhook Action - uses: tsickert/discord-webhook@c840d45a03a323fbc3f7507ac7769dbd91bfb164 # v5.3.0 + uses: tsickert/discord-webhook@b217a69502f52803de774ded2b1ab7c282e99645 # v7.0.0 with: webhook-url: ${{ secrets.DISCORD_WEBHOOK_URL }} content: ${{ steps.get-content.outputs.string }} From 5d69e6b64dca1ef4590a942e8de173fc40e46a34 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 31 Mar 2025 17:49:47 +0000 Subject: [PATCH 297/483] Bump actions/setup-python from 5.4.0 to 5.5.0 Bumps [actions/setup-python](https://github.com/actions/setup-python) from 5.4.0 to 5.5.0. - [Release notes](https://github.com/actions/setup-python/releases) - [Commits](https://github.com/actions/setup-python/compare/42375524e23c412d93fb67b49958b491fce71c38...8d9ed9ac5c53483de85588cdf95a591a75ab9f55) --- updated-dependencies: - dependency-name: actions/setup-python dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- .github/workflows/ci-build.yaml | 12 ++++++------ .github/workflows/jax-array-api.yml | 2 +- .github/workflows/upstream-nightly.yml | 2 +- .github/workflows/wheel_win_x64.yml | 2 +- .github/workflows/windows_ci.yml | 2 +- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index f43407af2ed9..c575c84cd422 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -31,7 +31,7 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python 3.11 - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 with: python-version: 3.11 - run: python -m pip install pre-commit @@ -70,7 +70,7 @@ jobs: apt update apt install -y libssl-dev - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -108,7 +108,7 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -142,7 +142,7 @@ jobs: apt update apt install -y libssl-dev libsqlite3-dev - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -168,7 +168,7 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -201,7 +201,7 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 with: python-version: 3.12 - name: Install JAX diff --git a/.github/workflows/jax-array-api.yml b/.github/workflows/jax-array-api.yml index 2b97c5a05c1c..c91ab6b8b7da 100644 --- a/.github/workflows/jax-array-api.yml +++ b/.github/workflows/jax-array-api.yml @@ -32,7 +32,7 @@ jobs: submodules: 'true' path: 'array-api-tests' - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 with: python-version: ${{ matrix.python-version }} - name: Install dependencies diff --git a/.github/workflows/upstream-nightly.yml b/.github/workflows/upstream-nightly.yml index 5132a12cf16f..ba2c750f8a8a 100644 --- a/.github/workflows/upstream-nightly.yml +++ b/.github/workflows/upstream-nightly.yml @@ -33,7 +33,7 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 with: python-version: ${{ matrix.python-version }} - name: Install JAX test requirements diff --git a/.github/workflows/wheel_win_x64.yml b/.github/workflows/wheel_win_x64.yml index 912088428fd5..a2b3aeddc24a 100644 --- a/.github/workflows/wheel_win_x64.yml +++ b/.github/workflows/wheel_win_x64.yml @@ -27,7 +27,7 @@ jobs: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + - uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 with: python-version: ${{ matrix.pyver }} cache: 'pip' diff --git a/.github/workflows/windows_ci.yml b/.github/workflows/windows_ci.yml index fc2b63396f56..5a435023ffda 100644 --- a/.github/workflows/windows_ci.yml +++ b/.github/workflows/windows_ci.yml @@ -35,7 +35,7 @@ jobs: with: path: jax - - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + - uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 with: python-version: ${{ matrix.pyver }} cache: 'pip' From 1355e7c65003428c5922df306cef77cef48412ed Mon Sep 17 00:00:00 2001 From: Olli Lupton Date: Wed, 5 Mar 2025 15:32:25 +0000 Subject: [PATCH 298/483] AutoPGLE: force-disable graphs less Previously, XLA's command buffers (CUDA graphs) would be disabled both for PGLE profile collection and when re-compiling using the profile data. With this change, they are only disabled when collecting the profile data. --- jax/_src/compiler.py | 239 ++++++++++++++++++++++--------------------- tests/pgle_test.py | 52 +++++++++- 2 files changed, 174 insertions(+), 117 deletions(-) diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index dea532d13031..9ac47aa4f0ea 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -17,6 +17,8 @@ from __future__ import annotations from collections.abc import Sequence +import copy +from functools import partial import logging import time from typing import Any, Callable @@ -197,15 +199,6 @@ def get_compile_options( config.memory_fitting_level.value ).value - # This is a temporary workaround to simplify the AutoPGLE usage. - # TODO(b/376647494): Remove once the bug is fixed. - if ((config.enable_pgle.value and config.pgle_profiling_runs.value > 0) - or config.compilation_cache_expect_pgle.value): - logger.debug("Explicitly disabling command buffer scheduling for AutoPGLE.") - if env_options_overrides is None: - env_options_overrides = {} - env_options_overrides['xla_gpu_enable_command_buffer'] = '' - if env_options_overrides is not None: # Some overrides are passed directly on build_options. overrides_on_build_options = [ @@ -298,6 +291,8 @@ def backend_compile( options: xc.CompileOptions, host_callbacks: Sequence[Any], ) -> xc.LoadedExecutable: + sym_name = module.operation.attributes['sym_name'] + module_name = ir.StringAttr(sym_name).value # Convert ir.Module to a string representation, unless the backend # explicitly flags the ability to handle a module directly (avoiding the # overhead of back and forth conversions). @@ -308,6 +303,14 @@ def backend_compile( else: built_c = module + if (options.executable_build_options.fdo_profile is not None + and len(options.executable_build_options.fdo_profile)): + logger.debug( + "Compiling module %s with FDO profile of length %d", + module_name, + len(options.executable_build_options.fdo_profile), + ) + try: # we use a separate function call to ensure that XLA compilation appears # separately in Python profiling results @@ -362,72 +365,31 @@ def compile_or_get_cached( if dumped_to := mlir.dump_module_to_file(computation, "compile"): logging.info("Dumped the module to %s.", dumped_to) - use_compilation_cache = compilation_cache.is_cache_used(backend) - is_multi_process = ( len({device.process_index for device in devices.flatten()}) > 1 ) min_device_process_id = min( devices.flatten(), key=lambda device: device.id ).process_index - is_auto_pgle_used = ( - config.enable_pgle.value and config.pgle_profiling_runs.value > 0 - ) - if not use_compilation_cache: - if ( - is_multi_process - and is_auto_pgle_used - and distributed.global_state.client is not None - ): - compile_options.executable_build_options.fdo_profile = ( - _share_fdo_profiles( - computation, - devices, - compile_options, - backend, - distributed.global_state.client, - min_device_process_id, - ) - ) + # cache_key: may be None if compilation caching is disabled + cache_key, compile_options = _resolve_compilation_strategy( + computation, + devices, + compile_options, + backend, + pgle_profiler, + is_multi_process, + module_name, + min_device_process_id, + ) + if cache_key is None: return backend_compile(backend, computation, compile_options, host_callbacks) monitoring.record_event('/jax/compilation_cache/compile_requests_use_cache') - try: - if config.remove_custom_partitioning_ptr_from_cache_key.value: - ignore_callbacks = cache_key_type.IgnoreCallbacks.CUSTOM_PARTITIONING - else: - ignore_callbacks = cache_key_type.IgnoreCallbacks.NO - - cache_key = compilation_cache.get_cache_key( - computation, - devices, - compile_options, - backend, - ignore_callbacks=ignore_callbacks, - ) - except xc._xla.XlaRuntimeError as ex: - logger.error("compile_or_get_cached: unable to generate cache key, " - "skipping the cache: %s", ex) - return backend_compile(backend, computation, compile_options, - host_callbacks) - - if is_auto_pgle_used or config.compilation_cache_expect_pgle.value: - cache_key = _resolve_pgle_module_cache_key( - computation, - devices, - compile_options, - backend, - pgle_profiler, - is_multi_process, - cache_key, - module_name, - min_device_process_id, - ) - cache_retrieval_start = time.monotonic() retrieved_executable, retrieved_compile_time = _cache_read( module_name, cache_key, compile_options, backend) @@ -481,85 +443,130 @@ def compile_or_get_cached( # 1. PGLE optimized module (the one which was recompiled with FDO profile) is # in the persistent cache. In this case the module should be returned from # cache and PGLE should be disabled for this module. Is module is stored in -# the persistent cache under the "pgle_profiled_module_key" which calculated -# with replacing FDO profile with flag which identify that module were PGLE -# profiled. +# the persistent cache under the "pgle_optimized_cache_key", which is +# calculated by replacing the FDO profile with a sentinel value that identifies +# that the module was optimized with PGLE. # 2. PGLE profiled module is not in the persistent cache and the module is -# getting built with an FDO profile. In this case we need to share FDO profile -# with other processes and store the result under the -# "pgle_profiled_module_key" so later in case 1 we will be able to find the +# getting built with an FDO profile. In this case we need to share the FDO +# profile with any other processes and store the result under the +# "pgle_optimized_cache_key" so later in case 1 we will be able to find the # module. # 3. PGLE profiled module is not in the persistent cache and the module is # getting compiled to be PGLEd (FDO profile is empty). In this case we need to -# simply return the non-PGLE profiled module from the persistent cache. +# simply return the non-PGLE profiled module from the persistent cache if it +# exists, and otherwise compile it. # # If the compilation_cache_expect_pgle option is set then in case 1 the PGLE # optimized module will be loaded even if PGLE is not enabled in the current # process. This is useful if we want to combine the use of PGLE with other # profiling tools (e.g. Nsight Systems) that cannot co-exist with PGLE due to # contention for CUPTI resources. -def _resolve_pgle_module_cache_key( +def _resolve_compilation_strategy( computation: ir.Module, devices: np.ndarray, compile_options: xc.CompileOptions, backend: xc.Client, pgle_profiler: profiler.PGLEProfiler | None, is_multi_process: bool, - cache_key: str, module_name: str, min_device_process_id: int, -) -> str: - fdo_profile = compile_options.executable_build_options.fdo_profile - compile_options.executable_build_options.fdo_profile = b"pgle profiled" - - pgle_profiled_module_key = compilation_cache.get_cache_key( - computation, - devices, - compile_options, - backend, - cache_key_type.IgnoreCallbacks.ALL, +) -> tuple[str | None, xc.CompileOptions]: + is_auto_pgle_used = ( + config.enable_pgle.value and config.pgle_profiling_runs.value > 0 ) - compile_options.executable_build_options.fdo_profile = fdo_profile - - result_key = cache_key - if _is_executable_in_cache(backend, pgle_profiled_module_key): - # Load PGLE profiled module from the persistent cache. - result_key = pgle_profiled_module_key - if config.compilation_cache_expect_pgle.value: - logging.info(f"PGLE-optimized {module_name} loaded from compilation cache") - if pgle_profiler is not None: - pgle_profiler.disable() + + get_cache_key = partial(_get_cache_key, backend=backend, + computation=computation, devices=devices) + + if is_auto_pgle_used or config.compilation_cache_expect_pgle.value: + # This can be None if cache key generation fails. + pgle_optimized_cache_key = get_cache_key(compile_options, + override_fdo_profile=b"pgle profiled") + # TODO(b/376647494): remove the workaround when the bug is fixed; the JAX + # profiler cannot collect sufficiently detailed profile data for PGLE if + # command buffers / CUDA graphs are enabled. Therefore disable command + # buffers when compiling for PGLE data collection, but not if AutoPGLE is + # not enabled, and not when re-compiling using PGLE data. This condition + # includes `compilation_cache_expect_pgle` so that slow-to-compile modules + # that are not executed often enough to trigger re-compilation will still + # be cached between an "enable_pgle" run and an "expect_pgle" run. + first_pass_compile_options = copy.deepcopy(compile_options) + first_pass_compile_options.env_option_overrides += [ + ("xla_gpu_enable_command_buffer", ""), + ] else: - # No PGLE-optimised module found in the persistent cache. - if (config.compilation_cache_expect_pgle.value - and _is_executable_in_cache(backend, cache_key)): - # The user asserted this miss was unexpected; emit a warning + pgle_optimized_cache_key = None + first_pass_compile_options = compile_options + + # This can be None if cache key generation fails or caching is disabled + cache_key = get_cache_key(first_pass_compile_options) + + if cache_key is not None and pgle_optimized_cache_key is not None: + # The compilation cache is enabled and AutoPGLE is enabled/expected + if _is_executable_in_cache(backend, pgle_optimized_cache_key): + if config.compilation_cache_expect_pgle.value: + logging.info(f"PGLE-optimized {module_name} loaded from compilation cache") + # No need to record N profiles in this case + if pgle_profiler is not None: + pgle_profiler.disable() + return pgle_optimized_cache_key, compile_options + elif (config.compilation_cache_expect_pgle.value + and _is_executable_in_cache(backend, cache_key)): + # No PGLE-optimized module found in the persistent cache, and the user + # asserted (expect_pgle) that this miss was unexpected warnings.warn(f"PERSISTENT CACHE MISS for PGLE-optimized {module_name} " "despite non-PGLE hit; it may not have been executed " "enough times when the cache was populated") - if fdo_profile is not None and len(fdo_profile) > 0: - # Store module under PGLE profiled module cache key. - result_key = pgle_profiled_module_key - if is_multi_process and distributed.global_state.client is not None: - compile_options.executable_build_options.fdo_profile = ( - _share_fdo_profiles( - computation, - devices, - compile_options, - backend, - distributed.global_state.client, - min_device_process_id, - ) - ) - else: - compile_options.executable_build_options.fdo_profile = fdo_profile - logger.debug( - "Compiling module %s with FDO profile of length %d", - module_name, - len(compile_options.executable_build_options.fdo_profile), + + if (is_auto_pgle_used + and compile_options.executable_build_options.fdo_profile is not None + and len(compile_options.executable_build_options.fdo_profile)): + # Profile data are available to trigger a PGLE-optimized recompilation; + # store under `pgle_optimized_cache_key` if the cache is enabled + if is_multi_process and distributed.global_state.client is not None: + compile_options.executable_build_options.fdo_profile = ( + _share_fdo_profiles( + computation, + devices, + compile_options, + backend, + distributed.global_state.client, + min_device_process_id, ) - return result_key + ) + return pgle_optimized_cache_key, compile_options + else: + # Compile for PGLE collection, store under `cache_key` if the cache is + # enabled. This is also the AutoPGLE-disabled path. + return cache_key, first_pass_compile_options +def _get_cache_key( + options: xc.CompileOptions, + backend: xc.Client, + computation: ir.Module, + devices: np.ndarray, + override_fdo_profile: bytes | None = None) -> str | None: + if not compilation_cache.is_cache_used(backend): + return None + if config.remove_custom_partitioning_ptr_from_cache_key.value: + ignore_callbacks = cache_key_type.IgnoreCallbacks.CUSTOM_PARTITIONING + else: + ignore_callbacks = cache_key_type.IgnoreCallbacks.NO + if override_fdo_profile is not None: + options = copy.deepcopy(options) + options.executable_build_options.fdo_profile = override_fdo_profile + try: + return compilation_cache.get_cache_key( + computation, + devices, + options, + backend, + ignore_callbacks, + ) + except xc._xla.XlaRuntimeError as ex: + logger.error("compile_or_get_cached: unable to generate cache key, " + "skipping the cache: %s", ex) + return None # The process that has the lowest device ID should share FDO profile before # compilation with other processes. diff --git a/tests/pgle_test.py b/tests/pgle_test.py index 7dabd809d95e..2787de4c6e17 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -21,7 +21,7 @@ import tempfile import warnings -from absl.testing import absltest +from absl.testing import absltest, parameterized import jax from jax._src import api from jax._src import compilation_cache as cc @@ -478,5 +478,55 @@ def check_if_cache_hit(event): self.assertLen(w, 1) self.assertIn("PERSISTENT CACHE WRITE with key jit_h-", str(w[0].message)) + @parameterized.parameters([True, False]) + @jtu.thread_unsafe_test() + def testAutoPgleWithCommandBuffers(self, enable_compilation_cache): + with (config.pgle_profiling_runs(1), + config.enable_compilation_cache(enable_compilation_cache), + config.enable_pgle(True), + tempfile.TemporaryDirectory() as dump_dir, + tempfile.TemporaryDirectory() as cache_dir): + if enable_compilation_cache: + cc.reset_cache() + cc.set_cache_dir(cache_dir) + compiler_options = { + 'xla_dump_to': dump_dir, + # FUSION, see https://github.com/openxla/xla/issues/22459 + 'xla_gpu_enable_command_buffer': 1, + 'xla_gpu_graph_min_graph_size': 1, + } + @partial( + jax.jit, + compiler_options=compiler_options, + ) + def f(x): + return x * 2 + + x = jnp.arange(1) + expected = x * 2 + + # This is ugly, but it does not seem possible to get the AutoPGLE-recompiled + # executable text (.lower(x).compile().as_text() or similar). + def get_new_hlo(): + additions = set(os.listdir(dump_dir)) - get_new_hlo.seen_files + get_new_hlo.seen_files |= additions + new_hlos = list(filter(lambda f: f.endswith("_gpu_after_optimizations.txt"), additions)) + assert len(new_hlos) == 1 + with open(os.path.join(dump_dir, new_hlos[0]), "r") as ifile: + return ifile.read() + + get_new_hlo.seen_files = set() + + # Run 1 + self.assertArraysEqual(f(x), expected) + self.assertNotIn("command_buffer", get_new_hlo()) # b/376647494 workaround + # Run 2 + self.assertArraysEqual(f(x), expected) + self.assertIn("command_buffer", get_new_hlo()) # workaround disabled + + api.clear_caches() + pjit._pgle_profiler_dict.clear() + + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From d6b4fed5ed25432fd5298fe108e797dac734465d Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 31 Mar 2025 11:33:05 -0700 Subject: [PATCH 299/483] Propagate sharding and vma rule for axis_index_p. There's no need for pbroadcast insertion for axis_index_p in the traceable PiperOrigin-RevId: 742334213 --- jax/_src/lax/parallel.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 8fc8c336d61a..ebc6255cb66b 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -35,6 +35,7 @@ from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.interpreters import pxla +from jax._src.mesh import get_abstract_mesh from jax._src.lax import lax from jax._src.lax import slicing from jax._src.lib.mlir import ir @@ -1860,8 +1861,14 @@ def _axis_index_lowering(ctx, *, axis_name): ctx.module_context.axis_env)] def _axis_index_effectful_abstract_eval(*, axis_name): - _check_axis_names([axis_name]) - return ShapedArray((), np.int32), {core.NamedAxisEffect(axis_name)} + effect = {core.NamedAxisEffect(axis_name)} + axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name + _check_axis_names(axis_name) + mesh = get_abstract_mesh() + sharding = NamedSharding(mesh, P()) + vma = ((frozenset(axis_name) if mesh._any_axis_manual else frozenset()) + if config.varying_axes_in_types.value else frozenset()) + return ShapedArray((), np.int32, sharding=sharding, vma=vma), effect def _axis_index_batcher(axis_data, vals_in, dims_in, *, axis_name): return lax.iota(np.int32, axis_data.size), 0 From 8cda2a23dda416bd150b39f1c3580602fe2aa4f5 Mon Sep 17 00:00:00 2001 From: Nicolas Castet Date: Thu, 20 Mar 2025 15:20:14 -0500 Subject: [PATCH 300/483] [Mosaic-GPU] [2/3] Add NVSHMEM support to Mosaic-GPU custom call --- .../mosaic_gpu/pallas_call_registration.py | 1 + jax/experimental/mosaic/gpu/core.py | 48 ++++- jaxlib/mosaic/gpu/BUILD | 15 ++ jaxlib/mosaic/gpu/custom_call.cc | 165 +++++++++++++++--- jaxlib/mosaic/gpu/mosaic_gpu_comm.h | 86 +++++++++ jaxlib/mosaic/gpu/runtime.cc | 13 ++ 6 files changed, 293 insertions(+), 35 deletions(-) create mode 100644 jaxlib/mosaic/gpu/mosaic_gpu_comm.h diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py index 1d4be26187ce..ff3c4f89d30c 100644 --- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py @@ -96,6 +96,7 @@ def zero_init_gmem_scratch(): module=module, out_types=(*lowering_result.new_out_shapes, *lowering_result.gmem_scratch_shapes), input_output_aliases=input_output_aliases, + use_custom_barrier=False, # False until we add get_barrier_semaphore() feature ) if lowering_result.gmem_scratch_shapes: # Drop the GMEM scratch. outs = outs[:-len(lowering_result.gmem_scratch_shapes)] diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index fcc5d3db6d60..43b93e7da023 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -83,6 +83,15 @@ # Set this so that the custom call can find it os.environ["MOSAIC_GPU_RUNTIME_LIB_PATH"] = str(RUNTIME_PATH) +if os.environ.get("MOSAIC_GPU_NVSHMEM_LLVM_LIB_PATH") is None: + try: + from nvidia import nvshmem + except ImportError: + pass + else: + os.environ["MOSAIC_GPU_NVSHMEM_LLVM_LIB_PATH"] = ( + os.path.join(nvshmem.__path__[0], 'lib/libnvshmem_device.bc') + ) mosaic_gpu_p = jax._src.core.Primitive("mosaic_gpu_p") mosaic_gpu_p.multiple_results = True @@ -103,6 +112,7 @@ def _mosaic_gpu_lowering_rule( module, out_types, input_output_aliases: tuple[tuple[int, int], ...] = (), + use_custom_barrier: bool = False, ): assert len(args) == len(ctx.avals_in) assert len(out_types) == len(ctx.avals_out) @@ -121,15 +131,35 @@ def _mosaic_gpu_lowering_rule( raise RuntimeError("Hash collision!") else: KNOWN_KERNELS[kernel_id] = module_asm - op = mlir.custom_call( - "mosaic_gpu", - result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], - operands=args, - operand_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_in], - result_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_out], - backend_config=kernel_id + module_asm, - operand_output_aliases=dict(input_output_aliases), - ) + + if ctx.is_forward_compat(): + if use_custom_barrier: + raise ValueError("Barrier semaphore is not supported in forward compatibility mode. " + "Please, use 'export_ignore_forward_compatibility=True'.") + op = mlir.custom_call( + "mosaic_gpu", + result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], + operands=args, + operand_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_in], + result_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_out], + backend_config=kernel_id + module, + operand_output_aliases=dict(input_output_aliases), + ) + else: + op = mlir.custom_call( + "mosaic_gpu_v2", + result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], + operands=args, + operand_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_in], + result_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_out], + backend_config=dict( + kernel_hash=ir.StringAttr.get(kernel_id), + module=ir.StringAttr.get(module_asm), + use_custom_barrier=ir.BoolAttr.get(use_custom_barrier), + ), + operand_output_aliases=dict(input_output_aliases), + api_version=4, + ) return op.results diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD index 80a8f0e51080..6f9a729688ff 100644 --- a/jaxlib/mosaic/gpu/BUILD +++ b/jaxlib/mosaic/gpu/BUILD @@ -114,11 +114,21 @@ cc_library( # Linker may prune these symbols if they are not explicitly exported. linkopts = ["-Wl,--export-dynamic-symbol='mosaic_gpu_*'"], deps = [ + ":mosaic_gpu_comm", "@local_config_cuda//cuda:cuda_headers", ], alwayslink = True, ) +cc_library( + name = "mosaic_gpu_comm", + hdrs = ["mosaic_gpu_comm.h"], + deps = [ + "@local_config_cuda//cuda:cuda_headers", + "@xla//xla/tsl/cuda:cudart", + ], +) + cc_library( name = "custom_call", srcs = ["custom_call.cc"], @@ -127,9 +137,11 @@ cc_library( ":target", "//jaxlib/cuda:cuda_vendor", "//jaxlib/mosaic/dialect/gpu:mosaic_gpu", + "//jaxlib/mosaic/gpu:mosaic_gpu_comm", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -175,6 +187,8 @@ cc_library( "@llvm-project//mlir:VectorToLLVM", "@xla//xla/service:custom_call_status", "@xla//xla/service:custom_call_target_registry", + "@xla//xla/ffi", + "@xla//xla/ffi:ffi_api", ], alwayslink = True, ) @@ -210,5 +224,6 @@ cc_binary( deps = [ "@local_config_cuda//cuda:cuda_headers", "@xla//xla/tsl/cuda:cudart", + "//jaxlib/mosaic/gpu:mosaic_gpu_comm", ], ) diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index d9a69c57e142..465551e2903b 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -35,6 +35,7 @@ limitations under the License. #include "absl/base/optimization.h" #include "absl/cleanup/cleanup.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -87,14 +88,19 @@ limitations under the License. #include "jaxlib/gpu/vendor.h" #include "jaxlib/mosaic/dialect/gpu/mosaic_gpu.h" #include "jaxlib/mosaic/gpu/launch_lowering.h" +#include "jaxlib/mosaic/gpu/mosaic_gpu_comm.h" #include "jaxlib/mosaic/gpu/passes.h" #include "jaxlib/mosaic/gpu/serde.h" #include "jaxlib/mosaic/gpu/target.h" +#include "xla/ffi/ffi.h" +#include "xla/ffi/ffi_api.h" #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_target_registry.h" namespace { +namespace ffi = xla::ffi; + using MosaicInitFunc = void(void****); using MosaicHostFunc = void(void**); @@ -121,7 +127,7 @@ absl::StatusOr> GetSmAndPtxIsaVersion() { mlir::FailureOr GetPassPipeline( mlir::MLIRContext* ctx, mlir::gpu::CompilationTarget target, - const std::string& sm, const std::string& ptx_isa) { + const std::string& sm, const std::string& ptx_isa, const std::string& nvshmem_path) { static bool register_once = []() { llvm::InitializeNativeTarget(); llvm::InitializeNativeTarget(); @@ -179,8 +185,8 @@ mlir::FailureOr GetPassPipeline( gpu.module(mosaic-byval-insertion), gpu.module(reconcile-unrealized-casts), mosaic-convert-gpu-to-llvm, - gpu-module-to-binary{format=)", - mlir::gpu::stringifyCompilationTarget(target).str(), R"(}, + gpu-module-to-binary{format=)" + + mlir::gpu::stringifyCompilationTarget(target).str() + (!nvshmem_path.empty() ? R"( l=)" + nvshmem_path : "") + R"(}, convert-math-to-llvm{approximate-log1p=true}, canonicalize{max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, @@ -289,7 +295,7 @@ class TemporaryDirectory { }; void DumpCompilationOutput(mlir::ModuleOp module, const std::string& sm, - const std::string& ptx_isa) { + const std::string& ptx_isa, const std::string& nvshmem_path) { bool dump_ptx = getenv("MOSAIC_GPU_DUMP_PTX") != nullptr; bool dump_ptxas = getenv("MOSAIC_GPU_DUMP_PTXAS") != nullptr; bool dump_sass = getenv("MOSAIC_GPU_DUMP_SASS") != nullptr; @@ -300,7 +306,8 @@ void DumpCompilationOutput(mlir::ModuleOp module, const std::string& sm, module = module.clone(); // Prevent accidental modification. absl::Cleanup module_destroyer = [module] { module->erase(); }; auto passes = GetPassPipeline( - module.getContext(), mlir::gpu::CompilationTarget::Assembly, sm, ptx_isa); + module.getContext(), mlir::gpu::CompilationTarget::Assembly, + sm, ptx_isa, nvshmem_path); if (mlir::failed(passes) || mlir::failed(RunPasses(std::move(*passes), module))) { return; @@ -358,7 +365,29 @@ void DumpCompilationOutput(mlir::ModuleOp module, const std::string& sm, } } -absl::StatusOr> Compile( +bool is_nvshmem_used(mlir::ModuleOp module) { + constexpr std::string_view prefix1 = "nvshmem_"; + constexpr std::string_view prefix2 = "nvshmemx_"; + for (mlir::LLVM::LLVMFuncOp llvm_func : module.getOps()) { + const auto& func_name = llvm_func.getName(); + if (!func_name.starts_with(prefix1) && !func_name.starts_with(prefix2)) { + continue; + } + auto uses = mlir::SymbolTable::getSymbolUses(llvm_func, module.getOperation()); + if (uses && !uses->empty()) { + return true; + } + } + return false; +} + +absl::StatusOr get_nvshmem_llvm_lib_path() { + const char * nvshmem_path_ptr = getenv("MOSAIC_GPU_NVSHMEM_LLVM_LIB_PATH"); + if (!nvshmem_path_ptr) return absl::InternalError("Failed to get MOSAIC_GPU_NVSHMEM_LLVM_LIB_PATH"); + return nvshmem_path_ptr; +} + +absl::StatusOr, bool>> Compile( mlir::ModuleOp module) { auto sm_and_ptx_isa = GetSmAndPtxIsaVersion(); if (!sm_and_ptx_isa.ok()) { @@ -366,9 +395,16 @@ absl::StatusOr> Compile( } const std::string sm = sm_and_ptx_isa.value().first; const std::string ptx_isa = sm_and_ptx_isa.value().second; - DumpCompilationOutput(module, sm, ptx_isa); + bool is_comm_used = is_nvshmem_used(module); + std::string nvshmem_path = ""; + if (is_comm_used) { + TF_ASSIGN_OR_RETURN(nvshmem_path, get_nvshmem_llvm_lib_path()); + } + DumpCompilationOutput(module, sm, ptx_isa, nvshmem_path); auto passes = GetPassPipeline( - module.getContext(), mlir::gpu::CompilationTarget::Binary, sm, ptx_isa); + module.getContext(), + mlir::gpu::CompilationTarget::Binary, + sm, ptx_isa, nvshmem_path); if (mlir::failed(passes)) { return absl::InternalError("Failed to construct pass pipeline"); } @@ -392,23 +428,25 @@ absl::StatusOr> Compile( if (!maybe_execution_engine) { return absl::InternalError("Failed to compile kernel"); } - return std::move(*maybe_execution_engine); + return std::make_pair(std::move(*maybe_execution_engine), is_comm_used); } class CompiledKernel { public: CompiledKernel(std::unique_ptr engine, void* ctx, - MosaicHostFunc* host_launch) - : engine_(std::move(engine)), ctx_(ctx), host_launch_(host_launch) {} + MosaicHostFunc* host_launch, bool is_comm_used) + : engine_(std::move(engine)), ctx_(ctx), host_launch_(host_launch), + is_comm_used_(is_comm_used) {} - std::tuple GetHostLaunch() { - return std::make_tuple(ctx_, host_launch_); + std::tuple GetHostLaunch() { + return std::make_tuple(ctx_, host_launch_, is_comm_used_); } private: std::unique_ptr engine_; void* ctx_; // TODO(apaszke): Destroy this properly MosaicHostFunc* host_launch_; + bool is_comm_used_; }; using KernelHash = std::array; @@ -477,7 +515,8 @@ absl::StatusOr CompileAndInit(const char* module) { if (!maybe_engine.ok()) { return maybe_engine.status(); } - mlir::ExecutionEngine* execution_engine = maybe_engine->get(); + mlir::ExecutionEngine* execution_engine = maybe_engine.value().first.get(); + bool is_comm_used = maybe_engine.value().second; auto host_and_init_func_names = GetHostAndInitFuncNames(*module_op); if (!host_and_init_func_names.ok()) { @@ -496,14 +535,15 @@ absl::StatusOr CompileAndInit(const char* module) { void** kernel_ptr_ptr = &kernel_ptr; void*** init_args[2] = {&module_ptr_ptr, &kernel_ptr_ptr}; reinterpret_cast(*init)(init_args); - return CompiledKernel(std::move(*maybe_engine), kernel_ptr, - reinterpret_cast(*host)); + return CompiledKernel(std::move(maybe_engine.value().first), kernel_ptr, + reinterpret_cast(*host), + is_comm_used); } // Each compiled kernel has a unique init func, and each kernel is used from // a single HLO module. So it should be safe to not include the CUDA context // in the key. -absl::StatusOr> CachedCompileAndInit( +absl::StatusOr CachedCompileAndInit( CacheKey key, const char* module) { auto cache_and_mutex = GetKernelCache(); auto* cache = cache_and_mutex.first; @@ -514,7 +554,7 @@ absl::StatusOr> CachedCompileAndInit( absl::ReaderMutexLock lock(mutex); auto it = cache->find(key); if (ABSL_PREDICT_TRUE(it != cache->end())) - return it->second.GetHostLaunch(); + return &it->second; } absl::MutexLock lock(mutex); @@ -526,11 +566,12 @@ absl::StatusOr> CachedCompileAndInit( } cache->insert_or_assign(key, std::move(*compiled)); } - return cache->at(key).GetHostLaunch(); + return &cache->at(key); } void MosaicGPUCustomCall(void* stream, void** buffers, char* opaque, size_t opaque_len, XlaCustomCallStatus* status) { + // Forward-compatible version using the legacy FFI API if (reinterpret_cast(opaque) % alignof(KernelHash)) { fprintf(stderr, "Misaligned opaque pointer\n"); abort(); @@ -542,20 +583,92 @@ void MosaicGPUCustomCall(void* stream, void** buffers, char* opaque, abort(); } CacheKey key(hash, reinterpret_cast(ctx)); - auto ctx_and_kernel = CachedCompileAndInit(key, opaque + sizeof(KernelHash)); - if (!ctx_and_kernel.ok()) { + auto compiled_kernel = CachedCompileAndInit(key, opaque + sizeof(KernelHash)); + if (!compiled_kernel.ok()) { XlaCustomCallStatusSetFailure(status, - ctx_and_kernel.status().message().data(), - ctx_and_kernel.status().message().size()); + compiled_kernel.status().message().data(), + compiled_kernel.status().message().size()); return; } - void* args[4] = {&std::get<0>(*ctx_and_kernel), &stream, &buffers}; - std::get<1>(*ctx_and_kernel)(args); + auto ctx_kernel_comm = (*compiled_kernel)->GetHostLaunch(); + bool is_comm_used = std::get<2>(ctx_kernel_comm); + void* args[4] = {&std::get<0>(ctx_kernel_comm), &stream, &buffers}; + if (is_comm_used) { + mosaic::gpu::NvshmemApi::Default().barrier_all_on_stream( + reinterpret_cast(stream)); + } + std::get<1>(ctx_kernel_comm)(args); } XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("mosaic_gpu", &MosaicGPUCustomCall, "CUDA"); +absl::Status MosaicGpuExecute(gpuStream_t stream, ffi::RemainingArgs inputs, + ffi::RemainingRets results, + absl::string_view kernel_hash, + absl::string_view module, + bool use_custom_barrier, + xla::RunId run_id) { + // Updated version using the new FFI API supporting custom barrier + // for distributed kernels + if (use_custom_barrier) { + fprintf(stderr, "Custom barrier is not supported on GPUs.\n"); + abort(); + } + if (reinterpret_cast(kernel_hash.data()) % + alignof(KernelHash) || + kernel_hash.size() != sizeof(KernelHash)) { + fprintf(stderr, "Misaligned opaque pointer\n"); + abort(); + } + auto hash = *reinterpret_cast(kernel_hash.data()); + CUcontext ctx; + if (cuCtxGetCurrent(&ctx) != CUDA_SUCCESS) { + fprintf(stderr, "Failed to get current CUDA context\n"); + abort(); + } + CacheKey key(hash, reinterpret_cast(ctx)); + TF_ASSIGN_OR_RETURN(auto compiled_kernel, CachedCompileAndInit(key, module.data())); + auto ctx_kernel_comm = compiled_kernel->GetHostLaunch(); + bool is_comm_used = std::get<2>(ctx_kernel_comm); + + std::vector buffers; + buffers.reserve(inputs.size() + results.size()); + for (int i = 0; i < inputs.size(); ++i) { + buffers.push_back(inputs.get(i)->untyped_data()); + } + for (int i = 0; i < results.size(); ++i) { + buffers.push_back((*results.get(i))->untyped_data()); + } + void **buffers_ptr = buffers.data(); + void *args[4] = {&std::get<0>(ctx_kernel_comm), &stream, &buffers_ptr}; + + if (is_comm_used) { + mosaic::gpu::NvshmemApi::Default().barrier_all_on_stream( + reinterpret_cast(stream)); + } + std::get<1>(ctx_kernel_comm)(args); + return absl::OkStatus(); +} + +XLA_FFI_DEFINE_HANDLER(kMosaicGpuExecute, MosaicGpuExecute, + ffi::Ffi::Bind() + .Ctx>() + .RemainingArgs() + .RemainingRets() + .Attr("kernel_hash") + .Attr("module") + .Attr("use_custom_barrier") + .Ctx()); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "mosaic_gpu_v2", "CUDA", + { + /*instantiate=*/nullptr, + /*prepare=*/nullptr, + /*initialize=*/nullptr, + /*execute=*/kMosaicGpuExecute, + }); + } // namespace extern "C" { @@ -566,7 +679,7 @@ void** MosaicGpuCompile(const char* module) { if (!compiled.ok()) { return nullptr; } - auto [ctx, launch] = compiled->GetHostLaunch(); + auto [ctx, launch, is_comm_used] = compiled->GetHostLaunch(); auto tuple_ptr = std::unique_ptr(new void*[3]); if (!tuple_ptr) { return nullptr; diff --git a/jaxlib/mosaic/gpu/mosaic_gpu_comm.h b/jaxlib/mosaic/gpu/mosaic_gpu_comm.h new file mode 100644 index 000000000000..b0bd94883e43 --- /dev/null +++ b/jaxlib/mosaic/gpu/mosaic_gpu_comm.h @@ -0,0 +1,86 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_MOSAIC_GPU_COMM_H_ +#define JAXLIB_MOSAIC_GPU_COMM_H_ + +#include +#include +#include + +#include "third_party/gpus/cuda/include/cuda.h" +#include "third_party/gpus/cuda/include/cuda_runtime_api.h" + +#define NVSHMEM_SUCCESS 0 +#define NVSHMEM_LIB_SONAME "libnvshmem_host.so.3" + +namespace mosaic { +namespace gpu { + +#define NVSHMEM_SET_FN(FnName) \ + FnName = reinterpret_cast(dlsym(library, #FnName)); \ + if (!FnName) { \ + fprintf(stderr, #FnName " not available in this library."); \ + abort(); \ + } + +class NvshmemApi { + public: + // Returns a default NvshmemApi for a current process. + // NvshmemApi follows the Singleton design pattern + static NvshmemApi& Default() { + static NvshmemApi instance; + return instance; + } + + int cumodule_int(CUmodule module) { + std::lock_guard lock(mutex_); + return nvshmemx_cumodule_init(module); + } + + void barrier_all_on_stream(cudaStream_t stream) { + nvshmemx_barrier_all_on_stream(stream); + } + + NvshmemApi(NvshmemApi const&) = delete; + void operator=(NvshmemApi const&) = delete; + + private: + NvshmemApi() { + const char* env_value = getenv("NVSHMEM_LIBRARY_PATH"); + const char* libnvshmem_path = + env_value && *env_value != 0 ? env_value : NVSHMEM_LIB_SONAME; + void* library = dlopen(libnvshmem_path, RTLD_LAZY); + if (library == nullptr) { + fprintf(stderr, "Failed to open %s library: %s", libnvshmem_path, dlerror()); + abort(); + } + + // Initialize supported NVSHMEM host API + NVSHMEM_SET_FN(nvshmemx_cumodule_init) + NVSHMEM_SET_FN(nvshmemx_barrier_all_on_stream) + } + + // Dlopened NVSHMEM API + int (*nvshmemx_cumodule_init)(CUmodule); + int (*nvshmemx_barrier_all_on_stream)(cudaStream_t); + + std::mutex mutex_; +}; + +} // namespace gpu +} // namespace mosaic + +#endif // JAXLIB_MOSAIC_GPU_COMM_H_ diff --git a/jaxlib/mosaic/gpu/runtime.cc b/jaxlib/mosaic/gpu/runtime.cc index ad3cd0e19644..6897bcf350df 100644 --- a/jaxlib/mosaic/gpu/runtime.cc +++ b/jaxlib/mosaic/gpu/runtime.cc @@ -17,8 +17,10 @@ limitations under the License. #include #include +#include "jaxlib/mosaic/gpu/mosaic_gpu_comm.h" #include "third_party/gpus/cuda/include/cuda.h" + extern "C" { void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr, @@ -154,6 +156,17 @@ void* mosaic_gpu_module_load(void *data) { fprintf(stderr, "cuModuleLoadData failed: %s\n", ptr); abort(); } + + CUdeviceptr ptr = 0; + size_t size = 0; + // Check if module contains NVSHMEM globals implying NVSHMEM state needs to set + if (cuModuleGetGlobal(&ptr, &size, module, "nvshmemi_device_lib_version_d") == CUDA_SUCCESS) { + if (mosaic::gpu::NvshmemApi::Default().cumodule_int(module) != NVSHMEM_SUCCESS) { + fprintf(stderr, "nvshmemx_cumodule_init failed.\n"); + abort(); + } + } + return module; } From ca36047ac91b4e1b5107cfa55da7a7cd6301716d Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 31 Mar 2025 15:14:47 -0700 Subject: [PATCH 301/483] __jax_array__: add support in jnp.reshape, jnp.transpose, jnp.matrix_transpose --- jax/_src/numpy/lax_numpy.py | 8 ++++---- tests/array_extensibility_test.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 63edaed0adeb..7b900e09068e 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -1203,8 +1203,8 @@ def transpose(a: ArrayLike, axes: Sequence[int] | None = None) -> Array: Array([[1, 3], [2, 4]], dtype=int32) """ - util.check_arraylike("transpose", a) - axes_ = list(range(np.ndim(a))[::-1]) if axes is None else axes + a = util.ensure_arraylike("transpose", a) + axes_ = list(range(a.ndim)[::-1]) if axes is None else axes axes_ = [_canonicalize_axis(i, np.ndim(a)) for i in axes_] return lax.transpose(a, axes_) @@ -1285,8 +1285,8 @@ def matrix_transpose(x: ArrayLike, /) -> Array: [[5, 7], [6, 8]]], dtype=int32) """ - util.check_arraylike("matrix_transpose", x) - ndim = np.ndim(x) + x = util.ensure_arraylike("matrix_transpose", x) + ndim = x.ndim if ndim < 2: raise ValueError(f"x must be at least two-dimensional for matrix_transpose; got {ndim=}") axes = (*range(ndim - 2), ndim - 1, ndim - 2) diff --git a/tests/array_extensibility_test.py b/tests/array_extensibility_test.py index fae9129dd99a..63a8762cd0b0 100644 --- a/tests/array_extensibility_test.py +++ b/tests/array_extensibility_test.py @@ -375,7 +375,7 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.logical_or, Int[5], Int[5]), NumPyAPI.sig(jnp.logical_xor, Int[5], Int[5]), NumPyAPI.sig(jnp.matmul, Float[5, 5], Float[5]), - # NumPyAPI.sig(jnp.matrix_transpose, Float[5, 6]), + NumPyAPI.sig(jnp.matrix_transpose, Float[5, 6]), NumPyAPI.sig(jnp.matvec, Float[5, 5], Float[5]), NumPyAPI.sig(jnp.max, Float[5]), NumPyAPI.sig(jnp.maximum, Float[5], Float[5]), @@ -442,7 +442,7 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.reciprocal, Float[5]), NumPyAPI.sig(jnp.remainder, Float[5], Float[5]), # NumPyAPI.sig(jnp.repeat, Float[5], Int[5]), - # NumPyAPI.sig(jnp.reshape, Float[6], (2, 3)), + NumPyAPI.sig(jnp.reshape, Float[6], shape=(2, 3)), NumPyAPI.sig(jnp.resize, Float[6], new_shape=(2, 3)), NumPyAPI.sig(jnp.right_shift, Int[5], Int[5]), NumPyAPI.sig(jnp.rint, Float[5]), @@ -481,7 +481,7 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.tensordot, Float[2, 3, 4], Float[3, 4, 5]), # NumPyAPI.sig(jnp.tile, Float[5], reps=(2,)), NumPyAPI.sig(jnp.trace, Float[5, 5]), - # NumPyAPI.sig(jnp.transpose, Float[5, 6]), + NumPyAPI.sig(jnp.transpose, Float[5, 6]), NumPyAPI.sig(jnp.trapezoid, Float[5]), NumPyAPI.sig(jnp.tril, Float[5, 6]), NumPyAPI.sig(jnp.tril_indices_from, Float[5, 6]), From f59f615f6f45c4524c9326daa800c95bead6cbf0 Mon Sep 17 00:00:00 2001 From: Ayaka Date: Mon, 31 Mar 2025 15:54:01 -0700 Subject: [PATCH 302/483] Minor docstring updates for AOT wrappers in error checking PiperOrigin-RevId: 742431349 --- jax/_src/error_check.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/jax/_src/error_check.py b/jax/_src/error_check.py index e78b9bc82115..b80def4fd2db 100644 --- a/jax/_src/error_check.py +++ b/jax/_src/error_check.py @@ -289,6 +289,11 @@ def wrap_for_export(f): function scope, making it possible to export the function and later import in other processes. + When the function is later imported, it must be wrapped with + :func:`unwrap_from_import` to integrate the error checking mechanism of the + imported function into the global error checking mechanism of the current + process. + This function should only be applied once to a function; wrapping the same function multiple times is unnecessary. """ @@ -327,6 +332,9 @@ def unwrap_from_import(f): separate from the global error state of the current process. This wrapper ensures that errors detected during execution are correctly integrated into the global error checking mechanism of the current process. + + This function should only be applied to functions that were previously wrapped + with :func:`wrap_for_export` before export. """ if _error_storage.ref is None: with core.eval_context(): From 4003e2d0eec70ab0b5ed4e5c8bad8a1148a2efd8 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 27 Mar 2025 13:29:34 -0700 Subject: [PATCH 303/483] jnp.power: support __jax_array__ on inputs --- jax/_src/numpy/ufuncs.py | 5 +++++ tests/array_extensibility_test.py | 4 ++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index e561b7ae71b6..3902b24b35ac 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -2652,6 +2652,11 @@ def power(x1: ArrayLike, x2: ArrayLike, /) -> Array: [nan, 27., 1.]], dtype=float32) """ check_arraylike("power", x1, x2) + + # Must do __jax_array__ conversion prior to dtype check. + x1 = x1.__jax_array__() if hasattr(x1, "__jax_array__") else x1 + x2 = x2.__jax_array__() if hasattr(x2, "__jax_array__") else x2 + check_no_float0s("power", x1, x2) # We apply special cases, both for algorithmic and autodiff reasons: diff --git a/tests/array_extensibility_test.py b/tests/array_extensibility_test.py index 63a8762cd0b0..45847b6f0f29 100644 --- a/tests/array_extensibility_test.py +++ b/tests/array_extensibility_test.py @@ -427,8 +427,8 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.polysub, Float[5], Float[5]), NumPyAPI.sig(jnp.polyval, Float[5], Float[10]), NumPyAPI.sig(jnp.positive, Float[5]), - # NumPyAPI.sig(jnp.pow, Float[5], Float[5]), - # NumPyAPI.sig(jnp.power, Float[5], Float[5]), + NumPyAPI.sig(jnp.pow, Float[5], Float[5]), + NumPyAPI.sig(jnp.power, Float[5], Float[5]), NumPyAPI.sig(jnp.prod, Float[5]), NumPyAPI.sig(jnp.ptp, Float[5]), NumPyAPI.sig(jnp.put, Float[5], Int[()], Float[()], inplace=False), From 994af3efb85339b69112b9e75c9975d24d90d8b3 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 31 Mar 2025 17:37:13 -0700 Subject: [PATCH 304/483] [Pallas TPU] Remove forward compatibility code for float -> signed conversions This will be submitted automatically once the compatibility window has passed PiperOrigin-RevId: 742464046 --- jax/_src/pallas/mosaic/lowering.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 617324d43bf9..1139630ae602 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -2096,16 +2096,6 @@ def _convert_helper(x, *, to_dtype): # unsigned -> float is unsupported. We fall through and raise at the bottom. if not jnp.issubdtype(to_dtype, jnp.floating): return x.astype(to_dtype) - if jnp.issubdtype(from_dtype, jnp.floating) and jnp.issubdtype( - to_dtype, jnp.signedinteger - ): - if from_dtype.itemsize < 4: - x = x.astype(jnp.float32) - if to_dtype.itemsize < 4: - # Need to clip values to match XLA - minval, maxval = jnp.iinfo(to_dtype).min, jnp.iinfo(to_dtype).max - x = jnp.clip(x, minval, maxval) - return x.astype(jnp.int32).astype(to_dtype) raise NotImplementedError(f"Unsupported cast: {from_dtype} -> {to_dtype}") def _convert_element_type_lowering_rule( @@ -2149,10 +2139,7 @@ def _convert_element_type_lowering_rule( return x # TODO(apaszke): Remove both_32bit constraints using the Mosaic canonicalizer. elif _from(floating) and _to(signed): - # TODO(apaszke): Remove once a month has passed, along with the - # _convert_helper float -> signed conversion above. - if not ctx.forward_compatible or both_32bit: - return arith.fptosi(out_type, x) + return arith.fptosi(out_type, x) elif _from(signed) and _to(floating) and both_32bit: return arith.sitofp(out_type, x) elif old_dtype == jnp.bool_ and _to(integer) and new_dtype.itemsize == 4: From 006a6a63feb64bf9984526030ba008186d69d2b4 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 31 Mar 2025 22:01:48 -0700 Subject: [PATCH 305/483] [Easy] Make pallas mesh grid handling more resilient to tuple names. PiperOrigin-RevId: 742531956 --- jax/_src/lax/parallel.py | 2 +- jax/_src/pallas/mosaic/lowering.py | 16 +++++++++++----- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index ebc6255cb66b..39b6c68679ca 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -1861,8 +1861,8 @@ def _axis_index_lowering(ctx, *, axis_name): ctx.module_context.axis_env)] def _axis_index_effectful_abstract_eval(*, axis_name): - effect = {core.NamedAxisEffect(axis_name)} axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name + effect = {core.NamedAxisEffect(axis_name)} _check_axis_names(axis_name) mesh = get_abstract_mesh() sharding = NamedSharding(mesh, P()) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 1139630ae602..f8f49f3d7aea 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -517,11 +517,17 @@ def has_communication(self) -> bool: nonlocal_axis_names = set() def _get_nonlocal_axis_names(jaxpr: jax_core.Jaxpr): return { - e.name - for e in jaxpr.effects - if isinstance(e, jax_core.NamedAxisEffect) - and (not self.grid_names or e.name not in self.grid_names) - } + e.name + for e in jaxpr.effects + if isinstance(e, jax_core.NamedAxisEffect) + and ( + not self.grid_names + or all( + name not in self.grid_names + for name in tree_util.tree_leaves(e.name) + ) + ) + } nonlocal_axis_names.update(_get_nonlocal_axis_names(self.jaxpr)) for bm in self.block_mappings: if bm is not None: From 6adb7289754edff320a670630dec12fc697100ab Mon Sep 17 00:00:00 2001 From: Louis-Justin TALLOT <72044417+LouisJustinTALLOT@users.noreply.github.com> Date: Tue, 1 Apr 2025 02:46:30 -0400 Subject: [PATCH 306/483] Clarify documentation of jnp.heaviside --- jax/_src/numpy/ufuncs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 3902b24b35ac..60e10b3be048 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -3640,9 +3640,9 @@ def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array: .. math:: \mathrm{heaviside}(x1, x2) = \begin{cases} - 0., & x < 0\\ - x2, & x = 0\\ - 1., & x > 0. + 0, & x1 < 0\\ + x2, & x1 = 0\\ + 1, & x1 > 0. \end{cases} Args: From 5d1bc005a00546ece0172d01bd3434f5026c80c6 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 1 Apr 2025 05:02:28 -0700 Subject: [PATCH 307/483] Update XLA dependency to use revision http://github.com/openxla/xla/commit/b1971cc2b3407e87fada2674a057d72897b79acc. PiperOrigin-RevId: 742646393 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index d078359af86a..13223c4a4b88 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "f4a53456b04acf9b63b3b30bd828cec29c4aa7de" -XLA_SHA256 = "2ee32b70af547fd13ce404d75c3fa9834bc8be46a488cd8f0caa10e9a6ec7ede" +XLA_COMMIT = "b1971cc2b3407e87fada2674a057d72897b79acc" +XLA_SHA256 = "3b2feabbcd6adc5721533edfbe3dc2ad6517cb1b059cf41dea63f62874bff12d" def repo(): tf_http_archive( From 40a3d0c78dad1d539180a4f830a3e1c17460ced0 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 1 Apr 2025 09:11:08 -0700 Subject: [PATCH 308/483] Create the test targets for the wheel size verification. Add the tests to the Bazel presubmit RBE jobs (except `arm64`/`aarch64` jobs that use RBE cross-compilation). PiperOrigin-RevId: 742724458 --- BUILD.bazel | 16 ++++++++++ ci/run_bazel_test_cpu_rbe.sh | 4 ++- ci/run_bazel_test_cuda_rbe.sh | 8 ++++- jaxlib/tools/BUILD.bazel | 48 ++++++++++++++++++++++++++++ jaxlib/tools/wheel_size_test.py | 56 +++++++++++++++++++++++++++++++++ 5 files changed, 130 insertions(+), 2 deletions(-) create mode 100644 jaxlib/tools/wheel_size_test.py diff --git a/BUILD.bazel b/BUILD.bazel index 2c10f0d9a748..8dbf2bed0902 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -22,6 +22,7 @@ load( "jax_source_package", "jax_wheel", "py_deps", + "pytype_test", ) collect_data_files( @@ -152,3 +153,18 @@ py_import( wheel_deps = [":wheel_additives"], deps = COMMON_DEPS, ) + +pytype_test( + name = "jax_wheel_size_test", + srcs = ["//jaxlib/tools:wheel_size_test.py"], + args = [ + "--wheel-path=$(location :jax_wheel)", + "--max-size-mib=5", + ], + data = [":jax_wheel"], + main = "wheel_size_test.py", + tags = [ + "manual", + "notap", + ], +) diff --git a/ci/run_bazel_test_cpu_rbe.sh b/ci/run_bazel_test_cpu_rbe.sh index 248111e0247a..d8cb190079e0 100755 --- a/ci/run_bazel_test_cpu_rbe.sh +++ b/ci/run_bazel_test_cpu_rbe.sh @@ -64,5 +64,7 @@ else --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \ --test_output=errors \ --color=yes \ - //tests:cpu_tests //tests:backend_independent_tests + //tests:cpu_tests //tests:backend_independent_tests \ + //jaxlib/tools:jaxlib_wheel_size_test \ + //:jax_wheel_size_test fi \ No newline at end of file diff --git a/ci/run_bazel_test_cuda_rbe.sh b/ci/run_bazel_test_cuda_rbe.sh index 17bd8d9db4f8..94c6a89fdb8c 100755 --- a/ci/run_bazel_test_cuda_rbe.sh +++ b/ci/run_bazel_test_cuda_rbe.sh @@ -48,4 +48,10 @@ bazel test --config=rbe_linux_x86_64_cuda \ --test_env=JAX_SKIP_SLOW_TESTS=true \ --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \ --color=yes \ - //tests:gpu_tests //tests:backend_independent_tests //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests \ No newline at end of file + --@local_config_cuda//cuda:override_include_cuda_libs=true \ + //tests:gpu_tests //tests:backend_independent_tests \ + //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests \ + //jaxlib/tools:jax_cuda_plugin_wheel_size_test \ + //jaxlib/tools:jax_cuda_pjrt_wheel_size_test \ + //jaxlib/tools:jaxlib_wheel_size_test \ + //:jax_wheel_size_test \ No newline at end of file diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index 2ddc9e90a702..79a1f7e7089d 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -33,12 +33,15 @@ load( "jax_py_test", "jax_wheel", "pytype_strict_library", + "pytype_test", ) licenses(["notice"]) # Apache 2 package(default_visibility = ["//visibility:public"]) +exports_files(["wheel_size_test.py"]) + genrule( name = "platform_tags_py", srcs = [], @@ -389,3 +392,48 @@ verify_manylinux_compliance_test( wheel = ":jax_cuda_pjrt_wheel", x86_64_compliance_tag = X86_64_MANYLINUX_TAG, ) + +pytype_test( + name = "jaxlib_wheel_size_test", + srcs = [":wheel_size_test.py"], + args = [ + "--wheel-path=$(location :jaxlib_wheel)", + "--max-size-mib=110", + ], + data = [":jaxlib_wheel"], + main = "wheel_size_test.py", + tags = [ + "manual", + "notap", + ], +) + +pytype_test( + name = "jax_cuda_plugin_wheel_size_test", + srcs = [":wheel_size_test.py"], + args = [ + "--wheel-path=$(location :jax_cuda_plugin_wheel)", + "--max-size-mib=20", + ], + data = [":jax_cuda_plugin_wheel"], + main = "wheel_size_test.py", + tags = [ + "manual", + "notap", + ], +) + +pytype_test( + name = "jax_cuda_pjrt_wheel_size_test", + srcs = [":wheel_size_test.py"], + args = [ + "--wheel-path=$(location :jax_cuda_pjrt_wheel)", + "--max-size-mib=120", + ], + data = [":jax_cuda_pjrt_wheel"], + main = "wheel_size_test.py", + tags = [ + "manual", + "notap", + ], +) diff --git a/jaxlib/tools/wheel_size_test.py b/jaxlib/tools/wheel_size_test.py new file mode 100644 index 000000000000..7e9c08ff9797 --- /dev/null +++ b/jaxlib/tools/wheel_size_test.py @@ -0,0 +1,56 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import os + + +def parse_args(): + """Arguments parser.""" + parser = argparse.ArgumentParser( + description="Helper for the wheel size verification", + fromfile_prefix_chars="@", + ) + parser.add_argument( + "--wheel-path", required=True, help="Path of the wheel, mandatory" + ) + parser.add_argument( + "--max-size-mib", + required=True, + help="Maximum size of the wheel in MiB", + ) + return parser.parse_args() + + +def verify_wheel_size(args): + wheel_size_mib = os.path.getsize(args.wheel_path) >> 20 + wheel_name = os.path.basename(args.wheel_path) + if wheel_size_mib > int(args.max_size_mib): + raise RuntimeError( + "The {name} size is {size} MiB, which is larger than the maximum size" + " {max_size} MiB".format( + name=wheel_name, + size=wheel_size_mib, + max_size=args.max_size_mb, + ) + ) + else: + logging.info( + "The %s size is %s MiB, which is less than the maximum size" + " %s MB", wheel_name, wheel_size_mib, args.max_size_mib) + + +if __name__ == "__main__": + verify_wheel_size(parse_args()) From 76271d638ad94f3df854054640f3b35161ee5be4 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 1 Apr 2025 09:50:00 -0700 Subject: [PATCH 309/483] Add scan_p and cond_p vma rule. PiperOrigin-RevId: 742737384 --- jax/_src/core.py | 7 ++++--- jax/_src/lax/control_flow/conditionals.py | 9 +++++++++ jax/_src/lax/control_flow/loops.py | 12 ++++++++++-- jax/_src/state/types.py | 9 +++++++++ jax/experimental/shard_map.py | 4 ++-- 5 files changed, 34 insertions(+), 7 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index ee6537650f20..ae94782ce98a 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -574,7 +574,7 @@ def read(v: Atom) -> Any: def write(v: Var, val: Any) -> None: if config.enable_checks.value and not config.dynamic_shapes.value: - assert typecheck(v.aval, val), (v.aval, val) + assert typecheck(v.aval, val), (v.aval, get_aval(val)) env[v] = val env: dict[Var, Any] = {} @@ -2594,7 +2594,7 @@ def _map_shaped_array( if axis is None: return aval sharding = aval.sharding.with_spec(tuple_delete(aval.sharding.spec, axis)) return ShapedArray(tuple_delete(aval.shape, axis), aval.dtype, - weak_type=aval.weak_type, sharding=sharding) + weak_type=aval.weak_type, sharding=sharding, vma=aval.vma) def _unmap_shaped_array( size: int, axis: int | None, explicit_mesh_axis, aval: ShapedArray @@ -2604,7 +2604,8 @@ def _unmap_shaped_array( sharding = aval.sharding.with_spec(tuple_insert( aval.sharding.spec, axis, explicit_mesh_axis)) return ShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype, - weak_type=aval.weak_type, sharding=sharding) + weak_type=aval.weak_type, sharding=sharding, + vma=aval.vma) else: raise TypeError(axis) def _map_dshaped_array( diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 63896cc2a0bf..b0e1221752bd 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -347,6 +347,15 @@ def _cond_abstract_eval(*avals: core.AbstractValue, if disallowed_effects: raise NotImplementedError( f'Effects not supported in `cond`: {disallowed_effects}') + b0_vma = [o.vma for o in branches[0].out_avals] + for branch in branches[1:]: + b_vma = [o.vma for o in branch.out_avals] + if b0_vma != b_vma: + raise Exception("The branches of cond produced mismatched varying manual " + f"axes. Got {b0_vma} and {b_vma}. Please open an issue " + "at https://github.com/jax-ml/jax/issues, and as a " + "temporary workaround pass the check_rep=False argument " + "to shard_map") return branches[0].out_avals, joined_effects def _bcast_select(pred, on_true, on_false): diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 56323949a607..9a66dd037d3a 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -570,9 +570,17 @@ def _prepend_dim_to_aval(sz, aval): def _scan_abstract_eval(*args, reverse, length, num_consts, num_carry, jaxpr, linear, unroll, _split_transpose): - carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry]) + out_carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry]) + _, in_carry_avals, _ = split_list(args, [num_consts, num_carry]) + if [i.vma for i in in_carry_avals] != [o.vma for o in out_carry_avals]: + raise ValueError( + 'Scan carry input and output got mismatched varying manual axes ' + f'{in_carry_avals} and {out_carry_avals}. Please open an ' + 'issue at https://github.com/jax-ml/jax/issues, and as a ' + 'temporary workaround pass the check_rep=False argument to ' + 'shard_map') ys_avals = _map(partial(_prepend_dim_to_aval, length), y_avals) - return carry_avals + ys_avals, jaxpr.effects + return out_carry_avals + ys_avals, jaxpr.effects def _scan_jvp(primals, tangents, reverse, length, jaxpr, num_consts, num_carry, linear, unroll, _split_transpose): diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index e926e3a35f80..b9dbaf35c5d2 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -366,6 +366,15 @@ def sharding(self): f"`Ref{{{self.inner_aval.str_short()}}} has no `sharding`." ) from None + @property + def vma(self): + try: + return self.inner_aval.vma # pytype: disable=attribute-error + except AttributeError: + raise AttributeError( + f"`Ref{{{self.inner_aval.str_short()}}} has no `vma`." + ) from None + @core.aval_property def at(self): return RefIndexer(self) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 8e2d93af2639..4b9daf170dce 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -1365,7 +1365,7 @@ def _scan_check(mesh, *in_rep, jaxpr, num_consts, num_carry, **_): _, carry_rep_in, _ = split_list(in_rep, [num_consts, num_carry]) out_rep = _check_rep(mesh, jaxpr.jaxpr, in_rep) carry_rep_out, _ = split_list(out_rep, [num_carry]) - if not carry_rep_in == carry_rep_out: + if carry_rep_in != carry_rep_out: raise Exception("Scan carry input and output got mismatched replication " f"types {carry_rep_in} and {carry_rep_out}. Please open an " "issue at https://github.com/jax-ml/jax/issues, and as a " @@ -1403,7 +1403,7 @@ def _cond_rule(mesh, *in_rep, branches): out_rep = _check_rep(mesh, branches[0].jaxpr, args_rep) for branch in branches[1:]: out_rep_ = _check_rep(mesh, branch.jaxpr, args_rep) - if not out_rep_ == out_rep: + if out_rep_ != out_rep: raise Exception("The branches of cond produced mismatched replication " "types. Please open an issue at " "https://github.com/jax-ml/jax/issues, and as a " From a34c4628755877f431d075cde50d73ee33158b34 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 1 Apr 2025 09:53:29 -0700 Subject: [PATCH 310/483] jnp.select: support __jax_array__ for inputs --- jax/_src/numpy/lax_numpy.py | 6 ++++++ tests/array_extensibility_test.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 7b900e09068e..a47f66e5f621 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -2907,6 +2907,12 @@ def select( raise ValueError(msg.format(len(condlist), len(choicelist))) if len(condlist) == 0: raise ValueError("condlist must be non-empty") + + util.check_arraylike("select", *condlist, *choicelist, default) + condlist = [asarray(cond) for cond in condlist] + choicelist = [asarray(choice) for choice in choicelist] + default = asarray(default) + # Put the default at front with condition False because # argmax returns zero for an array of False values. choicelist = util.promote_dtypes(default, *choicelist) diff --git a/tests/array_extensibility_test.py b/tests/array_extensibility_test.py index 45847b6f0f29..f62491a608c7 100644 --- a/tests/array_extensibility_test.py +++ b/tests/array_extensibility_test.py @@ -452,7 +452,7 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.rot90, Float[5, 3]), NumPyAPI.sig(jnp.round, Float[5]), NumPyAPI.sig(jnp.searchsorted, Float[5], Float[5]), - # NumPyAPI.sig(jnp.select, [Bool[5], Bool[5]], [Float[5], Float[5]], Float[5]), + NumPyAPI.sig(jnp.select, [Bool[5], Bool[5]], [Float[5], Float[5]], Float[()]), NumPyAPI.sig(jnp.setdiff1d, Int[5], Int[5]), NumPyAPI.sig(jnp.setxor1d, Int[5], Int[5]), NumPyAPI.sig(jnp.shape, Float[5, 3]), From a80f6279e9eba6ec0aa1fc2b37e979f883768c31 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Tue, 1 Apr 2025 00:17:30 +0000 Subject: [PATCH 311/483] make random_gamma_grad not a primitive anymore Fixes #16076 Co-authored-by: Roy Frostig --- jax/_src/checkify.py | 2 +- jax/_src/lax/special.py | 47 ++++++++++++++++------------------- jax/extend/core/primitives.py | 1 - jax/lax/__init__.py | 1 - 4 files changed, 23 insertions(+), 28 deletions(-) diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index f80a0cbd1d75..f0abf53b0717 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -600,7 +600,7 @@ def isnan(x): lax.igamma_p, lax.igammac_p, lax.integer_pow_p, lax.lgamma_p, lax.linear_solve_p, lax.log1p_p, lax.log_p, lax.logistic_p, lax.mul_p, lax.pad_p, lax.pow_p, lax.psum_p, - lax.random_gamma_grad_p, lax.reduce_p, lax.reduce_prod_p, + lax.reduce_p, lax.reduce_prod_p, lax.reduce_sum_p, lax.reduce_window_p, lax.reduce_window_sum_p, lax.regularized_incomplete_beta_p, lax.rem_p, lax.rng_uniform_p, lax.rsqrt_p, lax.sin_p, diff --git a/jax/_src/lax/special.py b/jax/_src/lax/special.py index 041205156d58..a59d62523c9f 100644 --- a/jax/_src/lax/special.py +++ b/jax/_src/lax/special.py @@ -38,6 +38,25 @@ from jax._src.lib.mlir.dialects import chlo from jax._src.typing import Array, ArrayLike +# TODO(mattjj): this function sucks, delete it +def _up_and_broadcast(doit): + def up_and_broadcast(*args): + broadcasted_shape = broadcast_shapes(*(a.shape for a in args)) + args = [broadcast_in_dim(a, broadcasted_shape, list(range(a.ndim))) for a in args] + + a_dtype = args[0].dtype + needs_upcast = a_dtype == dtypes.bfloat16 or a_dtype == np.float16 + if needs_upcast: + args = [convert_element_type(a, np.float32) for a in args] + a_x_type = np.float32 + else: + a_x_type = a_dtype + result = doit(*args, dtype=a_x_type) + if needs_upcast: + result = convert_element_type(result, a_dtype) + return result + return up_and_broadcast + def betainc(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array: r"""Elementwise regularized incomplete beta integral.""" a, b, x = core.standard_insert_pbroadcast(a, b, x) @@ -71,10 +90,11 @@ def igamma_grad_a(a: ArrayLike, x: ArrayLike) -> Array: a, x = core.standard_insert_pbroadcast(a, x) return igamma_grad_a_p.bind(a, x) -def random_gamma_grad(a: ArrayLike, x: ArrayLike) -> Array: +@_up_and_broadcast +def random_gamma_grad(a: ArrayLike, x: ArrayLike, *, dtype) -> Array: r"""Elementwise derivative of samples from `Gamma(a, 1)`.""" a, x = core.standard_insert_pbroadcast(a, x) - return random_gamma_grad_p.bind(a, x) + return random_gamma_grad_impl(a, x, dtype=dtype) def zeta(x: ArrayLike, q: ArrayLike) -> Array: r"""Elementwise Hurwitz zeta function: :math:`\zeta(x, q)`""" @@ -531,24 +551,6 @@ def random_gamma_grad_impl(a, x, *, dtype): full_like(a, float('nan')), output) return output -def _up_and_broadcast(doit): - def up_and_broadcast(*args): - broadcasted_shape = broadcast_shapes(*(a.shape for a in args)) - args = [broadcast_in_dim(a, broadcasted_shape, list(range(a.ndim))) for a in args] - - a_dtype = args[0].dtype - needs_upcast = a_dtype == dtypes.bfloat16 or a_dtype == np.float16 - if needs_upcast: - args = [convert_element_type(a, np.float32) for a in args] - a_x_type = np.float32 - else: - a_x_type = a_dtype - result = doit(*args, dtype=a_x_type) - if needs_upcast: - result = convert_element_type(result, a_dtype) - return result - return up_and_broadcast - def evaluate_chebyshev_polynomial(x, coefficients): b0 = full_like(x,0) @@ -694,11 +696,6 @@ def bessel_i0e_impl(x): ad.defjvp(igammac_p, igammac_grada, igammac_gradx) -random_gamma_grad_p = standard_naryop([_float, _float], 'random_gamma_grad') -mlir.register_lowering(random_gamma_grad_p, - mlir.lower_fun(_up_and_broadcast(random_gamma_grad_impl), - multiple_results=False)) - zeta_p = standard_naryop([_float, _float], 'zeta') mlir.register_lowering(zeta_p, partial(_nary_lower_hlo, chlo.zeta)) diff --git a/jax/extend/core/primitives.py b/jax/extend/core/primitives.py index d8a10154cf4a..60d8cd24a949 100644 --- a/jax/extend/core/primitives.py +++ b/jax/extend/core/primitives.py @@ -149,7 +149,6 @@ igamma_p as igamma_p, lgamma_p as lgamma_p, polygamma_p as polygamma_p, - random_gamma_grad_p as random_gamma_grad_p, regularized_incomplete_beta_p as regularized_incomplete_beta_p, zeta_p as zeta_p, ) diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index 6f2163c424a6..43c4cf17e559 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -261,7 +261,6 @@ polygamma as polygamma, polygamma_p as polygamma_p, random_gamma_grad as random_gamma_grad, - random_gamma_grad_p as random_gamma_grad_p, regularized_incomplete_beta_p as regularized_incomplete_beta_p, zeta as zeta, zeta_p as zeta_p, From 2d2be0bbb922c2571d433eeaeb5209c043334e97 Mon Sep 17 00:00:00 2001 From: Zac Cranko Date: Tue, 1 Apr 2025 10:44:52 -0700 Subject: [PATCH 312/483] Update permisisons community_release_actions.yml --- .github/workflows/community_release_actions.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/community_release_actions.yml b/.github/workflows/community_release_actions.yml index 1980e803ba9b..1110cbad9475 100644 --- a/.github/workflows/community_release_actions.yml +++ b/.github/workflows/community_release_actions.yml @@ -4,6 +4,9 @@ on: release: types: [published] +permissions: + contents: read + jobs: discord_release: if: github.repository_owner == 'jax-ml' From 5370ac2ec59c1acb347eb68771beec2487c8de64 Mon Sep 17 00:00:00 2001 From: Vladimir Belitskiy Date: Tue, 1 Apr 2025 11:33:00 -0700 Subject: [PATCH 313/483] Remove the try/except for Shardy imports. Shardy has been been included in JAX for a while now. PiperOrigin-RevId: 742778405 --- jax/_src/interpreters/mlir.py | 4 +--- jax/_src/lib/mlir/dialects/__init__.py | 6 +----- jax/extend/mlir/dialects/sdy.py | 6 +----- tests/pjit_test.py | 7 ------- 4 files changed, 3 insertions(+), 20 deletions(-) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 23d1b5dd9d89..a1b37876f87e 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -616,9 +616,7 @@ def make_ir_context() -> ir.Context: # we don't do any heavy computation on MLIR modules from Python anyway, so we # just disable threading. context.enable_multithreading(False) - # TODO(bartchr): Once JAX is released with SDY, remove the if. - if dialects.sdy: - dialects.sdy.register_dialect(context) + dialects.sdy.register_dialect(context) dialects.mhlo.register_mhlo_dialect(context) dialects.chlo.register_dialect(context) dialects.hlo.register_dialect(context) diff --git a/jax/_src/lib/mlir/dialects/__init__.py b/jax/_src/lib/mlir/dialects/__init__.py index a9bae8821db5..be5317824c36 100644 --- a/jax/_src/lib/mlir/dialects/__init__.py +++ b/jax/_src/lib/mlir/dialects/__init__.py @@ -51,11 +51,7 @@ ]) del _lazy -# TODO(bartchr): Once JAX is released with SDY, remove the try/except. -try: - from jaxlib.mlir.dialects import sdy as sdy -except ImportError: - sdy: Any = None # type: ignore[no-redef] +from jaxlib.mlir.dialects import sdy # Alias that is set up to abstract away the transition from MHLO to StableHLO. from jaxlib.mlir.dialects import stablehlo as hlo diff --git a/jax/extend/mlir/dialects/sdy.py b/jax/extend/mlir/dialects/sdy.py index 48586cc26760..d83fd90ecdf4 100644 --- a/jax/extend/mlir/dialects/sdy.py +++ b/jax/extend/mlir/dialects/sdy.py @@ -14,8 +14,4 @@ # ruff: noqa: F403 -# TODO(bartchr): Once JAX is released with SDY, remove the try/except. -try: - from jaxlib.mlir.dialects.sdy import * -except ImportError: - pass +from jaxlib.mlir.dialects.sdy import * diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 0b2daee8ccff..ee4a8cd3e15e 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -59,7 +59,6 @@ from jax._src import mesh as mesh_lib from jax._src.mesh import AxisType from jax._src.interpreters import pxla -from jax._src.lib.mlir import dialects from jax._src import xla_bridge from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension @@ -8067,12 +8066,6 @@ def f(x, y): @jtu.with_config(jax_use_shardy_partitioner=True) class ShardyTest(jtu.JaxTestCase): - # TODO(bartchr): Once JAX is released with SDY, remove setUp. - def setUp(self): - if not dialects.sdy: - raise unittest.SkipTest('Shardy is not available.') - super().setUp() - def test_lowering_input_output_sharding(self): mesh = jtu.create_mesh((4, 2), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) From 0b199f48c7e0d4e5837cee34ced7f3fc7065732f Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Tue, 1 Apr 2025 11:49:03 -0700 Subject: [PATCH 314/483] [jaxlib] Roll back subbyte types due to failing asan tests. Reverts 12526ea11646a75fac201e26c1a2e901f94a4c76 PiperOrigin-RevId: 742784183 --- jaxlib/cuda/BUILD | 1 - jaxlib/gpu/py_client_gpu.cc | 89 ++++++++++++--------------------- jaxlib/rocm/BUILD | 1 - jaxlib/xla/BUILD | 1 - jaxlib/xla/py_client_cpu.cc | 81 ++++++++++-------------------- tests/python_callback_test.py | 94 +++++++++++++++-------------------- 6 files changed, 99 insertions(+), 168 deletions(-) diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index d35e421ef904..fac62c81dee7 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -689,7 +689,6 @@ cc_library( "@xla//third_party/python_runtime:headers", # buildcleaner: keep "@xla//xla:comparison_util", "@xla//xla:shape_util", - "@xla//xla:util", "@xla//xla:xla_data_proto_cc", "@xla//xla/ffi:ffi_api", "@xla//xla/ffi/api:ffi", diff --git a/jaxlib/gpu/py_client_gpu.cc b/jaxlib/gpu/py_client_gpu.cc index 38f2ac1896e7..861ffce3e749 100644 --- a/jaxlib/gpu/py_client_gpu.cc +++ b/jaxlib/gpu/py_client_gpu.cc @@ -43,7 +43,6 @@ limitations under the License. #include "xla/python/nb_numpy.h" #include "xla/python/types.h" #include "xla/shape_util.h" -#include "xla/util.h" #include "xla/xla_data.pb.h" namespace nb = nanobind; @@ -81,14 +80,13 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, for (size_t i = 0; i < arity; ++i) { auto arg = args.get(i); auto ptype = static_cast(arg->element_type()); - // TODO(b/402422886): Remove this once we form Jax arrays directly instead - // of packing/unpacking to/from numpy arrays. - if (ptype == xla::S1 || ptype == xla::U1) { + // TODO(b/395428868): Remove this check once we support subbyte types. + if (ptype == xla::S1 || ptype == xla::S2 || ptype == xla::S4 || + ptype == xla::U1 || ptype == xla::U2 || ptype == xla::U4) { return xla::ffi::Error(xla::ffi::ErrorCode::kUnimplemented, absl::StrFormat("Unsupported primitive type: %s", PrimitiveType_Name(ptype))); } - if (ptype == xla::TOKEN) { host_input_buffers[i] = nullptr; continue; @@ -114,6 +112,9 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, PyTuple_SET_ITEM(host_input_arrays.ptr(), i, nb::none().inc_ref().ptr()); continue; } + nb::capsule base(host_input_buffers[i], [](void* ptr) noexcept { + delete[] static_cast(ptr); + }); auto maybe_dtype = PrimitiveTypeToNbDtype(ptype); if (!maybe_dtype.ok()) { return xla::ffi::Error::Internal(maybe_dtype.status().ToString()); @@ -121,22 +122,8 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, auto dtype = maybe_dtype.value(); auto dims = absl::Span(arg->dimensions().begin(), arg->dimensions().size()); - // TODO(b/402422886): Remove this once we form Jax arrays directly instead - // of packing/unpacking to/from numpy arrays. - // We pass in data using default numpy layout i.e., std::nullopt. - size_t bits_per_element = xla::primitive_util::BitWidth(ptype); - if (bits_per_element == 2 || bits_per_element == 4) { - auto buffer = xla::UnpackIntN( - bits_per_element, static_cast(host_input_buffers[i]), - arg->size_bytes()); - delete[] static_cast(host_input_buffers[i]); - host_input_buffers[i] = buffer.release(); - } - nb::capsule base(host_input_buffers[i], [](void* ptr) noexcept { - delete[] static_cast(ptr); - }); auto array = xla::nb_numpy_ndarray(dtype, dims, std::nullopt, - host_input_buffers[i], /*base=*/base); + host_input_buffers[i], base); array.attr("flags").attr("writeable") = nb::bool_(false); PyTuple_SET_ITEM(host_input_arrays.ptr(), i, array.inc_ref().ptr()); } @@ -159,7 +146,8 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, auto ret = rets.get(i).value(); auto ptype = static_cast(ret->element_type()); // TODO(b/395428868): Remove this check once we support subbyte types. - if (ptype == xla::S1 || ptype == xla::U1) { + if (ptype == xla::S1 || ptype == xla::S2 || ptype == xla::S4 || + ptype == xla::U1 || ptype == xla::U2 || ptype == xla::U4) { return xla::ffi::Error(xla::ffi::ErrorCode::kUnimplemented, absl::StrFormat("Unsupported primitive type: %s", PrimitiveType_Name(ptype))); @@ -180,43 +168,32 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, } auto expected_shape = maybe_expected_shape.value(); auto expected_strides = xla::ByteStridesForShape(expected_shape); - - const void* data = array.data(); - if (strides != expected_strides) { - xla::TransposePlan::Options options; - options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); - options.dims = absl::Span( - reinterpret_cast(array.shape()), array.ndim()); - absl::InlinedVector reversed_layout; - reversed_layout.resize(expected_shape.dimensions().size()); - absl::c_reverse_copy(expected_shape.layout().minor_to_major(), - reversed_layout.begin()); - options.permutation = reversed_layout; - options.input_layout = xla::TransposePlan::Striding{strides}; - auto maybe_plan = transpose_cache->cache.GetOrCreate(options); - if (!maybe_plan.ok()) { - return xla::ffi::Error::Internal(maybe_plan.status().ToString()); - } - auto plan = maybe_plan.value(); - void* temp = new char[ret->size_bytes()]; - temp_buffers.push_back(temp); - plan->Execute(data, temp); - data = temp; + if (strides == expected_strides) { + auto gpu_res = + gpuMemcpyAsync(ret->untyped_data(), array.data(), ret->size_bytes(), + gpuMemcpyHostToDevice, stream); + CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; + continue; } - - // TODO(b/402422886): Remove this once we form Jax arrays directly instead - // of packing/unpacking to/from numpy arrays. - std::unique_ptr buffer; - size_t bits_per_element = xla::primitive_util::BitWidth(ptype); - if (bits_per_element == 2 || bits_per_element == 4) { - // NOTE(dsuo): FFI arguments and return buffers are sized assuming - // minimum 1-byte element sizes, even if the data itself is packed. - buffer = xla::PackIntN(bits_per_element, static_cast(data), - ret->size_bytes()); - data = buffer.get(); + void* temp = new char[ret->size_bytes()]; + temp_buffers.push_back(temp); + xla::TransposePlan::Options options; + options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); + options.dims = absl::Span( + reinterpret_cast(array.shape()), array.ndim()); + absl::InlinedVector reversed_layout; + reversed_layout.resize(expected_shape.dimensions().size()); + absl::c_reverse_copy(expected_shape.layout().minor_to_major(), + reversed_layout.begin()); + options.permutation = reversed_layout; + options.input_layout = xla::TransposePlan::Striding{strides}; + auto maybe_plan = transpose_cache->cache.GetOrCreate(options); + if (!maybe_plan.ok()) { + return xla::ffi::Error::Internal(maybe_plan.status().ToString()); } - - auto gpu_res = gpuMemcpyAsync(ret->untyped_data(), data, ret->size_bytes(), + auto plan = maybe_plan.value(); + plan->Execute(array.data(), temp); + auto gpu_res = gpuMemcpyAsync(ret->untyped_data(), temp, ret->size_bytes(), gpuMemcpyHostToDevice, stream); CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; } diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index 358a6d1cc9aa..d0c0c798abb8 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -588,7 +588,6 @@ cc_library( "@xla//third_party/python_runtime:headers", # buildcleaner: keep "@xla//xla:comparison_util", "@xla//xla:shape_util", - "@xla//xla:util", "@xla//xla:xla_data_proto_cc", "@xla//xla/ffi:ffi_api", "@xla//xla/ffi/api:ffi", diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD index 5b532c1dc501..2ca18afda13d 100644 --- a/jaxlib/xla/BUILD +++ b/jaxlib/xla/BUILD @@ -637,7 +637,6 @@ cc_library( "@nanobind", "@xla//third_party/python_runtime:headers", # buildcleaner: keep "@xla//xla:shape_util", - "@xla//xla:util", "@xla//xla:xla_data_proto_cc", "@xla//xla/ffi:ffi_api", "@xla//xla/ffi/api:ffi", diff --git a/jaxlib/xla/py_client_cpu.cc b/jaxlib/xla/py_client_cpu.cc index fc4f895af6aa..ac4e7bee5680 100644 --- a/jaxlib/xla/py_client_cpu.cc +++ b/jaxlib/xla/py_client_cpu.cc @@ -41,7 +41,6 @@ limitations under the License. #include "xla/python/nb_numpy.h" #include "xla/python/types.h" #include "xla/shape_util.h" -#include "xla/util.h" #include "xla/xla_data.pb.h" namespace nb = nanobind; @@ -79,9 +78,9 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks, for (size_t i = 0; i < args.size(); ++i) { auto arg = args.get(i); auto ptype = static_cast(arg->element_type()); - // TODO(b/402422886): Remove this once we form Jax arrays directly instead - // of packing/unpacking to/from numpy arrays. - if (ptype == S1 || ptype == U1) { + // TODO(b/395428868): Remove this check once we support subbyte types. + if (ptype == S1 || ptype == S2 || ptype == S4 || ptype == U1 || + ptype == U2 || ptype == U4) { return ffi::Error(ffi::ErrorCode::kUnimplemented, absl::StrFormat("Unsupported primitive type: %s", PrimitiveType_Name(ptype))); @@ -97,18 +96,9 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks, auto dtype = maybe_dtype.value(); auto dims = absl::Span(arg->dimensions().begin(), arg->dimensions().size()); - // TODO(b/402422886): Remove this once we form Jax arrays directly instead - // of packing/unpacking to/from numpy arrays. - std::unique_ptr buffer; - const void* data = arg->untyped_data(); - size_t bits_per_element = xla::primitive_util::BitWidth(ptype); - if (bits_per_element == 2 || bits_per_element == 4) { - buffer = xla::UnpackIntN(bits_per_element, static_cast(data), - arg->size_bytes()); - data = buffer.get(); - } // We pass in data using default numpy layout i.e., std::nullopt. - auto array = nb_numpy_ndarray(dtype, dims, std::nullopt, data); + auto array = + nb_numpy_ndarray(dtype, dims, std::nullopt, arg.value().untyped_data()); array.attr("flags").attr("writeable") = nb::bool_(false); PyTuple_SET_ITEM(nb_args.ptr(), i, array.release().ptr()); } @@ -129,9 +119,9 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks, for (size_t i = 0; i < rets.size(); ++i) { auto ret = rets.get(i).value(); auto ptype = static_cast(ret->element_type()); - // TODO(b/402422886): Remove this once we form Jax arrays directly instead - // of packing/unpacking to/from numpy arrays. - if (ptype == S1 || ptype == U1) { + // TODO(b/395428868): Remove this check once we support subbyte types. + if (ptype == S1 || ptype == S2 || ptype == S4 || ptype == U1 || + ptype == U2 || ptype == U4) { return ffi::Error(ffi::ErrorCode::kUnimplemented, absl::StrFormat("Unsupported primitive type: %s", PrimitiveType_Name(ptype))); @@ -151,45 +141,26 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks, } auto expected_shape = maybe_expected_shape.value(); auto expected_strides = ByteStridesForShape(expected_shape); - - const void* data = array.data(); - if (strides != expected_strides) { - xla::TransposePlan::Options options; - options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); - options.dims = absl::Span( - reinterpret_cast(array.shape()), array.ndim()); - absl::InlinedVector reversed_layout; - reversed_layout.resize(expected_shape.dimensions().size()); - absl::c_reverse_copy(expected_shape.layout().minor_to_major(), - reversed_layout.begin()); - options.permutation = reversed_layout; - options.input_layout = xla::TransposePlan::Striding{strides}; - auto maybe_plan = transpose_cache->cache.GetOrCreate(options); - if (!maybe_plan.ok()) { - return ffi::Error::Internal(maybe_plan.status().ToString()); - } - auto plan = maybe_plan.value(); - plan->Execute(data, ret->untyped_data()); - data = ret->untyped_data(); - } - - // TODO(b/402422886): Remove this once we form Jax arrays directly instead - // of packing/unpacking to/from numpy arrays. - std::unique_ptr buffer; - size_t bits_per_element = xla::primitive_util::BitWidth(ptype); - if (bits_per_element == 2 || bits_per_element == 4) { - // NOTE(dsuo): FFI arguments and return buffers are sized assuming - // minimum 1-byte element sizes, even if the data itself is packed. - buffer = xla::PackIntN(bits_per_element, static_cast(data), - ret->size_bytes()); - data = buffer.get(); + if (strides == expected_strides) { + std::memcpy(ret->untyped_data(), array.data(), ret->size_bytes()); + continue; } - - // Copy data to output buffer if haven't already or modified the data to - // write back. - if (data != ret->untyped_data()) { - std::memcpy(ret->untyped_data(), data, ret->size_bytes()); + xla::TransposePlan::Options options; + options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); + options.dims = absl::Span( + reinterpret_cast(array.shape()), array.ndim()); + absl::InlinedVector reversed_layout; + reversed_layout.resize(expected_shape.dimensions_size()); + absl::c_reverse_copy(expected_shape.layout().minor_to_major(), + reversed_layout.begin()); + options.permutation = reversed_layout; + options.input_layout = xla::TransposePlan::Striding{strides}; + auto maybe_plan = transpose_cache->cache.GetOrCreate(options); + if (!maybe_plan.ok()) { + return ffi::Error::Internal(maybe_plan.status().ToString()); } + auto plan = maybe_plan.value(); + plan->Execute(array.data(), ret->untyped_data()); } return ffi::Error::Success(); diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 34ab20c05644..a8442b4a1356 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -586,15 +586,10 @@ def fun(x): self.assertAllClose(2 * x, fun(x)) self.assertEqual(count(), 1) - @parameterized.parameters("int2", "int4", "uint2", "uint4", "float4_e2m1fn") + @parameterized.parameters("int2", "int4", "uint2", "uint4") def test_subbyte_operands(self, dtype: str): - if jaxlib_extension_version < 323: - self.skipTest("Requires jaxlib_extension_version >= 323.") - if "2" in dtype and jtu.test_device_matches(["tpu"]): - self.skipTest( - "TODO(dsuo): TPU callbacks send SIGABRT for int2, uint2, and" - " float4_e2m1fn." - ) + if jaxlib_extension_version <= 321: + self.skipTest("Requires jaxlib_extension_version >= 322.") def get(x): return x def f(x): @@ -605,17 +600,21 @@ def f(x): ) return y x = np.arange(8, dtype=dtype) - np.testing.assert_array_equal(jax.jit(f)(x), np.arange(8, dtype=dtype)) + # TODO(b/395428868): Remove this check once we support subbyte types. + if jtu.test_device_matches(["tpu"]): + if "2" in dtype: + self.skipTest("TODO(dsuo): TPU callbacks send SIGABRT for int2/uint2.") + np.testing.assert_array_equal(jax.jit(f)(x), np.arange(8, dtype=dtype)) + else: + with self.assertRaisesRegex( + Exception, "Unsupported primitive type" + ): + _ = jax.jit(f)(x) - @parameterized.parameters("int2", "int4", "uint2", "uint4", "float4_e2m1fn") + @parameterized.parameters("int2", "int4", "uint2", "uint4") def test_subbyte_results(self, dtype: str): - if jaxlib_extension_version < 323: - self.skipTest("Requires jaxlib_extension_version >= 323.") - if "2" in dtype and jtu.test_device_matches(["tpu"]): - self.skipTest( - "TODO(dsuo): TPU callbacks send SIGABRT for int2, uint2, and" - " float4_e2m1fn." - ) + if jaxlib_extension_version <= 321: + self.skipTest("Requires jaxlib_extension_version >= 322.") def get(): return np.arange(8, dtype=dtype) @@ -626,43 +625,16 @@ def f(): ) return y - np.testing.assert_array_equal(jax.jit(f)(), np.arange(8, dtype=dtype)) - - @parameterized.parameters("int2", "int4", "uint2", "uint4", "float4_e2m1fn") - def test_non_default_stride_subbyte_results(self, dtype: str): - if jaxlib_extension_version < 323: - self.skipTest("Requires jaxlib_extension_version >= 323.") - if "2" in dtype and jtu.test_device_matches(["tpu"]): - self.skipTest( - "TODO(dsuo): TPU callbacks send SIGABRT for int2, uint2, and" - " float4_e2m1fn." - ) - x = jnp.arange(24, dtype=dtype).reshape(2, 3, 4) - def callback(x): - return np.asfortranarray(x) - - @jax.jit - def f(x): - return jax.pure_callback( - callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x - ) - - result = f(x) - np.testing.assert_array_equal(x, result) - - def test_non_default_stride(self): - x = jnp.arange(24, dtype=jnp.float32).reshape(2, 3, 4) - def callback(x): - return np.asfortranarray(x) - - @jax.jit - def f(x): - return jax.pure_callback( - callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x - ) - - result = f(x) - np.testing.assert_array_equal(x, result) + # TODO(b/395428868): Remove this check once we support subbyte types. + if jtu.test_device_matches(["tpu"]): + if "2" in dtype: + self.skipTest("TODO(dsuo): TPU callbacks send SIGABRT for int2/uint2.") + np.testing.assert_array_equal(jax.jit(f)(), np.arange(8, dtype=dtype)) + else: + with self.assertRaisesRegex( + Exception, "Unsupported primitive type" + ): + _ = jax.jit(f)() class PureCallbackTest(jtu.JaxTestCase): @@ -1136,6 +1108,20 @@ def fun(x): result += fun(jnp.ones((500, 500), jnp.complex64))[1] jax.block_until_ready(result) # doesn't deadlock + def test_non_default_stride(self): + x = jnp.arange(24, dtype=jnp.float32).reshape(2, 3, 4) + def callback(x): + return np.asfortranarray(x) + + @jax.jit + def f(x): + return jax.pure_callback( + callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x + ) + + result = f(x) + np.testing.assert_array_equal(x, result) + class IOCallbackTest(jtu.JaxTestCase): From 7b04a79fbdc0fe7b75e44a77cae8ed7a003a6821 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 1 Apr 2025 11:25:32 -0700 Subject: [PATCH 315/483] jnp.einsum: add support for __jax_array__ --- jax/_src/numpy/einsum.py | 4 ++++ tests/array_extensibility_test.py | 23 +++++++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/jax/_src/numpy/einsum.py b/jax/_src/numpy/einsum.py index 9d745643b596..21333a9e7a0d 100644 --- a/jax/_src/numpy/einsum.py +++ b/jax/_src/numpy/einsum.py @@ -288,6 +288,10 @@ def einsum( spec = operands[0] if isinstance(operands[0], str) else None path_type = 'optimal' if optimize is True else Unoptimized() if optimize is False else optimize + # Extract __jax_array__ before passing to contract_path() + operands = tuple(op.__jax_array__() if hasattr(op, "__jax_array__") else op + for op in operands) + # Allow handling of shape polymorphism non_constant_dim_types = { type(d) for op in operands if not isinstance(op, str) diff --git a/tests/array_extensibility_test.py b/tests/array_extensibility_test.py index f62491a608c7..69e9e1609f86 100644 --- a/tests/array_extensibility_test.py +++ b/tests/array_extensibility_test.py @@ -548,6 +548,29 @@ def test_array_creation_from_duck_typed_array(self, func): self.assertEqual(result.shape, obj.shape) self.assertEqual(result.dtype, obj.dtype) + @parameterized.named_parameters( + {"testcase_name": "subscript-form", "args": ("jk,k->j", Float[5, 3], Float[3])}, + {"testcase_name": "index-form", "args": (Float[5, 3], (0, 1), Float[3], (1,), (0,))}, + ) + def test_einsum(self, args): + rng = jtu.rand_default(self.rng()) + def make_arg(arg): + if isinstance(arg, jax.ShapeDtypeStruct): + return rng(arg.shape, arg.dtype) + return arg + args = jax.tree.map(make_arg, args) + + def wrap_array(arg): + if isinstance(arg, (jax.Array, np.ndarray)): + return JaxArrayWrapper(arg) + return arg + wrapped_args = jax.tree.map(wrap_array, args) + + expected = jnp.einsum(*args) + actual = jnp.einsum(*wrapped_args) + + self.assertAllClose(actual, expected, atol=0, rtol=0) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 4908b2f167a78783e95c1d677a849d0a98a97dc4 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 1 Apr 2025 10:05:42 -0700 Subject: [PATCH 316/483] cumulative reductions: support __jax_array__ on inputs --- jax/_src/numpy/reductions.py | 10 ++++------ tests/array_extensibility_test.py | 8 ++++---- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 985b296bc06f..96b2782edc13 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -30,7 +30,7 @@ from jax._src import deprecations from jax._src import dtypes from jax._src.numpy.util import ( - _broadcast_to, check_arraylike, _complex_elem_type, + _broadcast_to, check_arraylike, _complex_elem_type, ensure_arraylike, promote_dtypes_inexact, promote_dtypes_numeric, _where) from jax._src.lax import lax as lax_internal from jax._src.typing import Array, ArrayLike, DType, DTypeLike, DeprecatedArg @@ -1992,7 +1992,7 @@ def _cumulative_reduction( fill_nan: bool = False, fill_value: ArrayLike = 0, promote_integers: bool = False) -> Array: """Helper function for implementing cumulative reductions.""" - check_arraylike(name, a) + a = ensure_arraylike(name, a) if out is not None: raise NotImplementedError(f"The 'out' argument to jnp.{name} is not supported") dtypes.check_user_dtype_supported(dtype, name) @@ -2242,8 +2242,7 @@ def cumulative_sum( Array([[ 0, 1, 3, 6], [ 0, 4, 9, 15]], dtype=int32) """ - check_arraylike("cumulative_sum", x) - x = lax_internal.asarray(x) + x = ensure_arraylike("cumulative_sum", x) if x.ndim == 0: raise ValueError( "The input must be non-scalar to take a cumulative sum, however a " @@ -2304,8 +2303,7 @@ def cumulative_prod( Array([[ 1, 1, 2, 6], [ 1, 4, 20, 120]], dtype=int32) """ - check_arraylike("cumulative_prod", x) - x = lax_internal.asarray(x) + x = ensure_arraylike("cumulative_prod", x) if x.ndim == 0: raise ValueError( "The input must be non-scalar to take a cumulative product, however a " diff --git a/tests/array_extensibility_test.py b/tests/array_extensibility_test.py index 69e9e1609f86..8f5ea33b5894 100644 --- a/tests/array_extensibility_test.py +++ b/tests/array_extensibility_test.py @@ -283,10 +283,10 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: # NumPyAPI.sig(np.count_nonzero, [float], [(10,)]), # NumPyAPI.sig(np.cov, [float], [(10,)]), # NumPyAPI.sig(np.cross, [float, float], [(3,), (3,)]), - # NumPyAPI.sig(np.cumprod, [float], [(10,)]), - # NumPyAPI.sig(np.cumsum, [float], [(10,)]), - # NumPyAPI.sig(np.cumulative_prod, [float], [(10,)]), - # NumPyAPI.sig(np.cumulative_sum, [float], [(10,)]), + NumPyAPI.sig(jnp.cumprod, Float[5]), + NumPyAPI.sig(jnp.cumsum, Float[5]), + NumPyAPI.sig(jnp.cumulative_prod, Float[5]), + NumPyAPI.sig(jnp.cumulative_sum, Float[5]), NumPyAPI.sig(jnp.deg2rad, Float[5]), NumPyAPI.sig(jnp.degrees, Float[5]), # NumPyAPI.sig(jnp.delete, Float[5], Int[()]), From 05269a8ec90a1e14f89d514a0f4b228525bf906c Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Tue, 1 Apr 2025 20:18:32 +0000 Subject: [PATCH 317/483] [mutable-arrays] add vmap rule for mutable_array_p, very basic test --- jax/_src/interpreters/batching.py | 5 +++++ tests/mutable_array_test.py | 10 ++++++++++ 2 files changed, 15 insertions(+) diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 03c9a95105d7..a187d42511ac 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -1169,3 +1169,8 @@ def add_batched(batched_args, batch_dims): x = moveaxis(x, bdx, bdy) return add_jaxvals(x, y), bdy primitive_batchers[add_jaxvals_p] = add_batched + + +### mutable arrays + +defvectorized(core.mutable_array_p) diff --git a/tests/mutable_array_test.py b/tests/mutable_array_test.py index 950bddf544d7..a51e1d7841ce 100644 --- a/tests/mutable_array_test.py +++ b/tests/mutable_array_test.py @@ -239,6 +239,16 @@ def f(x_ref): x_ref = core.mutable_array(x) y = f(x_ref) + def test_vmap_basic(self): + @jax.vmap + def f(x): + x_ref = core.mutable_array(x) + x_ref[...] = x_ref[...] * x_ref[...] + return x_ref[...] + xs = jnp.arange(4.) + ys = f(xs) + self.assertAllClose(ys, xs ** 2, check_dtypes=False) + @jtu.with_config(jax_mutable_array_checks=True) class MutableArrayErrorsTest(jtu.JaxTestCase): From ff5a2e8c91c3e32db6a547326d1356023226f83c Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 1 Apr 2025 14:25:40 -0700 Subject: [PATCH 318/483] Enable test_scan_offload in memories_test. PiperOrigin-RevId: 742840628 --- tests/memories_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/memories_test.py b/tests/memories_test.py index 64ee2829873d..570b0c375834 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -1090,7 +1090,6 @@ def f_bwd(res, tx): self.assertArraysEqual(g(arr), all_true) def test_scan_offload(self): - self.skipTest('b/406586554') np_inp = jnp.arange(4096).reshape(16, 16, 16) @jax.jit From f13919220118aacd7eea9d7794a6217eea83066d Mon Sep 17 00:00:00 2001 From: Ayaka Date: Tue, 1 Apr 2025 19:17:02 -0700 Subject: [PATCH 319/483] Add OOB checks to jax.numpy array indexing PiperOrigin-RevId: 742927160 --- jax/_src/numpy/error.py | 53 ++++++++++++++++++++++++++++++++++- jax/_src/numpy/indexing.py | 17 +++++++---- tests/jax_numpy_error_test.py | 36 ++++++++++++++++++++++++ 3 files changed, 99 insertions(+), 7 deletions(-) diff --git a/jax/_src/numpy/error.py b/jax/_src/numpy/error.py index 20dab289d779..e2c23b43bdf8 100644 --- a/jax/_src/numpy/error.py +++ b/jax/_src/numpy/error.py @@ -13,10 +13,11 @@ # limitations under the License. import contextlib -from typing import Literal +from typing import Literal, Sequence import jax from jax._src import config +from jax._src.typing import ArrayLike Category = Literal["nan", "divide", "oob"] @@ -102,6 +103,56 @@ def _set_error_if_divide_by_zero(pred: jax.Array, /): error_check_lib.set_error_if(pred == zero, "Division by zero encountered") +def _check_precondition_oob_gather( + shape: tuple[int, ...], gather_indices: ArrayLike +) -> None: + """Check for out of bounds errors before calling `lax.gather`.""" + if config.error_checking_behavior_oob.value == "ignore": + return + + # TODO(mattjj): fix the circular import issue. + from jax._src import error_check as error_check_lib + import jax.numpy as jnp + + shape = jnp.array(shape, dtype=jnp.int32) + error_check_lib.set_error_if( + jnp.logical_or( + jnp.min(gather_indices) < -shape, + jnp.max(gather_indices) >= shape, + ), + "Out of bounds encountered before calling `lax.gather`", + ) + + +def _check_precondition_oob_dynamic_slice( + shape: tuple[int, ...], + start_indices: Sequence[ArrayLike], + slice_sizes: list[int], + allow_negative_indices: list[bool], +) -> None: + """Check for out of bounds errors before calling `lax.dynamic_slice`.""" + if config.error_checking_behavior_oob.value == "ignore": + return + + # TODO(mattjj): fix the circular import issue. + from jax._src import error_check as error_check_lib + import jax.numpy as jnp + + shape = jnp.array(shape, dtype=jnp.int32) + start_indices = jnp.array(start_indices, dtype=jnp.int32) + slice_sizes = jnp.array(slice_sizes, dtype=jnp.int32) + allow_negative_indices = jnp.array(allow_negative_indices, dtype=jnp.bool_) + + lower_bound = jnp.where(allow_negative_indices, -shape, 0) + error_check_lib.set_error_if( + jnp.logical_or( + jnp.minimum(start_indices, start_indices + slice_sizes) < lower_bound, + jnp.maximum(start_indices, start_indices + slice_sizes) >= shape, + ), + "Out of bounds encountered before calling `lax.dynamic_slice`", + ) + + Behavior = Literal["ignore", "raise"] diff --git a/jax/_src/numpy/indexing.py b/jax/_src/numpy/indexing.py index 5d59bb53b457..863f0c775ec6 100644 --- a/jax/_src/numpy/indexing.py +++ b/jax/_src/numpy/indexing.py @@ -20,8 +20,6 @@ import string from typing import Any, NamedTuple, Sequence -import numpy as np - import jax from jax import lax from jax._src import array @@ -30,17 +28,19 @@ from jax._src import dispatch from jax._src import dtypes from jax._src import errors +from jax._src import mesh as mesh_lib from jax._src.api import jit from jax._src.lax import lax as lax_internal from jax._src.numpy import einsum -from jax._src import mesh as mesh_lib -from jax._src.pjit import auto_axes +from jax._src.numpy import error as jnp_error from jax._src.numpy import lax_numpy from jax._src.numpy import ufuncs from jax._src.numpy import util +from jax._src.pjit import auto_axes from jax._src.tree_util import tree_flatten from jax._src.typing import Array, ArrayLike, StaticScalar -from jax._src.util import canonicalize_axis, set_module, tuple_replace, safe_zip +from jax._src.util import canonicalize_axis, safe_zip, set_module, tuple_replace +import numpy as np export = set_module('jax.numpy') @@ -570,7 +570,7 @@ def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None) -> idx += (arr.ndim - len(idx)) * (slice(None),) start_indices: Sequence[ArrayLike] = [] - slice_sizes: Sequence[int] = [] + slice_sizes: list[int] = [] allow_negative_indices: list[bool] = [] for ind, size in safe_zip(idx, arr.shape): @@ -587,6 +587,7 @@ def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None) -> slice_sizes.append(1) allow_negative_indices.append( not isinstance(ind, (int, np.integer)) or bool(ind < 0)) + # Try to use static slicing when possible. if all(isinstance(i, (int, np.integer)) and i >= 0 for i in start_indices): int_start_indices = [int(i) for i in start_indices] # type: ignore @@ -598,6 +599,9 @@ def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None) -> # start indices to have matching types. if len(start_indices) > 1: start_indices = util.promote_dtypes(*start_indices) + jnp_error._check_precondition_oob_dynamic_slice( + arr.shape, start_indices, slice_sizes, allow_negative_indices + ) arr = lax.dynamic_slice( arr, start_indices=start_indices, slice_sizes=slice_sizes, allow_negative_indices=allow_negative_indices) @@ -640,6 +644,7 @@ def _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, unique_indices, mode, fill_value, out_sharding): idx = merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx) indexer = index_to_gather(np.shape(arr), idx) # shared with _scatter_update + jnp_error._check_precondition_oob_gather(arr.shape, indexer.gather_indices) y = arr if fill_value is not None: diff --git a/tests/jax_numpy_error_test.py b/tests/jax_numpy_error_test.py index f2262d8b5dc0..a38e7d5509f9 100644 --- a/tests/jax_numpy_error_test.py +++ b/tests/jax_numpy_error_test.py @@ -231,6 +231,42 @@ def test_can_raise_divide_by_zero_error(self, jit, div_func, dtype): with self.assertRaisesRegex(JaxValueError, "Division by zero"): error_check.raise_if_error() + @parameterized.product(jit=[True, False]) + def test_can_raise_oob_error_take(self, jit): + def f(x, a): + return x[a] + + if jit: + f = jax.jit(f) + + x = jnp.arange(10) + a = jnp.int32(10) + + with jnp_error.error_checking_behavior(oob="ignore"): + f(x, a) + error_check.raise_if_error() # should not raise error + + with jnp_error.error_checking_behavior(oob="raise"): + f(x, a) + with self.assertRaisesRegex(JaxValueError, "Out of bounds"): + error_check.raise_if_error() + + def test_can_raise_oob_error_dynamic_slice(self): + def f(x, a): + return x[:, a:a+4] # dynamic indices are non-jittable + + x = jnp.arange(10).reshape(2, 5) + a = jnp.array(3, dtype=jnp.int32) + + with jnp_error.error_checking_behavior(oob="ignore"): + f(x, a) + error_check.raise_if_error() # should not raise error + + with jnp_error.error_checking_behavior(oob="raise"): + f(x, a) + with self.assertRaisesRegex(JaxValueError, "Out of bounds"): + error_check.raise_if_error() + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 1875c76bd2944f64967e9c9b7989233502d8da95 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Tue, 1 Apr 2025 19:10:39 -0700 Subject: [PATCH 320/483] let XLA metadata be unset in nested dynamic scopes Treat `None` metadata values as a special instruction not to set (or to unset, if nested) the corresponding entry. In particular, this makes it possible to unset metadata within the sub-computations of higher-order operations (e.g. branches in conditionals, loop bodies, etc.). This can be used, for example, to annotate a conditional but not all the operations in its branches. That is, the HLO for the following function `f` on a scalar float argument: ``` def cos(x): with set_xla_metadata(a=None): return jnp.cos(x) @jax.jit def f(x): with set_xla_metadata(a="b"): return jax.lax.cond(x < 0., jnp.sin, cos, x) ``` produces an attribute `a` on the conditional and on the sine, but not on the cosine. --- jax/_src/core.py | 2 +- jax/_src/xla_metadata.py | 13 ++++++++++--- tests/xla_metadata_test.py | 33 +++++++++++++++++++++++++++++++++ 3 files changed, 44 insertions(+), 4 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index ee6537650f20..bb23a540e526 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -320,7 +320,7 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): config.compute_on_context_manager.set_local(self.prev_compute_type) config.threefry_partitionable.set_local(self.prev_threefry_partitionable) - if self.context.xla_metadata is not None: + if self.context.xla_metadata: config.xla_metadata_context_manager.set_local(self.prev_xla_metadata) config.abstract_mesh_context_manager.set_local(self.prev_abstract_mesh) diff --git a/jax/_src/xla_metadata.py b/jax/_src/xla_metadata.py index 91895b4e7851..77c0e2ff9910 100644 --- a/jax/_src/xla_metadata.py +++ b/jax/_src/xla_metadata.py @@ -24,6 +24,8 @@ class XlaMetadata: __slots__ = ['val', 'hash'] + val: dict[str, Any] + def __init__(self, val): self.val = val self.hash = hash(tuple(sorted(self.val.items()))) @@ -35,14 +37,19 @@ def __eq__(self, other): return other is not None and self.val == other.val +def filter_nones(d: dict) -> dict: + return {k: v for k, v in d.items() if v is not None} + + def update_metadata(a, b: dict[str, Any]): if not b: return a if a is None or a is config_ext.unset: - return XlaMetadata(b) - val = a.val.copy() + val = {} + else: + val = a.val.copy() val.update(b) - return XlaMetadata(val) + return XlaMetadata(filter_nones(val)) def current_xla_metadata(): diff --git a/tests/xla_metadata_test.py b/tests/xla_metadata_test.py index d141bc15c249..33fd7a08b1de 100644 --- a/tests/xla_metadata_test.py +++ b/tests/xla_metadata_test.py @@ -190,6 +190,39 @@ def while_fn(a): if "stablehlo.add" in line: self.assertIn('mhlo.frontend_attributes = {a = "c"}', line) + def test_cond_annotates_branches(self): + sin = jnp.sin + cos = jnp.cos + + @jax.jit + def f(x): + with set_xla_metadata(a="b"): + return jax.lax.cond(x < 0., sin, cos, x) + + hlo_lines = f.lower(1.).as_text().split("\n") + sin_hlo, = [line for line in hlo_lines if "stablehlo.sine" in line] + cos_hlo, = [line for line in hlo_lines if "stablehlo.cosine" in line] + self.assertIn('mhlo.frontend_attributes = {a = "b"}', sin_hlo) + self.assertIn('mhlo.frontend_attributes = {a = "b"}', cos_hlo) + + def test_cond_annotates_branches_and_none_unsets(self): + sin = jnp.sin + + def cos(x): + with set_xla_metadata(a=None): + return jnp.cos(x) + + @jax.jit + def f(x): + with set_xla_metadata(a="b"): + return jax.lax.cond(x < 0., sin, cos, x) + + hlo_lines = f.lower(1.).as_text().split("\n") + sin_hlo, = [line for line in hlo_lines if "stablehlo.sine" in line] + cos_hlo, = [line for line in hlo_lines if "stablehlo.cosine" in line] + self.assertIn( 'mhlo.frontend_attributes = {a = "b"}', sin_hlo) + self.assertNotIn('mhlo.frontend_attributes = {a = "b"}', cos_hlo) + def test_nested_jit(self): @jax.jit def f(x, y): From 6fe6d8050663358a3a4447e4022efe012285f840 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Tue, 1 Apr 2025 22:17:41 -0700 Subject: [PATCH 321/483] upgrade docs from `jax.core` to `jax.extend.core` where needed to fix doc build --- docs/jax-primitives.md | 4 ++-- docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb | 2 +- docs/notebooks/Writing_custom_interpreters_in_Jax.md | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/jax-primitives.md b/docs/jax-primitives.md index abdc8be6d0a8..38a45ef4823e 100644 --- a/docs/jax-primitives.md +++ b/docs/jax-primitives.md @@ -21,7 +21,7 @@ kernelspec: A JAX primitive is the basic computational unit of a JAX program. This document explains the interface that a JAX primitive must support to allow JAX to perform all its transformations (this is not a how-to guide). -For example, the multiply-add operation can be implemented in terms of the low-level `jax.lax.*` primitives (which are like XLA operator wrappers) or `jax.core.Primitive("multiply_add")`, as demonstrated further below. +For example, the multiply-add operation can be implemented in terms of the low-level `jax.lax.*` primitives (which are like XLA operator wrappers) or `jax.extend.core.Primitive("multiply_add")`, as demonstrated further below. And JAX is able to take sequences of such primitive operations, and transform them via its composable transformations of Python functions, such as {func}`jax.jit`, {func}`jax.grad` and {func}`jax.vmap`. JAX implements these transforms in a *JAX-traceable* way. This means that when a Python function is executed, the only operations it applies to the data are either: @@ -171,7 +171,7 @@ The JAX traceability property is satisfied as long as the function is written in The right way to add support for multiply-add is in terms of existing JAX primitives, as shown above. However, to demonstrate how JAX primitives work, pretend that you want to add a new primitive to JAX for the multiply-add functionality. ```{code-cell} -from jax import core +from jax.extend import core multiply_add_p = core.Primitive("multiply_add") # Create the primitive diff --git a/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb b/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb index 00ba9186eeec..56b2d80fc58e 100644 --- a/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb +++ b/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb @@ -215,8 +215,8 @@ "# Importing Jax functions useful for tracing/interpreting.\n", "from functools import wraps\n", "\n", - "from jax import core\n", "from jax import lax\n", + "from jax.extend import core\n", "from jax._src.util import safe_map" ] }, diff --git a/docs/notebooks/Writing_custom_interpreters_in_Jax.md b/docs/notebooks/Writing_custom_interpreters_in_Jax.md index 10c4e7cb6e3b..6b993a630e93 100644 --- a/docs/notebooks/Writing_custom_interpreters_in_Jax.md +++ b/docs/notebooks/Writing_custom_interpreters_in_Jax.md @@ -147,8 +147,8 @@ Let's use `make_jaxpr` to trace a function into a Jaxpr. # Importing Jax functions useful for tracing/interpreting. from functools import wraps -from jax import core from jax import lax +from jax.extend import core from jax._src.util import safe_map ``` From 8e2c1a18c7676a2b481c5a41128ddf191793831b Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Thu, 16 Jan 2025 21:52:37 +0000 Subject: [PATCH 322/483] Updates for 3.14 Added tsan ci cpython 3.14 job --- ...essions.txt => tsan-suppressions_3.13.txt} | 4 + .github/workflows/tsan-suppressions_3.14.txt | 26 +++ .github/workflows/tsan.yaml | 166 +++++++++++++++--- WORKSPACE | 1 + build/build.py | 2 + build/requirements_lock_3_13_ft.txt | 2 +- build/requirements_lock_3_14_ft.txt | 107 +++++++++++ build/tools/utils.py | 14 ++ jaxlib/jax.bzl | 4 +- 9 files changed, 296 insertions(+), 30 deletions(-) rename .github/workflows/{tsan-suppressions.txt => tsan-suppressions_3.13.txt} (93%) create mode 100644 .github/workflows/tsan-suppressions_3.14.txt create mode 100644 build/requirements_lock_3_14_ft.txt diff --git a/.github/workflows/tsan-suppressions.txt b/.github/workflows/tsan-suppressions_3.13.txt similarity index 93% rename from .github/workflows/tsan-suppressions.txt rename to .github/workflows/tsan-suppressions_3.13.txt index bdffddc58ca0..833fa856a7d6 100644 --- a/.github/workflows/tsan-suppressions.txt +++ b/.github/workflows/tsan-suppressions_3.13.txt @@ -21,6 +21,10 @@ race:_PyUnicode_InternImmortal # Fixed in Python 3.14, but not backported to 3.13. race_top:PyMember_GetOne +# https://github.com/python/cpython/issues/131680 +# Fixed in Python 3.14, but not backported to 3.13. +race_top: new_reference + # https://github.com/python/cpython/issues/129748 race:mi_block_set_nextx diff --git a/.github/workflows/tsan-suppressions_3.14.txt b/.github/workflows/tsan-suppressions_3.14.txt new file mode 100644 index 000000000000..9cfc68e1ae36 --- /dev/null +++ b/.github/workflows/tsan-suppressions_3.14.txt @@ -0,0 +1,26 @@ +# false-positive caused because we haven't tsan-instrumented libgcc_s. Multiple threads +# are racing on a call to __register_frame_info(), but that function appears to be correctly locked internally. +race:llvm::RuntimeDyldELF::registerEHFrames + +# https://github.com/openxla/xla/issues/20686 +race:dnnl_sgemm + +# https://github.com/python/cpython/issues/128050 +race:partial_vectorcall_fallback + +# Likely only happens when the process is crashing. +race:dump_traceback + +# https://github.com/python/cpython/issues/129748 +race:mi_block_set_nextx + +# https://github.com/python/cpython/issues/128130 +race_top:run_eval_code_obj + +# Races because the LAPACK and BLAS in our scipy isn't TSAN instrumented. +race:heevd_ffi +race:gesdd_ffi +race:dscal_k_ +race:scal_k_ +race:gemm_beta +race:gemm_oncopy diff --git a/.github/workflows/tsan.yaml b/.github/workflows/tsan.yaml index cd59c0bf45e0..4c28608a8257 100644 --- a/.github/workflows/tsan.yaml +++ b/.github/workflows/tsan.yaml @@ -22,6 +22,16 @@ jobs: image: index.docker.io/library/ubuntu@sha256:b359f1067efa76f37863778f7b6d0e8d911e3ee8efa807ad01fbf5dc1ef9006b # ratchet:ubuntu:24.04 strategy: fail-fast: false + matrix: + include: + - name-prefix: "with 3.13" + python-version: "3.13" + github_branch: "3.13" + requirements_lock_name: "requirements_lock_3_13_ft" + - name-prefix: "with 3.14" + python-version: "3.14" + github_branch: "main" + requirements_lock_name: "requirements_lock_3_14_ft" defaults: run: shell: bash -l {0} @@ -44,22 +54,33 @@ jobs: with: repository: python/cpython path: cpython - ref: "3.13" + ref: ${{ matrix.github_branch }} - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: repository: numpy/numpy path: numpy submodules: true + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + if: ${{ matrix.python-version == '3.14' }} + with: + repository: scipy/scipy + path: scipy + submodules: true - - name: Restore cached TSAN CPython + - name: Get year & week number + id: get-date + run: echo "date=$(/bin/date "+%Y-%U")" >> $GITHUB_OUTPUT + shell: bash -l {0} + + - name: Restore cached TSAN CPython ${{ matrix.python-version }} id: cache-cpython-tsan-restore uses: actions/cache/restore@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 with: path: | ./python-tsan.tgz - key: ${{ runner.os }}-cpython-tsan-${{ hashFiles('cpython/configure.ac') }} + key: ${{ runner.os }}-cpython-tsan-${{ matrix.python-version }}-${{ steps.get-date.outputs.date }} - - name: Build CPython with enabled TSAN + - name: Build TSAN CPython ${{ matrix.python-version }} if: steps.cache-cpython-tsan-restore.outputs.cache-hit != 'true' run: | cd cpython @@ -73,19 +94,14 @@ jobs: # Create archive to be used with bazel as hermetic python: cd ${GITHUB_WORKSPACE} && tar -czpf python-tsan.tgz cpython-tsan - - name: Save TSAN CPython + - name: Save TSAN CPython ${{ matrix.python-version }} id: cache-cpython-tsan-save if: steps.cache-cpython-tsan-restore.outputs.cache-hit != 'true' uses: actions/cache/save@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 with: path: | ./python-tsan.tgz - key: ${{ runner.os }}-cpython-tsan-${{ hashFiles('cpython/configure.ac') }} - - - name: Get year & week number - id: get-date - run: echo "date=$(/bin/date "+%Y-%U")" >> $GITHUB_OUTPUT - shell: bash -l {0} + key: ${{ runner.os }}-cpython-tsan-${{ matrix.python-version }}-${{ steps.get-date.outputs.date }} - name: Restore cached TSAN Numpy id: cache-numpy-tsan-restore @@ -93,7 +109,7 @@ jobs: with: path: | ./wheelhouse - key: ${{ runner.os }}-numpy-tsan-${{ hashFiles('numpy/pyproject.toml') }}-${{ steps.get-date.outputs.date }} + key: ${{ runner.os }}-numpy-tsan-${{ matrix.python-version }}-${{ hashFiles('numpy/pyproject.toml') }}-${{ steps.get-date.outputs.date }} - name: Build TSAN Numpy wheel if: steps.cache-numpy-tsan-restore.outputs.cache-hit != 'true' @@ -114,7 +130,8 @@ jobs: python3 -m pip install uv~=0.5.30 # Make sure to install a compatible Cython version (master branch is best for now) - python3 -m uv pip install -r requirements/build_requirements.txt -U git+https://github.com/cython/cython + NO_CYTHON_COMPILE=true python3 -m uv pip install -U git+https://github.com/cython/cython + python3 -m uv pip install -r requirements/build_requirements.txt CC=clang-18 CXX=clang++-18 python3 -m pip wheel --wheel-dir dist -v . --no-build-isolation -Csetup-args=-Db_sanitize=thread -Csetup-args=-Dbuildtype=debugoptimized @@ -147,7 +164,83 @@ jobs: with: path: | ./wheelhouse - key: ${{ runner.os }}-numpy-tsan-${{ hashFiles('numpy/pyproject.toml') }}-${{ steps.get-date.outputs.date }} + key: ${{ runner.os }}-numpy-tsan-${{ matrix.python-version }}-${{ hashFiles('numpy/pyproject.toml') }}-${{ steps.get-date.outputs.date }} + + - name: Restore cached Scipy + if: ${{ matrix.python-version == '3.14' }} + id: cache-scipy-restore + uses: actions/cache/restore@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + with: + path: | + ./wheelhouse + key: ${{ runner.os }}-scipy-${{ matrix.python-version }}-${{ hashFiles('scipy/pyproject.toml') }}-${{ steps.get-date.outputs.date }} + + - name: Build Scipy wheel + if: ${{ steps.cache-scipy-restore.outputs.cache-hit != 'true' && matrix.python-version == '3.14' }} + run: | + # Install scipy dependencies: + apt-get install -y gfortran libopenblas-dev liblapack-dev pkg-config --no-install-recommends + + cd scipy + + # If we restored cpython from cache, we need to get python interpreter from python-tsan.tgz + if [ ! -d ${GITHUB_WORKSPACE}/cpython-tsan/bin/ ]; then + echo "Extract cpython from python-tsan.tgz" + pushd . + ls ${GITHUB_WORKSPACE}/python-tsan.tgz + cd ${GITHUB_WORKSPACE} && tar -xzf python-tsan.tgz + ls ${GITHUB_WORKSPACE}/cpython-tsan/bin/ + popd + fi + + export PATH=${GITHUB_WORKSPACE}/cpython-tsan/bin/:$PATH + + python3 -m pip install uv~=0.5.30 + # Make sure to install a compatible Cython version (master branch is best for now) + NO_CYTHON_COMPILE=true python3 -m uv pip install -U git+https://github.com/cython/cython + python3 -m uv pip install -U --pre numpy --extra-index-url file://${GITHUB_WORKSPACE}/wheelhouse/ + python3 -m uv pip install pythran pybind11 meson-python ninja + + python3 -m uv pip list | grep -E "(numpy|pythran|cython|pybind11)" + + export CC=clang-18 + export CXX=clang++-18 + python3 -m pip wheel --wheel-dir dist -vvv . --no-build-isolation --no-deps -Csetup-args=-Dbuildtype=debugoptimized + + python3 -m uv pip list | grep -E "(numpy|pythran|cython|pybind11)" + + # Create simple index and copy the wheel + mkdir -p ${GITHUB_WORKSPACE}/wheelhouse/scipy + + scipy_whl_name=($(cd dist && ls scipy*.whl)) + if [ -z "${scipy_whl_name}" ]; then exit 1; fi + + echo "Built TSAN Scipy wheel: ${scipy_whl_name}" + + cp dist/${scipy_whl_name} ${GITHUB_WORKSPACE}/wheelhouse/scipy + + # Recreate wheelhouse index with Numpy and Scipy + cat << EOF > ${GITHUB_WORKSPACE}/wheelhouse/index.html + + numpy>
+ scipy>
+ + EOF + + cat << EOF > ${GITHUB_WORKSPACE}/wheelhouse/scipy/index.html + + ${scipy_whl_name}
+ + EOF + + - name: Save Scipy wheel + id: cache-scipy-save + if: ${{ steps.cache-scipy-restore.outputs.cache-hit != 'true' && matrix.python-version == '3.14' }} + uses: actions/cache/save@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + with: + path: | + ./wheelhouse + key: ${{ runner.os }}-scipy-${{ matrix.python-version }}-${{ hashFiles('scipy/pyproject.toml') }}-${{ steps.get-date.outputs.date }} - name: Build Jax and run tests timeout-minutes: 120 @@ -164,7 +257,7 @@ jobs: python3 -VV python3 build/build.py build --configure_only \ - --python_version=3.13-ft \ + --python_version=${{ matrix.python-version }}-ft \ --bazel_options=--repo_env=HERMETIC_PYTHON_URL="file://${GITHUB_WORKSPACE}/python-tsan.tgz" \ --bazel_options=--repo_env=HERMETIC_PYTHON_SHA256=${PYTHON_SHA256} \ --bazel_options=--repo_env=HERMETIC_PYTHON_PREFIX="cpython-tsan/" \ @@ -174,18 +267,32 @@ jobs: --bazel_options=--copt=-g \ --clang_path=/usr/bin/clang-18 - # Patch build/requirements_lock_3_13_ft.txt to use TSAN instrumented NumPy - sed -i "s|+--extra-index-url.*|+--extra-index-url file://${GITHUB_WORKSPACE}/wheelhouse/|" .github/workflows/requirements_lock_3_13_ft.patch - cat .github/workflows/requirements_lock_3_13_ft.patch - git apply .github/workflows/requirements_lock_3_13_ft.patch || exit 1 + if [ "${{ matrix.python-version }}" == "3.13" ]; then + # Patch build/requirements_lock_3_13_ft.txt to use TSAN instrumented NumPy - # Display the content for debugging in logs - cat build/requirements_lock_3_13_ft.txt | head -15 - # Check the patch - cat build/requirements_lock_3_13_ft.txt | head -15 | grep -E "(--pre|.*${GITHUB_WORKSPACE}/wheelhouse/|numpy)" - if [ "$?" == "1" ]; then echo "Could not find the patch in the requirements_lock_3_13_ft.txt"; exit 1; fi - cat build/requirements_lock_3_13_ft.txt | grep -E "(numpy==)" - if [ "$?" == "0" ]; then "Found original numpy dependency in the requirements_lock_3_13_ft.txt"; exit 1; fi + sed -i "s|+--extra-index-url.*|+--extra-index-url file://${GITHUB_WORKSPACE}/wheelhouse/|" .github/workflows/${{ matrix.requirements_lock_name }}.patch + cat .github/workflows/${{ matrix.requirements_lock_name }}.patch + git apply .github/workflows/${{ matrix.requirements_lock_name }}.patch || exit 1 + + # Display the content for debugging in logs + cat build/${{ matrix.requirements_lock_name }}.txt | head -15 + # Check the patch + cat build/${{ matrix.requirements_lock_name }}.txt | head -15 | grep -E "(--pre|.*${GITHUB_WORKSPACE}/wheelhouse/|numpy)" + if [ "$?" == "1" ]; then echo "Could not find the patch in the ${{ matrix.requirements_lock_name }}.txt"; exit 1; fi + cat build/${{ matrix.requirements_lock_name }}.txt | grep -E "(numpy==)" + if [ "$?" == "0" ]; then "Found original numpy dependency in the ${{ matrix.requirements_lock_name }}.txt"; exit 1; fi + + else + # Patch build/requirements_lock_3_14_ft.txt to use TSAN instrumented NumPy and Scipy + + sed -i "s|--extra-index-url.*|--extra-index-url file://${GITHUB_WORKSPACE}/wheelhouse/|" build/${{ matrix.requirements_lock_name }}.txt + + # We should install jpeg dev package to be able to build Pillow from source: + apt-get install -y libjpeg-dev --no-install-recommends + + # Install scipy runtime dependencies (in case we restore scipy wheel from cache): + apt-get install -y libopenblas-dev liblapack-dev --no-install-recommends + fi echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES" echo "JAX_ENABLE_X64=$JAX_ENABLE_X64" @@ -201,13 +308,18 @@ jobs: # Check numpy version ./bazel cquery @pypi_numpy//:* | grep whl + if [ "${{ matrix.python-version }}" == "3.14" ]; then + # Check scipy version + ./bazel cquery @pypi_scipy//:* | grep whl + fi + # Build JAX and run tests ./bazel test \ --test_env=JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES \ --test_env=JAX_ENABLE_X64=$JAX_ENABLE_X64 \ --test_env=JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS \ --test_env=PYTHON_GIL=0 \ - --test_env=TSAN_OPTIONS=halt_on_error=1,suppressions=$PWD/.github/workflows/tsan-suppressions.txt \ + --test_env=TSAN_OPTIONS=halt_on_error=1,suppressions=$PWD/.github/workflows/tsan-suppressions_${{ matrix.python-version }}.txt \ --test_env=JAX_TEST_NUM_THREADS=8 \ --test_output=errors \ --local_test_jobs=32 \ diff --git a/WORKSPACE b/WORKSPACE index a6968446a1ec..5c093ec2228f 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -14,6 +14,7 @@ python_init_repositories( "3.12": "//build:requirements_lock_3_12.txt", "3.13": "//build:requirements_lock_3_13.txt", "3.13-ft": "//build:requirements_lock_3_13_ft.txt", + "3.14-ft": "//build:requirements_lock_3_14_ft.txt", }, local_wheel_inclusion_list = [ "jax-*", diff --git a/build/build.py b/build/build.py index f8c0ccbfa6a4..226d984b3d89 100755 --- a/build/build.py +++ b/build/build.py @@ -496,6 +496,7 @@ async def main(): if args.use_clang: clang_path = args.clang_path or utils.get_clang_path_or_exit() clang_major_version = utils.get_clang_major_version(clang_path) + clangpp_path = utils.get_clangpp_path(clang_path) logging.debug( "Using Clang as the compiler, clang path: %s, clang version: %s", clang_path, @@ -505,6 +506,7 @@ async def main(): # Use double quotes around clang path to avoid path issues on Windows. wheel_build_command_base.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"") wheel_build_command_base.append(f"--repo_env=CC=\"{clang_path}\"") + wheel_build_command_base.append(f"--repo_env=CXX=\"{clangpp_path}\"") wheel_build_command_base.append(f"--repo_env=BAZEL_COMPILER=\"{clang_path}\"") if clang_major_version >= 16: diff --git a/build/requirements_lock_3_13_ft.txt b/build/requirements_lock_3_13_ft.txt index 5157706c00e8..a96a3e6e489b 100644 --- a/build/requirements_lock_3_13_ft.txt +++ b/build/requirements_lock_3_13_ft.txt @@ -658,7 +658,7 @@ zipp==3.21.0 \ --hash=sha256:2c9958f6430a2040341a52eb608ed6dd93ef4392e02ffe219417c1b28b5dd1f4 \ --hash=sha256:ac1bbe05fd2991f160ebce24ffbac5f6d11d83dc90891255885223d42b3cd931 # via etils -# python 3.13t can compile 0.23.0 +# python 3.13t can't compile 0.23.0 # due to https://github.com/indygreg/python-zstandard/issues/231 # zstandard==0.23.0 \ # --hash=sha256:034b88913ecc1b097f528e42b539453fa82c3557e414b3de9d5632c80439a473 \ diff --git a/build/requirements_lock_3_14_ft.txt b/build/requirements_lock_3_14_ft.txt new file mode 100644 index 000000000000..18e4ef6d576a --- /dev/null +++ b/build/requirements_lock_3_14_ft.txt @@ -0,0 +1,107 @@ +--pre +--extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple +numpy + +--pre +--extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple +scipy + +absl-py==2.1.0 + +attrs==24.3.0 + +auditwheel==6.2.0 + +build==1.2.2.post1 + +cloudpickle==3.1.1 # version 3.1.0 leads to recursion error + +colorama==0.4.6 + +contourpy==1.3.1 + +cycler==0.12.1 + +etils[epath,epy]==1.11.0 + +execnet==2.1.1 + +filelock==3.16.1 + +flatbuffers==24.12.23 + +fonttools==4.56.0 + +fsspec==2024.12.0 + +hypothesis==6.123.9 + +importlib-resources==6.5.2 + +iniconfig==2.0.0 + +kiwisolver==1.4.8 + +markdown-it-py==3.0.0 + +matplotlib==3.10.1 + +mdurl==0.1.2 + +ml-dtypes==0.5.1 + +mpmath==1.3.0 + +nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" + +nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux" +nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux" +nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" +nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" +nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" +nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux" +nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux" +nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux" + +nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux" +opt-einsum==3.4.0 + +packaging==24.2 + +pillow==11.1.0 +pluggy==1.5.0 + +portpicker==1.6.0 + +psutil==6.1.1 +pyelftools==0.31 + +pygments==2.19.1 + +pyparsing==3.2.2 # version 3.2.1 fails with SyntaxError(originally SyntaxWarning): 'return' in a 'finally' block in pyparsing/core.py", line 5716 + +pyproject-hooks==1.2.0 + +pytest==8.3.4 + +pytest-xdist==3.6.1 + +python-dateutil==2.9.0.post0 + +rich==13.9.4 + +six==1.17.0 + +sortedcontainers==2.4.0 + +typing-extensions==4.12.2 + +wheel==0.45.1 + +zipp==3.21.0 + +# python 3.14t can't compile 0.23.0 +# due to https://github.com/indygreg/python-zstandard/issues/231 +# zstandard==0.23.0 + +setuptools==70.3.0 diff --git a/build/tools/utils.py b/build/tools/utils.py index 8b8dc80d1e0f..ccce8aff09cc 100644 --- a/build/tools/utils.py +++ b/build/tools/utils.py @@ -202,6 +202,20 @@ def get_clang_major_version(clang_path): return major_version +def get_clangpp_path(clang_path): + clang_path = pathlib.Path(clang_path) + clang_exec_name = clang_path.stem + clangpp_exec_name = clang_exec_name + if "clang++" not in clang_exec_name: + clangpp_exec_name = clang_exec_name.replace("clang", "clang++") + clangpp_path = clang_path.parent / clangpp_exec_name + if not clangpp_path.exists(): + raise FileNotFoundError( + f"Failed to get clang++ path from clang path: '{clang_path!s}'. " + f"Tried the path: '{clangpp_path!s}'." + ) + return str(clangpp_path) + def get_gcc_major_version(gcc_path: str): gcc_version_proc = subprocess.run( [gcc_path, "-dumpversion"], diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 1cc4fab12591..93e9ebacfa6f 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -76,9 +76,9 @@ _CPU_PYPI_WHEEL_DEPS = [ "@pypi_jaxlib//:pkg", ] -# TODO(vam): remove this once zstandard builds against Python 3.13 +# TODO(vam): remove this once zstandard builds against Python >3.13 def get_zstandard(): - if HERMETIC_PYTHON_VERSION == "3.13" or HERMETIC_PYTHON_VERSION == "3.13-ft": + if HERMETIC_PYTHON_VERSION in ("3.13", "3.13-ft", "3.14", "3.14-ft"): return [] return ["@pypi_zstandard//:pkg"] From 076d021057722aa58d0621d79630ddfab4a64bce Mon Sep 17 00:00:00 2001 From: George Necula Date: Wed, 26 Mar 2025 12:39:23 +0200 Subject: [PATCH 323/483] [better_errors] Fix the handling of kwargs for debug_info. kwargs are passed sorted by the actual kwarg keyword. This order must be accounted for when we construct the `debug_info.arg_names`. Extended the tests to be more precise about not mixing up kwargs, e.g., use different shapes and look for the shape in the HLO. --- jax/_src/api.py | 2 +- jax/_src/api_util.py | 57 ++++++++---- jax/_src/pjit.py | 4 +- tests/debug_info_test.py | 191 +++++++++++++++++++++++++++++++++------ 4 files changed, 207 insertions(+), 47 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 55e2b2126a68..fb10245c30e9 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -192,7 +192,7 @@ def jit( constant). Static arguments should be hashable, meaning both ``__hash__`` and - ``__eq__`` are implemented, and immutable. Otherwise they can be arbitrary + ``__eq__`` are implemented, and immutable. Otherwise, they can be arbitrary Python objects. Calling the jitted function with different values for these constants will trigger recompilation. Arguments that are not array-like or containers thereof must be marked as static. diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index a42141b96fbd..451d2e490a15 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -28,12 +28,12 @@ from jax._src.tree_util import ( PyTreeDef, tree_flatten, tree_unflatten, tree_map, treedef_children, generate_key_paths, broadcast_prefix, - prefix_errors) -from jax._src.tree_util import _replace_nones + prefix_errors, _replace_nones) from jax._src import linear_util as lu from jax._src.util import (safe_map, WrapKwArgs, Hashable, HashableFunction, - Unhashable, safe_zip) + Unhashable, safe_zip as zip) from jax._src import traceback_util + traceback_util.register_exclusion(__file__) map = safe_map @@ -201,9 +201,11 @@ def _validate_argnames( f"in {argnames_name}. Function does not take these args.") -def argnums_partial(f, dyn_argnums, args, require_static_args_hashable=True): +def argnums_partial(f: lu.WrappedFun, dyn_argnums: int | Sequence[int], + args: Sequence, require_static_args_hashable=True): dyn_argnums = _ensure_index_tuple(dyn_argnums) dyn_argnums = _ensure_inbounds(False, len(args), dyn_argnums) + fixed_args: list if require_static_args_hashable: fixed_args = [] for i, arg in enumerate(args): @@ -273,7 +275,9 @@ def argnums_partial_except(f: lu.WrappedFun, static_argnums: tuple[int, ...], return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args @lu.transformation2 -def _argnums_partial(_fun, _dyn_argnums, _fixed_args, *dyn_args, **kwargs): +def _argnums_partial(_fun: Callable, + _dyn_argnums: Sequence[int], + _fixed_args: Sequence, *dyn_args, **kwargs): sentinel = object() args = [sentinel] * (len(_fixed_args) + len(dyn_args)) for i, arg in zip(_dyn_argnums, dyn_args): @@ -334,7 +338,7 @@ def donation_vector(donate_argnums, donate_argnames, in_tree, donate = bool(i in donate_argnums) res.extend((donate,) * arg.num_leaves) if kwargs_tree is not None: - for key, val in safe_zip(kwargs_tree.node_data()[1], kwargs_tree.children()): # type: ignore + for key, val in zip(kwargs_tree.node_data()[1], kwargs_tree.children()): # type: ignore donate = key in donate_argnames res.extend((donate,) * val.num_leaves) return tuple(res) @@ -673,28 +677,45 @@ def _non_static_arg_names(fn_signature: inspect.Signature | None, top-level arguments. In other cases, including when the `args` and `kwargs` do not match the signature, we use names like `args[0[]`, `args[1]`, etc. """ + # Use the same argument parsing as jit: positional followed by kwargs + # sorted by keys. static = object() static_argnums_ = _ensure_inbounds(True, len(args), static_argnums) static_argnames_ = set(static_argnames) args_ = [static if i in static_argnums_ else x for i, x in enumerate(args)] - kwargs_ = {k:static if k in static_argnames_ else x for k, x in kwargs.items()} + kwargs_ = {k: static if k in static_argnames_ else x for k, x in kwargs.items()} + ordered_args: Sequence[tuple[str, Any]] | None = None if fn_signature is not None: try: ba = fn_signature.bind(*args_, **kwargs_) except (ValueError, TypeError): pass else: - return tuple(f'{name}{lu._clean_keystr_arg_names(path)}' - for name, x in ba.arguments.items() - for path, l in generate_key_paths(x) if l is not static) - args_arg_names = tuple(f'args{lu._clean_keystr_arg_names(path)}' - for path, l in generate_key_paths(args_) - if l is not static) - kwargs_arg_names = tuple(f'kwargs{lu._clean_keystr_arg_names(path)}' - for path, l in generate_key_paths(kwargs_) - if l is not static) - arg_names = args_arg_names + kwargs_arg_names - return arg_names + # Do we have a **kwargs + kwargs_name = next((name for name, p in fn_signature.parameters.items() + if p.kind == inspect.Parameter.VAR_KEYWORD), None) + # Positional argument are those not passed by keyword and not passed + # by **kwargs. + positional = [(name, x) for name, x in ba.arguments.items() + if name not in kwargs and name != kwargs_name] + # Keyword arguments are passed sorted by actual kwarg keyword + sorted_kwargs = sorted(((name, x) for name, x in kwargs_.items()), + key=lambda name_x: name_x[0]) + sorted_kwargs = [(name if name in ba.arguments else f"{kwargs_name}['{name}']", + x) + for name, x in sorted_kwargs] + ordered_args = positional + sorted_kwargs + + if ordered_args is None: + positional = [("args", args_)] + keyword = sorted([(f"kwargs['{name}']", x) for name, x in kwargs_.items() if x is not static], + key=lambda name_x: name_x[0]) + ordered_args = positional + keyword + + return tuple(f'{name}{lu._clean_keystr_arg_names(path)}' + for name, x in ordered_args + for path, l in generate_key_paths(x) if l is not static) + def hoist_obj_attrs(f, flat_args): idxs, objs, flat_args_ = [], [], [] diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 5727c36a646b..af744ae5db96 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -683,7 +683,7 @@ def __init__(self): # We use an outer cache that is keyed on the signature of the arguments, but # when populating a cache entry using _infer_params_impl, we need to provide -# actual arguments. In principle we could refactor _infer_params_impl to look +# actual arguments. In principle, we could refactor _infer_params_impl to look # only at an argument signature instead of args/kwargs in those cases that we # cache, but this was a more minimal change. @util.weakref_lru_cache @@ -730,7 +730,7 @@ def _infer_params_internal( if entry.pjit_params is None: p, args_flat = _infer_params_impl( fun, ji, ctx_mesh, dbg, args, kwargs, in_avals=avals) - if p.attrs_tracked: # if attrs, don't popoulate the cache + if p.attrs_tracked: # if attrs, don't populate the cache return p, p.consts + args_flat entry.pjit_params = p return entry.pjit_params, entry.pjit_params.consts + dynargs diff --git a/tests/debug_info_test.py b/tests/debug_info_test.py index 1d2935ea34d7..8ec5c42ef24c 100644 --- a/tests/debug_info_test.py +++ b/tests/debug_info_test.py @@ -46,6 +46,7 @@ from jax._src.compilation_cache import is_persistent_cache_enabled from jax._src.lax.control_flow import for_loop from jax._src.interpreters import mlir +from jax._src import util as util import numpy as np @@ -241,7 +242,7 @@ def my_f(x, y, z, w): dbg = api_util.debug_info("jit", my_f, (1, 2), dict(z=3, w=4)) self.assertRegex(dbg.func_src_info, r"^my_f at .*debug_info_test.py:\d+") self.assertEqual(dbg.func_name, "my_f") - self.assertEqual(dbg.arg_names, ("x", "y", "z", "w")) + self.assertEqual(dbg.arg_names, ("x", "y", "w", "z")) self.assertIsNone(dbg.result_paths) def test_debug_info_arg_passed_as_kwarg(self): @@ -261,23 +262,29 @@ def my_f(x_tree, *, y_tree): "y_tree['w']", "y_tree['z']")) def test_debug_info_with_statics(self): - def my_f(x, y, *, z, w): + def my_f(x, z, *, w, y): pass - dbg = api_util.debug_info("jit", my_f, (1, 2), dict(z=3, w=4), + dbg = api_util.debug_info("jit", my_f, (1,), dict(y=2, z=3, w=4), static_argnums=(1,), static_argnames=("w",)) - self.assertEqual(dbg.arg_names, ("x", "z")) + self.assertEqual(dbg.arg_names, ("x", "y", "z")) def test_debug_info_with_pytrees_and_statics(self): - def my_f(x, y, *, z, w): + def my_f(x, y, *, z, w, t): pass dbg = api_util.debug_info("jit", my_f, ((1, 2), (2, 3)), - dict(z=(3, 4), w=(5, 6)), + dict(z=(3, 4), w=(5, 6), t=7), static_argnums=(1,), static_argnames=("w",)) - self.assertEqual(dbg.arg_names, ("x[0]", "x[1]", "z[0]", "z[1]")) + self.assertEqual(dbg.arg_names, ("x[0]", "x[1]", "t", "z[0]", "z[1]")) + + dbg = api_util.debug_info("jit", my_f, ((1, 2),), + dict(z=(3, 4), w=(5, 6), t=7, y=3), + static_argnums=(1,), + static_argnames=("w",)) + self.assertEqual(dbg.arg_names, ("x[0]", "x[1]", "t", "y", "z[0]", "z[1]")) def test_debug_info_too_many_args(self): def my_f(x): @@ -287,7 +294,7 @@ def my_f(x): self.assertEqual(dbg.arg_names, ('args[0]', 'args[1]', 'args[2]', "kwargs['z']")) def test_debug_info_no_source_info_built_in(self): - # built-in function "int" does not have an inspect.Signature + # built-in function "max" does not have an inspect.Signature dbg = api_util.debug_info("jit", max, (1,), {}) self.assertEqual(dbg.func_src_info, "max") self.assertEqual(dbg.arg_names, ("args[0]",)) @@ -761,6 +768,122 @@ def f(x, y, *args, **kwargs): re.compile(r".*func.func public @main\(.*\{jax.result_info = \"result\"\}"), ]) + def test_jit_arg_names_with_out_of_order_kwargs(self): + tracer_spy = TracerSpy() + + # The shapes are different, to differentiate them easily + a1 = (np.float32(0),) # a hashable tuple, can be static + b2 = np.arange(2, dtype=np.float32) # b2 + z3 = np.arange(3, dtype=np.float32) + y4 = (np.float32(0.), np.float32(1.), np.float32(2.), np.float32(3.)) + x5 = np.arange(5, dtype=np.float32) + u6 = np.arange(6, dtype=np.float32) + t7 = np.arange(7, dtype=np.float32) + + def my_f(a1, b2, z3, y4, x5, *, u6, t7): + assert np.shape(a1[0]) == () + assert np.shape(b2) == (2,) + assert np.shape(z3) == (3,) + assert np.shape(y4) == (4,) + assert np.shape(x5) == (5,) + assert np.shape(u6) == (6,) + assert np.shape(t7) == (7,) + tracer_spy.append(b2) + tracer_spy.append(x5) + return a1[0] + b2[0] + z3[0] + y4[0] + x5[0] + u6[0] + t7[0] + + self._check_tracers_and_jaxprs( + jax.jit(my_f, static_argnums=(0,), static_argnames=("y4",)), + # Some positional args passed as keyword + a1, b2, x5=x5, y4=y4, z3=z3, t7=t7, u6=u6, + expected_jaxpr_debug_infos=[ + "traced_for=jit, fun=my_f, arg_names=b2,t7,u6,x5,z3, result_paths=result", + ], + tracer_spy=tracer_spy, + expected_tracer_debug_infos=[ + "traced_for=jit, fun=my_f, arg_names=b2,t7,u6,x5,z3, from b2", + "traced_for=jit, fun=my_f, arg_names=b2,t7,u6,x5,z3, from x5", + ], + expected_lowering_lines=[ + re.compile(r".*func.func public @main\(%arg0: tensor<2xf..> loc\(\"b2\"\)"), + re.compile(r".*func.func public @main\(.*%arg1: tensor<7xf..> loc\(\"t7\"\)"), + re.compile(r".*func.func public @main\(.*%arg2: tensor<6xf..> loc\(\"u6\"\)"), + re.compile(r".*func.func public @main\(.*%arg3: tensor<5xf..> loc\(\"x5\"\)"), + ] + ) + + tracer_spy.tracers = [] + util.clear_all_caches() + self._check_tracers_and_jaxprs( + jax.jit(my_f, static_argnames=("y4",)), + # Positional argument y4 is static and passed by kwarg + a1, b2, z3, x5=x5, y4=y4, t7=t7, u6=u6, + expected_jaxpr_debug_infos=[ + "traced_for=jit, fun=my_f, arg_names=a1[0],b2,z3,t7,u6,x5, result_paths=result", + ], + tracer_spy=tracer_spy, + expected_tracer_debug_infos=[ + "traced_for=jit, fun=my_f, arg_names=a1[0],b2,z3,t7,u6,x5, from b2", + "traced_for=jit, fun=my_f, arg_names=a1[0],b2,z3,t7,u6,x5, from x5", + ], + expected_lowering_lines=[ + re.compile(r".*func.func public @main\(%arg0: tensor loc\(\"a1\[0\]\"\)"), + re.compile(r".*func.func public @main\(.*%arg1: tensor<2xf..> loc\(\"b2\"\)"), + re.compile(r".*func.func public @main\(.*%arg2: tensor<3xf..> loc\(\"z3\"\)"), + re.compile(r".*func.func public @main\(.*%arg3: tensor<7xf..> loc\(\"t7\"\)"), + re.compile(r".*func.func public @main\(.*%arg4: tensor<6xf..> loc\(\"u6\"\)"), + re.compile(r".*func.func public @main\(.*%arg5: tensor<5xf..> loc\(\"x5\"\)"), + ] + ) + + tracer_spy.tracers = [] + util.clear_all_caches() + self._check_tracers_and_jaxprs( + jax.jit(my_f, static_argnames=("y4",)), + # Positional argument y4 is static (declared as static_argnames) + a1, b2, z3, y4, x5=x5, t7=t7, u6=u6, + expected_jaxpr_debug_infos=[ + "traced_for=jit, fun=my_f, arg_names=a1[0],b2,z3,t7,u6,x5, result_paths=result", + ], + tracer_spy=tracer_spy, + expected_tracer_debug_infos=[ + "traced_for=jit, fun=my_f, arg_names=a1[0],b2,z3,t7,u6,x5, from b2", + "traced_for=jit, fun=my_f, arg_names=a1[0],b2,z3,t7,u6,x5, from x5", + ], + expected_lowering_lines=[ + re.compile(r".*func.func public @main\(%arg0: tensor loc\(\"a1\[0\]\"\)"), + re.compile(r".*func.func public @main\(.*%arg1: tensor<2xf..> loc\(\"b2\"\)"), + re.compile(r".*func.func public @main\(.*%arg2: tensor<3xf..> loc\(\"z3\"\)"), + re.compile(r".*func.func public @main\(.*%arg3: tensor<7xf..> loc\(\"t7\"\)"), + re.compile(r".*func.func public @main\(.*%arg4: tensor<6xf..> loc\(\"u6\"\)"), + re.compile(r".*func.func public @main\(.*%arg5: tensor<5xf..> loc\(\"x5\"\)"), + ] + ) + + tracer_spy.tracers = [] + util.clear_all_caches() + self._check_tracers_and_jaxprs( + jax.jit(my_f, static_argnums=(3,)), + # Positional argument y4 is static (declared as static_argnums) + a1, b2, z3, y4, x5=x5, t7=t7, u6=u6, + expected_jaxpr_debug_infos=[ + "traced_for=jit, fun=my_f, arg_names=a1[0],b2,z3,t7,u6,x5, result_paths=result", + ], + tracer_spy=tracer_spy, + expected_tracer_debug_infos=[ + "traced_for=jit, fun=my_f, arg_names=a1[0],b2,z3,t7,u6,x5, from b2", + "traced_for=jit, fun=my_f, arg_names=a1[0],b2,z3,t7,u6,x5, from x5", + ], + expected_lowering_lines=[ + re.compile(r".*func.func public @main\(%arg0: tensor loc\(\"a1\[0\]\"\)"), + re.compile(r".*func.func public @main\(.*%arg1: tensor<2xf..> loc\(\"b2\"\)"), + re.compile(r".*func.func public @main\(.*%arg2: tensor<3xf..> loc\(\"z3\"\)"), + re.compile(r".*func.func public @main\(.*%arg3: tensor<7xf..> loc\(\"t7\"\)"), + re.compile(r".*func.func public @main\(.*%arg4: tensor<6xf..> loc\(\"u6\"\)"), + re.compile(r".*func.func public @main\(.*%arg5: tensor<5xf..> loc\(\"x5\"\)"), + ] + ) + def test_jit_result_info(self): def f(x, y, z): return {'a': x, 'b': [y]} @@ -1493,34 +1616,50 @@ def my_f(x): def test_pmap_with_arg_and_result_names(self): tracer_spy = TracerSpy() - x = np.ones((jax.device_count(),), dtype=np.float32) - def my_f(x, y, *args, a, **kwargs): - # y and kwargs[c] is dead + + # Use different shapes arguments to distinguish them in the HLO + def my_f(x0, y1, *args, b4, **kwargs): + assert np.shape(x0) == () + assert np.shape(y1) == (1,) + assert np.shape(args[0]) == (2,) + assert np.shape(args[1]) == (3,) + assert np.shape(b4) == (4,) + assert np.shape(kwargs["a5"]) == (5,) + assert np.shape(kwargs["c6"]) == (6,) + # kwargs[b5] is dead tracer_spy.append(args[1]) - s = x + a + args[1] + kwargs["d"] - return dict(u=s, v=x) + tracer_spy.append(b4) + tracer_spy.append(kwargs["c6"]) + s0 = x0 + y1[0] + b4[0] + args[1][0] + kwargs["c6"][0] + return dict(v1=jnp.broadcast_to(s0, (1,)), u0=s0) self._check_tracers_and_jaxprs( jax.pmap(my_f, static_broadcasted_argnums=(0,)), - 1., x, x, x, # x, y, args[0], args[1] - d=x, a=x, b=x, # kwargs + 1., # x0 + np.ones((jax.device_count(), 1), dtype=np.float32), # y1 + np.ones((jax.device_count(), 2), dtype=np.float32), # args[0] + np.ones((jax.device_count(), 3), dtype=np.float32), # args[1] + b4=np.ones((jax.device_count(), 4), dtype=np.float32), + a5=np.ones((jax.device_count(), 5), dtype=np.float32), + c6=np.ones((jax.device_count(), 6), dtype=np.float32), expected_jaxpr_debug_infos=[ - "traced_for=pmap, fun=my_f, arg_names=y,args[0],args[1],a,kwargs['b'],kwargs['d'], result_paths=result['u'],result['v']", + "traced_for=pmap, fun=my_f, arg_names=y1,args[0],args[1],kwargs['a5'],b4,kwargs['c6'], result_paths=result['u0'],result['v1']", ], tracer_spy=tracer_spy, expected_tracer_debug_infos=[ - "traced_for=pmap, fun=my_f, arg_names=y,args[0],args[1],a,kwargs['b'],kwargs['d'], from args[1]", + "traced_for=pmap, fun=my_f, arg_names=y1,args[0],args[1],kwargs['a5'],b4,kwargs['c6'], from args[1]", + "traced_for=pmap, fun=my_f, arg_names=y1,args[0],args[1],kwargs['a5'],b4,kwargs['c6'], from b4", + "traced_for=pmap, fun=my_f, arg_names=y1,args[0],args[1],kwargs['a5'],b4,kwargs['c6'], from kwargs['c6']", ], expected_lowering_lines=[ - # TODO(necula): we did not DCE y? - re.compile(r".*func.func public @main\(.*%arg0: tensor<1xf..> loc\(\"y\"\)"), - re.compile(r".*func.func public @main\(.*%arg1: tensor<1xf..> loc\(\"args\[0\]\"\)"), - re.compile(r".*func.func public @main\(.*%arg2: tensor<1xf..> loc\(\"args\[1\]\"\)"), - re.compile(r".*func.func public @main\(.*%arg3: tensor<1xf..> loc\(\"a\"\)"), - re.compile(r".*func.func public @main\(.*%arg4: tensor<1xf..> loc\(\"kwargs\['b'\]\"\)"), - re.compile(r".*func.func public @main\(.*%arg5: tensor<1xf..> loc\(\"kwargs\['d'\]\"\)"), - re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"result\['u'\]\"\}"), - re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"result\['v'\]\"\}"), + re.compile(r".*func.func public @main\(.*%arg0: tensor<1x1xf..> loc\(\"y1\"\)"), + re.compile(r".*func.func public @main\(.*%arg1: tensor<1x2xf..> loc\(\"args\[0\]\"\)"), + re.compile(r".*func.func public @main\(.*%arg2: tensor<1x3xf..> loc\(\"args\[1\]\"\)"), + re.compile(r".*func.func public @main\(.*%arg3: tensor<1x5xf..> loc\(\"kwargs\['a5'\]\"\)"), + re.compile(r".*func.func public @main\(.*%arg4: tensor<1x4xf..> loc\(\"b4\"\)"), + re.compile(r".*func.func public @main\(.*%arg5: tensor<1x6xf..> loc\(\"kwargs\['c6'\]\"\)"), + re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"result\['u0'\]\"\}"), + re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"result\['v1'\]\"\}"), ] ) From 82ec5737ff3e2466a8fa5615e5a49ebfbdbcd99e Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Wed, 2 Apr 2025 03:33:23 -0700 Subject: [PATCH 324/483] Remove nanobind pin now that nanobind fix landed. Reverts 33d306ab4090c17b427908853b314e17cb449661 PiperOrigin-RevId: 743062185 --- examples/ffi/pyproject.toml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/ffi/pyproject.toml b/examples/ffi/pyproject.toml index 6f188ee037da..130dd91bbc70 100644 --- a/examples/ffi/pyproject.toml +++ b/examples/ffi/pyproject.toml @@ -1,7 +1,5 @@ [build-system] -# TODO(dsuo): Remove nanobind pin after -# https://github.com/wjacob/nanobind/pull/980 lands. -requires = ["scikit-build-core", "nanobind==2.5.0", "jax>=0.4.31"] +requires = ["scikit-build-core", "nanobind", "jax>=0.4.31"] build-backend = "scikit_build_core.build" [project] From 735cec18cb2f8dff2aea5e503fd886a37aee094e Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Wed, 2 Apr 2025 03:39:25 -0700 Subject: [PATCH 325/483] [jaxlib] Fix asan tests for subbyte types in CPU/GPU callbacks. Reverts 0b199f48c7e0d4e5837cee34ced7f3fc7065732f PiperOrigin-RevId: 743063615 --- jaxlib/cuda/BUILD | 1 + jaxlib/gpu/py_client_gpu.cc | 88 ++++++++++++++++++++------------ jaxlib/rocm/BUILD | 1 + jaxlib/xla/BUILD | 1 + jaxlib/xla/py_client_cpu.cc | 87 +++++++++++++++++++++++--------- tests/python_callback_test.py | 94 ++++++++++++++++++++--------------- 6 files changed, 177 insertions(+), 95 deletions(-) diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index fac62c81dee7..d35e421ef904 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -689,6 +689,7 @@ cc_library( "@xla//third_party/python_runtime:headers", # buildcleaner: keep "@xla//xla:comparison_util", "@xla//xla:shape_util", + "@xla//xla:util", "@xla//xla:xla_data_proto_cc", "@xla//xla/ffi:ffi_api", "@xla//xla/ffi/api:ffi", diff --git a/jaxlib/gpu/py_client_gpu.cc b/jaxlib/gpu/py_client_gpu.cc index 861ffce3e749..e3aec51d8d25 100644 --- a/jaxlib/gpu/py_client_gpu.cc +++ b/jaxlib/gpu/py_client_gpu.cc @@ -44,6 +44,7 @@ limitations under the License. #include "xla/python/types.h" #include "xla/shape_util.h" #include "xla/xla_data.pb.h" +#include "xla/util.h" namespace nb = nanobind; @@ -81,8 +82,7 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, auto arg = args.get(i); auto ptype = static_cast(arg->element_type()); // TODO(b/395428868): Remove this check once we support subbyte types. - if (ptype == xla::S1 || ptype == xla::S2 || ptype == xla::S4 || - ptype == xla::U1 || ptype == xla::U2 || ptype == xla::U4) { + if (ptype == xla::S1 || ptype == xla::U1) { return xla::ffi::Error(xla::ffi::ErrorCode::kUnimplemented, absl::StrFormat("Unsupported primitive type: %s", PrimitiveType_Name(ptype))); @@ -112,9 +112,6 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, PyTuple_SET_ITEM(host_input_arrays.ptr(), i, nb::none().inc_ref().ptr()); continue; } - nb::capsule base(host_input_buffers[i], [](void* ptr) noexcept { - delete[] static_cast(ptr); - }); auto maybe_dtype = PrimitiveTypeToNbDtype(ptype); if (!maybe_dtype.ok()) { return xla::ffi::Error::Internal(maybe_dtype.status().ToString()); @@ -122,6 +119,23 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, auto dtype = maybe_dtype.value(); auto dims = absl::Span(arg->dimensions().begin(), arg->dimensions().size()); + // TODO(b/402422886): Remove this once we form Jax arrays directly instead + // of packing/unpacking to/from numpy arrays. + // We pass in data using default numpy layout i.e., std::nullopt. + size_t bits_per_element = xla::primitive_util::BitWidth(ptype); + if (bits_per_element == 2 || bits_per_element == 4) { + // NOTE(dsuo): FFI argument and return buffers are sized assuming + // minimum 1-byte element sizes, even if the data itself is packed. + size_t packed_size = arg->size_bytes() * bits_per_element / 8; + auto buffer = xla::UnpackIntN( + bits_per_element, static_cast(host_input_buffers[i]), + packed_size); + delete[] static_cast(host_input_buffers[i]); + host_input_buffers[i] = buffer.release(); + } + nb::capsule base(host_input_buffers[i], [](void* ptr) noexcept { + delete[] static_cast(ptr); + }); auto array = xla::nb_numpy_ndarray(dtype, dims, std::nullopt, host_input_buffers[i], base); array.attr("flags").attr("writeable") = nb::bool_(false); @@ -146,8 +160,7 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, auto ret = rets.get(i).value(); auto ptype = static_cast(ret->element_type()); // TODO(b/395428868): Remove this check once we support subbyte types. - if (ptype == xla::S1 || ptype == xla::S2 || ptype == xla::S4 || - ptype == xla::U1 || ptype == xla::U2 || ptype == xla::U4) { + if (ptype == xla::S1 || ptype == xla::U1) { return xla::ffi::Error(xla::ffi::ErrorCode::kUnimplemented, absl::StrFormat("Unsupported primitive type: %s", PrimitiveType_Name(ptype))); @@ -168,32 +181,45 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, } auto expected_shape = maybe_expected_shape.value(); auto expected_strides = xla::ByteStridesForShape(expected_shape); - if (strides == expected_strides) { - auto gpu_res = - gpuMemcpyAsync(ret->untyped_data(), array.data(), ret->size_bytes(), - gpuMemcpyHostToDevice, stream); - CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; - continue; + + const void* data = array.data(); + size_t size_bytes = array.size() * array.itemsize(); + if (strides != expected_strides) { + xla::TransposePlan::Options options; + options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); + options.dims = absl::Span( + reinterpret_cast(array.shape()), array.ndim()); + absl::InlinedVector reversed_layout; + reversed_layout.resize(expected_shape.dimensions().size()); + absl::c_reverse_copy(expected_shape.layout().minor_to_major(), + reversed_layout.begin()); + options.permutation = reversed_layout; + options.input_layout = xla::TransposePlan::Striding{strides}; + auto maybe_plan = transpose_cache->cache.GetOrCreate(options); + if (!maybe_plan.ok()) { + return xla::ffi::Error::Internal(maybe_plan.status().ToString()); + } + auto plan = maybe_plan.value(); + void* temp = new char[size_bytes]; + temp_buffers.push_back(temp); + plan->Execute(data, temp); + data = temp; } - void* temp = new char[ret->size_bytes()]; - temp_buffers.push_back(temp); - xla::TransposePlan::Options options; - options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); - options.dims = absl::Span( - reinterpret_cast(array.shape()), array.ndim()); - absl::InlinedVector reversed_layout; - reversed_layout.resize(expected_shape.dimensions().size()); - absl::c_reverse_copy(expected_shape.layout().minor_to_major(), - reversed_layout.begin()); - options.permutation = reversed_layout; - options.input_layout = xla::TransposePlan::Striding{strides}; - auto maybe_plan = transpose_cache->cache.GetOrCreate(options); - if (!maybe_plan.ok()) { - return xla::ffi::Error::Internal(maybe_plan.status().ToString()); + + // TODO(b/402422886): Remove this once we form Jax arrays directly instead + // of packing/unpacking to/from numpy arrays. + std::unique_ptr buffer; + size_t bits_per_element = xla::primitive_util::BitWidth(ptype); + if (bits_per_element == 2 || bits_per_element == 4) { + // NOTE(dsuo): FFI arguments and return buffers are sized assuming + // minimum 1-byte element sizes, even if the data itself is packed. + buffer = xla::PackIntN(bits_per_element, static_cast(data), + size_bytes); + data = buffer.get(); + size_bytes = (size_bytes * bits_per_element) / 8; } - auto plan = maybe_plan.value(); - plan->Execute(array.data(), temp); - auto gpu_res = gpuMemcpyAsync(ret->untyped_data(), temp, ret->size_bytes(), + + auto gpu_res = gpuMemcpyAsync(ret->untyped_data(), data, size_bytes, gpuMemcpyHostToDevice, stream); CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; } diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index d0c0c798abb8..358a6d1cc9aa 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -588,6 +588,7 @@ cc_library( "@xla//third_party/python_runtime:headers", # buildcleaner: keep "@xla//xla:comparison_util", "@xla//xla:shape_util", + "@xla//xla:util", "@xla//xla:xla_data_proto_cc", "@xla//xla/ffi:ffi_api", "@xla//xla/ffi/api:ffi", diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD index 2ca18afda13d..5b532c1dc501 100644 --- a/jaxlib/xla/BUILD +++ b/jaxlib/xla/BUILD @@ -637,6 +637,7 @@ cc_library( "@nanobind", "@xla//third_party/python_runtime:headers", # buildcleaner: keep "@xla//xla:shape_util", + "@xla//xla:util", "@xla//xla:xla_data_proto_cc", "@xla//xla/ffi:ffi_api", "@xla//xla/ffi/api:ffi", diff --git a/jaxlib/xla/py_client_cpu.cc b/jaxlib/xla/py_client_cpu.cc index ac4e7bee5680..fef6a54aab2d 100644 --- a/jaxlib/xla/py_client_cpu.cc +++ b/jaxlib/xla/py_client_cpu.cc @@ -41,6 +41,7 @@ limitations under the License. #include "xla/python/nb_numpy.h" #include "xla/python/types.h" #include "xla/shape_util.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" namespace nb = nanobind; @@ -79,8 +80,7 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks, auto arg = args.get(i); auto ptype = static_cast(arg->element_type()); // TODO(b/395428868): Remove this check once we support subbyte types. - if (ptype == S1 || ptype == S2 || ptype == S4 || ptype == U1 || - ptype == U2 || ptype == U4) { + if (ptype == S1 || ptype == U1) { return ffi::Error(ffi::ErrorCode::kUnimplemented, absl::StrFormat("Unsupported primitive type: %s", PrimitiveType_Name(ptype))); @@ -96,9 +96,20 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks, auto dtype = maybe_dtype.value(); auto dims = absl::Span(arg->dimensions().begin(), arg->dimensions().size()); + // TODO(b/402422886): Remove this once we form Jax arrays directly instead + std::unique_ptr buffer; + const void* data = arg->untyped_data(); + size_t bits_per_element = xla::primitive_util::BitWidth(ptype); + if (bits_per_element == 2 || bits_per_element == 4) { + // NOTE(dsuo): FFI argument and return buffers are sized assuming + size_t packed_size = arg->size_bytes() * bits_per_element / 8; + buffer = xla::UnpackIntN(bits_per_element, static_cast(data), + packed_size); + data = buffer.get(); + } // We pass in data using default numpy layout i.e., std::nullopt. auto array = - nb_numpy_ndarray(dtype, dims, std::nullopt, arg.value().untyped_data()); + nb_numpy_ndarray(dtype, dims, std::nullopt, data); array.attr("flags").attr("writeable") = nb::bool_(false); PyTuple_SET_ITEM(nb_args.ptr(), i, array.release().ptr()); } @@ -119,9 +130,8 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks, for (size_t i = 0; i < rets.size(); ++i) { auto ret = rets.get(i).value(); auto ptype = static_cast(ret->element_type()); - // TODO(b/395428868): Remove this check once we support subbyte types. - if (ptype == S1 || ptype == S2 || ptype == S4 || ptype == U1 || - ptype == U2 || ptype == U4) { + // TODO(b/402422886): Remove this once we form Jax arrays directly instead + if (ptype == S1 || ptype == U1) { return ffi::Error(ffi::ErrorCode::kUnimplemented, absl::StrFormat("Unsupported primitive type: %s", PrimitiveType_Name(ptype))); @@ -141,26 +151,55 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks, } auto expected_shape = maybe_expected_shape.value(); auto expected_strides = ByteStridesForShape(expected_shape); - if (strides == expected_strides) { - std::memcpy(ret->untyped_data(), array.data(), ret->size_bytes()); - continue; + + const void* data = array.data(); + std::unique_ptr buffer; + size_t bits_per_element = xla::primitive_util::BitWidth(ptype); + size_t size_bytes = array.size() * array.itemsize(); + if (strides != expected_strides) { + xla::TransposePlan::Options options; + options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); + options.dims = absl::Span( + reinterpret_cast(array.shape()), array.ndim()); + absl::InlinedVector reversed_layout; + reversed_layout.resize(expected_shape.dimensions().size()); + absl::c_reverse_copy(expected_shape.layout().minor_to_major(), + reversed_layout.begin()); + options.permutation = reversed_layout; + options.input_layout = xla::TransposePlan::Striding{strides}; + auto maybe_plan = transpose_cache->cache.GetOrCreate(options); + if (!maybe_plan.ok()) { + return ffi::Error::Internal(maybe_plan.status().ToString()); + } + auto plan = maybe_plan.value(); + if (bits_per_element == 2 || bits_per_element == 4) { + // NOTE(dsuo): If the data needs to be unpacked, don't use return buffer + // supplied by FFI directly. + buffer = std::make_unique(size_bytes); + plan->Execute(data, buffer.get()); + data = buffer.get(); + } else { + plan->Execute(data, ret->untyped_data()); + data = ret->untyped_data(); + } } - xla::TransposePlan::Options options; - options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); - options.dims = absl::Span( - reinterpret_cast(array.shape()), array.ndim()); - absl::InlinedVector reversed_layout; - reversed_layout.resize(expected_shape.dimensions_size()); - absl::c_reverse_copy(expected_shape.layout().minor_to_major(), - reversed_layout.begin()); - options.permutation = reversed_layout; - options.input_layout = xla::TransposePlan::Striding{strides}; - auto maybe_plan = transpose_cache->cache.GetOrCreate(options); - if (!maybe_plan.ok()) { - return ffi::Error::Internal(maybe_plan.status().ToString()); + + // TODO(b/402422886): Remove this once we form Jax arrays directly instead + // of packing/unpacking to/from numpy arrays. + if (bits_per_element == 2 || bits_per_element == 4) { + // NOTE(dsuo): FFI arguments and return buffers are sized assuming + // minimum 1-byte element sizes, even if the data itself is packed. + buffer = xla::PackIntN(bits_per_element, static_cast(data), + size_bytes); + data = buffer.get(); + size_bytes = (size_bytes * bits_per_element) / 8; + } + + // Copy data to output buffer if haven't already or modified the data to + // write back. + if (data != ret->untyped_data()) { + std::memcpy(ret->untyped_data(), data, size_bytes); } - auto plan = maybe_plan.value(); - plan->Execute(array.data(), ret->untyped_data()); } return ffi::Error::Success(); diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index a8442b4a1356..34ab20c05644 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -586,10 +586,15 @@ def fun(x): self.assertAllClose(2 * x, fun(x)) self.assertEqual(count(), 1) - @parameterized.parameters("int2", "int4", "uint2", "uint4") + @parameterized.parameters("int2", "int4", "uint2", "uint4", "float4_e2m1fn") def test_subbyte_operands(self, dtype: str): - if jaxlib_extension_version <= 321: - self.skipTest("Requires jaxlib_extension_version >= 322.") + if jaxlib_extension_version < 323: + self.skipTest("Requires jaxlib_extension_version >= 323.") + if "2" in dtype and jtu.test_device_matches(["tpu"]): + self.skipTest( + "TODO(dsuo): TPU callbacks send SIGABRT for int2, uint2, and" + " float4_e2m1fn." + ) def get(x): return x def f(x): @@ -600,21 +605,17 @@ def f(x): ) return y x = np.arange(8, dtype=dtype) - # TODO(b/395428868): Remove this check once we support subbyte types. - if jtu.test_device_matches(["tpu"]): - if "2" in dtype: - self.skipTest("TODO(dsuo): TPU callbacks send SIGABRT for int2/uint2.") - np.testing.assert_array_equal(jax.jit(f)(x), np.arange(8, dtype=dtype)) - else: - with self.assertRaisesRegex( - Exception, "Unsupported primitive type" - ): - _ = jax.jit(f)(x) + np.testing.assert_array_equal(jax.jit(f)(x), np.arange(8, dtype=dtype)) - @parameterized.parameters("int2", "int4", "uint2", "uint4") + @parameterized.parameters("int2", "int4", "uint2", "uint4", "float4_e2m1fn") def test_subbyte_results(self, dtype: str): - if jaxlib_extension_version <= 321: - self.skipTest("Requires jaxlib_extension_version >= 322.") + if jaxlib_extension_version < 323: + self.skipTest("Requires jaxlib_extension_version >= 323.") + if "2" in dtype and jtu.test_device_matches(["tpu"]): + self.skipTest( + "TODO(dsuo): TPU callbacks send SIGABRT for int2, uint2, and" + " float4_e2m1fn." + ) def get(): return np.arange(8, dtype=dtype) @@ -625,16 +626,43 @@ def f(): ) return y - # TODO(b/395428868): Remove this check once we support subbyte types. - if jtu.test_device_matches(["tpu"]): - if "2" in dtype: - self.skipTest("TODO(dsuo): TPU callbacks send SIGABRT for int2/uint2.") - np.testing.assert_array_equal(jax.jit(f)(), np.arange(8, dtype=dtype)) - else: - with self.assertRaisesRegex( - Exception, "Unsupported primitive type" - ): - _ = jax.jit(f)() + np.testing.assert_array_equal(jax.jit(f)(), np.arange(8, dtype=dtype)) + + @parameterized.parameters("int2", "int4", "uint2", "uint4", "float4_e2m1fn") + def test_non_default_stride_subbyte_results(self, dtype: str): + if jaxlib_extension_version < 323: + self.skipTest("Requires jaxlib_extension_version >= 323.") + if "2" in dtype and jtu.test_device_matches(["tpu"]): + self.skipTest( + "TODO(dsuo): TPU callbacks send SIGABRT for int2, uint2, and" + " float4_e2m1fn." + ) + x = jnp.arange(24, dtype=dtype).reshape(2, 3, 4) + def callback(x): + return np.asfortranarray(x) + + @jax.jit + def f(x): + return jax.pure_callback( + callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x + ) + + result = f(x) + np.testing.assert_array_equal(x, result) + + def test_non_default_stride(self): + x = jnp.arange(24, dtype=jnp.float32).reshape(2, 3, 4) + def callback(x): + return np.asfortranarray(x) + + @jax.jit + def f(x): + return jax.pure_callback( + callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x + ) + + result = f(x) + np.testing.assert_array_equal(x, result) class PureCallbackTest(jtu.JaxTestCase): @@ -1108,20 +1136,6 @@ def fun(x): result += fun(jnp.ones((500, 500), jnp.complex64))[1] jax.block_until_ready(result) # doesn't deadlock - def test_non_default_stride(self): - x = jnp.arange(24, dtype=jnp.float32).reshape(2, 3, 4) - def callback(x): - return np.asfortranarray(x) - - @jax.jit - def f(x): - return jax.pure_callback( - callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x - ) - - result = f(x) - np.testing.assert_array_equal(x, result) - class IOCallbackTest(jtu.JaxTestCase): From 45d577d3dc12f894416ab1eefb1ef48a15b1f3da Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 2 Apr 2025 04:16:48 -0700 Subject: [PATCH 326/483] Prepare for disallowing `jnp.array(None)` PiperOrigin-RevId: 743074472 --- jax/_src/numpy/lax_numpy.py | 2 +- tests/lax_numpy_test.py | 12 ++++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index a47f66e5f621..ae32703e7113 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -5503,7 +5503,7 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, if any(leaf is None for leaf in leaves): # Added Nov 16 2023 if deprecations.is_accelerated("jax-numpy-array-none"): - raise TypeError("None is not a valid value for jnp.array") + raise ValueError("None is not a valid value for jnp.array") warnings.warn( "None encountered in jnp.array(); this is currently treated as NaN. " "In the future this will result in an error.", diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index c0650441edd7..f94f42f027ce 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -47,6 +47,7 @@ from jax._src import array from jax._src import config from jax._src import core +from jax._src import deprecations from jax._src import dtypes from jax._src import test_util as jtu from jax._src.lax import lax as lax_internal @@ -3796,8 +3797,15 @@ def testArrayFromList(self): jnp.array([0, val]) def testArrayNoneWarning(self): - # TODO(jakevdp): make this an error after the deprecation period. - with self.assertWarnsRegex(FutureWarning, r"None encountered in jnp.array\(\)"): + if deprecations.is_accelerated('jax-numpy-array-none'): + ctx = self.assertRaisesRegex( + ValueError, 'None is not a valid value for jnp.array' + ) + else: + ctx = self.assertWarnsRegex( + FutureWarning, r'None encountered in jnp.array\(\)' + ) + with ctx: jnp.array([0.0, None]) def testIssue121(self): From 0bee42b6cebe844d151ee4047406af0142756998 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 2 Apr 2025 05:26:24 -0700 Subject: [PATCH 327/483] Update XLA dependency to use revision http://github.com/openxla/xla/commit/c3087e022f3c07f7ed1dd4e47024c437a504341b. PiperOrigin-RevId: 743093178 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 13223c4a4b88..90a19ac95e51 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "b1971cc2b3407e87fada2674a057d72897b79acc" -XLA_SHA256 = "3b2feabbcd6adc5721533edfbe3dc2ad6517cb1b059cf41dea63f62874bff12d" +XLA_COMMIT = "c3087e022f3c07f7ed1dd4e47024c437a504341b" +XLA_SHA256 = "66457303ddec4dbbe43accf38a8b6b635d55808938cf2495443b09ee9c95a147" def repo(): tf_http_archive( From 10b2cda90e9066dc9c02ae5a068dc47a1e745a2a Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 2 Apr 2025 06:09:36 -0700 Subject: [PATCH 328/483] Relax the aval check in `select_hlo_lowering_opaque` to only check for shardings if they are not empty. The same thing happens in select_p's sharding rule PiperOrigin-RevId: 743105350 --- jax/_src/lax/lax.py | 6 +++++- tests/pjit_test.py | 20 ++++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index fd956136ccd3..ac6054328f73 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -7357,7 +7357,11 @@ def _select_jvp(primals, tangents): def _select_hlo_lowering_opaque(ctx, which, *cases): avals_in = ctx.avals_in aval_out, = ctx.avals_out - assert all(aval_case == aval_out for aval_case in avals_in[1:]) + assert all((aval_case.shape, aval_case.dtype) == (aval_out.shape, aval_out.dtype) + for aval_case in avals_in[1:]) + assert all( + aval_case == aval_out for aval_case in avals_in[1:] + if not aval_case.sharding.mesh.empty and not aval_out.sharding.mesh.empty) select_lower = _select_hlo_lowering physical_aval_out = core.physical_aval(aval_out) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index ee4a8cd3e15e..38f191302ea1 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7365,6 +7365,26 @@ def h(y): out = h(arr) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + def test_scan_with_random_key_inside_jit(self): + mesh = jtu.create_mesh((2,), ('x',)) + sharding = NamedSharding(mesh, P(None, 'x')) + + @jax.jit + def scan(xs): + def step(carry, x): + next_carry = jax.vmap(jax.random.fold_in)(carry, x) + next_carry = jnp.where(x % 2 == 0, carry, next_carry) + return next_carry, None + rng = jnp.broadcast_to(jax.random.key(0), xs.shape[1:]) + rng, _ = jax.lax.scan(step, rng, xs) + return rng + + xs = jnp.arange(8).reshape(2, 4) + scan(xs) + + xs = jax.device_put(xs, sharding) + scan(xs) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): From 6242ffb1ca207e783150176cbca6d97db6fc3325 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Wed, 2 Apr 2025 07:40:02 -0700 Subject: [PATCH 329/483] Remove unused Attrs from `lu_pivots_to_permutation` FFI kernel. It has been more than 6 months since the release of 0.4.32 which was the first release to stop including `permutation_size` as an attribute when lowering, so it is now safe (via our compatibility policy) to remove this argument. PiperOrigin-RevId: 743132169 --- .../cuda_lu_pivots_to_permutation.py | 25 ++++++++----------- jaxlib/gpu/linalg_kernels.cc | 7 +----- tests/export_back_compat_test.py | 4 +-- 3 files changed, 14 insertions(+), 22 deletions(-) diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_lu_pivots_to_permutation.py b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_lu_pivots_to_permutation.py index 12285a45b77a..8063d9f44722 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_lu_pivots_to_permutation.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_lu_pivots_to_permutation.py @@ -16,11 +16,11 @@ from numpy import array, int32 # Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_08 = dict( +data_2025_04_01 = dict( testdata_version=1, platform='cuda', custom_call_targets=['cu_lu_pivots_to_permutation'], - serialized_date=datetime.date(2024, 8, 8), + serialized_date=datetime.date(2025, 4, 1), inputs=(), expected_outputs=(array([[[0, 1, 2, 3, 4, 5, 6, 7], [4, 5, 6, 7, 0, 1, 2, 3], @@ -31,25 +31,22 @@ [0, 1, 2, 3, 4, 5, 6, 7]]], dtype=int32),), mlir_module_text=r""" module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<2x3x8xi32> {jax.result_info = "", mhlo.layout_mode = "default"}) { + func.func public @main() -> (tensor<2x3x8xi32> {jax.result_info = "result"}) { %0 = stablehlo.iota dim = 0 : tensor<24xi32> loc(#loc4) %1 = stablehlo.reshape %0 : (tensor<24xi32>) -> tensor<2x3x4xi32> loc(#loc5) - %c = stablehlo.constant dense<2> : tensor loc(#loc6) - %c_0 = stablehlo.constant dense<3> : tensor loc(#loc6) - %c_1 = stablehlo.constant dense<4> : tensor loc(#loc6) - %2 = stablehlo.custom_call @cu_lu_pivots_to_permutation(%1) {mhlo.backend_config = {permutation_size = 8 : i32}, operand_layouts = [dense<[2, 1, 0]> : tensor<3xindex>], result_layouts = [dense<[2, 1, 0]> : tensor<3xindex>]} : (tensor<2x3x4xi32>) -> tensor<2x3x8xi32> loc(#loc6) + %2 = stablehlo.custom_call @cu_lu_pivots_to_permutation(%1) {mhlo.backend_config = {}, mhlo.frontend_attributes = {num_batch_dims = "2"}, operand_layouts = [dense<[2, 1, 0]> : tensor<3xindex>], result_layouts = [dense<[2, 1, 0]> : tensor<3xindex>]} : (tensor<2x3x4xi32>) -> tensor<2x3x8xi32> loc(#loc6) return %2 : tensor<2x3x8xi32> loc(#loc) } loc(#loc) } loc(#loc) #loc = loc(unknown) -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":347:26) -#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":347:14) -#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":348:11) -#loc4 = loc("jit()/jit(main)/iota[dtype=int32 shape=(24,) dimension=0]"(#loc1)) -#loc5 = loc("jit()/jit(main)/reshape[new_sizes=(2, 3, 4) dimensions=None]"(#loc2)) -#loc6 = loc("jit()/jit(main)/lu_pivots_to_permutation[permutation_size=8]"(#loc3)) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":408:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":408:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":409:11) +#loc4 = loc("jit()/jit(main)/iota"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape"(#loc2)) +#loc6 = loc("jit()/jit(main)/lu_pivots_to_permutation"(#loc3)) """, - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1d\x05\x01\x03\x01\x03\x05\x03\r\x07\t\x0b\r\x0f\x11\x03\xa7}\x17\x01Q\x0f\x07\x0b\x0b\x0f\x0b+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x17\x0f\x0b\x17\x13\x0b\x17\x13\x13S\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03-\x0b\x0b\x0f\x0b\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x0f///\x0b\x0b\x0b\x13\x0b\x0fo\x01\x05\x0b\x0f\x03\x13\x0f\x07\x1b\x07\x13\x13\x1b\x13\x07\x02Z\x04\x1d57\x1f\x05\x13\x05\x15\x11\x03\x05\x05\x17\x03\t\x0f\x11\x13\t\x15\t\x0b\x17\x05\x19\x11\x01\x00\x05\x1b\x05\x1d\x05\x1f\x03\x0b\x1bQ\x1dW\x1fY\x0bc!e\x05!\x05#\x05%\x05'\x03\x03%g\x05)\x1d)+\x05+\x17\x05n\x055\x1d/1\x05-\x17\x05n\x05\x1d\x03\x03\x07i\x05/\x17\x05r\x05\x17\x03\x03\x07k\x03\x03\x07m\x03\x13?oASCqEQGsIuKUMQOU\x051\x053\x055\x057\x059\x05;\x05=\x05?\x05A\x03\x01\x1dC\x03\x03{#\r\x03\x03[\r\x05]S_a\x1dE\x1dG\x1dI\x1dK\x1dM\x13\x0b\x01\x1f\x05\x11\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x05\x11\x03\x00\x00\x00\x00\x00\x00\x00\x1f\x05\x11\x04\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1dO\x05\x01\r\x03wy\x1dQ\x13\x07!\x1f\x131\x02\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x01\x0b\x1b)\x07\t\r!\x07\x1d\x11\x01\x03\t)\x03a\x07)\x07\t\r\x11\x07)\x03\r\x15\x13\x04{\x05\x01\x11\x03\r\x07\x03\x01\x05\x05\x11\x03\x19\x07\x03\r\x1d\x07\x03'#\x03\x0f\t\x06-\x03\x11\x03\x01\x03\x03\x013\x03\x05\x03\x03\x019\x03\x05\x03\x03\x01;\x03\x05\x0b\x07\x01=\x03\t\x03\x03\r\x04\x03\x03\x0b\x06\x03\x01\x05\x01\x00f\x0cS#9\x0f\x0b\x11#!\x03\x1f/!)!)#\x1f\x19\x8b\x8b\x85\x1f\x1f\x15\x1d\x15\x1b%)9\x13\ri\x15\x1f\x17\x11\x11\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00value\x00sym_name\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit()/jit(main)/iota[dtype=int32 shape=(24,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(2, 3, 4) dimensions=None]\x00jit()/jit(main)/lu_pivots_to_permutation[permutation_size=8]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00\x00jax.result_info\x00mhlo.layout_mode\x00default\x00main\x00public\x00cu_lu_pivots_to_permutation\x00permutation_size\x00", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.9.3\x00\x01\x1d\x05\x01\x05\r\x01\x03\x0b\x03\x0b\x0f\x13\x17\x1b\x1f\x03yQ\x15\x01+\x07\x0b\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x17\x0f\x0b\x17\x1b\x0b\x0b\x0f\x0b\x17\x03'\x0b\x0f\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0f\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0bo\x01\x05\x0b\x0f\x03\x11\x07\x1b\x13\x13\x07\x1b\x13\x07\x02\x9e\x02\x1f\x05\x11\x11\x03\x05\x03\x07\t\x0b\r\x05\x0f\x05\x05\x13\x11\x01\x00\x05\x15\x05\x17\x05\x19\x1d\x15\x17\x05\x1b\x17\x03b\x065\x1d\x1b\x1d\x05\x1d\x17\x03b\x06\x1d\x03\x05!?#A\x05\x1f\x05!\x1d')\x05#\x17\x03f\x06\x17\x03\x01\x03\x03O#\t\x03\x033\r\x0357\x1d%\x1d'\x1d)\x1d+\x13\r\x01\r\x01\r\x03CE\x1d-\x1d/\x0b\x03\x1d1\x1d3\x05\x01\x1f\x111\x02\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02\x1b)\x07\t\r!\x05\x11\x01\x03\x07)\x03a\x05\x1d)\x07\t\r\x11\x05)\x03\r\x13\x13\x04c\x05\x01Q\x01\x07\x01\x07\x04Q\x03\x01\x05\x03P\x01\x03\x07\x04=\x03\x07\x11\x05B\x13\x05\x03\x0b\x07\x06\x19\x03\x0f\x03\x01\tG%\x1f\x07\x03\x07\x03\x03\x0b\x04\x01\x03\x05\x06\x03\x01\x05\x01\x00J\x0759\x03\x05\x1f\x0f\x0b\x0f!c3)A;\x1b%)9i\x15\x1f\x17\x11\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/iota\x00jit()/jit(main)/reshape\x00mhlo.backend_config\x00mhlo.frontend_attributes\x00jit()/jit(main)/lu_pivots_to_permutation\x00jax.result_info\x00result\x00main\x00public\x00num_batch_dims\x002\x00\x00cu_lu_pivots_to_permutation\x00\x08+\t\x05#\x01\x0b+/19;\x03=\x11GIK+M-+-", xla_call_module_version=9, nr_devices=1, ) # End paste diff --git a/jaxlib/gpu/linalg_kernels.cc b/jaxlib/gpu/linalg_kernels.cc index 2293bef89b7d..b48e64f2181d 100644 --- a/jaxlib/gpu/linalg_kernels.cc +++ b/jaxlib/gpu/linalg_kernels.cc @@ -90,8 +90,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(CholeskyUpdateFfi, CholeskyUpdateFfiImpl, namespace { ffi::Error LuPivotsToPermutationImpl( - gpuStream_t stream, ffi::Dictionary /* unused */, - ffi::Buffer pivots, + gpuStream_t stream, ffi::Buffer pivots, ffi::Result> permutation) { FFI_ASSIGN_OR_RETURN((auto [batch_size, pivot_size]), SplitBatch1D(pivots.dimensions())); @@ -119,10 +118,6 @@ ffi::Error LuPivotsToPermutationImpl( XLA_FFI_DEFINE_HANDLER_SYMBOL(LuPivotsToPermutation, LuPivotsToPermutationImpl, ffi::Ffi::Bind() .Ctx>() - // TODO(b/358275922): remove Attrs (and the - // unused Dictionary above) 12 weeks after - // release of jaxlib v0.4.32. - .Attrs() .Arg>() .Ret>()); diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index 9b457b8f27a5..6a6c8c213a64 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -140,7 +140,7 @@ def test_custom_call_coverage(self): cpu_qr_lapack_geqrf.data_2023_03_17, cuda_threefry2x32.data_2024_07_30, cpu_lu_lapack_getrf.data_2023_06_14, - cuda_lu_pivots_to_permutation.data_2024_08_08, + cuda_lu_pivots_to_permutation.data_2025_04_01, cuda_lu_cusolver_getrf.data_2024_08_19, cuda_qr_cusolver_geqrf.data_2024_09_26, cuda_eigh_cusolver_syev.data_2024_09_30, @@ -411,7 +411,7 @@ def lu_pivots_to_permutation_harness(shape): def test_cuda_lu_pivots_to_permutation(self): shape = (2, 3, 4) func = lambda: CompatTest.lu_pivots_to_permutation_harness(shape) - data = self.load_testdata(cuda_lu_pivots_to_permutation.data_2024_08_08) + data = self.load_testdata(cuda_lu_pivots_to_permutation.data_2025_04_01) self.run_one_test(func, data) @parameterized.named_parameters( From 297a4f42dec42e0db08270cbdf436c993586445c Mon Sep 17 00:00:00 2001 From: Olli Lupton Date: Tue, 1 Apr 2025 09:59:38 +0000 Subject: [PATCH 330/483] docs: compilation_cache_expect_pgle option --- docs/gpu_performance_tips.md | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/docs/gpu_performance_tips.md b/docs/gpu_performance_tips.md index 737486485736..b3643cb8e292 100644 --- a/docs/gpu_performance_tips.md +++ b/docs/gpu_performance_tips.md @@ -71,6 +71,10 @@ JAX will collect profile information and recompile a module in a single run. Whi in manual mode you need to run a task twice, the first time to collect and save profiles and the second to compile and run with provided data. +**Important**: the JAX profiler, which is used by both of the PGLE workflows documented +below, cannot co-exist with the NVIDIA Nsight Systems profiler. This limitation can be +avoided by using the JAX compilation cache, as described below. + ### Auto PGLE The auto PGLE can be turned on by setting the following environment variables: @@ -129,6 +133,28 @@ with config.enable_pgle(True), config.pgle_profiling_runs(1): train_step_compiled() ``` +#### Collecting NVIDIA Nsight Systems profiles when using AutoPGLE +[jax#24910](https://github.com/jax-ml/jax/pull/24910) (JAX v0.5.1 and newer) added a +new JAX configuration option, `JAX_COMPILATION_CACHE_EXPECT_PGLE`, which tells JAX to +attempt to load PGLE-optimized compiled functions from the persistent compilation +cache. + +This allows a two-step process, where the first step writes a PGLE-optimized function +to the cache: +```bash +export JAX_ENABLE_COMPILATION_CACHE=yes # not strictly needed, on by default +export JAX_COMPILATION_CACHE_DIR=/root/jax_cache +JAX_ENABLE_PGLE=yes python my-model.py +``` +And the second step uses Nsight Systems and loads the PGLE-optimized function from the +cache: +```bash +JAX_COMPILATION_CACHE_EXPECT_PGLE=yes nsys profile python my-model.py +``` +See also [this page]( +https://docs.jax.dev/en/latest/persistent_compilation_cache.html#pitfalls) for more +information about the persistent compilation cache and possible pitfalls. + ### Manual PGLE If you still want to use a manual Profile Guided Latency Estimator the workflow in XLA/GPU is: From c18139ba7b511aff24c83c44aae5d9e1e0a5e014 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Wed, 2 Apr 2025 08:09:13 -0700 Subject: [PATCH 331/483] Remove legacy GPU kernels for QR decomposition. Following the compatibility timeline described here: https://docs.jax.dev/en/latest/export/export.html#ensuring-forward-and-backward-compatibility On Apr 2, it will have been 6 months since the release of 0.4.34 which is the relevant release for this kernels. PiperOrigin-RevId: 743142261 --- .../cuda_qr_cusolver_geqrf.py | 141 +------------- .../rocm_qr_hipsolver_geqrf.py | 176 ------------------ jaxlib/cuda/BUILD | 52 ------ jaxlib/gpu/BUILD | 3 - jaxlib/gpu/blas.cc | 75 -------- jaxlib/gpu/blas_kernels.cc | 138 -------------- jaxlib/gpu/blas_kernels.h | 48 ----- jaxlib/gpu/gpu_kernels.cc | 5 - jaxlib/gpu/solver.cc | 86 --------- jaxlib/gpu/solver_kernels.cc | 172 ----------------- jaxlib/gpu/solver_kernels.h | 20 -- jaxlib/gpu_solver.py | 6 - jaxlib/rocm/BUILD | 49 ----- jaxlib/tools/build_gpu_kernels_wheel.py | 2 - tests/export_back_compat_test.py | 25 --- 15 files changed, 1 insertion(+), 997 deletions(-) delete mode 100644 jax/_src/internal_test_util/export_back_compat_test_data/rocm_qr_hipsolver_geqrf.py delete mode 100644 jaxlib/gpu/blas.cc delete mode 100644 jaxlib/gpu/blas_kernels.cc delete mode 100644 jaxlib/gpu/blas_kernels.h diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_qr_cusolver_geqrf.py b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_qr_cusolver_geqrf.py index be5c6e01f8d8..00ced41a0492 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_qr_cusolver_geqrf.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_qr_cusolver_geqrf.py @@ -15,149 +15,10 @@ # ruff: noqa import datetime -from numpy import array, float32, float64, complex64, complex128 +from numpy import array, float32, complex64 -data_2023_03_18 = {} - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_03_18["unbatched"] = dict( - testdata_version=1, - platform='cuda', - custom_call_targets=['cusolver_geqrf', 'cusolver_orgqr'], - serialized_date=datetime.date(2023, 3, 18), - inputs=(), - expected_outputs=(array([[ 0. , 0.9128705 , 0.40824863], - [-0.44721356, 0.36514878, -0.8164964 ], - [-0.8944271 , -0.18257457, 0.40824813]], dtype=float32), array([[-6.7082043e+00, -8.0498438e+00, -9.3914843e+00], - [ 0.0000000e+00, 1.0954436e+00, 2.1908882e+00], - [ 0.0000000e+00, 0.0000000e+00, 5.6703755e-08]], dtype=float32)), - mlir_module_text=r""" -module @jit__lambda_ { - func.func public @main() -> (tensor<3x3xf32> {jax.result_info = "[0]"}, tensor<3x3xf32> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<9xf32> - %1 = stablehlo.reshape %0 : (tensor<9xf32>) -> tensor<3x3xf32> - %2 = stablehlo.custom_call @cusolver_geqrf(%1) {api_version = 2 : i32, backend_config = "\00\00\00\00\01\00\00\00\03\00\00\00\03\00\00\00\00\00\03\00", operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xf32>) -> tuple, tensor<3xf32>, tensor, tensor<196608xf32>> - %3 = stablehlo.get_tuple_element %2[0] : (tuple, tensor<3xf32>, tensor, tensor<196608xf32>>) -> tensor<3x3xf32> - %4 = stablehlo.get_tuple_element %2[1] : (tuple, tensor<3xf32>, tensor, tensor<196608xf32>>) -> tensor<3xf32> - %5 = stablehlo.get_tuple_element %2[2] : (tuple, tensor<3xf32>, tensor, tensor<196608xf32>>) -> tensor - %6 = stablehlo.get_tuple_element %2[3] : (tuple, tensor<3xf32>, tensor, tensor<196608xf32>>) -> tensor<196608xf32> - %7 = stablehlo.constant dense<0> : tensor - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor - %9 = stablehlo.compare EQ, %5, %8, SIGNED : (tensor, tensor) -> tensor - %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1x1xi1> - %11 = stablehlo.constant dense<0x7FC00000> : tensor - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor) -> tensor<3x3xf32> - %13 = stablehlo.broadcast_in_dim %10, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> - %14 = stablehlo.select %13, %3, %12 : tensor<3x3xi1>, tensor<3x3xf32> - %15 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1xi1> - %16 = stablehlo.constant dense<0x7FC00000> : tensor - %17 = stablehlo.broadcast_in_dim %16, dims = [] : (tensor) -> tensor<3xf32> - %18 = stablehlo.broadcast_in_dim %15, dims = [0] : (tensor<1xi1>) -> tensor<3xi1> - %19 = stablehlo.select %18, %4, %17 : tensor<3xi1>, tensor<3xf32> - %20 = stablehlo.constant dense<0.000000e+00> : tensor - %21 = stablehlo.pad %14, %20, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xf32>, tensor) -> tensor<3x3xf32> - %22 = stablehlo.custom_call @cusolver_orgqr(%21, %19) {api_version = 2 : i32, backend_config = "\00\00\00\00\01\00\00\00\03\00\00\00\03\00\00\00\03\00\00\00 \81\00\00", operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xf32>, tensor<3xf32>) -> tuple, tensor, tensor<33056xf32>> - %23 = stablehlo.get_tuple_element %22[0] : (tuple, tensor, tensor<33056xf32>>) -> tensor<3x3xf32> - %24 = stablehlo.get_tuple_element %22[1] : (tuple, tensor, tensor<33056xf32>>) -> tensor - %25 = stablehlo.get_tuple_element %22[2] : (tuple, tensor, tensor<33056xf32>>) -> tensor<33056xf32> - %26 = stablehlo.constant dense<0> : tensor - %27 = stablehlo.broadcast_in_dim %26, dims = [] : (tensor) -> tensor - %28 = stablehlo.compare EQ, %24, %27, SIGNED : (tensor, tensor) -> tensor - %29 = stablehlo.broadcast_in_dim %28, dims = [] : (tensor) -> tensor<1x1xi1> - %30 = stablehlo.constant dense<0x7FC00000> : tensor - %31 = stablehlo.broadcast_in_dim %30, dims = [] : (tensor) -> tensor<3x3xf32> - %32 = stablehlo.broadcast_in_dim %29, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> - %33 = stablehlo.select %32, %23, %31 : tensor<3x3xi1>, tensor<3x3xf32> - %34 = call @triu(%14) : (tensor<3x3xf32>) -> tensor<3x3xf32> - return %33, %34 : tensor<3x3xf32>, tensor<3x3xf32> - } - func.func private @triu(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> - %1 = stablehlo.constant dense<-1> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<3x3xi32> - %3 = stablehlo.add %0, %2 : tensor<3x3xi32> - %4 = stablehlo.iota dim = 1 : tensor<3x3xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> - %6 = stablehlo.constant dense<0.000000e+00> : tensor - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<3x3xf32> - %8 = stablehlo.select %5, %7, %arg0 : tensor<3x3xi1>, tensor<3x3xf32> - return %8 : tensor<3x3xf32> - } -} -""", - mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01+\x05\x01\x05\x01\x03\x05\x03\x1b\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f\x03~\x02\xf79\x01\x99\x0f\x0f\x17\x13\x0f\x07\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x13\x17\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x1b\x13\x13\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0bK\x0b\x13\x13\x0f\x0b#\x0b\x0b\x0b\x0f\x0bK\x0b\x13\x0b\x03_O/\x0b/\x0b\x0f\x0b\x0b\x0b\x0b\x0f\x0f\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b\x1f\x0b\x0b\x0f\x17\x1b\x0f\x0f\x0f\x0f\x1f\x0b\x1fO/\x0b\x0b\x13\x17\x039\x17\x0f\x0f\x07\x07\x07\x07\x17\x13\x17\x07\x1b\x0f\x17\x13\x1b\x17\x17\x13\x13\x1b\x13\x13\x13\x13\x13\x13\x17\x02\x06\t\x1d{\x05\x1d\x93\x05\x17\x1f\n\x06\x01\x03\x03\x13\xcb\x1dS\x05\x1f\x05!\x05#\x05%\x05'\x03\x03\r\xe9\x05)\x05+\x05-\x05/\x051\x03\x03#\xc7\x053\x1d[\x05\x055\x057\x03\x03\r\xd1\x17\x1f\x06\x06\x01\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x03\x03\x0f\xdd\x03\x03\x0f\xdf\x03\x03\x0f\xe1\x03\x03\r\xe5\x03\x05'\xa7)\xe7\x03\x03\x13\xeb\x03\x03\x11M\x05I\x03\x0b\x17\x9d\x19\xb1\x1b\xb3\x11\xbd\x1d\xbf\x03\x0b\x17\xa3\x19\xc3\x1b\xa3\x11\xa5\x1d\xc5\x05K\x1dW\x05\x05M\x03\x03\r\xc9\x05O\x03\x03#\xcd\x1da\x05\x05Q\x03\x05'\xa7)\xcf\x1dg\x05\x05S\x1dk\x05\x05U\x1do\x05\x05W\x1ds-\x05Y\x1dw-\x05[\x03\x11/\xa91\xd33\xd55\x9d7\xab9\xd7;\xad=\xdb\x05]\x03\x03\x0f\xe3\x03\x03\x13\xed\x1d\x83\x05\x05_\x03\x07\x87\x9f\x89\x9f\x8b\x9f\x05a\x05c\x05e\x1d\x8f\x05\x05g\x03\x11/\xa91\xef3\xf15\x9d7\xab9\xf3;\xad=\xf5\x05i\x03\x03\x97\xa5\x05k\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f-\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dm\x03\x03\xc1\x1do\t\x07\x0b\x05\x05\x01\x03\x03\xd9\x1f/\x01#!\x03\x05\xb5\xb9\r\x03\xa1\xb7\x1dq\r\x03\xa1\xbb\x1ds\x1du\x1dw\r\x01##\x1dy\x13\x0b\x01\x1f\x03\t\xff\xff\xff\xff\x1f%\x01\x13\x0b\x05\x07\x05\x1f\x05\t\x00\x00\x00\x00\x1d{\x1d}\x03\x03\x99\x15\x03\x01\x01\x01\x03\t\x99\x9b\xaf\x9b\x13\t\x01\x13\t\x05\x13\t\t\x13\t\r\x1f\x03\t\x00\x00\x00\x00\x07\x01\x1f\x05\t\x00\x00\xc0\x7f\x1f\x1d!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d\x7f\x1d\x81\x03\x05\x99\x9b\x03\x07\x99\xaf\x9b)\x05\r\r\x07)\x01\t)\x01\x07\t\x1b\x1d\x01)\x05\r\r\t)\x03\r\x07)\x05\r\r\r\x13)\x03\x04\x000\x07)\x01\r)\x05\x05\x05\r)\x03\t\x0b)\x03\x04\x12\x08\x07\x11\x01\x05\x01\x01\x11\x03\x01\x03\x01)\x03\x01\x0b)\x03%\x07/\t\x01\x11\x03\x17)\x03\t\x15)\x03\x05\x15)\x03\x01\x15)\x03\x05\r)\x03\r\r)\x03\x05\x0b/\x07\x01\x03\x1f\x04\xe6\x05\x05\x01\x11\x0bK\x07\x03\x01\t\x0f\x11\x0bO\x05\x03G\x91\x0b\x03q!\x03'\x17\x06u\x03\x01\x03\x01\x13\x07\x01y\x03)\x03\x03\x07\x07\x01?\x03\x01\x03\x05\x07\x07\x01A\x03\x11\x03\x05\x07\x07\x01C\x03\x03\x03\x05\x07\x07\x01}\x03\x17\x03\x05\x05\x03\x01E\x03\x03\x03\x07\x01\x07\x03\x03\x03\x0f\r\x07\x01G\x03\x19\x05\x0b\x11\x03\x07\x01\x07\x03\x1b\x03\x13\x05\x03\x01\x15\x03\x05\x03\x07\x01\x07\x03\x01\x03\x17\x03\x07\x01I\x03\x13\x03\x15\t\x06\x01\x03\x01\x07\x1b\x07\x19\x03\x07\x01\x07\x031\x03\x13\x05\x03\x01\x15\x03\x05\x03\x07\x01\x07\x03\x11\x03!\x03\x07\x01\x7f\x033\x03\x1f\t\x06\x01\x03\x11\x07%\t#\x05\x03\x81+\x03\x05\x19\x07\x8d\x85\x03\x01\x05\x1d)\x13\x07\x03\x91\x037\x05+'\x07\x07\x03?\x03\x01\x03-\x07\x07\x03A\x03\x03\x03-\x07\x07\x03C\x03\x1f\x03-\x05\x03\x03E\x03\x03\x03\x07\x03\x07\x03\x03\x035\r\x07\x03G\x03\x19\x0517\x03\x07\x03\x07\x03\x1b\x039\x05\x03\x03\x15\x03\x05\x03\x07\x03\x07\x03\x01\x03=\x03\x07\x03I\x03\x13\x03;\t\x06\x03\x03\x01\x07A/?\x1b\x07\t\x95\x03\x01\x03\x1d\x11\x04\x0b\x05CE\x0f\x11\tQ\x05\x03\x15+\x03\x01\x0b\x0b\x03U!\x03\x0f\x05\x03\tY\x03\x03\x03\x07%\x07\x03\x0f\x03\x05\x15\x06%\x03\x0f\x05\x03\x07\x0b\x03_]\x03\x0f\r\x07ec\x03\x13\x05\t\x0b\x05\x03\t+\x03\x05\x03\x07i\x07\x03\x01\x03\x0f\t\x06m\x03\x01\x07\r\x11\x01\x11\x04\t\x03\x13\x06\x03\x01\x05\x01\x00\x86\x19\x83\x1f3\x1f+\x11\x0f\x0b\t\t\x0b!\x0fY\x87##%_=\x85\x87W\xb3K\x9bM\x9b\xd2\x02\x1b\x1f/!!)#\x1f\x19+\x1b\x1f\x83\x1f\x15\x1d\x15+\x13\r\r\x11\x0f\x17\x0f\x1f\x15\x11\x17\x11\x15+\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00get_tuple_element_v1\x00select_v1\x00iota_v1\x00compare_v1\x00func_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00index\x00sym_name\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00iota_dimension\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=float32 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]\x00jit()/jit(main)/householder_product\x00callee\x00jax.result_info\x00triu\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00\x00\x00\x00\x01\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x00\x00\x03\x00\x00cusolver_geqrf\x00\x00\x00\x00\x00\x01\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00 \x81\x00\x00\x00cusolver_orgqr\x00", - xla_call_module_version=4, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_03_18["batched"] = dict( - testdata_version=1, - platform='cuda', - custom_call_targets=['cublas_geqrf_batched', 'cusolver_orgqr'], - serialized_date=datetime.date(2023, 3, 18), - inputs=(), - expected_outputs=(array([[[ 0. , 0.91287094, 0.40824836], - [-0.4472136 , 0.36514843, -0.81649655], - [-0.8944272 , -0.18257417, 0.4082483 ]], - - [[-0.42426407, 0.80828977, 0.40824953], - [-0.5656854 , 0.11547142, -0.8164964 ], - [-0.7071068 , -0.5773508 , 0.4082474 ]]], dtype=float32), array([[[-6.7082038e+00, -8.0498447e+00, -9.3914852e+00], - [ 0.0000000e+00, 1.0954450e+00, 2.1908898e+00], - [ 0.0000000e+00, 0.0000000e+00, 4.8374091e-08]], - - [[-2.1213203e+01, -2.2910259e+01, -2.4607319e+01], - [ 0.0000000e+00, 3.4641042e-01, 6.9282258e-01], - [ 0.0000000e+00, 0.0000000e+00, 1.4548683e-06]]], dtype=float32)), - mlir_module_text=r""" -module @jit__lambda_ { - func.func public @main() -> (tensor<2x3x3xf32> {jax.result_info = "[0]"}, tensor<2x3x3xf32> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<18xf32> - %1 = stablehlo.reshape %0 : (tensor<18xf32>) -> tensor<2x3x3xf32> - %2 = stablehlo.custom_call @cublas_geqrf_batched(%1) {api_version = 2 : i32, backend_config = "\00\00\00\00\02\00\00\00\03\00\00\00\03\00\00\00", operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x3x3xf32>) -> tuple, tensor<2x3xf32>, tensor<16xi8>, tensor<16xi8>> - %3 = stablehlo.get_tuple_element %2[0] : (tuple, tensor<2x3xf32>, tensor<16xi8>, tensor<16xi8>>) -> tensor<2x3x3xf32> - %4 = stablehlo.get_tuple_element %2[1] : (tuple, tensor<2x3xf32>, tensor<16xi8>, tensor<16xi8>>) -> tensor<2x3xf32> - %5 = stablehlo.get_tuple_element %2[2] : (tuple, tensor<2x3xf32>, tensor<16xi8>, tensor<16xi8>>) -> tensor<16xi8> - %6 = stablehlo.get_tuple_element %2[3] : (tuple, tensor<2x3xf32>, tensor<16xi8>, tensor<16xi8>>) -> tensor<16xi8> - %7 = stablehlo.constant dense<0.000000e+00> : tensor - %8 = stablehlo.pad %3, %7, low = [0, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<2x3x3xf32>, tensor) -> tensor<2x3x3xf32> - %9 = stablehlo.custom_call @cusolver_orgqr(%8, %4) {api_version = 2 : i32, backend_config = "\00\00\00\00\02\00\00\00\03\00\00\00\03\00\00\00\03\00\00\00 \81\00\00", operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x3x3xf32>, tensor<2x3xf32>) -> tuple, tensor<2xi32>, tensor<33056xf32>> - %10 = stablehlo.get_tuple_element %9[0] : (tuple, tensor<2xi32>, tensor<33056xf32>>) -> tensor<2x3x3xf32> - %11 = stablehlo.get_tuple_element %9[1] : (tuple, tensor<2xi32>, tensor<33056xf32>>) -> tensor<2xi32> - %12 = stablehlo.get_tuple_element %9[2] : (tuple, tensor<2xi32>, tensor<33056xf32>>) -> tensor<33056xf32> - %13 = stablehlo.constant dense<0> : tensor - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<2xi32> - %15 = stablehlo.compare EQ, %11, %14, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - %16 = stablehlo.broadcast_in_dim %15, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> - %17 = stablehlo.constant dense<0x7FC00000> : tensor - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor) -> tensor<2x3x3xf32> - %19 = stablehlo.broadcast_in_dim %16, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x3x3xi1> - %20 = stablehlo.select %19, %10, %18 : tensor<2x3x3xi1>, tensor<2x3x3xf32> - %21 = call @triu(%3) : (tensor<2x3x3xf32>) -> tensor<2x3x3xf32> - return %20, %21 : tensor<2x3x3xf32>, tensor<2x3x3xf32> - } - func.func private @triu(%arg0: tensor<2x3x3xf32>) -> tensor<2x3x3xf32> { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> - %1 = stablehlo.constant dense<-1> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<3x3xi32> - %3 = stablehlo.add %0, %2 : tensor<3x3xi32> - %4 = stablehlo.iota dim = 1 : tensor<3x3xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> - %6 = stablehlo.broadcast_in_dim %5, dims = [1, 2] : (tensor<3x3xi1>) -> tensor<2x3x3xi1> - %7 = stablehlo.constant dense<0.000000e+00> : tensor - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<2x3x3xf32> - %9 = stablehlo.select %6, %8, %arg0 : tensor<2x3x3xi1>, tensor<2x3x3xf32> - return %9 : tensor<2x3x3xf32> - } -} -""", - mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01+\x05\x01\x05\x01\x03\x05\x03\x1b\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f\x03\x96\x02\xff=\x01\x9f\x17\x0f\x0f\x0f\x07\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x13\x17\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0bK\x0b\x13\x0f\x0b#\x0b\x0b\x0b\x0f\x0bK\x0b\x13\x1b\x13\x13\x13\x13\x0b\x03ao/\x0b/\x0b\x0f\x0b\x0b\x0b\x0b\x0fO\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0bO\x1f\x0b\x0b\x0f\x17\x1b\x0f\x0f\x0f\x0f\x0b\x0b\x13\x17\x1f\x0b/\x1fo\x03=\x1b\x07\x07\x07\x0f\x17\x0f\x07\x13\x07\x13\x1b\x17\x13\x1b\x17\x17\x13\x17\x13\x13\x1b\x07\x13\x13\x13\x17\x13\x1b\x13\x02\x1a\n\x17\x1d\n\x06\x01\x1d\x8f\x01\x1dK\x01\x1dy\x01\x1f\x05!\x03\x03\x0f\xd1\x05#\x05%\x05'\x05)\x05+\x05-\x05/\x051\x03\x03!\xcd\x053\x1dS\x01\x055\x057\x03\x03\x0b\xd9\x17\x1d\x06\x06\x01\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x03\x03\x11\xe5\x03\x03\x11\xe7\x03\x03\x11\xe9\x03\x03\x13E\x05I\x03\x0b\x15\xa3\x17\xb7\x19\xb9\x13\xc3\x1b\xc5\x03\x0b\x15\xa9\x17\xc9\x19\xa9\x13\xab\x1b\xcb\x05K\x1dO\x01\x05M\x03\x03\x0b\xcf\x05O\x03\x03!\xd3\x1dY\x01\x05Q\x03\x05%\xad'\xd5\x1d_\x01\x05S\x03\x03\x0f\xd7\x1de\x01\x05U\x1di\x01\x05W\x1dm\x01\x05Y\x1dq+\x05[\x1du+\x05]\x03\x11-\xaf/\xdb1\xdd3\xa35\xb17\xdf9\xb3;\xe3\x05_\x03\x03\x11\xeb\x1d\x7f\x01\x05a\x03\x07\x83\xa5\x85\xa5\x87\xa5\x05c\x05e\x05g\x1d\x8b\x01\x05i\x03\x11-\xaf/\xed1\xef3\xa35\xb17\xf19\xb3;\xf3\x05k\x03\x03\x0b\xf5\x03\x05%\xad'\xf7\x03\x03\x0f\xf9\x03\x03\x0b\xfb\x03\x03\x0f\xfd\x03\x03\x9d\xab\x05m\x1f/1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1f\x1b\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1do\x03\x03\xc7\x1dq\t\x07\x0b\x05\x05\x01\x03\x03\xe1\x1f1!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x05\xbb\xbf\r\x03\xa7\xbd\x1ds\r\x03\xa7\xc1\x1du\x1dw\x1dy\r\x01#!\x1d{\x13\x05\x01\x1f\r\t\xff\xff\xff\xff\x1f#\x01\x13\x05\x05\x07\x05\x1f'!\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f\t\t\x00\x00\x00\x00\x1d}\x1d\x7f\x03\x03\x9f\x15\x03\x01\x01\x01\x03\t\x9f\xb5\xa1\xa1\x13\x03\x01\x13\x03\x05\x13\x03\t\x13\x03\r\x1d\x81\x1d\x83\x03\x05\x9f\xb5\x03\x07\x9f\xa1\xa1\x1f\r\t\x00\x00\x00\x00\x07\x01\x1f;\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\t\t\x00\x00\xc0\x7f\x1f\x1b1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00)\x07\t\r\r\x07\x1b\x1d\t)\x01\x07)\x05\r\r\x03)\x01\x03\x01)\x03A-\x13)\x03\t\x03)\x07\t\r\r\x0f)\x05\t\r\x07)\x03\r\x05)\x03\x04\x12\x08\x07\x11\x01\x05\x01\x01\x11\x03\x01\x03\x01)\x03\x01\x05)\x05\r\r\x0f)\x03\t\x05)\x03I\x07/\t\x01\x19\x11\x11\x17)\x03\r\x13)\x03\t\x13)\x03\x05\x13/\x07\x01\x15\x1d)\x03\t\x0f)\x07\t\x05\x05\x0f)\x03\x05\x05\x04r\x04\x05\x01\x11\tC\x07\x03\x01\t\x0b\x11\tG\x05\x03-]\t\x03o\x1f\x03)\x17\x06s\x03\x01\x03\x01\x13\x07\x07w\x03+\x03\x03\x05\x07\x07=\x03\x01\x03\x05\x05\x07\x07?\x03\x19\x03\x05\x05\x07\x07A\x03\x11\x03\x05\x05\x07\x07{\x03\x11\x03\x05\x07\x03})\x03\t\x19\x07\x89\x81\x03\x01\x05\x07\x0f\x13\x07\x03\x8d\x035\x05\x11\t\x05\x07\x03=\x03\x01\x03\x13\x05\x07\x03?\x03\x15\x03\x13\x05\x07\x03A\x03\x1d\x03\x13\x07\x03\x03\x91\x03\r\x03\x07\x03\r\x03\x15\x03\x1b\r\x07\x03\x93\x037\x05\x17\x1d\x03\x07\x03\x95\x039\x03\x1f\x07\x03\x03\x97\x03\t\x03\x07\x03\r\x03\x01\x03#\x03\x07\x03\x99\x03\x17\x03!\x0f\x06\x03\x03\x01\x07'\x15%\x1b\x07\x05\x9b\x03\x01\x03\x07\x11\x04\t\x05)+\x0b\x11\x05I\x05\x03\x17/\x03\x01\t\t\x03M\x1f\x03\x0b\x07\x03\x05Q\x03\r\x03\x07#\r\x03\x0b\x03\x05\x15\x06#\x03\x0b\x05\x03\x07\t\x03WU\x03\x0b\r\x07][\x03%\x05\t\x0b\x03\x07ca\x03\x17\x03\r\x07\x03\x05)\x03\t\x03\x07g\r\x03\x01\x03\x11\x0f\x06k\x03\x01\x07\x0f\x13\x01\x11\x04\x05\x03\x15\x06\x03\x01\x05\x01\x00Z\x1b\x85\x1f3+#\x11\x0f\x0b\t\t\x0b!\x0fY\x9d##%_=\x8b\x89W\xb9\xc1K\x9bM\x9b\xd2\x02\x1b\x1f/!!)#\x1f\x19+\x1b\x1f\x83\x1f\x15\x1d\x15\x13\r+\r\x11\x0f\x17\x0f\x1f\x15\x15\x17\x11\x11\x19+)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00get_tuple_element_v1\x00constant_v1\x00iota_v1\x00func_v1\x00compare_v1\x00select_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00broadcast_dimensions\x00index\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00iota_dimension\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(2, 3, 3) broadcast_dimensions=(1, 2)]\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(2, 3, 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=float32 shape=(18,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(2, 3, 3) dimensions=None]\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0), (0, 0, 0))]\x00jit()/jit(main)/householder_product\x00callee\x00jax.result_info\x00triu\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x00cublas_geqrf_batched\x00\x00\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00 \x81\x00\x00\x00cusolver_orgqr\x00", - xla_call_module_version=4, -) # End paste data_2024_09_26 = {} - data_2024_09_26["f32"] = dict( testdata_version=1, platform='cuda', diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/rocm_qr_hipsolver_geqrf.py b/jax/_src/internal_test_util/export_back_compat_test_data/rocm_qr_hipsolver_geqrf.py deleted file mode 100644 index bd5fa628741e..000000000000 --- a/jax/_src/internal_test_util/export_back_compat_test_data/rocm_qr_hipsolver_geqrf.py +++ /dev/null @@ -1,176 +0,0 @@ -# Copyright 2023 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import datetime -from numpy import array, float32 - -data_2024_08_05 = {} - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_05["unbatched"] = dict( - testdata_version=1, - platform='rocm', - custom_call_targets=['hipsolver_geqrf', 'hipsolver_orgqr'], - serialized_date=datetime.date(2024, 8, 5), - inputs=(), - expected_outputs=(array([[ 0. , 0.9128709 , 0.40824834], - [-0.4472136 , 0.3651484 , -0.81649655], - [-0.8944272 , -0.18257423, 0.40824828]], dtype=float32), array([[-6.7082038e+00, -8.0498447e+00, -9.3914852e+00], - [ 0.0000000e+00, 1.0954450e+00, 2.1908898e+00], - [ 0.0000000e+00, 0.0000000e+00, 1.6371473e-09]], dtype=float32)), - mlir_module_text=r""" -#loc2 = loc("/release/jax/tests/export_back_compat_test.py":346:0) -#loc9 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc2)) -module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<3x3xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<3x3xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { - %0 = stablehlo.iota dim = 0 : tensor<9xf32> loc(#loc3) - %1 = stablehlo.reshape %0 : (tensor<9xf32>) -> tensor<3x3xf32> loc(#loc4) - %2:4 = stablehlo.custom_call @hipsolver_geqrf(%1) {api_version = 2 : i32, backend_config = "\00\00\00\00\01\00\00\00\03\00\00\00\03\00\00\00\00\01\00\00", operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xf32>) -> (tensor<3x3xf32>, tensor<3xf32>, tensor, tensor<256xf32>) loc(#loc5) - %c = stablehlo.constant dense<0> : tensor loc(#loc5) - %3 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor loc(#loc5) - %4 = stablehlo.compare EQ, %2#2, %3, SIGNED : (tensor, tensor) -> tensor loc(#loc5) - %5 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %cst = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc5) - %6 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<3x3xf32> loc(#loc5) - %7 = stablehlo.broadcast_in_dim %5, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc5) - %8 = stablehlo.select %7, %2#0, %6 : tensor<3x3xi1>, tensor<3x3xf32> loc(#loc5) - %9 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<1xi1> loc(#loc5) - %cst_0 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc5) - %10 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor) -> tensor<3xf32> loc(#loc5) - %11 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<1xi1>) -> tensor<3xi1> loc(#loc5) - %12 = stablehlo.select %11, %2#1, %10 : tensor<3xi1>, tensor<3xf32> loc(#loc5) - %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc6) - %13 = stablehlo.pad %8, %cst_1, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xf32>, tensor) -> tensor<3x3xf32> loc(#loc7) - %14:3 = stablehlo.custom_call @hipsolver_orgqr(%13, %12) {api_version = 2 : i32, backend_config = "\00\00\00\00\01\00\00\00\03\00\00\00\03\00\00\00\03\00\00\00\80\00\00\00", operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xf32>, tensor<3xf32>) -> (tensor<3x3xf32>, tensor, tensor<128xf32>) loc(#loc8) - %c_2 = stablehlo.constant dense<0> : tensor loc(#loc8) - %15 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor loc(#loc8) - %16 = stablehlo.compare EQ, %14#1, %15, SIGNED : (tensor, tensor) -> tensor loc(#loc8) - %17 = stablehlo.broadcast_in_dim %16, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc8) - %cst_3 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc8) - %18 = stablehlo.broadcast_in_dim %cst_3, dims = [] : (tensor) -> tensor<3x3xf32> loc(#loc8) - %19 = stablehlo.broadcast_in_dim %17, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc8) - %20 = stablehlo.select %19, %14#0, %18 : tensor<3x3xi1>, tensor<3x3xf32> loc(#loc8) - %21 = call @triu(%8) : (tensor<3x3xf32>) -> tensor<3x3xf32> loc(#loc9) - return %20, %21 : tensor<3x3xf32>, tensor<3x3xf32> loc(#loc) - } loc(#loc) - func.func private @triu(%arg0: tensor<3x3xf32> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc2))) -> (tensor<3x3xf32> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc10) - %c = stablehlo.constant dense<-1> : tensor loc(#loc9) - %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc11) - %2 = stablehlo.add %0, %1 : tensor<3x3xi32> loc(#loc11) - %3 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc12) - %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc13) - %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc9) - %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<3x3xf32> loc(#loc14) - %6 = stablehlo.select %4, %5, %arg0 : tensor<3x3xi1>, tensor<3x3xf32> loc(#loc15) - return %6 : tensor<3x3xf32> loc(#loc9) - } loc(#loc9) -} loc(#loc) -#loc = loc(unknown) -#loc1 = loc("/release/jax/tests/export_back_compat_test.py":345:0) -#loc3 = loc("jit()/jit(main)/iota[dtype=float32 shape=(9,) dimension=0]"(#loc1)) -#loc4 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc1)) -#loc5 = loc("jit()/jit(main)/geqrf"(#loc2)) -#loc6 = loc("jit()/jit(main)/qr[full_matrices=True]"(#loc2)) -#loc7 = loc("jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]"(#loc2)) -#loc8 = loc("jit()/jit(main)/householder_product"(#loc2)) -#loc10 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]"(#loc2)) -#loc11 = loc("jit()/jit(main)/jit(triu)/add"(#loc2)) -#loc12 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]"(#loc2)) -#loc13 = loc("jit()/jit(main)/jit(triu)/ge"(#loc2)) -#loc14 = loc("jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]"(#loc2)) -#loc15 = loc("jit()/jit(main)/jit(triu)/select_n"(#loc2)) -""", - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01)\x05\x01\x03\x01\x03\x05\x03\x19\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x03~\x02\xf39\x01\x99\x0f\x17\x13\x0f\x0f\x0b\x0b\x07\x0b\x13\x0f\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x13\x17\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x13+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0bK\x0b\x13\x0f\x0b#\x0b\x0b\x0b\x0f\x0bK\x0b\x13\x0b\x03[O/\x0b\x0b\x0b/\x0b\x0f\x0b\x0b\x0b\x0b\x0f\x0f\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0b\x1f\x0b\x0b\x0f\x17\x1b\x1f\x0b\x1fO/\x0b\x0b\x13\x17\x01\x05\x0b\x0f\x035\x17\x0f\x0f\x07\x07\x07\x17\x17\x13\x07\x07\x0f\x17\x13\x17\x17\x13\x13\x17\x13\x13\x13\x13\x13\x13\x17\x02\xde\x08\x1d}\x03\x17\x1fj\x05\x01\x03\x03\x11\xcf\x1d\x93\x03\x1dU\x03\x05\x1f\x05!\x1f\x05#\x03\x03\x0b\xe5\x11\x03\x05\x05%\x05'\x05)\x05+\x05-\x03\x03#\xcb\x05/\x1d]\x03\x051\x053\x03\x03\x0b\xd5\x17\x1ff\x05\x01\x055\x057\x059\x05;\x05=\x05?\x05A\x05C\x03\x03\x0b\xe1\x03\x05'\xab)\xe3\x03\x03\x11\xe7\x03\tGIK\x15M\x15\rO\x05E\x11\x01\x00\x05G\x05I\x05K\x03\x0b\x17\x9d\x19\xb5\x1b\xb7\r\xc1\x1d\xc3\x03\x0b\x17\xa7\x19\xc7\x1b\xa7\r\xa9\x1d\xc9\x05M\x1dY\x03\x05O\x03\x03\x0b\xcd\x05Q\x03\x03#\xd1\x1dc\x03\x05S\x03\x05'\xab)\xd3\x1di\x03\x05U\x1dm\x03\x05W\x1dq\x03\x05Y\x1du-\x05[\x1dy-\x05]\x03\x11/\xad1\xd73\xd95\x9d7\xaf9\xdb;\xb1=\xdf\x05_\x03\x03\x11\xe9\x1d\x83\x03\x05a\x03\x07\x87\xa3\x89\xa3\x8b\xa3\x05c\x05e\x05g\x1d\x8f\x03\x05i\x03\x11/\xad1\xeb3\xed5\x9d7\xaf9\xef;\xb1=\xf1\x05k\x03\x03\x97\xa9\x05m\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f-\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1do\x1dq\x1f\x1f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1ds\x03\x03\xc5\x1du\t\x07\x0b\x05\x05\x01\x03\x03\xdd\x1f/\x01#!\x03\x05\xb9\xbd\r\x05\xa5\xbb\x9f\xa1\x1dw\r\x05\xa5\xbf\x9f\xa1\x1dy\x1d{\x1d}\r\x03\x9f\xa1##\x1d\x7f\x13\r\x01\x1f\x07\t\xff\xff\xff\xff\x1f%\x01\x13\r\x05\x07\x05\x1f\t\t\x00\x00\x00\x00\x1d\x81\x1d\x83\x03\x03\x99\x15\x03\x01\x01\x01\x03\t\x99\x9b\xb3\x9b\x1f\x07\t\x00\x00\x00\x00\x07\x01\x1f\t\t\x00\x00\xc0\x7f\x1f\x1f!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d\x85\x1d\x87\x03\x05\x99\x9b\x03\x07\x99\xb3\x9b\x01\t\x01\x02\x02)\x05\r\r\x0b)\x01\x19)\x01\x0b\t\x1d\x01)\x05\r\r\x19)\x05\r\r\x0f)\x03\r\x0b\x13\x1b)\x01\x0f)\x05\x05\x05\x0f)\x03\t\r\x11\x01\x05\x05\x05\x11\x03\x05\x03\x05)\x03\x01\r)\x03%\x0b)\x03\x02\x08\x0b)\x03\t\x17)\x03\x05\x17)\x03\x01\x17)\x03\x05\x0f)\x03\r\x0f)\x03\x05\r)\x03\x02\x04\x0b\x04\x1a\x05\x05\x01\x11\x0fE\x07\x03\x01\t\r\x11\x0fQ\x07\x03Cu\t\x03s!\x03'\x15\x06w\x03\x05\x03\x01\x11\x07\x01{\t\x05\x15\x07)\x03\x03\x05\x03\x01?\x03\x07\x03\x07\x01\x05\x03\x07\x03\r\x0b\x07\x01A\x03\x1b\x05\t\x0f\x03\x07\x01\x05\x03\x1d\x03\x11\x05\x03\x01\x13\x03\t\x03\x07\x01\x05\x03\x05\x03\x15\x03\x07\x01C\x03\x13\x03\x13\x07\x06\x01\x03\x05\x07\x19\x05\x17\x03\x07\x01\x05\x031\x03\x11\x05\x03\x01\x13\x03\t\x03\x07\x01\x05\x03\x15\x03\x1f\x03\x07\x01\x7f\x033\x03\x1d\x07\x06\x01\x03\x15\x07#\x07!\x05\x03\x81+\x03\t\x17\x07\x8d\x85\x03\x05\x05\x1b'\x11\x07\x07\x91\x07\x05\x077\x05)%\x05\x03\x07?\x03\x07\x03\x07\x07\x05\x03\x07\x031\x0b\x07\x07A\x03\x1b\x05-3\x03\x07\x07\x05\x03\x1d\x035\x05\x03\x07\x13\x03\t\x03\x07\x07\x05\x03\x05\x039\x03\x07\x07C\x03\x13\x037\x07\x06\x07\x03\x05\x07=+;\x19\x07\t\x95\x03\x05\x03\x1b\x0f\x04\x0f\x05?A\r\x11\tS\x07\x03\x15+\x03\x05\t\t\x03W!\x03\x11\x05\x03\t[\x03\x07\x03\x07%\x05\x03\x11\x03\x05\x13\x06%\x03\x11\x05\x03\x07\t\x03a_\x03\x11\x0b\x07ge\x03\x13\x05\t\x0b\x05\x03\t+\x03\t\x03\x07k\x05\x03\x05\x03\x0f\x07\x06o\x03\x05\x07\r\x11\x01\x0f\x04\t\x03\x13\x06\x03\x01\x05\x01\x00\xea\x1a\x89!3!+\x11\x0f\x0b\t\t\x0b!\x11#\x0fY\x87##%_=\x85\x87W\xb3K\x9bM\x9bn\x03\x1b%)9\x1f/!!)#\x1f\x19+\x1b\x1f]\x1f\x15\x1d\x15+\x13\r\x11\x0f\x17\x0f\x1f\x15\x11\x17\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00iota_v1\x00compare_v1\x00func_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00sym_name\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00/release/jax/tests/export_back_compat_test.py\x00iota_dimension\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=float32 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]\x00jit()/jit(main)/householder_product\x00callee\x00mhlo.layout_mode\x00default\x00jax.result_info\x00triu\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00\x00\x00\x00\x01\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x00\x01\x00\x00\x00hipsolver_geqrf\x00\x00\x00\x00\x00\x01\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x80\x00\x00\x00\x00hipsolver_orgqr\x00", - xla_call_module_version=9, - nr_devices=1, -) # End paste - - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_05["batched"] = dict( - testdata_version=1, - platform='rocm', - custom_call_targets=['hipblas_geqrf_batched', 'hipsolver_orgqr'], - serialized_date=datetime.date(2024, 8, 5), - inputs=(), - expected_outputs=(array([[[ 0. , 0.9128709 , 0.40824834], - [-0.4472136 , 0.3651484 , -0.81649655], - [-0.8944272 , -0.18257423, 0.40824828]], - - [[-0.42426407, 0.8082888 , 0.4082513 ], - [-0.5656854 , 0.11547317, -0.81649613], - [-0.7071068 , -0.5773518 , 0.40824607]]], dtype=float32), array([[[-6.7082038e+00, -8.0498447e+00, -9.3914852e+00], - [ 0.0000000e+00, 1.0954450e+00, 2.1908898e+00], - [ 0.0000000e+00, 0.0000000e+00, 1.6371473e-09]], - - [[-2.1213203e+01, -2.2910259e+01, -2.4607313e+01], - [ 0.0000000e+00, 3.4641036e-01, 6.9281983e-01], - [ 0.0000000e+00, 0.0000000e+00, 8.3555670e-07]]], dtype=float32)), - mlir_module_text=r""" -#loc2 = loc("/release/jax/tests/export_back_compat_test.py":346:0) -#loc9 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc2)) -module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<2x3x3xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x3x3xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { - %0 = stablehlo.iota dim = 0 : tensor<18xf32> loc(#loc3) - %1 = stablehlo.reshape %0 : (tensor<18xf32>) -> tensor<2x3x3xf32> loc(#loc4) - %2:4 = stablehlo.custom_call @hipblas_geqrf_batched(%1) {api_version = 2 : i32, backend_config = "\00\00\00\00\02\00\00\00\03\00\00\00\03\00\00\00", operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x3x3xf32>) -> (tensor<2x3x3xf32>, tensor<2x3xf32>, tensor<16xi8>, tensor<16xi8>) loc(#loc5) - %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc6) - %3 = stablehlo.pad %2#0, %cst, low = [0, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<2x3x3xf32>, tensor) -> tensor<2x3x3xf32> loc(#loc7) - %4:3 = stablehlo.custom_call @hipsolver_orgqr(%3, %2#1) {api_version = 2 : i32, backend_config = "\00\00\00\00\02\00\00\00\03\00\00\00\03\00\00\00\03\00\00\00\80\00\00\00", operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x3x3xf32>, tensor<2x3xf32>) -> (tensor<2x3x3xf32>, tensor<2xi32>, tensor<128xf32>) loc(#loc8) - %c = stablehlo.constant dense<0> : tensor loc(#loc8) - %5 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<2xi32> loc(#loc8) - %6 = stablehlo.compare EQ, %4#1, %5, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc8) - %7 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc8) - %cst_0 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc8) - %8 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor) -> tensor<2x3x3xf32> loc(#loc8) - %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x3x3xi1> loc(#loc8) - %10 = stablehlo.select %9, %4#0, %8 : tensor<2x3x3xi1>, tensor<2x3x3xf32> loc(#loc8) - %11 = call @triu(%2#0) : (tensor<2x3x3xf32>) -> tensor<2x3x3xf32> loc(#loc9) - return %10, %11 : tensor<2x3x3xf32>, tensor<2x3x3xf32> loc(#loc) - } loc(#loc) - func.func private @triu(%arg0: tensor<2x3x3xf32> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc2))) -> (tensor<2x3x3xf32> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc10) - %c = stablehlo.constant dense<-1> : tensor loc(#loc9) - %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc11) - %2 = stablehlo.add %0, %1 : tensor<3x3xi32> loc(#loc11) - %3 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc12) - %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc13) - %5 = stablehlo.broadcast_in_dim %4, dims = [1, 2] : (tensor<3x3xi1>) -> tensor<2x3x3xi1> loc(#loc14) - %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc9) - %6 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<2x3x3xf32> loc(#loc15) - %7 = stablehlo.select %5, %6, %arg0 : tensor<2x3x3xi1>, tensor<2x3x3xf32> loc(#loc16) - return %7 : tensor<2x3x3xf32> loc(#loc9) - } loc(#loc9) -} loc(#loc) -#loc = loc(unknown) -#loc1 = loc("/release/jax/tests/export_back_compat_test.py":345:0) -#loc3 = loc("jit()/jit(main)/iota[dtype=float32 shape=(18,) dimension=0]"(#loc1)) -#loc4 = loc("jit()/jit(main)/reshape[new_sizes=(2, 3, 3) dimensions=None]"(#loc1)) -#loc5 = loc("jit()/jit(main)/geqrf"(#loc2)) -#loc6 = loc("jit()/jit(main)/qr[full_matrices=True]"(#loc2)) -#loc7 = loc("jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0), (0, 0, 0))]"(#loc2)) -#loc8 = loc("jit()/jit(main)/householder_product"(#loc2)) -#loc10 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]"(#loc2)) -#loc11 = loc("jit()/jit(main)/jit(triu)/add"(#loc2)) -#loc12 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]"(#loc2)) -#loc13 = loc("jit()/jit(main)/jit(triu)/ge"(#loc2)) -#loc14 = loc("jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(2, 3, 3) broadcast_dimensions=(1, 2)]"(#loc2)) -#loc15 = loc("jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(2, 3, 3) broadcast_dimensions=()]"(#loc2)) -#loc16 = loc("jit()/jit(main)/jit(triu)/select_n"(#loc2)) -""", - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01)\x05\x01\x03\x01\x03\x05\x03\x19\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x03\x96\x02\xfb=\x01\x9f\x17\x0f\x0f\x0b\x13\x0b\x0b\x07\x0f\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x13\x17\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0bK\x0f\x0b\x0f\x0b#\x0b\x0b\x0b\x0f\x0bK\x0b\x13\x1b\x13\x13\x13\x13\x0b\x03]o/\x0b\x0b\x0b/\x0b\x0f\x0b\x0b\x0b\x0b\x0fO\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0bO\x1f\x0b\x0b\x0f\x17\x1b\x0b\x0b\x13\x17\x1f\x0b/\x1fo\x01\x05\x0b\x0f\x039\x1b\x07\x07\x0f\x17\x0f\x07\x07\x07\x1b\x13\x13\x13\x17\x17\x13\x17\x13\x13\x17\x07\x13\x13\x13\x17\x13\x1b\x13\x02\xf6\t\x17\x1bj\x05\x01\x1d\x8f\x01\x1dK\x01\x05\x1f\x03\x03\x0b\xd5\x05!\x05#\x1f\x11\x03\x05\x05%\x05'\x05)\x05+\x05-\x03\x03\x1f\xd1\x05/\x1dS\x01\x051\x053\x03\x03\x07\xdd\x17\x1bf\x05\x01\x055\x057\x059\x05;\x05=\x05?\x05A\x05C\x03\t=?A\x11C\x11\rE\x05E\x11\x01\x00\x05G\x05I\x05K\x03\x0b\x13\xa3\x15\xbb\x17\xbd\r\xc7\x19\xc9\x03\x0b\x13\xad\x15\xcd\x17\xad\r\xaf\x19\xcf\x05M\x1dO\x01\x05O\x03\x03\x07\xd3\x05Q\x03\x03\x1f\xd7\x1dY\x01\x05S\x03\x05#\xb1%\xd9\x1d_\x01\x05U\x03\x03\x0b\xdb\x1de\x01\x05W\x1di\x01\x05Y\x1dm\x01\x05[\x1dq)\x05]\x1du)\x05_\x03\x11+\xb3-\xdf/\xe11\xa33\xb55\xe37\xb79\xe7\x1d{\x01\x05a\x1d\x7f\x01\x05c\x03\x07\x83\xa9\x85\xa9\x87\xa9\x05e\x05g\x05i\x1d\x8b\x01\x05k\x03\x11+\xb3-\xe9/\xeb1\xa33\xb55\xed7\xb79\xef\x05m\x03\x03\x07\xf1\x03\x05#\xb1%\xf3\x03\x03\x0b\xf5\x03\x03\x07\xf7\x03\x03\x0b\xf9\x03\x03\x9d\xaf\x05o\x1f/1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1dq\x1ds\x1f\x1b\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1du\x03\x03\xcb\x1dw\t\x07\x0b\x05\x05\x01\x03\x03\xe5\x1f1!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x05\xbf\xc3\r\x05\xab\xc1\xa5\xa7\x1dy\r\x05\xab\xc5\xa5\xa7\x1d{\x1d}\x1d\x7f\r\x03\xa5\xa7#!\x1d\x81\x13\x07\x01\x1f\x0f\t\xff\xff\xff\xff\x1f#\x01\x13\x07\x05\x07\x05\x1f'!\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x0b\t\x00\x00\x00\x00\x1d\x83\x1d\x85\x03\x03\x9f\x15\x03\x01\x01\x01\x03\t\x9f\xb9\xa1\xa1\x1d\x87\x1d\x89\x03\x05\x9f\xb9\x03\x07\x9f\xa1\xa1\x1f\x0f\t\x00\x00\x00\x00\x07\x01\x1f;\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x0b\t\x00\x00\xc0\x7f\x1f\x1b1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\r\r\t\x1d\t)\x01\t)\x05\r\r\x13)\x01\x13\x01\x1b\x13)\x07\t\r\r\x11)\x03A-)\x03\r\x07)\x03\t\x13\x11\x01\x05\x05\x05\x11\x03\x05\x03\x05)\x03\x01\x07)\x05\r\r\x11)\x03\t\x07)\x03I\t)\x05\t\r\t\x17)\x03\r\x15)\x03\t\x15)\x03\x05\x15)\x03\x02\x04\t)\x03\t\x11)\x07\t\x05\x05\x11)\x03\x05\x07\x04\xa6\x03\x05\x01\x11\x0f;\x07\x03\x01\t\t\x11\x0fG\x07\x03)A\x07\x03o\x1d\x03)\x15\x06s\x03\x05\x03\x01\x11\x07yw\t\x05+\x19\x19\x03\x03\x05\x03}'\x03\x0b\x17\x07\x89\x81\x03\x05\x05\x05\r\x11\x07\x03\x8d\x07\x05\x1d5\x05\x0f\x07\x05\x03\x03\x91\x03\x0f\x03\x07\x03\t\x03\x1d\x03\x17\x0b\x07\x03\x93\x037\x05\x13\x19\x03\x07\x03\x95\x039\x03\x1b\x05\x03\x03\x97\x03\x0b\x03\x07\x03\t\x03\x05\x03\x1f\x03\x07\x03\x99\x03\x17\x03\x1d\r\x06\x03\x03\x05\x07#\x11!\x19\x07\x05\x9b\x03\x05\x03\x05\x0f\x04\x0f\x05%'\t\x11\x05I\x07\x03\x17/\x03\x05\x05\x07\x03M\x1d\x03\r\x05\x03\x05Q\x03\x0f\x03\x07!\t\x03\r\x03\x05\x13\x06!\x03\r\x05\x03\x07\x07\x03WU\x03\r\x0b\x07][\x03%\x05\t\x0b\x03\x07ca\x03\x17\x03\r\x05\x03\x05'\x03\x0b\x03\x07g\t\x03\x05\x03\x11\r\x06k\x03\x05\x07\x0f\x13\x01\x0f\x04\x05\x03\x15\x06\x03\x01\x05\x01\x00\xbe\x1c\x8b!3-#\x11\x0f\x0b\t\t\x0b!\x11#\x0fY\x9d##%_=\x8b\x89W\xb9\xc1K\x9bM\x9bn\x03\x1b%)9\x1f/!!)#\x1f\x19+\x1b\x1f]\x1f\x15\x1d\x15\x13+\r\x11\x0f\x17\x0f\x1f\x15\x15\x17\x11\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00iota_v1\x00func_v1\x00compare_v1\x00select_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00broadcast_dimensions\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00/release/jax/tests/export_back_compat_test.py\x00iota_dimension\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(2, 3, 3) broadcast_dimensions=(1, 2)]\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(2, 3, 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=float32 shape=(18,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(2, 3, 3) dimensions=None]\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0), (0, 0, 0))]\x00jit()/jit(main)/householder_product\x00callee\x00mhlo.layout_mode\x00default\x00jax.result_info\x00triu\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x00hipblas_geqrf_batched\x00\x00\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x80\x00\x00\x00\x00hipsolver_orgqr\x00", - xla_call_module_version=9, - nr_devices=1, -) # End paste diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index d35e421ef904..d7035c92b24a 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -64,7 +64,6 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@local_config_cuda//cuda:cublas_headers", "@local_config_cuda//cuda:cuda_headers", "@xla//xla/tsl/cuda:cupti", "@xla//xla/tsl/cuda:cusolver", @@ -98,55 +97,6 @@ cc_library( ], ) -cc_library( - name = "cublas_kernels", - srcs = ["//jaxlib/gpu:blas_kernels.cc"], - hdrs = ["//jaxlib/gpu:blas_kernels.h"], - deps = [ - ":cuda_blas_handle_pool", - ":cuda_gpu_kernel_helpers", - ":cuda_make_batch_pointers", - ":cuda_vendor", - "//jaxlib:kernel_helpers", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@local_config_cuda//cuda:cublas_headers", - "@local_config_cuda//cuda:cuda_headers", - "@xla//xla/service:custom_call_status", - "@xla//xla/tsl/cuda:cublas", - "@xla//xla/tsl/cuda:cudart", - ], -) - -nanobind_extension( - name = "_blas", - srcs = ["//jaxlib/gpu:blas.cc"], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], - module_name = "_blas", - deps = [ - ":cublas_kernels", - ":cuda_vendor", - "//jaxlib:kernel_nanobind_helpers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings:str_format", - "@nanobind", - "@xla//xla/tsl/cuda:cublas", - "@xla//xla/tsl/cuda:cudart", - "@xla//xla/tsl/python/lib/core:numpy", - ], -) - cc_library( name = "cudnn_rnn_kernels", srcs = ["//jaxlib/gpu:rnn_kernels.cc"], @@ -516,7 +466,6 @@ cc_library( srcs = ["//jaxlib/gpu:gpu_kernels.cc"], visibility = ["//visibility:public"], deps = [ - ":cublas_kernels", ":cuda_linalg_kernels", ":cuda_prng_kernels", ":cuda_vendor", @@ -651,7 +600,6 @@ nanobind_extension( py_library( name = "cuda_gpu_support", deps = [ - ":_blas", ":_hybrid", ":_linalg", ":_prng", diff --git a/jaxlib/gpu/BUILD b/jaxlib/gpu/BUILD index 1fd2775ecf9a..e153e0588cf6 100644 --- a/jaxlib/gpu/BUILD +++ b/jaxlib/gpu/BUILD @@ -30,11 +30,8 @@ package( ) exports_files(srcs = [ - "blas.cc", "blas_handle_pool.cc", "blas_handle_pool.h", - "blas_kernels.cc", - "blas_kernels.h", "ffi_wrapper.h", "gpu_kernel_helpers.cc", "gpu_kernel_helpers.h", diff --git a/jaxlib/gpu/blas.cc b/jaxlib/gpu/blas.cc deleted file mode 100644 index 59bf2c4603f6..000000000000 --- a/jaxlib/gpu/blas.cc +++ /dev/null @@ -1,75 +0,0 @@ -/* Copyright 2019 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/strings/str_format.h" -#include "nanobind/nanobind.h" -#include "nanobind/stl/pair.h" // IWYU pragma: keep -#include "jaxlib/gpu/blas_kernels.h" -#include "jaxlib/gpu/vendor.h" -#include "jaxlib/kernel_nanobind_helpers.h" -#include "xla/tsl/python/lib/core/numpy.h" - -namespace jax { -namespace JAX_GPU_NAMESPACE { -namespace { - -namespace nb = nanobind; - -// Converts a NumPy dtype to a Type. -BlasType DtypeToBlasType(const dtype& np_type) { - static auto* types = new absl::flat_hash_map, BlasType>({ - {{'f', 4}, BlasType::F32}, - {{'f', 8}, BlasType::F64}, - {{'c', 8}, BlasType::C64}, - {{'c', 16}, BlasType::C128}, - }); - auto it = types->find({np_type.kind(), np_type.itemsize()}); - if (it == types->end()) { - nb::str repr = nb::repr(np_type); - throw std::invalid_argument( - absl::StrFormat("Unsupported dtype %s", repr.c_str())); - } - return it->second; -} - -// Returns the descriptor for a GetrfBatched operation. -std::pair BuildGeqrfBatchedDescriptor(const dtype& dtype, - int b, int m, int n) { - BlasType type = DtypeToBlasType(dtype); - size_t size = b * sizeof(void*); - return {size, PackDescriptor(GeqrfBatchedDescriptor{type, b, m, n})}; -} - -nb::dict Registrations() { - nb::dict dict; - dict[JAX_GPU_PREFIX "blas_geqrf_batched"] = EncapsulateFunction(GeqrfBatched); - return dict; -} - -NB_MODULE(_blas, m) { - tsl::ImportNumpy(); - - m.def("registrations", &Registrations); - m.def("build_geqrf_batched_descriptor", &BuildGeqrfBatchedDescriptor); -} - -} // namespace -} // namespace JAX_GPU_NAMESPACE -} // namespace jax diff --git a/jaxlib/gpu/blas_kernels.cc b/jaxlib/gpu/blas_kernels.cc deleted file mode 100644 index cdcc154d026d..000000000000 --- a/jaxlib/gpu/blas_kernels.cc +++ /dev/null @@ -1,138 +0,0 @@ -/* Copyright 2019 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "jaxlib/gpu/blas_kernels.h" - -#include -#include -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_format.h" -#include "jaxlib/gpu/blas_handle_pool.h" -#include "jaxlib/gpu/gpu_kernel_helpers.h" -#include "jaxlib/gpu/make_batch_pointers.h" -#include "jaxlib/gpu/vendor.h" -#include "jaxlib/kernel_helpers.h" -#include "xla/service/custom_call_status.h" - -namespace jax { - -namespace JAX_GPU_NAMESPACE { - -namespace { - -int SizeOfBlasType(BlasType type) { - switch (type) { - case BlasType::F32: - return sizeof(float); - case BlasType::F64: - return sizeof(double); - case BlasType::C64: - return sizeof(gpublasComplex); - case BlasType::C128: - return sizeof(gpublasDoubleComplex); - } -} - -} // namespace - -// Batched QR decomposition: geqrfbatched - -static absl::Status GeqrfBatched_(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const GeqrfBatchedDescriptor& d = **s; - auto h = BlasHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - if (buffers[0] != buffers[1]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( - buffers[1], buffers[0], SizeOfBlasType(d.type) * d.batch * d.m * d.n, - gpuMemcpyDeviceToDevice, stream))); - } - - std::vector info(d.batch); - MakeBatchPointersAsync(stream, buffers[1], buffers[3], d.batch, - SizeOfBlasType(d.type) * d.m * d.n); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError())); - MakeBatchPointersAsync(stream, buffers[2], buffers[4], d.batch, - SizeOfBlasType(d.type) * std::min(d.m, d.n)); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError())); - switch (d.type) { - case BlasType::F32: { - float** a_batch_ptrs = static_cast(buffers[3]); - float** tau_batch_ptrs = static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpublasSgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m, - tau_batch_ptrs, info.data(), d.batch))); - break; - } - case BlasType::F64: { - double** a_batch_ptrs = static_cast(buffers[3]); - double** tau_batch_ptrs = static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpublasDgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m, - tau_batch_ptrs, info.data(), d.batch))); - break; - } - case BlasType::C64: { - gpublasComplex** a_batch_ptrs = static_cast(buffers[3]); - gpublasComplex** tau_batch_ptrs = - static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpublasCgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m, - tau_batch_ptrs, info.data(), d.batch))); - break; - } - case BlasType::C128: { - gpublasDoubleComplex** a_batch_ptrs = - static_cast(buffers[3]); - gpublasDoubleComplex** tau_batch_ptrs = - static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpublasZgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m, - tau_batch_ptrs, info.data(), d.batch))); - break; - } - } - auto it = - std::find_if(info.begin(), info.end(), [](int i) { return i != 0; }); - - if (it != info.end()) { - return absl::InvalidArgumentError( - absl::StrFormat("QR decomposition failed with status %d for batch " - "element %d", - *it, std::distance(info.begin(), it))); - } - - return absl::OkStatus(); -} - -void GeqrfBatched(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = GeqrfBatched_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - -} // namespace JAX_GPU_NAMESPACE -} // namespace jax diff --git a/jaxlib/gpu/blas_kernels.h b/jaxlib/gpu/blas_kernels.h deleted file mode 100644 index 8ca7b4db4668..000000000000 --- a/jaxlib/gpu/blas_kernels.h +++ /dev/null @@ -1,48 +0,0 @@ -/* Copyright 2019 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef JAXLIB_GPU_BLAS_KERNELS_H_ -#define JAXLIB_GPU_BLAS_KERNELS_H_ - -#include - -#include "jaxlib/gpu/vendor.h" -#include "xla/service/custom_call_status.h" - -namespace jax { -namespace JAX_GPU_NAMESPACE { - -// Set of types known to Cusolver. -enum class BlasType { - F32, - F64, - C64, - C128, -}; - -// Batched QR decomposition: geqrfbatched - -struct GeqrfBatchedDescriptor { - BlasType type; - int batch, m, n; -}; - -void GeqrfBatched(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -} // namespace JAX_GPU_NAMESPACE -} // namespace jax - -#endif // JAXLIB_GPU_BLAS_KERNELS_H_ diff --git a/jaxlib/gpu/gpu_kernels.cc b/jaxlib/gpu/gpu_kernels.cc index 840c313f2fa3..620f9cf45199 100644 --- a/jaxlib/gpu/gpu_kernels.cc +++ b/jaxlib/gpu/gpu_kernels.cc @@ -16,7 +16,6 @@ limitations under the License. // This file is not used by JAX itself, but exists to assist with running // JAX-generated HLO code from outside of JAX. -#include "jaxlib/gpu/blas_kernels.h" #include "jaxlib/gpu/linalg_kernels.h" #include "jaxlib/gpu/prng_kernels.h" #include "jaxlib/gpu/rnn_kernels.h" @@ -33,21 +32,17 @@ namespace jax { namespace JAX_GPU_NAMESPACE { namespace { -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cublas_geqrf_batched", GeqrfBatched, - "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cudnn_rnn", RNNForward, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cudnn_rnn_bwd", RNNBackward, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_getrf_ffi", "CUDA", GetrfFfi); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_syrk_ffi", "CUDA", SyrkFfi); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_geqrf", Geqrf, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_geqrf_ffi", "CUDA", GeqrfFfi); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_csrlsvqr", Csrlsvqr, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_csrlsvqr_ffi", "CUDA", CsrlsvqrFfi); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_orgqr", Orgqr, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_orgqr_ffi", "CUDA", OrgqrFfi); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_syevd", Syevd, "CUDA"); diff --git a/jaxlib/gpu/solver.cc b/jaxlib/gpu/solver.cc index 8013d9877ed5..3c76598e5285 100644 --- a/jaxlib/gpu/solver.cc +++ b/jaxlib/gpu/solver.cc @@ -54,45 +54,6 @@ SolverType DtypeToSolverType(const dtype& np_type) { return it->second; } -// geqrf: QR decomposition - -// Returns the workspace size and a descriptor for a geqrf operation. -std::pair BuildGeqrfDescriptor(const dtype& dtype, int b, int m, - int n) { - SolverType type = DtypeToSolverType(dtype); - auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); - JAX_THROW_IF_ERROR(h.status()); - auto& handle = *h; - int lwork; - switch (type) { - case SolverType::F32: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnSgeqrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); - break; - case SolverType::F64: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnDgeqrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); - break; - case SolverType::C64: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnCgeqrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); - break; - case SolverType::C128: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnZgeqrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); - break; - } - return {lwork, PackDescriptor(GeqrfDescriptor{type, b, m, n, lwork})}; -} - #ifdef JAX_GPU_CUDA // csrlsvqr: Linear system solve via Sparse QR @@ -106,49 +67,6 @@ nb::bytes BuildCsrlsvqrDescriptor(const dtype& dtype, int n, int nnzA, #endif // JAX_GPU_CUDA -// orgqr/ungqr: apply elementary Householder transformations - -// Returns the workspace size and a descriptor for a geqrf operation. -std::pair BuildOrgqrDescriptor(const dtype& dtype, int b, int m, - int n, int k) { - SolverType type = DtypeToSolverType(dtype); - auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); - JAX_THROW_IF_ERROR(h.status()); - auto& handle = *h; - int lwork; - switch (type) { - case SolverType::F32: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnSorgqr_bufferSize(handle.get(), m, n, k, - /*A=*/nullptr, - /*lda=*/m, - /*tau=*/nullptr, &lwork))); - break; - case SolverType::F64: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnDorgqr_bufferSize(handle.get(), m, n, k, - /*A=*/nullptr, - /*lda=*/m, - /*tau=*/nullptr, &lwork))); - break; - case SolverType::C64: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnCungqr_bufferSize(handle.get(), m, n, k, - /*A=*/nullptr, - /*lda=*/m, - /*tau=*/nullptr, &lwork))); - break; - case SolverType::C128: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnZungqr_bufferSize(handle.get(), m, n, k, - /*A=*/nullptr, - /*lda=*/m, - /*tau=*/nullptr, &lwork))); - break; - } - return {lwork, PackDescriptor(OrgqrDescriptor{type, b, m, n, k, lwork})}; -} - // Symmetric (Hermitian) eigendecomposition, QR algorithm: syevd/heevd // Returns the workspace size and a descriptor for a syevd operation. @@ -423,8 +341,6 @@ std::pair BuildSytrdDescriptor(const dtype& dtype, bool lower, nb::dict Registrations() { nb::dict dict; - dict[JAX_GPU_PREFIX "solver_geqrf"] = EncapsulateFunction(Geqrf); - dict[JAX_GPU_PREFIX "solver_orgqr"] = EncapsulateFunction(Orgqr); dict[JAX_GPU_PREFIX "solver_syevd"] = EncapsulateFunction(Syevd); dict[JAX_GPU_PREFIX "solver_syevj"] = EncapsulateFunction(Syevj); dict[JAX_GPU_PREFIX "solver_gesvd"] = EncapsulateFunction(Gesvd); @@ -456,8 +372,6 @@ nb::dict Registrations() { NB_MODULE(_solver, m) { tsl::ImportNumpy(); m.def("registrations", &Registrations); - m.def("build_geqrf_descriptor", &BuildGeqrfDescriptor); - m.def("build_orgqr_descriptor", &BuildOrgqrDescriptor); m.def("build_syevd_descriptor", &BuildSyevdDescriptor); m.def("build_syevj_descriptor", &BuildSyevjDescriptor); m.def("build_gesvd_descriptor", &BuildGesvdDescriptor); diff --git a/jaxlib/gpu/solver_kernels.cc b/jaxlib/gpu/solver_kernels.cc index 8971619d7f34..040b5a137bc6 100644 --- a/jaxlib/gpu/solver_kernels.cc +++ b/jaxlib/gpu/solver_kernels.cc @@ -50,92 +50,6 @@ static int SizeOfSolverType(SolverType type) { } } -// geqrf: QR decomposition - -static absl::Status Geqrf_(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const GeqrfDescriptor& d = **s; - auto h = SolverHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - if (buffers[1] != buffers[0]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( - buffers[1], buffers[0], - SizeOfSolverType(d.type) * static_cast(d.batch) * - static_cast(d.m) * static_cast(d.n), - gpuMemcpyDeviceToDevice, stream))); - } - - int* info = static_cast(buffers[3]); - void* workspace = buffers[4]; - switch (d.type) { - case SolverType::F32: { - float* a = static_cast(buffers[1]); - float* tau = static_cast(buffers[2]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpusolverDnSgeqrf(handle.get(), d.m, d.n, a, d.m, tau, - static_cast(workspace), d.lwork, info))); - a += d.m * d.n; - tau += std::min(d.m, d.n); - ++info; - } - break; - } - case SolverType::F64: { - double* a = static_cast(buffers[1]); - double* tau = static_cast(buffers[2]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpusolverDnDgeqrf(handle.get(), d.m, d.n, a, d.m, tau, - static_cast(workspace), d.lwork, info))); - a += d.m * d.n; - tau += std::min(d.m, d.n); - ++info; - } - break; - } - case SolverType::C64: { - gpuComplex* a = static_cast(buffers[1]); - gpuComplex* tau = static_cast(buffers[2]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnCgeqrf( - handle.get(), d.m, d.n, a, d.m, tau, - static_cast(workspace), d.lwork, info))); - a += d.m * d.n; - tau += std::min(d.m, d.n); - ++info; - } - break; - } - case SolverType::C128: { - gpuDoubleComplex* a = static_cast(buffers[1]); - gpuDoubleComplex* tau = static_cast(buffers[2]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnZgeqrf( - handle.get(), d.m, d.n, a, d.m, tau, - static_cast(workspace), d.lwork, info))); - a += d.m * d.n; - tau += std::min(d.m, d.n); - ++info; - } - break; - } - } - return absl::OkStatus(); -} - -void Geqrf(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = Geqrf_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - #ifdef JAX_GPU_CUDA // csrlsvqr: Linear system solve via Sparse QR @@ -237,92 +151,6 @@ void Csrlsvqr(gpuStream_t stream, void** buffers, const char* opaque, #endif // JAX_GPU_CUDA -// orgqr/ungqr: apply elementary Householder transformations - -static absl::Status Orgqr_(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const OrgqrDescriptor& d = **s; - auto h = SolverHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - if (buffers[2] != buffers[0]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( - buffers[2], buffers[0], - SizeOfSolverType(d.type) * static_cast(d.batch) * - static_cast(d.m) * static_cast(d.n), - gpuMemcpyDeviceToDevice, stream))); - } - - int* info = static_cast(buffers[3]); - void* workspace = buffers[4]; - switch (d.type) { - case SolverType::F32: { - float* a = static_cast(buffers[2]); - float* tau = static_cast(buffers[1]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpusolverDnSorgqr(handle.get(), d.m, d.n, d.k, a, d.m, tau, - static_cast(workspace), d.lwork, info))); - a += d.m * d.n; - tau += d.k; - ++info; - } - break; - } - case SolverType::F64: { - double* a = static_cast(buffers[2]); - double* tau = static_cast(buffers[1]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpusolverDnDorgqr(handle.get(), d.m, d.n, d.k, a, d.m, tau, - static_cast(workspace), d.lwork, info))); - a += d.m * d.n; - tau += d.k; - ++info; - } - break; - } - case SolverType::C64: { - gpuComplex* a = static_cast(buffers[2]); - gpuComplex* tau = static_cast(buffers[1]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnCungqr( - handle.get(), d.m, d.n, d.k, a, d.m, tau, - static_cast(workspace), d.lwork, info))); - a += d.m * d.n; - tau += d.k; - ++info; - } - break; - } - case SolverType::C128: { - gpuDoubleComplex* a = static_cast(buffers[2]); - gpuDoubleComplex* tau = static_cast(buffers[1]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnZungqr( - handle.get(), d.m, d.n, d.k, a, d.m, tau, - static_cast(workspace), d.lwork, info))); - a += d.m * d.n; - tau += d.k; - ++info; - } - break; - } - } - return absl::OkStatus(); -} - -void Orgqr(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = Orgqr_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - // Symmetric (Hermitian) eigendecomposition, QR algorithm: syevd/heevd static absl::Status Syevd_(gpuStream_t stream, void** buffers, diff --git a/jaxlib/gpu/solver_kernels.h b/jaxlib/gpu/solver_kernels.h index 6372e55b930d..a68aaf1ca233 100644 --- a/jaxlib/gpu/solver_kernels.h +++ b/jaxlib/gpu/solver_kernels.h @@ -33,16 +33,6 @@ enum class SolverType { C128, }; -// geqrf: QR decomposition - -struct GeqrfDescriptor { - SolverType type; - int batch, m, n, lwork; -}; - -void Geqrf(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - #ifdef JAX_GPU_CUDA // csrlsvpr: Linear system solve via Sparse QR @@ -58,16 +48,6 @@ void Csrlsvqr(gpuStream_t stream, void** buffers, const char* opaque, #endif // JAX_GPU_CUDA -// orgqr/ungqr: apply elementary Householder transformations - -struct OrgqrDescriptor { - SolverType type; - int batch, m, n, k, lwork; -}; - -void Orgqr(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - // Symmetric (Hermitian) eigendecomposition, QR algorithm: syevd/heevd struct SyevdDescriptor { diff --git a/jaxlib/gpu_solver.py b/jaxlib/gpu_solver.py index efb58f9a4164..c846c63e2ff8 100644 --- a/jaxlib/gpu_solver.py +++ b/jaxlib/gpu_solver.py @@ -16,21 +16,15 @@ from .plugin_support import import_from_plugin -_cublas = import_from_plugin("cuda", "_blas") _cusolver = import_from_plugin("cuda", "_solver") _cuhybrid = import_from_plugin("cuda", "_hybrid") -_hipblas = import_from_plugin("rocm", "_blas") _hipsolver = import_from_plugin("rocm", "_solver") _hiphybrid = import_from_plugin("rocm", "_hybrid") def registrations() -> dict[str, list[tuple[str, Any, int]]]: registrations = {"CUDA": [], "ROCM": []} - for platform, module in [("CUDA", _cublas), ("ROCM", _hipblas)]: - if module: - registrations[platform].extend( - (*i, 0) for i in module.registrations().items()) for platform, module in [("CUDA", _cusolver), ("ROCM", _hipsolver)]: if module: registrations[platform].extend( diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index 358a6d1cc9aa..5893af26de85 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -87,54 +87,6 @@ cc_library( ], ) -cc_library( - name = "hipblas_kernels", - srcs = ["//jaxlib/gpu:blas_kernels.cc"], - hdrs = ["//jaxlib/gpu:blas_kernels.h"], - deps = [ - ":hip_blas_handle_pool", - ":hip_gpu_kernel_helpers", - ":hip_make_batch_pointers", - ":hip_vendor", - "//jaxlib:kernel_helpers", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/synchronization", - "@local_config_rocm//rocm:hipblas", - "@local_config_rocm//rocm:rocm_headers", - "@xla//xla/service:custom_call_status", - ], -) - -nanobind_extension( - name = "_blas", - srcs = ["//jaxlib/gpu:blas.cc"], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], - module_name = "_blas", - deps = [ - ":hip_vendor", - ":hipblas_kernels", - "//jaxlib:kernel_nanobind_helpers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings:str_format", - "@local_config_rocm//rocm:hipblas", - "@local_config_rocm//rocm:rocm_headers", - "@nanobind", - "@xla//xla/tsl/python/lib/core:numpy", - ], -) - cc_library( name = "miopen_rnn_kernels", srcs = ["//jaxlib/gpu:rnn_kernels.cc"], @@ -552,7 +504,6 @@ nanobind_extension( py_library( name = "rocm_gpu_support", deps = [ - ":_blas", ":_hybrid", ":_linalg", ":_prng", diff --git a/jaxlib/tools/build_gpu_kernels_wheel.py b/jaxlib/tools/build_gpu_kernels_wheel.py index 2f81eacbdde4..e9684108caf0 100644 --- a/jaxlib/tools/build_gpu_kernels_wheel.py +++ b/jaxlib/tools/build_gpu_kernels_wheel.py @@ -102,7 +102,6 @@ def prepare_wheel_cuda( dst_dir=plugin_dir, src_files=[ f"__main__/jaxlib/cuda/_solver.{pyext}", - f"__main__/jaxlib/cuda/_blas.{pyext}", f"__main__/jaxlib/cuda/_linalg.{pyext}", f"__main__/jaxlib/cuda/_prng.{pyext}", f"__main__/jaxlib/cuda/_rnn.{pyext}", @@ -140,7 +139,6 @@ def prepare_wheel_rocm( copy_runfiles( dst_dir=plugin_dir, src_files=[ - f"__main__/jaxlib/rocm/_blas.{pyext}", f"__main__/jaxlib/rocm/_linalg.{pyext}", f"__main__/jaxlib/rocm/_prng.{pyext}", f"__main__/jaxlib/rocm/_solver.{pyext}", diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index 6a6c8c213a64..789838f99d14 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -38,7 +38,6 @@ from jax._src.internal_test_util.export_back_compat_test_data import cpu_eigh_lapack_syev from jax._src.internal_test_util.export_back_compat_test_data import cpu_lu_lapack_getrf from jax._src.internal_test_util.export_back_compat_test_data import cuda_qr_cusolver_geqrf -from jax._src.internal_test_util.export_back_compat_test_data import rocm_qr_hipsolver_geqrf from jax._src.internal_test_util.export_back_compat_test_data import cpu_qr_lapack_geqrf from jax._src.internal_test_util.export_back_compat_test_data import cpu_schur_lapack_gees from jax._src.internal_test_util.export_back_compat_test_data import cpu_svd_lapack_gesdd @@ -147,7 +146,6 @@ def test_custom_call_coverage(self): cuda_svd_cusolver_gesvd.data_2024_10_08, cpu_tridiagonal_solve_lapack_gtsv.data_2025_01_09, cuda_tridiagonal_cusolver_sytrd.data_2025_01_09, - rocm_qr_hipsolver_geqrf.data_2024_08_05, rocm_eigh_hipsolver_syev.data_2024_08_05, cpu_schur_lapack_gees.data_2023_07_16, cpu_svd_lapack_gesdd.data_2023_06_19, @@ -454,29 +452,6 @@ def test_cpu_qr_lapack_geqrf(self, dtype_name="f32"): self.run_one_test(func, data, rtol=rtol, expect_current_custom_calls=info["custom_call_targets"]) - # TODO(b/369826500): Remove legacy custom call test after mid March 2025. - @parameterized.named_parameters( - dict(testcase_name=f"_dtype={dtype_name}_{batched}", - dtype_name=dtype_name, batched=batched) - for dtype_name in ("f32",) - # For batched qr we use cublas_geqrf_batched/hipblas_geqrf_batched. - for batched in ("batched", "unbatched")) - def test_gpu_qr_solver_geqrf_legacy(self, dtype_name, batched): - if jtu.test_device_matches(["rocm"]): - data = self.load_testdata(rocm_qr_hipsolver_geqrf.data_2024_08_05[batched]) - prefix = "hip" - elif jtu.test_device_matches(["cuda"]): - data = self.load_testdata(cuda_qr_cusolver_geqrf.data_2023_03_18[batched]) - prefix = "cu" - else: - self.skipTest("Unsupported platform") - dtype = dict(f32=np.float32)[dtype_name] - rtol = dict(f32=1e-3)[dtype_name] - shape = dict(batched=(2, 3, 3), unbatched=(3, 3))[batched] - func = lambda: CompatTest.qr_harness(shape, dtype) - self.run_one_test(func, data, rtol=rtol, expect_current_custom_calls=[ - f"{prefix}solver_geqrf_ffi", f"{prefix}solver_orgqr_ffi"]) - @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) for dtype_name in ("f32", "f64", "c64", "c128")) From 576843283bf0e1a1be5fce7fe25dd1f8db94cde7 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 2 Apr 2025 14:57:30 +0000 Subject: [PATCH 332/483] Disabled default env var LOBPCG_EMIT_DEBUG_PLOTS=1 Description: - Disabled default env var LOBPCG_EMIT_DEBUG_PLOTS=1 - When run inside TSAN CI job with 3.14t cpython and under multi-threading the test code from main leads to `RecursionError: maximum recursion depth exceeded` error: ``` ERROR: testLobpcgMonotonicityF32cluster_k_1__n100 (__main__.F32LobpcgTest) F32LobpcgTest.testLobpcgMonotonicityF32cluster_k_1__n100 testLobpcgMonotonicityF32cluster_k_1__n100(matrix_name='cluster(k-1)', n=100, k=10, m=20, tol=2e-06) ---------------------------------------------------------------------- Traceback (most recent call last): File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/pypi_absl_py/site-packages/absl/testing/parameterized.py", line 319, in bound_param_test return test_method(self, **testcase_params) File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/__main__/tests/lobpcg_test.py", line 408, in testLobpcgMonotonicityF32 self.checkLobpcgMonotonicity(matrix_name, n, k, m, tol, jnp.float32) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/__main__/tests/lobpcg_test.py", line 272, in checkLobpcgMonotonicity self._possibly_plot(A, eigs, X, m, matrix_name) ~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/__main__/tests/lobpcg_test.py", line 290, in _possibly_plot self._debug_plots(X, eigs, info, matrix_name, plot_dir) ~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/__main__/tests/lobpcg_test.py", line 318, in _debug_plots ax0.legend() ~~~~~~~~~~^^ File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/pypi_matplotlib/site-packages/matplotlib/axes/_axes.py", line 337, in legend self.legend_ = mlegend.Legend(self, handles, labels, **kwargs) ~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/pypi_matplotlib/site-packages/matplotlib/legend.py", line 549, in __init__ self._init_legend_box(handles, labels, markerfirst) ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/pypi_matplotlib/site-packages/matplotlib/legend.py", line 896, in _init_legend_box handle_list.append(handler.legend_artist(self, orig_handle, ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^ fontsize, handlebox)) ^^^^^^^^^^^^^^^^^^^^ File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/pypi_matplotlib/site-packages/matplotlib/legend_handler.py", line 129, in legend_artist artists = self.create_artists(legend, orig_handle, xdescent, ydescent, width, height, fontsize, handlebox.get_transform()) File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/pypi_matplotlib/site-packages/matplotlib/legend_handler.py", line 303, in create_artists self.update_prop(legline, orig_handle, legend) ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/pypi_matplotlib/site-packages/matplotlib/legend_handler.py", line 88, in update_prop self._update_prop(legend_handle, orig_handle) ~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/pypi_matplotlib/site-packages/matplotlib/legend_handler.py", line 79, in _update_prop self._default_update_prop(legend_handle, orig_handle) ~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/pypi_matplotlib/site-packages/matplotlib/legend_handler.py", line 84, in _default_update_prop legend_handle.update_from(orig_handle) ~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^ File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/pypi_matplotlib/site-packages/matplotlib/lines.py", line 1358, in update_from self._marker = MarkerStyle(marker=other._marker) ~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^ File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/pypi_matplotlib/site-packages/matplotlib/markers.py", line 248, in __init__ self._set_marker(marker) ~~~~~~~~~~~~~~~~^^^^^^^^ File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/pypi_matplotlib/site-packages/matplotlib/markers.py", line 323, in _set_marker self.__dict__ = copy.deepcopy(marker.__dict__) ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^ File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/external/python_x86_64-unknown-linux-gnu-freethreaded/lib/python3.14t/copy.py", line 131, in deepcopy y = copier(x, memo) File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/external/python_x86_64-unknown-linux-gnu-freethreaded/lib/python3.14t/copy.py", line 202, in _deepcopy_dict y[deepcopy(key, memo)] = deepcopy(value, memo) ~~~~~~~~^^^^^^^^^^^^^ File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/external/python_x86_64-unknown-linux-gnu-freethreaded/lib/python3.14t/copy.py", line 138, in deepcopy y = copier(memo) File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/pypi_matplotlib/site-packages/matplotlib/path.py", line 285, in __deepcopy__ p = copy.deepcopy(super(), memo) File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/external/python_x86_64-unknown-linux-gnu-freethreaded/lib/python3.14t/copy.py", line 157, in deepcopy y = _reconstruct(x, memo, *rv) File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/external/python_x86_64-unknown-linux-gnu-freethreaded/lib/python3.14t/copy.py", line 234, in _reconstruct y = func(*args) File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/external/python_x86_64-unknown-linux-gnu-freethreaded/lib/python3.14t/copy.py", line 233, in args = (deepcopy(arg, memo) for arg in args) ~~~~~~~~^^^^^^^^^^^ File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/external/python_x86_64-unknown-linux-gnu-freethreaded/lib/python3.14t/copy.py", line 138, in deepcopy y = copier(memo) File "/root/.cache/bazel/_bazel_root/840503f2165a538d6d79458755b06642/execroot/__main__/bazel-out/k8-opt/bin/tests/lobpcg_test_cpu.runfiles/pypi_matplotlib/site-packages/matplotlib/path.py", line 285, in __deepcopy__ p = copy.deepcopy(super(), memo) ``` --- tests/BUILD | 5 ++++- tests/lobpcg_test.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/BUILD b/tests/BUILD index b501a614da39..23d59e8d549a 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -239,7 +239,10 @@ jax_multiplatform_test( jax_multiplatform_test( name = "lobpcg_test", srcs = ["lobpcg_test.py"], - env = {"LOBPCG_EMIT_DEBUG_PLOTS": "1"}, + # Set LOBPCG_EMIT_DEBUG_PLOTS=1 to debug + # checkLobpcgMonotonicity and checkApproxEigs tests + # using matplotlib plots + # env = {"LOBPCG_EMIT_DEBUG_PLOTS": "1"}, shard_count = { "cpu": 48, "gpu": 48, diff --git a/tests/lobpcg_test.py b/tests/lobpcg_test.py index fc2b0df849d1..76d6006432f4 100644 --- a/tests/lobpcg_test.py +++ b/tests/lobpcg_test.py @@ -272,7 +272,7 @@ def checkLobpcgMonotonicity(self, matrix_name, n, k, m, tol, dtype): self._possibly_plot(A, eigs, X, m, matrix_name) def _possibly_plot(self, A, eigs, X, m, matrix_name): - if not os.getenv('LOBPCG_EMIT_DEBUG_PLOTS'): + if os.getenv('LOBPCG_EMIT_DEBUG_PLOTS', '0') != '1': return if isinstance(A, (np.ndarray, jax.Array)): From 2e16367991bd72f98381d40a77835c4b03c2c3e1 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 2 Apr 2025 08:29:12 -0700 Subject: [PATCH 333/483] Remove the extra stack frame that was introduce in uniform due to dropping the entire function in auto axes. PiperOrigin-RevId: 743148311 --- jax/_src/random.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/jax/_src/random.py b/jax/_src/random.py index 7277ed5aa966..0dcbda7bb717 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -406,15 +406,13 @@ def uniform(key: ArrayLike, raise ValueError(f"dtype argument to `uniform` must be a float dtype, " f"got {dtype}") dtype = dtypes.canonicalize_dtype(dtype) - return _uniform_auto(key, shape, dtype, minval, maxval, out_sharding) - -@partial(jit, static_argnums=(1, 2, 5)) -def _uniform_auto(key, shape, dtype, minval, maxval, out_sharding) -> Array: if out_sharding is None: return _uniform(key, shape, dtype, minval, maxval) - def f(key, minval, maxval): return _uniform(key, shape, dtype, minval, maxval) + def f(k, minv, maxv): + return _uniform(k, shape, dtype, minv, maxv) return auto_axes(f, out_shardings=out_sharding)(key, minval, maxval) +@partial(jit, static_argnums=(1, 2)) def _uniform(key, shape, dtype, minval, maxval) -> Array: _check_shape("uniform", shape) if not jnp.issubdtype(dtype, np.floating): From 3aeabaedea957293ba6a2f777d8cb30a9bf0aed4 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 1 Apr 2025 13:13:25 -0700 Subject: [PATCH 334/483] jnp.isinf & friends: support __jax_array__ --- jax/_src/numpy/ufuncs.py | 10 ++++++---- tests/array_extensibility_test.py | 4 ++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 60e10b3be048..3fe63545e6df 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -36,7 +36,7 @@ from jax._src.numpy import reductions from jax._src.numpy.ufunc_api import ufunc from jax._src.numpy.util import ( - check_arraylike, promote_args, promote_args_inexact, + check_arraylike, ensure_arraylike, promote_args, promote_args_inexact, promote_args_numeric, promote_dtypes_inexact, promote_dtypes_numeric, promote_shapes, _where, check_no_float0s) from jax._src.util import set_module @@ -3500,7 +3500,7 @@ def isinf(x: ArrayLike, /) -> Array: >>> jnp.isinf(x) Array([False, True, False, True, False], dtype=bool) """ - check_arraylike("isinf", x) + x = ensure_arraylike("isinf", x) dtype = dtypes.dtype(x) if dtypes.issubdtype(dtype, np.floating): return lax.eq(lax.abs(x), _constant_like(x, np.inf)) @@ -3513,7 +3513,7 @@ def isinf(x: ArrayLike, /) -> Array: return lax.full_like(x, False, dtype=np.bool_) -def _isposneginf(infinity: float, x: ArrayLike, out) -> Array: +def _isposneginf(infinity: float, x: Array, out) -> Array: if out is not None: raise NotImplementedError("The 'out' argument to isneginf/isposinf is not supported.") dtype = dtypes.dtype(x) @@ -3556,6 +3556,7 @@ def isposinf(x, /, out=None): >>> jnp.isposinf(x) Array([False, False, True, False, False], dtype=bool) """ + x = ensure_arraylike("isposinf", x) return _isposneginf(np.inf, x, out) @@ -3590,6 +3591,7 @@ def isneginf(x, /, out=None): >>> jnp.isneginf(x) Array([ True, False, False, False, False], dtype=bool) """ + x = ensure_arraylike("isneginf", x) return _isposneginf(-np.inf, x, out) @@ -3624,7 +3626,7 @@ def isnan(x: ArrayLike, /) -> Array: >>> jnp.isnan(x) Array([False, False, False, True], dtype=bool) """ - check_arraylike("isnan", x) + x = ensure_arraylike("isnan", x) return lax.ne(x, x) diff --git a/tests/array_extensibility_test.py b/tests/array_extensibility_test.py index 8f5ea33b5894..55089720f520 100644 --- a/tests/array_extensibility_test.py +++ b/tests/array_extensibility_test.py @@ -353,8 +353,8 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.isin, Int[5], Int[10]), NumPyAPI.sig(jnp.isinf, Float[5]), NumPyAPI.sig(jnp.isnan, Float[5]), - # NumPyAPI.sig(jnp.isneginf, Float[5]), - # NumPyAPI.sig(jnp.isposinf, Float[5]), + NumPyAPI.sig(jnp.isneginf, Float[5]), + NumPyAPI.sig(jnp.isposinf, Float[5]), NumPyAPI.sig(jnp.isreal, Float[5]), NumPyAPI.sig(jnp.isrealobj, Float[5]), NumPyAPI.sig(jnp.isscalar, Float[()]), From 2a24b407368b392b384885c727eb0323be01c802 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 2 Apr 2025 16:42:25 +0000 Subject: [PATCH 335/483] Bump actions/cache from 4.2.0 to 4.2.3 Bumps [actions/cache](https://github.com/actions/cache) from 4.2.0 to 4.2.3. - [Release notes](https://github.com/actions/cache/releases) - [Changelog](https://github.com/actions/cache/blob/main/RELEASES.md) - [Commits](https://github.com/actions/cache/compare/1bd1e32a3bdc45362d1e726936510720a7c30a57...5a3ec84eff668545956fd18022155c47e93e2684) --- updated-dependencies: - dependency-name: actions/cache dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- .github/workflows/ci-build.yaml | 2 +- .github/workflows/tsan.yaml | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index c575c84cd422..5576ccd6e745 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -35,7 +35,7 @@ jobs: with: python-version: 3.11 - run: python -m pip install pre-commit - - uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + - uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 with: path: ~/.cache/pre-commit key: pre-commit-${{ env.pythonLocation }}-${{ hashFiles('.pre-commit-config.yaml', 'setup.py') }} diff --git a/.github/workflows/tsan.yaml b/.github/workflows/tsan.yaml index 4c28608a8257..1bdb36b2cd03 100644 --- a/.github/workflows/tsan.yaml +++ b/.github/workflows/tsan.yaml @@ -74,7 +74,7 @@ jobs: - name: Restore cached TSAN CPython ${{ matrix.python-version }} id: cache-cpython-tsan-restore - uses: actions/cache/restore@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + uses: actions/cache/restore@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 with: path: | ./python-tsan.tgz @@ -97,7 +97,7 @@ jobs: - name: Save TSAN CPython ${{ matrix.python-version }} id: cache-cpython-tsan-save if: steps.cache-cpython-tsan-restore.outputs.cache-hit != 'true' - uses: actions/cache/save@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + uses: actions/cache/save@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 with: path: | ./python-tsan.tgz @@ -105,7 +105,7 @@ jobs: - name: Restore cached TSAN Numpy id: cache-numpy-tsan-restore - uses: actions/cache/restore@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + uses: actions/cache/restore@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 with: path: | ./wheelhouse @@ -160,7 +160,7 @@ jobs: - name: Save TSAN Numpy wheel id: cache-numpy-tsan-save if: steps.cache-numpy-tsan-restore.outputs.cache-hit != 'true' - uses: actions/cache/save@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + uses: actions/cache/save@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 with: path: | ./wheelhouse @@ -169,7 +169,7 @@ jobs: - name: Restore cached Scipy if: ${{ matrix.python-version == '3.14' }} id: cache-scipy-restore - uses: actions/cache/restore@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + uses: actions/cache/restore@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 with: path: | ./wheelhouse @@ -236,7 +236,7 @@ jobs: - name: Save Scipy wheel id: cache-scipy-save if: ${{ steps.cache-scipy-restore.outputs.cache-hit != 'true' && matrix.python-version == '3.14' }} - uses: actions/cache/save@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + uses: actions/cache/save@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 with: path: | ./wheelhouse From 3d70fc819748b2ac78025653e9625660b2664886 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 2 Apr 2025 10:20:32 -0700 Subject: [PATCH 336/483] Add pbroadcast insertion for `psum_p` in the traceable. This effectively replaces `psum_p` with `psum2_p` if `varying_axes_in_types` is on. psum_p can be replaced with psum2_p in follow up CLs Also populate the aval of `ShardMapTracer` with `vma` PiperOrigin-RevId: 743188081 --- jax/_src/lax/parallel.py | 22 ++++++++++++++++++++-- jax/experimental/shard_map.py | 9 +++++++-- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 39b6c68679ca..5e02318da441 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -141,10 +141,28 @@ def pos_reduce(x): size = math.prod([core.get_axis_env().axis_size(name) for name in named_axes]) out_flat = tuple(lax._const(leaf, size) * pos_reduce(leaf) for leaf in leaves) else: - out_flat = psum_p.bind( - *leaves, axes=tuple(axis_name), axis_index_groups=axis_index_groups) + if config.varying_axes_in_types.value: + out_flat = bind_psum2_p(leaves, axes=tuple(axis_name), + axis_index_groups=axis_index_groups) + else: + out_flat = psum_p.bind( + *leaves, axes=tuple(axis_name), axis_index_groups=axis_index_groups) return tree_util.tree_unflatten(treedef, out_flat) +def bind_psum2_p(leaves, *, axes, axis_index_groups): + if axis_index_groups is not None: + raise NotImplementedError + + from jax.experimental.shard_map import psum2_p, pbroadcast + axes_ = frozenset(axes) + args_ = [] + for x in leaves: + in_vma = core.get_aval(x).vma + args_.append(pbroadcast(x, tuple(pbroadcast_names)) + if (pbroadcast_names := axes_ - in_vma) else x) + return psum2_p.bind(*args_, axes=axes, axis_index_groups=axis_index_groups) + + def pmean(x, axis_name, *, axis_index_groups=None): """Compute an all-reduce mean on ``x`` over the pmapped axis ``axis_name``. diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 4b9daf170dce..ef3751c96901 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -1018,7 +1018,10 @@ def aval(self): new_sharding = NamedSharding( _as_manual_mesh(self._trace.mesh, self._trace.auto), out.sharding.spec) # pytype: disable=attribute-error - return out.update(sharding=new_sharding) + manual_axes = set(self._trace.mesh.axis_names) - self._trace.auto + vma = (frozenset(manual_axes - self.rep) + if config.varying_axes_in_types.value else frozenset()) + return out.update(sharding=new_sharding, vma=vma) def to_concrete_value(self): if self.rep == set(self._trace.mesh.axis_names): @@ -1111,7 +1114,9 @@ def _pbroadcast_abstract_eval(*args, axes, axis_index_groups): f"over axis name {axes}. Please open an issue at " "https://github.com/jax-ml/jax/issues, and as a temporary " "workaround pass the check_rep=False argument to shard_map") - return [a.update(vma=a.vma.union(frozenset(axes))) for a in args] + sharding = NamedSharding(get_abstract_mesh(), P()) + return [a.update(sharding=sharding, vma=a.vma.union(frozenset(axes))) + for a in args] pbroadcast_p.def_abstract_eval(_pbroadcast_abstract_eval) mlir.register_lowering(pbroadcast_p, lambda ctx, *x, axes, axis_index_groups: x) From 92f7aeab48f144ba059cac29406b267e4030fe31 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Wed, 2 Apr 2025 12:09:48 -0700 Subject: [PATCH 337/483] Add simple vmap support for lax.ragged_all_to_all. PiperOrigin-RevId: 743230485 --- jax/_src/lax/parallel.py | 27 +++++ tests/ragged_collective_test.py | 194 ++++++++++++++++++++++++++++++++ 2 files changed, 221 insertions(+) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 5e02318da441..e533672a1d9b 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -1374,11 +1374,38 @@ def _ragged_all_to_all_transpose( output_t = jax.numpy.where(mask, 0, t) return [operand_t, output_t] + [None] * 4 +def _ragged_all_to_all_batched_collective(axis_data, vals_in, dims_in, + axis_name, axis_index_groups): + del axis_data + if axis_index_groups: + raise NotImplementedError("Please open a feature request!") + + operand, output, input_offsets, send_sizes, output_offsets, recv_sizes = vals_in + operand_dim, output_dim, input_offsets_dim, send_sizes_dim, output_offsets_dim, recv_sizes_dim = dims_in + if not (operand.shape[operand_dim] == output.shape[output_dim] == input_offsets.shape[input_offsets_dim] == send_sizes.shape[send_sizes_dim] == output_offsets.shape[output_offsets_dim] == recv_sizes.shape[recv_sizes_dim]): + raise ValueError("all operands must have the same batch sizes") + + sliced_results = [] + for i in range(operand.shape[operand_dim]): + sliced_operand = slicing.slice_in_dim(operand, start_index=i, limit_index=i+1, axis=operand_dim).flatten() + sliced_output = slicing.slice_in_dim(output, start_index=i, limit_index=i+1, axis=output_dim).flatten() + sliced_input_offsets = slicing.slice_in_dim(input_offsets, start_index=i, limit_index=i+1, axis=input_offsets_dim).flatten() + sliced_send_sizes = slicing.slice_in_dim(send_sizes, start_index=i, limit_index=i+1, axis=send_sizes_dim).flatten() + sliced_output_offsets = slicing.slice_in_dim(output_offsets, start_index=i, limit_index=i+1, axis=output_offsets_dim).flatten() + sliced_recv_sizes = slicing.slice_in_dim(recv_sizes, start_index=i, limit_index=i+1, axis=recv_sizes_dim).flatten() + sliced_result = ragged_all_to_all(sliced_operand, sliced_output, sliced_input_offsets, sliced_send_sizes, sliced_output_offsets, sliced_recv_sizes, axis_name=axis_name, axis_index_groups=axis_index_groups) + sliced_result = lax.expand_dims(sliced_result, dimensions=(output_dim,)) + sliced_results.append(sliced_result) + + concat_result = lax.concatenate(sliced_results, dimension=output_dim) + return concat_result, operand_dim + ragged_all_to_all_p = core.Primitive('ragged_all_to_all') ragged_all_to_all_p.def_effectful_abstract_eval(_ragged_all_to_all_effectful_abstract_eval) ad.primitive_jvps[ragged_all_to_all_p] = _ragged_all_to_all_jvp ad.primitive_transposes[ragged_all_to_all_p] = _ragged_all_to_all_transpose mlir.register_lowering(ragged_all_to_all_p, _ragged_all_to_all_lowering) +batching.fancy_primitive_batchers[ragged_all_to_all_p] = _ragged_all_to_all_batched_collective batching.skippable_batchers[ragged_all_to_all_p] = partial(_names_in_param, 'axis_name') def insert_collective_pbroadcast(axis_name, x): diff --git a/tests/ragged_collective_test.py b/tests/ragged_collective_test.py index 844892adc052..1dd6ef657561 100644 --- a/tests/ragged_collective_test.py +++ b/tests/ragged_collective_test.py @@ -21,6 +21,7 @@ import jax import jax.ad_checkpoint from jax import lax +from jax import vmap from jax.sharding import PartitionSpec as P from jax._src import config from jax._src import test_util as jtu @@ -381,6 +382,199 @@ def fwd( c, jnp.array([[0, 0, 1, 0], [0, 2, 3, 4]], dtype=jnp.int32) ) + @parameterized.named_parameters( + dict( + testcase_name='_batch_0_data_shard_axis_0_input_0', + axis_name='x', + vmap_axis_name='y', + mesh_axes=dict(x=2, y=2), + vmap_batch_axis=0, + data_shard_axis=0, + input_config=0, + ), + dict( + testcase_name='_batch_0_data_shard_axis_1_input_0', + axis_name='x', + vmap_axis_name='y', + mesh_axes=dict(x=2, y=2), + vmap_batch_axis=0, + data_shard_axis=1, + input_config=0, + ), + dict( + testcase_name='_batch_1_data_shard_axis_0_input_1', + axis_name='x', + vmap_axis_name='y', + mesh_axes=dict(x=2, y=2), + vmap_batch_axis=1, + data_shard_axis=0, + input_config=1, + ), + dict( + testcase_name='_batch_1_data_shard_axis_1_input_1', + axis_name='x', + vmap_axis_name='y', + mesh_axes=dict(x=2, y=2), + vmap_batch_axis=1, + data_shard_axis=1, + input_config=1, + ), + ) + def test_ragged_all_to_all_vmap( + self, + axis_name, + vmap_axis_name, + mesh_axes, + vmap_batch_axis, + data_shard_axis, + input_config, + ): + device_type = jax.devices()[0].platform + if device_type == 'tpu' and jtu.get_tpu_version() < 4: + raise unittest.SkipTest( + 'UNSUPPORTED: HLO opcode `ragged-all-to-all` is not supported by TPU' + f' v{jtu.get_tpu_version()}' + ) + mesh = jtu.create_mesh(tuple(mesh_axes.values()), tuple(mesh_axes.keys())) + + def get_data_sharding(axis): + if axis == 0: + return P(axis_name, None, None) + elif axis == 1: + return P(None, axis_name, None) + else: + raise ValueError("Invalid data_shard_axis") + + data_sharding = get_data_sharding(data_shard_axis) + + if input_config == 0: + operand_data = jnp.array([[[1, 2, 3], [4, 5, 6]], + [[1, 2, 3], [4, 5, 6]]], dtype=jnp.int32) + send_sizes_data = jnp.array([[[1, 2], [1, 1]], + [[1, 2], [1, 1]]], dtype=jnp.int32) + output_offsets_data = jnp.array([[[0, 0], [1, 2]], + [[0, 0], [1, 2]]], dtype=jnp.int32) + recv_sizes_data = jnp.array([[[1, 1], [2, 1]], + [[1, 1], [2, 1]]], dtype=jnp.int32) + elif input_config == 1: + operand_data = jnp.array([[[1, 2, 3], [1, 2, 3]], + [[4, 5, 6], [4, 5, 6]]], dtype=jnp.int32) + send_sizes_data = jnp.array([[[1, 2], [1, 2]], + [[1, 1], [1, 1]]], dtype=jnp.int32) + output_offsets_data = jnp.array([[[0, 0], [0, 0]], + [[1, 2], [1, 2]]], dtype=jnp.int32) + recv_sizes_data = jnp.array([[[1, 1], [1, 1]], + [[2, 1], [2, 1]]], dtype=jnp.int32) + else: + raise ValueError("Invalid input config") + + output_data = jnp.zeros((2, 2, 4), dtype=jnp.int32) + input_offsets_data = jnp.array([[[0, 1], [0, 1]], + [[0, 1], [0, 1]]], dtype=jnp.int32) + + operand = jax.device_put(operand_data, jax.sharding.NamedSharding(mesh, data_sharding)) + output = jax.device_put(output_data, jax.sharding.NamedSharding(mesh, data_sharding)) + input_offsets = jax.device_put(input_offsets_data, jax.sharding.NamedSharding(mesh, data_sharding)) + send_sizes = jax.device_put(send_sizes_data, jax.sharding.NamedSharding(mesh, data_sharding)) + output_offsets = jax.device_put(output_offsets_data, jax.sharding.NamedSharding(mesh, data_sharding)) + recv_sizes = jax.device_put(recv_sizes_data, jax.sharding.NamedSharding(mesh, data_sharding)) + + @partial( + shard_map, + mesh=mesh, + in_specs=( + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + ), + out_specs=P(axis_name), + check_rep=False, + ) + def fwd( + operand, output, input_offsets, send_sizes, output_offsets, recv_sizes + ): + return lax.ragged_all_to_all( + operand=operand.reshape(operand.shape[1:]), + output=output.reshape(output.shape[1:]), + input_offsets=input_offsets.reshape(input_offsets.shape[1:]), + send_sizes=send_sizes.reshape(send_sizes.shape[1:]), + output_offsets=output_offsets.reshape(output_offsets.shape[1:]), + recv_sizes=recv_sizes.reshape(recv_sizes.shape[1:]), + axis_name=axis_name, + ) + + res = vmap( + fwd, in_axes=vmap_batch_axis, out_axes=0, axis_name=vmap_axis_name + )( + operand, output, input_offsets, send_sizes, output_offsets, recv_sizes + ).reshape( + (2, 2, 4) + ) + expected_res = jnp.array([[[1, 4, 0, 0], [2, 3, 5, 0]], + [[1, 4, 0, 0], [2, 3, 5, 0]]], dtype=jnp.int32) + self.assertAllClose(res, expected_res) + + def test_ragged_all_to_all_vmap_unsupported_axis_index_groups(self): + device_type = jax.devices()[0].platform + if device_type == 'tpu' and jtu.get_tpu_version() < 4: + raise unittest.SkipTest( + 'UNSUPPORTED: HLO opcode `ragged-all-to-all` is not supported by TPU' + f' v{jtu.get_tpu_version()}' + ) + + axis_name = 'x' + mesh_axes = dict(x=2) + mesh = jtu.create_mesh(tuple(mesh_axes.values()), tuple(mesh_axes.keys())) + data_sharding = P(axis_name, None, None) + operand_data = jnp.zeros((2, 2, 3), dtype=jnp.int32) + output_data = jnp.zeros((2, 2, 4), dtype=jnp.int32) + input_offsets_data = jnp.zeros((2, 2, 2), dtype=jnp.int32) + send_sizes_data = jnp.zeros((2, 2, 2), dtype=jnp.int32) + output_offsets_data = jnp.zeros((2, 2, 2), dtype=jnp.int32) + recv_sizes_data = jnp.zeros((2, 2, 2), dtype=jnp.int32) + + operand = jax.device_put(operand_data, jax.sharding.NamedSharding(mesh, data_sharding)) + output = jax.device_put(output_data, jax.sharding.NamedSharding(mesh, data_sharding)) + input_offsets = jax.device_put(input_offsets_data, jax.sharding.NamedSharding(mesh, data_sharding)) + send_sizes = jax.device_put(send_sizes_data, jax.sharding.NamedSharding(mesh, data_sharding)) + output_offsets = jax.device_put(output_offsets_data, jax.sharding.NamedSharding(mesh, data_sharding)) + recv_sizes = jax.device_put(recv_sizes_data, jax.sharding.NamedSharding(mesh, data_sharding)) + + @partial( + shard_map, + mesh=mesh, + in_specs=( + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + ), + out_specs=P(axis_name), + check_rep=False, + ) + def fwd( + operand, output, input_offsets, send_sizes, output_offsets, recv_sizes + ): + return lax.ragged_all_to_all( + operand=operand.reshape(operand.shape[1:]), + output=output.reshape(output.shape[1:]), + input_offsets=input_offsets.reshape(input_offsets.shape[1:]), + send_sizes=send_sizes.reshape(send_sizes.shape[1:]), + output_offsets=output_offsets.reshape(output_offsets.shape[1:]), + recv_sizes=recv_sizes.reshape(recv_sizes.shape[1:]), + axis_name=axis_name, + axis_index_groups=[[0, 1]], + ) + + with self.assertRaisesWithLiteralMatch( + NotImplementedError, 'Please open a feature request!'): + vmap(fwd, in_axes=0, out_axes=0, axis_name='b')(operand, output, input_offsets, send_sizes, output_offsets, recv_sizes) + def test_ragged_all_to_all_errors(self): operand = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=jnp.float32) output = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=jnp.float32) From a442fecca8b75f0803a27601046ca66d5cba134c Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Wed, 2 Apr 2025 15:32:15 -0400 Subject: [PATCH 338/483] Fix custom_transpose when composed with custom_jvp and use_direct_linearize=True. --- jax/_src/custom_transpose.py | 22 +++++++++++++--------- tests/api_test.py | 22 ++++++++++++++++++++++ 2 files changed, 35 insertions(+), 9 deletions(-) diff --git a/jax/_src/custom_transpose.py b/jax/_src/custom_transpose.py index 5e87fdb203c9..21e607b5bff2 100644 --- a/jax/_src/custom_transpose.py +++ b/jax/_src/custom_transpose.py @@ -177,15 +177,19 @@ def bind_with_trace(self, trace, call_args, params): # TODO(frostig,mattjj): consider keeping `call` as a named parameter # instead of following this "call primitive" convention. def get_bind_params(self, params): - assert 'call_jaxpr' in params - assert 'transpose_jaxpr_thunk' in params - new_params: dict[str, Any] = dict(params) - new_params['transpose'] = make_transpose_from_thunk( - new_params.pop('transpose_jaxpr_thunk'), - new_params['lin_tree']) - call_jaxpr: core.ClosedJaxpr = new_params.pop('call_jaxpr') - call = lu.wrap_init(core.jaxpr_as_fun(call_jaxpr), - debug_info=call_jaxpr.jaxpr.debug_info) + if 'call_jaxpr' in params: + assert 'transpose_jaxpr_thunk' in params + new_params: dict[str, Any] = dict(params) + new_params['transpose'] = make_transpose_from_thunk( + new_params.pop('transpose_jaxpr_thunk'), + new_params['lin_tree']) + call_jaxpr: core.ClosedJaxpr = new_params.pop('call_jaxpr') + call = lu.wrap_init(core.jaxpr_as_fun(call_jaxpr), + debug_info=call_jaxpr.jaxpr.debug_info) + else: + assert 'transpose' in params + new_params: dict[str, Any] = dict(params) + call = new_params.pop("call") return [call], new_params diff --git a/tests/api_test.py b/tests/api_test.py index 032c09910fd9..83264f10e033 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -10387,6 +10387,28 @@ def cond_wrap(f): self.assertAllClose(f_(x), g_(x)) self.assertAllClose(f_t(x), g_t(x)) + def test_compose_custom_jvp(self): + @jax.custom_jvp + def f(x): + return jnp.sin(x) + + @f.defjvp + def f_jvp(primals, tangents): + x, = primals + dx, = tangents + return f(x), g(x, dx) + + @custom_transpose + def g(x, dx): + return jnp.cos(x) * dx + + @g.def_transpose + def gt(x, t): + return jnp.cos(x) * t + + with config.use_direct_linearize(True): + self.assertAllClose(jax.grad(f)(0.5), jnp.cos(0.5)) + class CustomDceTest(jtu.JaxTestCase): From 7f4e8c56fe0b47778ad3795545df2e946a4b4a57 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 2 Apr 2025 13:17:59 -0700 Subject: [PATCH 339/483] jnp.concat and friends: support __jax_array__ --- jax/_src/numpy/lax_numpy.py | 7 ++++--- jax/_src/numpy/util.py | 2 +- tests/array_extensibility_test.py | 12 ++++++------ 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index ae32703e7113..43b8923246d4 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -4466,7 +4466,7 @@ def stack(arrays: np.ndarray | Array | Sequence[ArrayLike], axis = _canonicalize_axis(axis, arrays.ndim) return concatenate(expand_dims(arrays, axis + 1), axis=axis, dtype=dtype) else: - util.check_arraylike("stack", *arrays) + arrays = util.ensure_arraylike_tuple("stack", arrays) shape0 = np.shape(arrays[0]) axis = _canonicalize_axis(axis, len(shape0) + 1) new_arrays = [] @@ -4555,7 +4555,7 @@ def tile(A: ArrayLike, reps: DimSize | Sequence[DimSize]) -> Array: [1, 2], [3, 4]], dtype=int32) """ - util.check_arraylike("tile", A) + A = util.ensure_arraylike("tile", A) try: iter(reps) # type: ignore[arg-type] except TypeError: @@ -4628,7 +4628,7 @@ def concatenate(arrays: np.ndarray | Array | Sequence[ArrayLike], """ if isinstance(arrays, (np.ndarray, Array)): return _concatenate_array(arrays, axis, dtype=dtype) - util.check_arraylike("concatenate", *arrays) + arrays = util.ensure_arraylike_tuple("concatenate", arrays) if not len(arrays): raise ValueError("Need at least one array to concatenate.") if axis is None: @@ -4870,6 +4870,7 @@ def dstack(tup: np.ndarray | Array | Sequence[ArrayLike], else: # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error. util.check_arraylike("dstack", *tup, emit_warning=True) + tup = util.ensure_arraylike_tuple("dstack", tup) arrs = [atleast_3d(m) for m in tup] return concatenate(arrs, axis=2, dtype=dtype) diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index e0e20d443e02..9d56267c4b61 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -159,7 +159,7 @@ def ensure_arraylike(fun_name: str, /, *args: Any) -> Array | tuple[Array, ...]: return tuple(_arraylike_asarray(arg) for arg in args) # pytype: disable=bad-return-type -def ensure_arraylike_tuple(fun_name: str, tup: tuple[Any, ...]) -> tuple[Array, ...]: +def ensure_arraylike_tuple(fun_name: str, tup: Sequence[Any]) -> tuple[Array, ...]: """Check that argument elements are arraylike and convert to a tuple of arrays. This is useful because ensure_arraylike with a single argument returns a single array. diff --git a/tests/array_extensibility_test.py b/tests/array_extensibility_test.py index 55089720f520..730001abef76 100644 --- a/tests/array_extensibility_test.py +++ b/tests/array_extensibility_test.py @@ -267,10 +267,10 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.ceil, Float[5]), # NumPyAPI.sig(jnp.choose, [int, float], [(3,), (10,)]), NumPyAPI.sig(jnp.clip, Float[5]), - # NumPyAPI.sig(jnp.column_stack, [float], [(3, 10)]), + NumPyAPI.sig(jnp.column_stack, [Float[5], Float[5], Float[5]]), NumPyAPI.sig(jnp.compress, Float[10], Bool[10]), - # NumPyAPI.sig(jnp.concat, [Float[5], Float[5]]), - # NumPyAPI.sig(jnp.concatenate, [Float[5], Float[5]]), + NumPyAPI.sig(jnp.concat, [Float[5], Float[5]]), + NumPyAPI.sig(jnp.concatenate, [Float[5], Float[5]]), NumPyAPI.sig(jnp.conj, Float[5]), NumPyAPI.sig(jnp.conjugate, Float[5]), NumPyAPI.sig(jnp.convolve, Float[7], Float[3]), @@ -300,7 +300,7 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.divmod, Float[5], Float[5]), NumPyAPI.sig(jnp.dot, Float[5], Float[5]), NumPyAPI.sig(jnp.dsplit, Float[3, 5, 6], indices_or_sections=2), - # NumPyAPI.sig(jnp.dstack, Float[3, 5]), + NumPyAPI.sig(jnp.dstack, [Float[3, 5, 1], Float[3, 5, 3]]), NumPyAPI.sig(jnp.ediff1d, Float[5]), NumPyAPI.sig(jnp.empty_like, Float[5]), NumPyAPI.sig(jnp.equal, Float[5], Float[5]), @@ -469,7 +469,7 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.sqrt, Float[5]), NumPyAPI.sig(jnp.square, Float[5]), NumPyAPI.sig(jnp.squeeze, Float[5]), - # NumPyAPI.sig(jnp.stack, [Float[2, 3], Float[2, 3]], axis=1), + NumPyAPI.sig(jnp.stack, [Float[2, 3], Float[2, 3]], axis=1), NumPyAPI.sig(jnp.std, Float[5]), NumPyAPI.sig(jnp.subtract, Float[5], Float[5]), NumPyAPI.sig(jnp.sum, Float[5]), @@ -479,7 +479,7 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.tan, Float[5]), NumPyAPI.sig(jnp.tanh, Float[5]), NumPyAPI.sig(jnp.tensordot, Float[2, 3, 4], Float[3, 4, 5]), - # NumPyAPI.sig(jnp.tile, Float[5], reps=(2,)), + NumPyAPI.sig(jnp.tile, Float[5], reps=(2,)), NumPyAPI.sig(jnp.trace, Float[5, 5]), NumPyAPI.sig(jnp.transpose, Float[5, 6]), NumPyAPI.sig(jnp.trapezoid, Float[5]), From a2d62e2d3a332b1d67e0f4ef7a23375182f1646e Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 2 Apr 2025 13:46:07 -0700 Subject: [PATCH 340/483] [array_api] update array_api_version to 2024.12 --- .github/workflows/jax-array-api.yml | 2 +- jax/_src/numpy/array_api_metadata.py | 2 +- tests/array_api_skips.txt | 55 ++++++++++++++++++++++++++++ 3 files changed, 57 insertions(+), 2 deletions(-) diff --git a/.github/workflows/jax-array-api.yml b/.github/workflows/jax-array-api.yml index c91ab6b8b7da..7df4228dd2a3 100644 --- a/.github/workflows/jax-array-api.yml +++ b/.github/workflows/jax-array-api.yml @@ -28,7 +28,7 @@ jobs: with: repository: data-apis/array-api-tests # TODO(jakevdp) update this to a stable release/tag when available. - ref: '0b89c5268e4e4a352223a487b8f63dbd1023872d' # Latest commit as of 2025-03-04 + ref: 'c48410f96fc58e02eea844e6b7f6cc01680f77ce' # Latest commit as of 2025-04-02 submodules: 'true' path: 'array-api-tests' - name: Set up Python ${{ matrix.python-version }} diff --git a/jax/_src/numpy/array_api_metadata.py b/jax/_src/numpy/array_api_metadata.py index 5267e51215ee..d634a2856a1b 100644 --- a/jax/_src/numpy/array_api_metadata.py +++ b/jax/_src/numpy/array_api_metadata.py @@ -29,7 +29,7 @@ from jax._src import xla_bridge as xb -__array_api_version__ = '2023.12' +__array_api_version__ = '2024.12' def __array_namespace__(self, *, api_version: None | str = None) -> ModuleType: diff --git a/tests/array_api_skips.txt b/tests/array_api_skips.txt index 2f8d4d1c666f..7534cf6f8acd 100644 --- a/tests/array_api_skips.txt +++ b/tests/array_api_skips.txt @@ -10,6 +10,24 @@ array_api_tests/test_creation_functions.py::test_asarray_arrays # Returns wrong zero sign array_api_tests/test_special_cases.py::test_unary[sign((x_i is -0 or x_i == +0)) -> 0] +array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i is +0 and x2_i < 0) -> -0] +array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i is -0 and x2_i > 0) -> +0] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -0 and x2_i > 0) -> -0] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -0 and x2_i < 0) -> +0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -0 and x2_i > 0) -> -0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -0 and x2_i < 0) -> +0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is -0 and x2_i > 0) -> +0] +array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is -0 and x2_i > 0) -> +0] +array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> +0] +array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is +0 and x2_i < 0) -> -0] +array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is +0 and x2_i < 0) -> -0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> +0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -0 and x2_i > 0) -> -0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -0 and x2_i < 0) -> +0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> +0] + # Returns int32 when int64 is expected array_api_tests/test_searching_functions.py::test_searchsorted @@ -19,3 +37,40 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_clip # JAX raises a ValueError rather than the expected IndexError for out-of-bound axis array_api_tests/test_manipulation_functions.py::test_expand_dims + +# Doesn't promote to uint64 +array_api_tests/test_statistical_functions.py::test_cumulative_prod + +# TODO(jakevdp): fix the following failures: + +# Returns NaN rather than inf +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i < 0 and x2_i is +0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i > 0 and x2_i is +0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i < 0 and x2_i is -0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i > 0 and x2_i is -0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i < 0 and x2_i is +0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i < 0 and x2_i is -0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i > 0 and x2_i is +0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i > 0 and x2_i is -0) -> -infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i > 0 and x2_i is +0) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i > 0 and x2_i is -0) -> -infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i < 0 and x2_i is -0) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i < 0 and x2_i is +0) -> -infinity] + +# Returns -1.0 rather than 0.0 +array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] \ No newline at end of file From bff0fa18adaf0544d6465b62f4e57e8b83b4e614 Mon Sep 17 00:00:00 2001 From: Zac Mustin Date: Wed, 2 Apr 2025 14:07:29 -0700 Subject: [PATCH 341/483] Support `conv` `unfused_flops` in roofline. Since calculating flops is non-trivial, we don't test all the cases currently tested by `test_conv_general_dilated_unfused_hbm_bytes`. Instead, we test behaviors more directly. PiperOrigin-RevId: 743272840 --- jax/experimental/roofline/rooflines.py | 203 ++++++++++++++++++++++++- tests/roofline_test.py | 106 ++++++++++++- 2 files changed, 298 insertions(+), 11 deletions(-) diff --git a/jax/experimental/roofline/rooflines.py b/jax/experimental/roofline/rooflines.py index bc8d65e966dd..63a2d7cc4698 100644 --- a/jax/experimental/roofline/rooflines.py +++ b/jax/experimental/roofline/rooflines.py @@ -14,6 +14,7 @@ from collections import defaultdict from dataclasses import replace import itertools as it +from typing import Sequence import numpy as np from jax._src import ad_util @@ -35,7 +36,7 @@ from jax.experimental import roofline from jax.experimental import shard_map - +# One FMA (Fused Multiply Add) takes 2 flops to compute. _FMA_FLOPS_FACTOR = 2 for prim in it.chain( @@ -179,16 +180,208 @@ def _dot_general_roofline( unfused_hbm_bytes=hbm_bytes, ) + +def _get_spatial_valid_position_count_for_one_dim( + window_dim_stride: int, + base_dilation: int, + window_dilation: int, + kernel_limit: int, + input_limit: int, + output_limit: int, + padding: tuple[int, int], +) -> int: + """Gets the valid position count for conv for a single spatial dimension. + + Args: + window_dim_stride: The stride of the window along this dimension. + base_dilation: The base dilation factor along this dimension. + window_dilation: The window dilation factor along this dimension. + kernel_limit: The size of the kernel along this dimension. + input_limit: The size of the input along this dimension. + output_limit: The size of the output along this dimension. + padding: The padding applied to the input along this dimension. + """ + padding_low = padding[0] + padding_high = padding[1] + + # These two conditions will create an N^2 iteration pattern with only N + # valid elements. This is a performance optimization and produces the same + # result as the whole loop. + if ( + input_limit == output_limit + and kernel_limit == output_limit + and input_limit == base_dilation + and window_dilation == 1 + and max(1, input_limit - 1) == window_dim_stride + and padding_low == 0 + and padding_high == 0 + ): + return input_limit + + if ( + input_limit == 1 + and kernel_limit == output_limit + and window_dilation == 1 + and base_dilation == 1 + and window_dim_stride == 1 + and padding_low == output_limit - 1 + and padding_high == output_limit - 1 + ): + return output_limit + + valid_position_count = 0 + # Loop over each point in the kernel + for kernel_idx in range(kernel_limit): + + # Skip loop for trivial stride and base_dilation + if window_dim_stride == 1 and base_dilation == 1: + undilated_index_base = padding_low - kernel_idx * window_dilation + upper_limit = min( + input_limit + undilated_index_base, + output_limit, + ) + lower_limit = max(0, undilated_index_base) + + valid_position_count += max(upper_limit - lower_limit, 0) + continue + + # Loop over each point in the output + for output_idx in range(output_limit): + # Calculate lhs (input) index without taking base dilation into account + undilated_index = ( + output_idx * window_dim_stride + - padding_low + + kernel_idx * window_dilation + ) + # Calculate the actual lhs (input) index after dilation + lhs_spatial_index = int(undilated_index / base_dilation) + + # Skip if the lhs (input) index is to be dilated. + if undilated_index != lhs_spatial_index * base_dilation: + continue + # Skip if input index is not in bound. + if lhs_spatial_index < 0 or lhs_spatial_index >= input_limit: + continue + + valid_position_count += 1 + return valid_position_count + + +def _get_spatial_valid_position_count( + dnums: convolution.ConvDimensionNumbers, + lhs: roofline.RooflineShape, + rhs: roofline.RooflineShape, + out: roofline.RooflineShape, + window_strides: Sequence[int], + padding: Sequence[tuple[int, int]], + lhs_dilation: Sequence[int], + rhs_dilation: Sequence[int], +) -> int: + """Gets the number of valid spatial positions for conv_general_dilated. + + Args: + dnums: The dimension numbers for the convolution. + lhs: The shape of the left-hand side of the convolution. + rhs: The shape of the right-hand side of the convolution. + out: The shape of the output of the convolution. + window_strides: The stride of the window along each spatial dimension. + padding: The padding applied to the input along each spatial dimension. + lhs_dilation: The dilation factor for the left-hand side along each spatial + dimension. + rhs_dilation: The dilation factor for the right-hand side along each spatial + dimension. + """ + input_spatial_dims, kernel_spatial_dims, out_spatial_dims = ( + dnums.lhs_spec[2:], + dnums.rhs_spec[2:], + dnums.out_spec[2:], + ) + + valid_position_counts = 1 + # Loop over each spatial dimension and determine how many valid positions + # there are for each dimension. + for d in range(len(input_spatial_dims)): + valid_position_counts *= _get_spatial_valid_position_count_for_one_dim( + window_dim_stride=window_strides[d], + base_dilation=lhs_dilation[d], + window_dilation=rhs_dilation[d], + kernel_limit=rhs.shape[kernel_spatial_dims[d]], + input_limit=lhs.shape[input_spatial_dims[d]], + output_limit=out.shape[out_spatial_dims[d]], + padding=padding[d], + ) + + return valid_position_counts + + +def _calculate_conv_flops( + lhs: roofline.RooflineShape, + rhs: roofline.RooflineShape, + out: roofline.RooflineShape, + window_strides: Sequence[int], + padding: Sequence[tuple[int, int]], + lhs_dilation: Sequence[int], + rhs_dilation: Sequence[int], + dimension_numbers: convolution.ConvGeneralDilatedDimensionNumbers, + batch_group_count: int, +) -> int: + """Calculates roofline unfused flops for Jax's conv_general_dilated primitive. + + See `jax.lax.conv_general_dilated` for details on the arguments. + """ + dnums = convolution.conv_dimension_numbers( + lhs.shape, rhs.shape, dimension_numbers + ) + + spatial_valid_position_counts = _get_spatial_valid_position_count( + dnums, lhs, rhs, out, window_strides, padding, lhs_dilation, rhs_dilation + ) + + batch = lhs.shape[dnums.lhs_spec[0]] + num_output_features = out.shape[dnums.out_spec[1]] + num_input_features = rhs.shape[dnums.rhs_spec[1]] + num_output_batch = batch / batch_group_count + + non_spatial_dims_factor = ( + num_input_features * num_output_features * num_output_batch + ) + + fma_count = non_spatial_dims_factor * spatial_valid_position_counts + flops = fma_count * _FMA_FLOPS_FACTOR + return int(flops) + + @roofline.register_roofline(convolution.conv_general_dilated_p) def _conv_general_dilated_roofline( - ctx: roofline.RooflineRuleContext, - *args, - **kw, + ctx: roofline.RooflineRuleContext, + *args, + window_strides: Sequence[int], + padding: Sequence[tuple[int, int]], + lhs_dilation: Sequence[int], + rhs_dilation: Sequence[int], + dimension_numbers: convolution.ConvGeneralDilatedDimensionNumbers, + batch_group_count: int, + **kw, ) -> roofline.RooflineResult: + """Roofline for Jax's conv_general_dilated primitive. + + See `jax.lax.conv_general_dilated` for details on the arguments. + """ lhs, rhs = (roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in) out = roofline.RooflineShape.from_aval(ctx.avals_out[0]) - # TODO(b/394648206): support computing unfused_flops for conv. + return roofline.RooflineResult( + unfused_flops=_calculate_conv_flops( + lhs, + rhs, + out, + window_strides, + padding, + lhs_dilation, + rhs_dilation, + dimension_numbers, + batch_group_count, + ), unfused_hbm_bytes=( lhs.dtype.itemsize * lhs.size + rhs.dtype.itemsize * rhs.size diff --git a/tests/roofline_test.py b/tests/roofline_test.py index 98f6176c22a0..140beb3c6e71 100644 --- a/tests/roofline_test.py +++ b/tests/roofline_test.py @@ -28,6 +28,8 @@ jax.config.parse_flags_with_absl() jtu.request_cpu_devices(8) +_VERY_LARGE_NUMBER = 512 * 1024 + def create_inputs( *shardings: P, @@ -628,7 +630,6 @@ def test_conv_general_dilated_unfused_hbm_bytes( expected_unfused_hbm_bytes = self._bytes_per_word * ( expected_input_size + expected_kernel_size + expected_output_size ) - # TODO(b/394648206): add subtest for unfused_flops once they are supported. self.assertEqual(result.unfused_hbm_bytes, expected_unfused_hbm_bytes) @jtu.parameterized.named_parameters( @@ -641,10 +642,10 @@ def test_conv_general_dilated_unfused_hbm_bytes( padding="SAME_LOWER", ), ) - def test_conv_general_dilated_padding_string_unfused_hbm_bytes( + def test_conv_general_dilated_padding_string( self, padding: str ): - input_data = jnp.zeros((1, 1, 10, 20), dtype=int) + input_data = jnp.zeros((1, 1, 3, 3), dtype=int) kernel_data = jnp.ones((1, 1, 3, 3), dtype=int) conv = lambda a, b: lax.conv_general_dilated( lhs=a, rhs=b, window_strides=(1, 1), padding=padding @@ -652,10 +653,11 @@ def test_conv_general_dilated_padding_string_unfused_hbm_bytes( _, result = roofline.roofline(conv)(input_data, kernel_data) - expected_input_size = 1 * 1 * 10 * 20 + # Test hbm bytes. + expected_input_size = 1 * 1 * 3 * 3 expected_kernel_size = 1 * 1 * 3 * 3 # Because of same{_lower} padding, output shape should equal to input shape. - # This may not be true for other `{feature, batch}`_group_count`s.c + # This may not be true for other `{feature, batch}`_group_count`s. expected_output_size = expected_input_size # Bytes accessed is sum of inputs and output. expected_unfused_hbm_bytes = self._bytes_per_word * ( @@ -663,7 +665,21 @@ def test_conv_general_dilated_padding_string_unfused_hbm_bytes( ) self.assertEqual(result.unfused_hbm_bytes, expected_unfused_hbm_bytes) - def test_conv_general_dilated_padding_string_valid_unfused_hbm_bytes(self): + # Test flops. + # For spatial_valid_position_counts, we have 3x3 output with the following + # flops for each element: + # 4 6 4 + # 6 9 6 + # 4 6 4 + # Non_spatial_dims_factor = 1 because `{batch, feature}_group_count` are + # both equal to 1. + # Each FMA is 2 flops. + self.assertEqual( + result.unfused_flops, + 2 * (4 + 6 + 4 + 6 + 9 + 6 + 4 + 6 + 4), + ) + + def test_conv_general_dilated_padding_string_valid(self): input_data = jnp.zeros((1, 1, 10, 20), dtype=int) kernel_data = jnp.ones((1, 1, 3, 3), dtype=int) conv = lambda a, b: lax.conv_general_dilated( @@ -681,12 +697,90 @@ def test_conv_general_dilated_padding_string_valid_unfused_hbm_bytes(self): * self.get_conv_output_dim(10, 3, 0, 0, 1) * self.get_conv_output_dim(20, 3, 0, 0, 1) ) + # Bytes accessed is sum of inputs and output. expected_unfused_hbm_bytes = self._bytes_per_word * ( expected_input_size + expected_kernel_size + expected_output_size ) self.assertEqual(result.unfused_hbm_bytes, expected_unfused_hbm_bytes) + # Output shape is [1x1x8x18] and each output element requires (3x3) FMAs, + # and each FMA is 2 flops. + self.assertEqual( + result.unfused_flops, 2 * expected_output_size * 3 * 3 + ) + + + @jtu.parameterized.named_parameters( + dict( + testcase_name="padding", + input_spatial_dim=1, + window_strides=[1], + padding=[(_VERY_LARGE_NUMBER - 1, _VERY_LARGE_NUMBER - 1)], + lhs_dilation=[1], + ), + dict( + testcase_name="input", + input_spatial_dim=_VERY_LARGE_NUMBER, + window_strides=[_VERY_LARGE_NUMBER - 1], + padding=[(0, 0)], + lhs_dilation=[_VERY_LARGE_NUMBER], + ), + ) + def test_conv_general_dilated_flops_very_large( + self, input_spatial_dim, window_strides, padding, lhs_dilation + ): + input_data = jnp.zeros((1, 1, input_spatial_dim), dtype=int) + kernel_data = jnp.ones((1, 1, _VERY_LARGE_NUMBER), dtype=int) + conv = lambda a, b: lax.conv_general_dilated( + lhs=a, + rhs=b, + window_strides=window_strides, + padding=padding, + lhs_dilation=lhs_dilation, + ) + _, result = roofline.roofline(conv)(input_data, kernel_data) + + self.assertEqual(result.unfused_flops, 2 * _VERY_LARGE_NUMBER) + + def test_conv_general_dilated_flops_feature_group_count(self): + feature_group_count = 120 + input_data = jnp.zeros((1, feature_group_count, 10, 20), dtype=int) + kernel_data = jnp.ones((feature_group_count, 1, 3, 3), dtype=int) + conv = lambda a, b: lax.conv_general_dilated( + lhs=a, + rhs=b, + window_strides=(1, 1), + padding=((0, 0), (0, 0)), + feature_group_count=feature_group_count, + ) + _, result = roofline.roofline(conv)(input_data, kernel_data) + + # Output shape is [1x120x8x18] and each output element requires (3x3) + # FMAs and one FMA is 2 flops. + self.assertEqual( + result.unfused_flops, 2 * 120 * 8 * 18 * 3 * 3 + ) + + def test_conv_general_dilated_flops_batch_group_count(self): + batch_group_count = 120 + input_data = jnp.zeros((batch_group_count, 1, 10, 20), dtype=int) + kernel_data = jnp.ones((batch_group_count, 1, 3, 3), dtype=int) + conv = lambda a, b: lax.conv_general_dilated( + lhs=a, + rhs=b, + window_strides=(1, 1), + padding=((0, 0), (0, 0)), + batch_group_count=batch_group_count, + ) + _, result = roofline.roofline(conv)(input_data, kernel_data) + + # Output shape is [120x1x8x18] and each output element requires (3x3) + # FMAs and one FMA is 2 flops. + self.assertEqual( + result.unfused_flops, 2 * 120 * 8 * 18 * 3 * 3 + ) + def test_reduce_sum_no_axis(self): _, result = roofline.roofline(lambda x: jnp.sum(x))(jnp.zeros((11, 4))) self.assertEqual(result.unfused_flops, 11 * 4 - 1) From 96780f19b0b8775f02dc5d57dda11597a2f9c97e Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 2 Apr 2025 14:28:54 -0700 Subject: [PATCH 342/483] jax.numpy: support __jax_array__ in several more functions --- jax/_src/numpy/lax_numpy.py | 5 +++-- jax/_src/numpy/util.py | 2 +- tests/array_extensibility_test.py | 14 +++++++------- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 43b8923246d4..dba208327adc 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -2984,7 +2984,7 @@ def bincount(x: ArrayLike, weights: ArrayLike | None = None, >>> jnp.bincount(x, length=5) Array([2, 1, 0, 1, 0], dtype=int32) """ - util.check_arraylike("bincount", x) + x = util.ensure_arraylike("bincount", x) if _dtype(x) == bool: x = lax.convert_element_type(x, 'int32') if not issubdtype(_dtype(x), np.integer): @@ -5018,7 +5018,7 @@ def choose(a, choices): """ if out is not None: raise NotImplementedError("The 'out' argument to jnp.choose is not supported.") - util.check_arraylike('choose', a, *choices) + a, *choices = util.ensure_arraylike_tuple('choose', (a, *choices)) if not issubdtype(_dtype(a), np.integer): raise ValueError("`a` array must be integer typed") N = len(choices) @@ -8781,6 +8781,7 @@ def argwhere( >>> jnp.argwhere(0) Array([], shape=(0, 0), dtype=int32) """ + a = util.ensure_arraylike("argwhere", a) result = transpose(vstack(nonzero(atleast_1d(a), size=size, fill_value=fill_value))) if np.ndim(a) == 0: return result[:0].reshape(result.shape[0], 0) diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index 9d56267c4b61..49605ffc3b0c 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -259,7 +259,7 @@ def _broadcast_arrays(*args: ArrayLike) -> list[Array]: def _broadcast_to(arr: ArrayLike, shape: DimSize | Shape, sharding=None ) -> Array: - check_arraylike("broadcast_to", arr) + arr = ensure_arraylike("broadcast_to", arr) arr = arr if isinstance(arr, Array) else lax.asarray(arr) if not isinstance(shape, tuple) and np.ndim(shape) == 0: shape = (shape,) diff --git a/tests/array_extensibility_test.py b/tests/array_extensibility_test.py index 730001abef76..14fcc18ca7a5 100644 --- a/tests/array_extensibility_test.py +++ b/tests/array_extensibility_test.py @@ -92,6 +92,8 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: 'apply_along_axis', 'apply_over_axes', 'arange', + 'array_str', + 'array_repr', 'astype', 'bartlett', 'bfloat16', @@ -101,6 +103,7 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: 'bool_', 'broadcast_shapes', 'c_', + 'can_cast', 'cdouble', 'character', 'complex128', @@ -233,14 +236,12 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.argmin, Float[10]), NumPyAPI.sig(jnp.argpartition, Float[10], kth=5), NumPyAPI.sig(jnp.argsort, Float[10]), - # NumPyAPI.sig(jnp.argwhere, [float], [(10,)]), + NumPyAPI.sig(jnp.argwhere, Float[10]), NumPyAPI.sig(jnp.around, Float[5]), NumPyAPI.sig(jnp.array, Float[5]), NumPyAPI.sig(jnp.array_equal, Float[5], Float[5]), NumPyAPI.sig(jnp.array_equiv, Float[5], Float[5]), - # NumPyAPI.sig(jnp.array_repr, Float[5]), NumPyAPI.sig(jnp.array_split, Float[9], indices_or_sections=3), - # NumPyAPI.sig(jnp.array_str, Float[5]), NumPyAPI.sig(jnp.asarray, Float[5]), NumPyAPI.sig(jnp.asin, Float[5]), NumPyAPI.sig(jnp.asinh, Float[5]), @@ -251,7 +252,7 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.atleast_2d, Float[5]), NumPyAPI.sig(jnp.atleast_3d, Float[5]), NumPyAPI.sig(jnp.average, Float[10]), - # NumPyAPI.sig(jnp.bincount, int[10]), + NumPyAPI.sig(jnp.bincount, Int[10]), NumPyAPI.sig(jnp.bitwise_and, Int[5], Int[5]), NumPyAPI.sig(jnp.bitwise_count, Int[5]), NumPyAPI.sig(jnp.bitwise_invert, Int[5]), @@ -261,11 +262,10 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.bitwise_right_shift, Int[5], Int[5]), NumPyAPI.sig(jnp.bitwise_xor, Int[5], Int[5]), NumPyAPI.sig(jnp.broadcast_arrays, Float[5]), - # NumPyAPI.sig(jnp.broadcast_to, Float[()], shape=(10,)), - # NumPyAPI.sig(jnp.can_cast, Float[()], to='int32'), + NumPyAPI.sig(jnp.broadcast_to, Float[()], shape=(10,)), NumPyAPI.sig(jnp.cbrt, Float[5]), NumPyAPI.sig(jnp.ceil, Float[5]), - # NumPyAPI.sig(jnp.choose, [int, float], [(3,), (10,)]), + NumPyAPI.sig(jnp.choose, Int[3], [Float[3], Float[3], Float[3]], mode='clip'), NumPyAPI.sig(jnp.clip, Float[5]), NumPyAPI.sig(jnp.column_stack, [Float[5], Float[5], Float[5]]), NumPyAPI.sig(jnp.compress, Float[10], Bool[10]), From 9c58a112b3e3ccf5a4eb8bdbddfb2760a9b2161a Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 2 Apr 2025 14:58:06 -0700 Subject: [PATCH 343/483] `jnp.array` no longer accepts None PiperOrigin-RevId: 743291099 --- CHANGELOG.md | 5 +++++ jax/_src/numpy/lax_numpy.py | 9 +-------- tests/lax_numpy_test.py | 15 ++++----------- 3 files changed, 10 insertions(+), 19 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 68450dca4057..ffd197b390d0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. ## Unreleased +* Breaking changes + + * {func}`jax.numpy.array` no longer accepts `None`. This behavior was + deprecated since November 2023 and is now removed. + * Changes * The minimum CuDNN version is v9.8. * JAX is now built using CUDA 12.8. All versions of CUDA 12.1 or newer remain diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 43b8923246d4..1b59363d14c7 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -5502,14 +5502,7 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, leaves, treedef = tree_flatten(object, is_leaf=lambda x: x is None) if any(leaf is None for leaf in leaves): - # Added Nov 16 2023 - if deprecations.is_accelerated("jax-numpy-array-none"): - raise ValueError("None is not a valid value for jnp.array") - warnings.warn( - "None encountered in jnp.array(); this is currently treated as NaN. " - "In the future this will result in an error.", - FutureWarning, stacklevel=2) - leaves, treedef = tree_flatten(object) + raise ValueError("None is not a valid value for jnp.array") leaves = [ leaf if (leaf_jax_array := getattr(leaf, "__jax_array__", None)) is None diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index f94f42f027ce..2c305af6e8f5 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -47,7 +47,6 @@ from jax._src import array from jax._src import config from jax._src import core -from jax._src import deprecations from jax._src import dtypes from jax._src import test_util as jtu from jax._src.lax import lax as lax_internal @@ -3796,16 +3795,10 @@ def testArrayFromList(self): with self.assertRaisesRegex(OverflowError, "Python int too large.*"): jnp.array([0, val]) - def testArrayNoneWarning(self): - if deprecations.is_accelerated('jax-numpy-array-none'): - ctx = self.assertRaisesRegex( - ValueError, 'None is not a valid value for jnp.array' - ) - else: - ctx = self.assertWarnsRegex( - FutureWarning, r'None encountered in jnp.array\(\)' - ) - with ctx: + def testArrayNone(self): + with self.assertRaisesRegex( + ValueError, 'None is not a valid value for jnp.array' + ): jnp.array([0.0, None]) def testIssue121(self): From 9fa5de7b0584f74fce9d0eea89817e8fa9b96b8f Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 2 Apr 2025 15:45:15 -0700 Subject: [PATCH 344/483] [pallas] Removed `pl.device_id`. Use `lax.axis_index` instead. PiperOrigin-RevId: 743307670 --- jax/_src/pallas/mosaic/lowering.py | 4 ---- jax/_src/pallas/primitives.py | 8 -------- jax/experimental/pallas/__init__.py | 1 - jax/experimental/pallas/tpu.py | 1 - tests/pallas/tpu_pallas_distributed_test.py | 3 +-- 5 files changed, 1 insertion(+), 16 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index f8f49f3d7aea..6c8b3c646a0d 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -3602,10 +3602,6 @@ def _dma_wait_lowering_rule(ctx: LoweringRuleContext, *args, tree, lowering_rules[tpu_primitives.dma_wait_p] = _dma_wait_lowering_rule -def _device_id_lowering_rule(ctx: LoweringRuleContext): - return tpu.device_id() -lowering_rules[primitives.device_id_p] = _device_id_lowering_rule - def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): grid_names = ctx.lowering_context.grid_names if grid_names and axis_name in grid_names: diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index 4971b83a9ba2..986a62571010 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -1237,11 +1237,3 @@ def _semaphore_wait_discharge_rule(in_avals, state_discharge.register_discharge_rule(semaphore_wait_p)( _semaphore_wait_discharge_rule ) - -device_id_p = jax_core.Primitive('device_id') - -@device_id_p.def_abstract_eval -def _device_id_abstract_eval(): - return jax_core.ShapedArray((), jnp.dtype("int32")) - -device_id = device_id_p.bind diff --git a/jax/experimental/pallas/__init__.py b/jax/experimental/pallas/__init__.py index fd523712fa9c..b6d2ac69d2c6 100644 --- a/jax/experimental/pallas/__init__.py +++ b/jax/experimental/pallas/__init__.py @@ -48,7 +48,6 @@ from jax._src.pallas.primitives import atomic_xchg as atomic_xchg from jax._src.pallas.primitives import atomic_xor as atomic_xor from jax._src.pallas.primitives import debug_print as debug_print -from jax._src.pallas.primitives import device_id as device_id from jax._src.pallas.primitives import dot as dot from jax._src.pallas.primitives import load as load from jax._src.pallas.primitives import max_contiguous as max_contiguous diff --git a/jax/experimental/pallas/tpu.py b/jax/experimental/pallas/tpu.py index da054bf18309..21976c47166b 100644 --- a/jax/experimental/pallas/tpu.py +++ b/jax/experimental/pallas/tpu.py @@ -52,7 +52,6 @@ # Those primitives got moved to Pallas core. Keeping the updated imports # here for backward compatibility. from jax._src.pallas.core import semaphore as semaphore -from jax._src.pallas.primitives import device_id as device_id from jax._src.pallas.primitives import DeviceIdType as DeviceIdType from jax._src.pallas.primitives import semaphore_read as semaphore_read from jax._src.pallas.primitives import semaphore_signal as semaphore_signal diff --git a/tests/pallas/tpu_pallas_distributed_test.py b/tests/pallas/tpu_pallas_distributed_test.py index f7d7daf1874f..737ab5137e99 100644 --- a/tests/pallas/tpu_pallas_distributed_test.py +++ b/tests/pallas/tpu_pallas_distributed_test.py @@ -51,8 +51,7 @@ def test_basic_remote_vmem_dma(self, mem): # Implements very simple collective permute def kernel(x_ref, y_ref): def body(ready_sem, send_sem, recv_sem): - dev_id = pltpu.device_id() - other_dev_id = 1 - dev_id + other_dev_id = 1 - lax.axis_index('x') pltpu.semaphore_signal(ready_sem, device_id=other_dev_id, device_id_type=pltpu.DeviceIdType.LOGICAL) pltpu.semaphore_wait(ready_sem) From 5ddec650868df2bee004e062c5664f9b69c762ee Mon Sep 17 00:00:00 2001 From: kaixih Date: Thu, 3 Apr 2025 00:00:25 +0000 Subject: [PATCH 345/483] Remove asserts --- jax/_src/nn/functions.py | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index cc4a345641dd..d0f5f770e196 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -1273,14 +1273,35 @@ def scaled_matmul( >>> b_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu) >>> scaled_matmul(a, b, a_scales, b_scales) """ - assert all(x.ndim == 3 for x in (a, b, a_scales, b_scales)) + if not all(x.ndim == 3 for x in (a, b, a_scales, b_scales)): + raise ValueError( + "scaled_matmul requires all inputs to be 3-dimensional arrays" + ) + B_a, M_a, K_a = a.shape B_b, N_b, K_b = b.shape - assert K_a == K_b and B_a == B_b + if K_a != K_b or B_a != B_b: + raise ValueError( + "scaled_matmul requires inputs a and b to have matching batch (B) " + f"and contract (K) dimensions, but got shapes {a.shape} and " + f"{b.shape}" + ) + B_as, M_as, K_as = a_scales.shape B_bs, N_bs, K_bs = b_scales.shape - assert K_as == K_bs and B_as == B_bs - assert M_as == M_a and N_bs == N_b + if K_as != K_bs or B_as != B_bs: + raise ValueError( + "scaled_matmul requires scales to have matching batch (B) and " + f"contract (K) dimensions, but got shapes {a_scales.shape} and " + f"{b_scales.shape}" + ) + + if M_as != M_a or N_bs != N_b: + raise ValueError( + "scaled_matmul requires scales to match non-contract dimensions of " + f"inputs, but got shapes a: {a.shape}, b: {b.shape}, a_scales: " + f"{a_scales.shape}, b_scales: {b_scales.shape}" + ) preferred_element_type = dtypes.canonicalize_dtype( np.dtype(preferred_element_type) From 2540fcde11e3531267b96e9ad40a80749984bace Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Wed, 2 Apr 2025 13:13:13 -0700 Subject: [PATCH 346/483] add an `out_sharding` option to `jax.random.bits` Drop into `Auto` mode in the implementation. --- jax/_src/random.py | 26 +++++++++++++++++--------- tests/pjit_test.py | 19 +++++++++++++++++++ 2 files changed, 36 insertions(+), 9 deletions(-) diff --git a/jax/_src/random.py b/jax/_src/random.py index 0dcbda7bb717..fc571be9493a 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -38,11 +38,11 @@ from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir -from jax._src.sharding_impls import canonicalize_sharding -from jax._src.pjit import auto_axes from jax._src.lax import lax as lax_internal from jax._src.numpy.lax_numpy import _convert_and_clip_integer from jax._src.numpy.util import _arraylike, check_arraylike, promote_dtypes_inexact +from jax._src.pjit import auto_axes +from jax._src.sharding_impls import canonicalize_sharding from jax._src.typing import Array, ArrayLike, DTypeLike from jax._src.util import canonicalize_axis @@ -348,9 +348,18 @@ def _check_shape(name: str, shape: Shape, *param_shapes) -> None: raise ValueError(msg.format(name, shape_, shape)) +def maybe_auto_axes(f, out_shardings, **hoist_kwargs): + f_ = partial(f, **hoist_kwargs) + if out_shardings is None: + return f_ + else: + return auto_axes(f_, out_shardings=out_shardings) + + def bits(key: ArrayLike, shape: Shape = (), - dtype: DTypeLikeUInt | None = None) -> Array: + dtype: DTypeLikeUInt | None = None, + out_sharding=None) -> Array: """Sample uniform bits in the form of unsigned integers. Args: @@ -373,8 +382,10 @@ def bits(key: ArrayLike, f"got {dtype}") dtype = dtypes.canonicalize_dtype(dtype) shape = core.canonicalize_shape(shape) + out_sharding = canonicalize_sharding(out_sharding, "bits") bit_width = dtype.itemsize * 8 - return _random_bits(key, bit_width, shape) + return maybe_auto_axes(_random_bits, out_sharding, + bit_width=bit_width, shape=shape)(key) def uniform(key: ArrayLike, @@ -711,16 +722,13 @@ def normal(key: ArrayLike, """ key, _ = _check_prng_key("normal", key) shape = core.canonicalize_shape(shape) - out_sharding = canonicalize_sharding(out_sharding, 'normal') + out_sharding = canonicalize_sharding(out_sharding, "normal") dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.inexact): raise ValueError(f"dtype argument to `normal` must be a float or complex dtype, " f"got {dtype}") dtype = dtypes.canonicalize_dtype(dtype) - if out_sharding is None: - return _normal(key, shape, dtype) - return auto_axes(partial(_normal, shape=shape, dtype=dtype), - out_shardings=out_sharding)(key) + return maybe_auto_axes(_normal, out_sharding, shape=shape, dtype=dtype)(key) @partial(jit, static_argnums=(1, 2)) def _normal(key, shape, dtype) -> Array: diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 38f191302ea1..ebfdd7fa0b20 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7274,6 +7274,25 @@ def f(key): out = f(key) self.assertEqual(out.sharding, NamedSharding(mesh, P())) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_random_bits(self, mesh): + @jax.jit + def f(key): + out = jax.random.bits(key, shape=(8, 12), out_sharding=P('x', 'y')) + self.assertEqual(out.aval.sharding.spec, P('x', 'y')) + return out + + key = jax.random.key(1) + out = f(key) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + + lowered_text = f.lower(key).as_text() + if config.use_shardy_partitioner.value: + self.assertIn('sdy.sharding_constraint', lowered_text) + self.assertIn('<@mesh, [{"x"}, {"y"}]>', lowered_text) + else: + self.assertIn('mhlo.sharding = "{devices=[2,2]<=[4]}"}', lowered_text) + @jtu.with_user_mesh((2, 2), ('x', 'y')) def test_random_uniform(self, mesh): @jax.jit From 2f617631fbceb56b33fb0312b228a49bc3bee608 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Wed, 2 Apr 2025 17:31:23 -0700 Subject: [PATCH 347/483] use common `maybe_auto_axes` helper in `random.uniform` --- jax/_src/random.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/jax/_src/random.py b/jax/_src/random.py index fc571be9493a..e519c284a567 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -417,11 +417,10 @@ def uniform(key: ArrayLike, raise ValueError(f"dtype argument to `uniform` must be a float dtype, " f"got {dtype}") dtype = dtypes.canonicalize_dtype(dtype) - if out_sharding is None: + def f(key, minval, maxval, shape, dtype): # reorder args return _uniform(key, shape, dtype, minval, maxval) - def f(k, minv, maxv): - return _uniform(k, shape, dtype, minv, maxv) - return auto_axes(f, out_shardings=out_sharding)(key, minval, maxval) + return maybe_auto_axes(f, out_sharding, shape=shape, dtype=dtype)( + key, minval, maxval) @partial(jit, static_argnums=(1, 2)) def _uniform(key, shape, dtype, minval, maxval) -> Array: From ab816ed8c4d787d5a6760e32b6b34db1fe55e1d1 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Wed, 2 Apr 2025 17:38:05 -0700 Subject: [PATCH 348/483] add an `out_sharding` option to `jax.random.randint` Drop into `Auto` mode in the implementation. --- jax/_src/random.py | 27 +++++++++++++++------------ tests/pjit_test.py | 20 ++++++++++++++++++++ 2 files changed, 35 insertions(+), 12 deletions(-) diff --git a/jax/_src/random.py b/jax/_src/random.py index e519c284a567..1d044ec111ff 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -417,13 +417,11 @@ def uniform(key: ArrayLike, raise ValueError(f"dtype argument to `uniform` must be a float dtype, " f"got {dtype}") dtype = dtypes.canonicalize_dtype(dtype) - def f(key, minval, maxval, shape, dtype): # reorder args - return _uniform(key, shape, dtype, minval, maxval) - return maybe_auto_axes(f, out_sharding, shape=shape, dtype=dtype)( - key, minval, maxval) + return maybe_auto_axes(_uniform, out_sharding, + shape=shape,dtype=dtype)(key, minval, maxval) -@partial(jit, static_argnums=(1, 2)) -def _uniform(key, shape, dtype, minval, maxval) -> Array: +@partial(jit, static_argnums=(3, 4)) +def _uniform(key, minval, maxval, shape, dtype) -> Array: _check_shape("uniform", shape) if not jnp.issubdtype(dtype, np.floating): raise TypeError("uniform only accepts floating point dtypes.") @@ -467,7 +465,8 @@ def randint(key: ArrayLike, shape: Shape, minval: IntegerArray, maxval: IntegerArray, - dtype: DTypeLikeInt = int) -> Array: + dtype: DTypeLikeInt = int, + out_sharding=None) -> Array: """Sample uniform random values in [minval, maxval) with given shape/dtype. Args: @@ -487,10 +486,12 @@ def randint(key: ArrayLike, dtypes.check_user_dtype_supported(dtype) dtype = dtypes.canonicalize_dtype(dtype) shape = core.canonicalize_shape(shape) - return _randint(key, shape, minval, maxval, dtype) + out_sharding = canonicalize_sharding(out_sharding, "randint") + return maybe_auto_axes(_randint, out_sharding, shape=shape, dtype=dtype)( + key, minval, maxval) -@partial(jit, static_argnums=(1, 4)) -def _randint(key, shape, minval, maxval, dtype) -> Array: +@partial(jit, static_argnums=(3, 4)) +def _randint(key, minval, maxval, shape, dtype) -> Array: _check_shape("randint", shape, np.shape(minval), np.shape(maxval)) if not jnp.issubdtype(dtype, np.integer): raise TypeError(f"randint only accepts integer dtypes, got {dtype}") @@ -1557,7 +1558,8 @@ def gumbel(key: ArrayLike, def _gumbel(key, shape, dtype, mode) -> Array: _check_shape("gumbel", shape) if mode == "high": - high, low = _uniform(key, (2,) + shape, dtype, minval=0., maxval=1.) + high, low = _uniform(key, minval=0., maxval=1., + shape=(2,) + shape, dtype=dtype) # TODO(parkers): The condition is to protect against rounding up but # we should be able to add safely with the right addition operation. x = jnp.where(high >= 0.5, high, @@ -1565,7 +1567,8 @@ def _gumbel(key, shape, dtype, mode) -> Array: return -jnp.log(-jnp.log1p(-x)) else: return -jnp.log(-jnp.log( - _uniform(key, shape, dtype, minval=jnp.finfo(dtype).tiny, maxval=1.))) + _uniform(key, minval=jnp.finfo(dtype).tiny, maxval=1., + shape=shape, dtype=dtype))) def categorical( diff --git a/tests/pjit_test.py b/tests/pjit_test.py index ebfdd7fa0b20..d3d9cab7a5ba 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7312,6 +7312,26 @@ def f(key): else: self.assertIn('mhlo.sharding = "{devices=[2,2]<=[4]}"}', lowered_text) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_random_randint(self, mesh): + @jax.jit + def f(key): + out = jax.random.randint(key, shape=(8, 12), minval=0, maxval=10, + out_sharding=P('x', 'y')) + self.assertEqual(out.aval.sharding.spec, P('x', 'y')) + return out + + key = jax.random.key(1) + out = f(key) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + + lowered_text = f.lower(key).as_text() + if config.use_shardy_partitioner.value: + self.assertIn('sdy.sharding_constraint', lowered_text) + self.assertIn('<@mesh, [{"x"}, {"y"}]>', lowered_text) + else: + self.assertIn('mhlo.sharding = "{devices=[2,2]<=[4]}"}', lowered_text) + @jtu.with_user_mesh((2, 2), ('x', 'y')) def test_random_normal(self, mesh): @jax.jit From f1adec35641553fd38aceff4266fbd5986c11ded Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Thu, 3 Apr 2025 00:26:46 -0700 Subject: [PATCH 349/483] [Mosaic GPU] Define the `mosaic_gpu.custom_primitive` dialect op. PiperOrigin-RevId: 743441718 --- jaxlib/mosaic/dialect/gpu/BUILD | 1 + jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc | 37 +++++++++++++++++++++ jaxlib/mosaic/dialect/gpu/mosaic_gpu.td | 26 +++++++++++++++ tests/mosaic/gpu_dialect_test.py | 44 +++++++++++++++++++++++++ 4 files changed, 108 insertions(+) diff --git a/jaxlib/mosaic/dialect/gpu/BUILD b/jaxlib/mosaic/dialect/gpu/BUILD index f0e399da0575..592d22b699a3 100644 --- a/jaxlib/mosaic/dialect/gpu/BUILD +++ b/jaxlib/mosaic/dialect/gpu/BUILD @@ -119,6 +119,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", "@llvm-project//mlir:LLVMCommonConversion", diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc index 1b3d08f91fb0..0a36aab6fbcd 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc @@ -31,10 +31,12 @@ limitations under the License. #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" #include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypeInterfaces.h" @@ -370,6 +372,41 @@ llvm::LogicalResult WGMMAOp::verify() { return llvm::success(); } +llvm::LogicalResult CustomPrimitiveOp::verify() { + int num_vector_operands = 0; + int num_smem_ref_operands = 0; + mlir::Attribute smem = mlir::gpu::AddressSpaceAttr::get( + getContext(), mlir::gpu::AddressSpace::Workgroup); + for (auto operand : getOperands()) { + if (mlir::isa(operand.getType())) { + ++num_vector_operands; + } + + if (auto ref_ty = mlir::dyn_cast(operand.getType())) { + if (ref_ty.getMemorySpace() == smem) { + ++num_smem_ref_operands; + } + } + } + + if (num_vector_operands != getInLayouts().size()) { + return emitOpError( + "Custom primitive must have a layout for each vector operand."); + } + + if (num_smem_ref_operands != getInTransforms().size()) { + return emitOpError( + "Custom primitive must have transforms for each memref operand in " + "smem."); + } + + if (getResults().size() != getOutLayouts().size()) { + return emitOpError("Custom primitive must have a layout for each result."); + } + + return llvm::success(); +} + mlir::AffineMap LayoutAttr::getAffineMap() const { // This always returns an identity map. It's technically not correct, but we // don't actually use it anywhere. It's only called during verification of the diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td index 85929080faec..0d954716b179 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td @@ -484,4 +484,30 @@ def MosaicGPU_OptimizationBarrierOp : Op { + let summary = "Allows defining a custom Mosaic GPU primitive."; + let description = [{ + Allows defining a custom Mosaic GPU primitive. + + Custom primitives should carry input and output layouts for each of their + vector operands and outputs, and input transforms for each of their memref + operands that live in SMEM. + + Custom primitives can only return vectors. + }]; + + let arguments = ( + ins Variadic:$operands, + // Attributes + ArrayAttr:$in_layouts, + ArrayAttr:$in_transforms, + ArrayAttr:$out_layouts + ); + + let results = (outs Variadic>); + let regions = (region AnyRegion:$body); + + let hasVerifier = 1; +} + #endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_GPU_MOSAIC_GPU_TD_ diff --git a/tests/mosaic/gpu_dialect_test.py b/tests/mosaic/gpu_dialect_test.py index ba9d23fa5b4f..bc94d72dc0d8 100644 --- a/tests/mosaic/gpu_dialect_test.py +++ b/tests/mosaic/gpu_dialect_test.py @@ -862,6 +862,50 @@ def test_lower_conversion_op_lowers_to_same_op(self, op, in_dtype, out_dtype): self.assertLen(conversion_ops, 1) self.assertEqual(conversion_ops[0].result.type, scalar_out_ty) + @parameterized.parameters( + (True, False, False), + (False, True, False), + (False, False, True), + ) + def test_custom_primitive_op_must_have_number_of_annotations_matching_operands_and_results( + self, omit_in_layouts, omit_in_transforms, omit_out_layouts + ): + vec_ty = ir.VectorType.get((4, 32), ir.BF16Type.get()) + out_layouts = [ + layouts.to_layout_attr( + mgpu.WGStridedFragLayout.from_shaped_type(vec_ty) + ) + ] + in_layouts = out_layouts * 2 + in_transforms = [ + ir.ArrayAttr.get([mgpu.dialect.SwizzleTransformAttr.get(128)]) + ] + + in_layouts = [] if omit_in_layouts else in_layouts + in_transforms = [] if omit_in_transforms else in_transforms + out_layouts = [] if omit_out_layouts else out_layouts + + def body(vec1, vec2, ref): + mgpu.dialect.custom_primitive( + [vec_ty], [vec1, vec2, ref], in_layouts, in_transforms, out_layouts + ) + + with ir.InsertionPoint(self.module.body): + smem = ir.Attribute.parse("#gpu.address_space") + ref_ty = ir.MemRefType.get((4, 32), ir.BF16Type.get(), memory_space=smem) + func.FuncOp.from_py_func(vec_ty, vec_ty, ref_ty)(body) + + if omit_in_layouts: + error = "layout for each vector operand" + elif omit_in_transforms: + error = "transforms for each memref operand in smem" + else: + assert omit_out_layouts + error = "layout for each result" + + with self.assertRaisesRegex(ir.MLIRError, error): + self.module.operation.verify() + if __name__ == "__main__": parameterized.absltest.main(testLoader=jtu.JaxTestLoader()) From 6243ac80fca6ba718b01facc52c4cde7277838bc Mon Sep 17 00:00:00 2001 From: Michael Hudgins Date: Thu, 3 Apr 2025 02:44:18 -0700 Subject: [PATCH 350/483] [CI] Enable nightly TPU CI tests for v6e. PiperOrigin-RevId: 743478967 --- .github/workflows/cloud-tpu-ci-nightly.yml | 12 ++++++++++-- .github/workflows/wheel_tests_nightly_release.yml | 11 +++++++++-- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/.github/workflows/cloud-tpu-ci-nightly.yml b/.github/workflows/cloud-tpu-ci-nightly.yml index 099f4ad5c520..fd799a3f70b5 100644 --- a/.github/workflows/cloud-tpu-ci-nightly.yml +++ b/.github/workflows/cloud-tpu-ci-nightly.yml @@ -26,11 +26,19 @@ jobs: matrix: jaxlib-version: ["head", "pypi_latest", "nightly", "nightly+oldest_supported_libtpu"] tpu: [ - # {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}, - {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"} + {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}, + {type: "v6e-8", cores: "8", runner: "linux-x86-ct6e-180-8tpu"} ] python-version: ["3.10"] + # Exclude v6e-8 tests for nightly+oldest_supported_libtpu and pypi_latest for resource constraints. + exclude: + - tpu: + type: "v6e-8" + jaxlib-version: "nightly+oldest_supported_libtpu" + - tpu: + type: "v6e-8" + jaxlib-version: "pypi_latest" name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu.type }})" env: LIBTPU_OLDEST_VERSION_DATE: 20241205 diff --git a/.github/workflows/wheel_tests_nightly_release.yml b/.github/workflows/wheel_tests_nightly_release.yml index fd4a52d296e0..6fd48d016bd0 100644 --- a/.github/workflows/wheel_tests_nightly_release.yml +++ b/.github/workflows/wheel_tests_nightly_release.yml @@ -80,19 +80,26 @@ jobs: tpu-specs: [ # {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}, - {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"} + {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}, + {type: "v6e-8", cores: "8", runner: "linux-x86-ct6e-180-8tpu"} ] libtpu-version-type: ["pypi_latest", "nightly", "oldest_supported_libtpu"] exclude: - libtpu-version-type: ${{ startsWith(github.ref_name, 'release/') && 'nightly' }} - libtpu-version-type: ${{ !startsWith(github.ref_name, 'release/') && 'pypi_latest' }} - # Run a single Python version for v4-8. + # Run a single Python version for v4-8 and v6e-8 - tpu-specs: type: "v4-8" python: "3.10" - tpu-specs: type: "v4-8" python: "3.11" + - tpu-specs: + type: "v6e-8" + python: "3.10" + - tpu-specs: + type: "v6e-8" + python: "3.11" # Run min and max Python versions for v5e-8 - tpu-specs: type: "v5e-8" From ea196dac12d53d011e01db724156bd3c7f9952f5 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 3 Apr 2025 03:34:50 -0700 Subject: [PATCH 351/483] [pallas:mosaic_gpu] Slightly reworded the docstrings for a few recently added primitives PiperOrigin-RevId: 743492343 --- jax/_src/pallas/mosaic_gpu/primitives.py | 29 ++++++++++--------- jax/experimental/mosaic/gpu/__init__.py | 1 + jax/experimental/mosaic/gpu/launch_context.py | 9 +++--- 3 files changed, 21 insertions(+), 18 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index ff2678454b42..46ec8a87082e 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -21,7 +21,7 @@ import enum import itertools import math -from typing import Any, Literal, Optional +from typing import Any, Literal import jax from jax._src import core as jax_core @@ -135,17 +135,17 @@ def _load_p_lowering_rule( def load( - src: _Ref, idx, *, layout: Optional[Layout | ParameterizedLayout] = None -) -> mgpu.FragmentedArray: - """ Loads a ref (SMEM or GMEM) into a FragmentedArray with the specified layout. + src: _Ref, idx, *, layout: Layout | ParameterizedLayout | None = None +) -> jax.Array: + """Loads from a reference into an array with the specified layout. Args: - src: The reference to copy from. + src: The reference to load from. Can be either in SMEM or GMEM. idx: The index to load from. - layout: The optional layout to use for the returned FragmentedArray. + layout: The optional layout to use for the resulting array. Returns: - A FragmentedArray containing the loaded data in the specified layout. + The loaded array. """ src, src_transforms = state_primitives.get_ref_and_transforms( src, idx, "load", force_trailing_indexer=True, @@ -160,6 +160,7 @@ def load( layout=layout ) + copy_smem_to_gmem_p = jax_core.Primitive("copy_smem_to_gmem") copy_smem_to_gmem_p.multiple_results = True @@ -185,7 +186,7 @@ def _copy_smem_to_gmem_lowering( dst_transforms_treedef, has_user_predicate, commit_group, - reduction_op: Literal["add", "min", "max", "inc", "dec", "and", "or", "xor"] | None, + reduction_op, ): if has_user_predicate: flat_args, user_predicate = flat_args[:-1], flat_args[-1] @@ -295,9 +296,7 @@ def copy_smem_to_gmem( predicate: jax.Array | None = None, *, commit_group: bool = True, - reduction_op: Literal[ - "add","min","max","inc","dec","and","or","xor" - ] | None = None, + reduction_op: mgpu.ReductionOp | None = None, ) -> None: """Asynchronously copies a SMEM reference to a GMEM reference. @@ -306,10 +305,12 @@ def copy_smem_to_gmem( dst: The GMEM reference to copy to. predicate: A boolean indicating whether the copy should be performed. If ``None``, the copy is always performed. - commit_group: If ``True``, this and any previously uncommitted copies - are committed to a group and can be awaited jointly via + commit_group: If ``True``, this and any previously uncommitted copies are + committed to a group and can be awaited jointly via :func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem`. - reduction_op: if set, perform the specified reduction op when copy to gmem + reduction_op: If set, perform the specified reduction operation when storing + to GMEM. For example, using ``"add"`` is conceptually equivalent to + doing ``src += dst``. See also: :func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem` diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index afc87b5d96fa..e645115940e4 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -32,6 +32,7 @@ from .launch_context import ( LaunchContext as LaunchContext, MemRefTransform as MemRefTransform, + ReductionOp as ReductionOp, Rounding as Rounding, TileTransform as TileTransform, TransposeTransform as TransposeTransform, diff --git a/jax/experimental/mosaic/gpu/launch_context.py b/jax/experimental/mosaic/gpu/launch_context.py index 41c15bc5492e..aca3fc723882 100644 --- a/jax/experimental/mosaic/gpu/launch_context.py +++ b/jax/experimental/mosaic/gpu/launch_context.py @@ -229,6 +229,7 @@ def batch(self, leading_rank: int) -> MemRefTransform: OnDeviceProfiler = profiler.OnDeviceProfiler +ReductionOp = Literal["add", "min", "max", "inc", "dec", "and", "or", "xor"] @dataclasses.dataclass() class LaunchContext: @@ -406,10 +407,10 @@ def async_copy( uniform: bool = True, collective: Sequence[gpu.Dimension] | gpu.Dimension | None = None, partitioned: int | None = None, - predicate: ir.Value | None = None, # Should select 0 or 1 threads from the WG. - reduction_op: Literal[ - "add","min","max","inc","dec","and","or","xor" - ] | None = None, + predicate: ( + ir.Value | None + ) = None, # Should select 0 or 1 threads from the WG. + reduction_op: ReductionOp | None = None, ): """Initiates an async copy between GMEM and SMEM. From 552eea8ebddccd7f9605f0f62e7ca685621bb0db Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 3 Apr 2025 03:57:38 -0700 Subject: [PATCH 352/483] [pallas:mosaic_gpu] `emit_pipeline*` now passes the loop indices into the body This replaces the old behavior where `emit_pipeline*` would replace the current parallel grid with the sequential grid, changing the output of `pl.program_id`. With this change, `pl.program_id` always works wrt the parallel grid. PiperOrigin-RevId: 743498194 --- jax/_src/pallas/mosaic_gpu/lowering.py | 8 +-- jax/_src/pallas/mosaic_gpu/pipeline.py | 49 ++++++++++--------- .../pallas/ops/gpu/attention_mgpu.py | 2 +- tests/pallas/mosaic_gpu_test.py | 26 +++++----- 4 files changed, 43 insertions(+), 42 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index f027d5bcb76d..aafab927d4c2 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -629,13 +629,9 @@ def scoped_pipeline_fn(*refs, sem_refs, scratch_refs): scratch_refs = [ next(sem_refs_it) if r is sem_placeholder else r for r in scratch_refs ] - def body_fn(*refs): - grid_env = pallas_core.current_grid_env() - assert grid_env is not None # Set by ``emit_pipeline``. + def body_fn(indices, *refs): program_ids_template = util.merge_lists( - which_parallel, - [grid_axis.index for grid_axis in grid_env], - [None] * sum(which_parallel), + which_parallel, indices, [None] * sum(which_parallel) ) assert len(refs) + len(scratch_refs) == len(jaxpr.invars) return gpu_primitives.jaxpr_call( diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index d85ba4ae2a03..ec088f43c4b2 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -33,7 +33,6 @@ from jax._src.pallas import core as pallas_core from jax._src.pallas.mosaic_gpu import core as gpu_core from jax._src.pallas.mosaic_gpu import primitives as gpu_primitives -from jax._src.util import foreach from jax.experimental import pallas as pl import jax.numpy as jnp @@ -171,7 +170,8 @@ def emit_pipeline( """Creates a function to emit a manual pipeline within a Pallas kernel. Args: - body: The pipeline body. + body: The pipeline body, called with the indices for the current step, the + input refs, followed by the output refs. grid: The grid to use for the pipeline. in_specs: The block specs for the inputs. out_specs: The block specs for the outputs. @@ -248,7 +248,8 @@ def scoped_pipeline( it.islice(it.product(*map(range, grid)), max_concurrent_steps) ): indices = tuple(map(lambda i: jnp.asarray(i, dtype=jnp.int32), indices)) - foreach(lambda bref: bref.copy_in(step, indices, barrier_ref), in_brefs) + for bref in in_brefs: + bref.copy_in(step, indices, barrier_ref) # This is true if any of the outputs need to be transferred inside the loop. copies_out_in_loop = not all(bref.is_index_invariant for bref in out_brefs) @@ -266,11 +267,13 @@ def loop_body(step, carry): max_concurrent_steps - (1 + delay_release), wait_read_only=True ) - with pallas_core.grid_env(map(pallas_core.GridAxis, indices, grid)): - body(*( - bref.get_ref_for_slot(slot) - for bref in it.chain(in_brefs, out_brefs) - )) + body( + indices, + *( + bref.get_ref_for_slot(slot) + for bref in it.chain(in_brefs, out_brefs) + ), + ) if copies_out_in_loop: gpu_primitives.commit_smem() @@ -355,6 +358,7 @@ def do_fetch(): return pipeline + def emit_pipeline_warp_specialized( body: Callable[..., None], *, @@ -376,14 +380,16 @@ def emit_pipeline_warp_specialized( ``manual_consumed_barriers`` argument is True. ``` - def body(*input_refs, *output_refs, [consumed_barriers]) -> None: + def body(indices, *input_refs, *output_refs, [consumed_barriers]) -> None: ``` or with a carries enabled (enabled via the ``carry_coroutine`` argument), where the body returns the next carry: ``` - def body(*input_refs, *output_refs, [consumed_barriers], carry) -> Carry: + def body( + indices, *input_refs, *output_refs, [consumed_barriers], carry + ) -> Carry: ``` Args: @@ -545,18 +551,17 @@ def compute_loop_body(step, carry): if copies_out_in_loop: gpu_primitives.wait_smem_to_gmem(max_concurrent_steps - 1) - with pallas_core.grid_env(map(pallas_core.GridAxis, indices, grid)): - body_refs = [] - for bref in it.chain(in_brefs, out_brefs): - buf_slot = _get_slot(slot, ~bref.is_index_invariant) - body_refs.append(bref.get_ref_for_slot(buf_slot)) - - body_args = body_refs - if manual_consumed_barriers: - body_args += [consumed_barrier_ref.at[slot] for consumed_barrier_ref in consumed_barrier_refs] - if has_carry: - body_args += [prev_body_carry] - next_body_carry = body(*body_args) + body_refs = [] + for bref in it.chain(in_brefs, out_brefs): + buf_slot = _get_slot(slot, ~bref.is_index_invariant) + body_refs.append(bref.get_ref_for_slot(buf_slot)) + + body_args = body_refs + if manual_consumed_barriers: + body_args += [consumed_barrier_ref.at[slot] for consumed_barrier_ref in consumed_barrier_refs] + if has_carry: + body_args += [prev_body_carry] + next_body_carry = body(indices, *body_args) if not manual_consumed_barriers: [consumed_barrier_ref] = consumed_barrier_refs diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py index b19e371a1eb8..d06d3b39cb7a 100644 --- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py +++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py @@ -310,7 +310,7 @@ def _compute_thread(): ) plgpu.wait_smem_to_gmem(0) - def kv_pipeline(k_smem, v_smem, + def kv_pipeline(_, k_smem, v_smem, k_consumed_barrier, v_consumed_barrier, carry): acc, m_i, l_i = carry diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index acf82ce23eba..06dfd453fb19 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1923,7 +1923,7 @@ def kernel(x_gmem, o_gmem): max_concurrent_steps=2, )(x_gmem, o_gmem) - def kernel_body(x_smem, o_smem): + def kernel_body(_, x_smem, o_smem): # +1 for the indexing done by ``emit_pipeline`. self.assertLen(x_smem.transforms, len(transforms) + 1) o_smem[...] = x_smem[...] + 1.0 @@ -1949,7 +1949,7 @@ def kernel(x_gmem, o_gmem): grid=(), )(x_gmem, o_gmem) - def nested_kernel(x_gmem, o_gmem): + def nested_kernel(_, x_gmem, o_gmem): plgpu.emit_pipeline( nested_kernel_body, in_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))], @@ -1958,7 +1958,7 @@ def nested_kernel(x_gmem, o_gmem): max_concurrent_steps=2, )(x_gmem, o_gmem) - def nested_kernel_body(x_smem, o_smem): + def nested_kernel_body(_, x_smem, o_smem): o_smem[...] = x_smem[...] + 1.0 x = jnp.arange(32 * num_steps * 16) @@ -1983,7 +1983,7 @@ def kernel(x_gmem, o_gmem): max_concurrent_steps=2, )(x_gmem, o_gmem) - def kernel_body(x_smem, o_smem): + def kernel_body(_, x_smem, o_smem): o_smem[...] = x_smem[...] + 1.0 x = jnp.arange(32 * num_steps * 16) @@ -2016,7 +2016,7 @@ def kernel(x_gmem, o_gmem): max_concurrent_steps=2, )(x_gmem, o_gmem) - def kernel_body(x_smem, o_smem): + def kernel_body(_, x_smem, o_smem): o_smem[...] = x_smem[...] + 1.0 x = jnp.arange(num_steps1 * 32 * num_steps2 * 16) @@ -2044,7 +2044,7 @@ def kernel(x_gmem, o_gmem): max_concurrent_steps=2, )(x_gmem, o_gmem) - def kernel_body(x_smem, o_smem): + def kernel_body(_, x_smem, o_smem): o_smem[...] = x_smem[...] + 1.0 x = jnp.arange(32 * num_steps1 * 16 * num_steps2 * 8) @@ -2086,7 +2086,7 @@ def test_realistic_matmul(self): ) def kernel(a_gmem, b_gmem, o_smem, acc): - def kernel_body(a_smem, b_smem): + def kernel_body(_, a_smem, b_smem): assert a_smem.shape == (tile_m, tile_k) assert b_smem.shape == (tile_k, tile_n) plgpu.wgmma(acc, a_smem, b_smem) @@ -2147,7 +2147,7 @@ def test_pipelined_copy(self, m, n, manual_consumed_barriers): x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float16) blk_m = blk_n = 64 - def copy_kernel(x_smem, o_smem, o_last_block_smem, *consumed_barriers): + def copy_kernel(_, x_smem, o_smem, o_last_block_smem, *consumed_barriers): # TODO(justinfu): Have each wg compute a separate slice # after multiple-indexers are supported. # This is currently a race, but the values written are the same. @@ -2201,7 +2201,7 @@ def test_elementwise_add(self, m=256, n=256, num_compute_wgs=2): block_shape=(blk_m, blk_n), index_map=lambda i, j: (i, j) ) - def tiled_add_kernel(x_smem, y_smem, o_smem): + def tiled_add_kernel(_, x_smem, y_smem, o_smem): # TODO(justinfu): Have each wg compute a separate slice # after multiple-indexers are supported. # This is currently a race, but the values written are the same. @@ -2265,7 +2265,7 @@ def _compute_thread(): plgpu.copy_smem_to_gmem(acc_smem, acc_gmem) plgpu.wait_smem_to_gmem(0) - def tiled_acc_kernel(x_smem, carry): + def tiled_acc_kernel(_, x_smem, carry): o_carry, = carry new_carry = x_smem[...] + o_carry return (new_carry,) @@ -2620,7 +2620,7 @@ def test_stage4(self): self.kernel, out_shape=x, grid=(2,), grid_names=("rows",) ) def kernel(l_ref, r_ref, o_ref): - def compute(l_smem, r_smem, o_smem): + def compute(_, l_smem, r_smem, o_smem): o_smem[...] = l_smem[...] + r_smem[...] r = lax.axis_index("rows") block = pl.BlockSpec((row_block, col_block), lambda c: (r, c)) @@ -2644,7 +2644,7 @@ def test_stage5(self): self.kernel, out_shape=x, grid=(2,), grid_names=("rows",) ) def kernel(l_ref, r_ref, o_ref): - def compute(l_smem, r_smem, o_smem): + def compute(_, l_smem, r_smem, o_smem): o_smem[...] = l_smem[...] + r_smem[...] r = lax.axis_index("rows") block = plgpu.GPUBlockSpec( @@ -2707,7 +2707,7 @@ def test_stage6(self): self.kernel, out_shape=x, grid=(2, 2), grid_names=("m", "n") ) def kernel(l_ref, r_ref, o_ref): - def compute(l_smem, r_smem, o_smem): + def compute(_, l_smem, r_smem, o_smem): def do_wgmma(acc_ref): plgpu.wgmma(acc_ref, l_smem, r_smem) return acc_ref[...] From 0ec1251d9ec6f8995ff50e35cecebc5a11afc71c Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Thu, 3 Apr 2025 05:06:06 -0700 Subject: [PATCH 353/483] [Mosaic GPU] Get rid of `LayoutAttr` and related comments. This is no longer used, since we elected to refine the IR by annotating it with `{in,out}_transforms` in the lowering pipeline instead. PiperOrigin-RevId: 743516621 --- jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc | 21 --------- .../dialect/gpu/integrations/c/attributes.cc | 34 -------------- .../dialect/gpu/integrations/c/attributes.h | 16 ------- jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc | 8 ---- jaxlib/mosaic/dialect/gpu/mosaic_gpu.td | 45 +++---------------- 5 files changed, 6 insertions(+), 118 deletions(-) diff --git a/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc b/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc index c73084abc99d..2751719fc61d 100644 --- a/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc +++ b/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc @@ -138,25 +138,4 @@ NB_MODULE(_mosaic_gpu_ext, m) { .def_property_readonly("swizzle", [](MlirAttribute self) { return mlirMosaicGpuSwizzleTransformAttrGetSwizzle(self); }); - - mlir::python::nanobind_adaptors::mlir_attribute_subclass( - m, "LayoutAttr", mlirMosaicGpuIsALayoutAttr) - .def_classmethod( - "get", - [](nb::object cls, int32_t num_dimensions, - std::vector& transforms, MlirContext ctx) { - return cls(mlirMosaicGpuLayoutAttrGet( - ctx, num_dimensions, transforms.data(), transforms.size())); - }, - nb::arg("cls"), nb::arg("num_dimensions"), nb::arg("transforms"), - nb::arg("context").none() = nb::none(), - "Creates a LayoutAttr with the given transforms.") - .def_property_readonly("transforms", [](MlirAttribute self) { - std::vector result; - for (int i = 0; i < mlirMosaicGpuLayoutAttrGetTransformsSize(self); - ++i) { - result.push_back(mlirMosaicGpuLayoutAttrGetTransform(self, i)); - } - return result; - }); } diff --git a/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.cc b/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.cc index 259c37fe5d07..523b14e425c9 100644 --- a/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.cc +++ b/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.cc @@ -16,7 +16,6 @@ limitations under the License. #include "jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h" #include -#include #include "mlir-c/IR.h" #include "mlir/CAPI/IR.h" @@ -97,36 +96,3 @@ int32_t mlirMosaicGpuSwizzleTransformAttrGetSwizzle(MlirAttribute attr) { .getSwizzle() .getValue()); } - -//===----------------------------------------------------------------------===// -// LayoutAttr -//===----------------------------------------------------------------------===// - -bool mlirMosaicGpuIsALayoutAttr(MlirAttribute attr) { - return mlir::isa(unwrap(attr)); -} - -MlirAttribute mlirMosaicGpuLayoutAttrGet(MlirContext ctx, - int32_t num_dimensions, - MlirAttribute* transforms, - int32_t transforms_size) { - std::vector unwrapped_transforms; - unwrapped_transforms.reserve(transforms_size); - for (int i = 0; i < transforms_size; ++i) { - unwrapped_transforms.push_back(unwrap(transforms[i])); - } - return wrap(mosaic_gpu::LayoutAttr::get(unwrap(ctx), num_dimensions, - unwrapped_transforms)); -} - -int32_t mlirMosaicGpuLayoutAttrGetTransformsSize(MlirAttribute attr) { - return mlir::cast(unwrap(attr)) - .getTransforms() - .size(); -} - -MlirAttribute mlirMosaicGpuLayoutAttrGetTransform(MlirAttribute attr, - int32_t index) { - return wrap( - mlir::cast(unwrap(attr)).getTransforms()[index]); -} \ No newline at end of file diff --git a/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h b/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h index 3b8425b6b142..3221b9220e5d 100644 --- a/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h +++ b/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h @@ -69,22 +69,6 @@ mlirMosaicGpuSwizzleTransformAttrGet(MlirContext ctx, int32_t swizzle); MLIR_CAPI_EXPORTED int32_t mlirMosaicGpuSwizzleTransformAttrGetSwizzle(MlirAttribute attr); -//===----------------------------------------------------------------------===// -// LayoutAttr -//===----------------------------------------------------------------------===// - -MLIR_CAPI_EXPORTED bool mlirMosaicGpuIsALayoutAttr(MlirAttribute attr); - -MLIR_CAPI_EXPORTED MlirAttribute -mlirMosaicGpuLayoutAttrGet(MlirContext ctx, int32_t num_dimensions, - MlirAttribute* transforms, int32_t transforms_size); - -MLIR_CAPI_EXPORTED int32_t -mlirMosaicGpuLayoutAttrGetTransformsSize(MlirAttribute attr); - -MLIR_CAPI_EXPORTED MlirAttribute -mlirMosaicGpuLayoutAttrGetTransform(MlirAttribute attr, int32_t index); - #ifdef __cplusplus } #endif diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc index 0a36aab6fbcd..073697df58ef 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc @@ -407,14 +407,6 @@ llvm::LogicalResult CustomPrimitiveOp::verify() { return llvm::success(); } -mlir::AffineMap LayoutAttr::getAffineMap() const { - // This always returns an identity map. It's technically not correct, but we - // don't actually use it anywhere. It's only called during verification of the - // layout attribute and needs to be semi-valid. - return mlir::AffineMap::getMultiDimIdentityMap(getNumDimensions(), - getContext()); -} - void MosaicGPUDialect::initialize() { addTypes< #define GET_TYPEDEF_LIST diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td index 0d954716b179..cda521855250 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td @@ -225,27 +225,6 @@ def SwizzleTransformAttr : MosaicGPU_Attr<"SwizzleTransform", "swizzle"> { let assemblyFormat = "`<` $swizzle `>`"; } -def LayoutAttr : MosaicGPU_Attr<"Layout", "layout", - [DeclareAttrInterfaceMethods]> { - let parameters = (ins - TypeParameter<"int32_t", "number of dimensions">:$num_dimensions, - ArrayRefParameter<"mlir::Attribute", "transforms">:$transforms - ); - - let summary = "Specifies a layout of a memref in SMEM."; - let description = [{ - This layout attribute is used to specify the layout of a memref in SMEM. - It is composed of a number of transforms, which are applied in the order - they are provided. The transforms can be any combination of: - - TileTransformAttr - - TransposeTransformAttr - - SwizzleTransformAttr - - The num_dimensions parameter must match the rank of the memref shape. - }]; - let assemblyFormat = "`<` $num_dimensions `,` $transforms `>`"; -} - def MosaicGPU_AsyncLoadOp : Op { let summary = "Schedules an async load of a MemRef from GMEM to SMEM"; @@ -265,16 +244,9 @@ def MosaicGPU_AsyncLoadOp : Op Date: Thu, 3 Apr 2025 06:18:14 -0700 Subject: [PATCH 354/483] Update XLA dependency to use revision http://github.com/openxla/xla/commit/921c164a67e8ac4cf052aab26e849f29b719f802. PiperOrigin-RevId: 743535272 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 90a19ac95e51..c30648a2b3a1 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "c3087e022f3c07f7ed1dd4e47024c437a504341b" -XLA_SHA256 = "66457303ddec4dbbe43accf38a8b6b635d55808938cf2495443b09ee9c95a147" +XLA_COMMIT = "921c164a67e8ac4cf052aab26e849f29b719f802" +XLA_SHA256 = "9e734da4a0211ac09a00cc07969645e31f107cfee19bbc5d2d1e21ddbb19090d" def repo(): tf_http_archive( From 8d59902e735dbf17dcc7c70bb4c76f858eb93dde Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 3 Apr 2025 15:00:29 +0100 Subject: [PATCH 355/483] Fix problem finding clang++ when building JAX via build.py on windows. It's important we use the un-stemmed name because on Windows there is an .exe suffix. --- build/tools/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build/tools/utils.py b/build/tools/utils.py index ccce8aff09cc..c52b89a1e6d2 100644 --- a/build/tools/utils.py +++ b/build/tools/utils.py @@ -204,7 +204,7 @@ def get_clang_major_version(clang_path): def get_clangpp_path(clang_path): clang_path = pathlib.Path(clang_path) - clang_exec_name = clang_path.stem + clang_exec_name = clang_path.name clangpp_exec_name = clang_exec_name if "clang++" not in clang_exec_name: clangpp_exec_name = clang_exec_name.replace("clang", "clang++") From 91b0884ad131ebddd69951927533b3ab12ec4113 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 3 Apr 2025 08:05:22 -0700 Subject: [PATCH 356/483] Restrict the regex for copying the wheels. The change is made to address the case when bazel dir has multiple wheels with different version suffixes. We need to copy only those wheels that were created by the current execution of build.py script. PiperOrigin-RevId: 743566122 --- build/build.py | 39 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/build/build.py b/build/build.py index 226d984b3d89..87aa36aeba8b 100755 --- a/build/build.py +++ b/build/build.py @@ -389,6 +389,11 @@ async def main(): arch = platform.machine() os_name = platform.system().lower() + custom_wheel_version_suffix = "" + wheel_build_date = "" + wheel_git_hash = "" + wheel_type = "snapshot" + args = parser.parse_args() logger.info("%s", BANNER) @@ -621,6 +626,17 @@ async def main(): ) for option in args.bazel_options: wheel_build_command_base.append(option) + + # Parse the build options for the wheel version suffix. + if "ML_WHEEL_TYPE" in option: + wheel_type = option.split("=")[-1] + if "ML_WHEEL_VERSION_SUFFIX" in option: + custom_wheel_version_suffix = option.split("=")[-1].replace("-", "") + if "ML_WHEEL_BUILD_DATE" in option: + wheel_build_date = option.split("=")[-1].replace("-", "") + if "ML_WHEEL_GIT_HASH" in option: + wheel_git_hash = option.split("=")[-1][:9] + if "cuda" in args.wheels: wheel_build_command_base.append("--config=cuda_libraries_from_stubs") @@ -729,10 +745,29 @@ async def main(): dst_dir = os.path.join(output_path, wheel_dir) utils.copy_dir_recursively(src_dir, dst_dir) else: - utils.copy_individual_files(bazel_dir, output_path, f"{wheel_dir}*.whl") + wheel_version_suffix = "dev0+selfbuilt" + if wheel_type == "release": + wheel_version_suffix = custom_wheel_version_suffix + elif wheel_type in ["nightly", "custom"]: + wheel_version_suffix = f".dev{wheel_build_date}" + if wheel_type == "custom": + wheel_version_suffix += ( + f"+{wheel_git_hash}{custom_wheel_version_suffix}" + ) + if wheel in ["jax", "jax-cuda-pjrt"]: + python_tag = "py" + else: + python_tag = "cp" + utils.copy_individual_files( + bazel_dir, + output_path, + f"{wheel_dir}*{wheel_version_suffix}-{python_tag}*.whl", + ) if wheel == "jax": utils.copy_individual_files( - bazel_dir, output_path, f"{wheel_dir}*.tar.gz" + bazel_dir, + output_path, + f"{wheel_dir}*{wheel_version_suffix}.tar.gz", ) # Exit with success if all wheels in the list were built successfully. From 1941714d261daffd3f164d87a3bf8dd89d996211 Mon Sep 17 00:00:00 2001 From: George Necula Date: Thu, 3 Apr 2025 10:25:02 +0100 Subject: [PATCH 357/483] [export] Add support for override_lowering_rules to jax.export. This parameter is already part of the internal API for the AOT lowering function, here we just expose it to `jax.export`. --- jax/_src/config.py | 2 +- jax/_src/export/_export.py | 15 +++++++++++++-- jax/_src/interpreters/mlir.py | 2 +- tests/export_test.py | 13 ++++++++++++- 4 files changed, 27 insertions(+), 5 deletions(-) diff --git a/jax/_src/config.py b/jax/_src/config.py index 5b8b87be2095..b4a12dcc1762 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -998,7 +998,7 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: name='jax_explain_cache_misses', default=False, help=('Each time there is a miss on one of the main caches (e.g. the ' - 'tracing cache), log an explanation.. Logging is performed with ' + 'tracing cache), log an explanation. Logging is performed with ' '`logging`. When this option is set, the log level is WARNING; ' 'otherwise the level is DEBUG.')) diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index 9b6a0f80930f..90cc0c186ad1 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -529,6 +529,7 @@ def export( *, platforms: Sequence[str] | None = None, disabled_checks: Sequence[DisabledSafetyCheck] = (), + _override_lowering_rules: Sequence[tuple[Any, Any]] | None = None ) -> Callable[..., Exported]: """Exports a JAX function for persistent serialization. @@ -541,6 +542,13 @@ def export( If None, then use the default JAX backend. The calling convention for multiple platforms is explained at https://jax.readthedocs.io/en/latest/export/export.html#module-calling-convention. + _override_lowering_rules: an optional sequence of custom lowering rules + for some JAX primitives. Each element of the sequence is a pair + of a JAX primitive and a lowering function. Defining lowering rules + is an advanced feature using JAX internal APIs, which are subject + to change. Furthermore, the responsibility for the stability of the + MLIR emitted through these custom lowering rules, rests with the user + of these rules. disabled_checks: the safety checks to disable. See documentation for of `jax.export.DisabledSafetyCheck`. @@ -568,7 +576,8 @@ def export( Array([0.09983342, 0.19866933, 0.29552022, 0.38941833], dtype=float32) """ return _export_internal(fun_jit, platforms=platforms, - disabled_checks=disabled_checks) + disabled_checks=disabled_checks, + override_lowering_rules=_override_lowering_rules) # TODO(necula): remove this once we improve the integration with jax2tf. @@ -577,7 +586,8 @@ def _export_internal( *, platforms: Sequence[str] | None = None, disabled_checks: Sequence[DisabledSafetyCheck] = (), - _device_assignment_for_internal_jax2tf_use_only = None, + _device_assignment_for_internal_jax2tf_use_only=None, + override_lowering_rules=None, ) -> Callable[..., Exported]: """Exports native serialization for a JAX function. @@ -604,6 +614,7 @@ def do_export(*args_specs, **kwargs_specs) -> Exported: lowered = traced.lower( lowering_platforms=actual_lowering_platforms, _private_parameters=mlir.LoweringParameters( + override_lowering_rules=override_lowering_rules, for_export=True, export_ignore_forward_compatibility=config.export_ignore_forward_compatibility.value)) return _export_lowered( diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index a1b37876f87e..a112063ce3ae 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -663,7 +663,7 @@ def __init__(self, @dataclasses.dataclass(frozen=True) class LoweringParameters: # A mapping between primitives and user-defined LoweringRules. - # When lowering a primitive, give priorioty to the rule in this map over + # When lowering a primitive, give priority to the rule in this map over # existing Jax rules. override_lowering_rules: tuple[tuple[core.Primitive, LoweringRule]] | None = None diff --git a/tests/export_test.py b/tests/export_test.py index 0b78a29a8e6a..2264fbdd997b 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -19,7 +19,6 @@ import dataclasses import functools import logging -import json import math import re import unittest @@ -281,6 +280,18 @@ def test_unused_args(self): self.assertAllClose(f(x, y), exp_f.call(x, y)) + def test_override_lowering_rules(self): + @jax.jit + def f(x): + return jnp.sin(x) + + def my_lowering_rule(ctx, arg, **_): + return mlir.hlo.CosineOp(arg).results + + exp = get_exported(f, _override_lowering_rules=( + (lax.sin_p, my_lowering_rule),))(42.) + self.assertIn("stablehlo.cosine", exp.mlir_module()) + def test_pytree(self): a = np.arange(4, dtype=np.float32) b = np.arange(6, dtype=np.float32) From f2f9152d573fb6f09ce2a500d5602b4aea14075b Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 3 Apr 2025 08:27:10 -0700 Subject: [PATCH 358/483] Moved the `jax.Array` baseclass to C++ This allows `ArrayImpl` to directly subclass `jax.Array` without relying on the expensive virtual subclasses machinery from `abc`. PiperOrigin-RevId: 743573028 --- jax/BUILD | 1 + jax/_src/array.py | 6 +-- jax/_src/basearray.py | 49 ++++++++++++------- jax/_src/basearray.pyi | 8 +-- jaxlib/xla/py_array.cc | 70 ++++++++++++++++++++++++++- jaxlib/xla/xla_client.py | 3 +- jaxlib/xla/xla_extension/__init__.pyi | 1 + 7 files changed, 111 insertions(+), 27 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index 5d37a8987445..f5745df0e5bf 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -400,6 +400,7 @@ pytype_strict_library( deps = [ ":partition_spec", ":sharding", + ":util", "//jax/_src/lib", ] + py_deps("numpy"), ) diff --git a/jax/_src/array.py b/jax/_src/array.py index ee196026887d..760593da9fa9 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -39,6 +39,7 @@ from jax._src.layout import AutoLayout, DeviceLocalLayout, Layout from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension as xe +from jax._src.lib import jaxlib_extension_version from jax._src.sharding import Sharding from jax._src.sharding_impls import ( PmapSharding, SingleDeviceSharding, @@ -1093,9 +1094,8 @@ def _get_aval_array(self): return core.update_aval_with_sharding(self.aval, self.sharding) core.pytype_aval_mappings[ArrayImpl] = _get_aval_array -# TODO(jakevdp) replace this with true inheritance at the C++ level. -basearray.Array.register(ArrayImpl) - +if jaxlib_extension_version < 325: + basearray.Array.register(ArrayImpl) def _array_mlir_constant_handler(val): try: diff --git a/jax/_src/basearray.py b/jax/_src/basearray.py index fbd14d157e78..3aabba4440ec 100644 --- a/jax/_src/basearray.py +++ b/jax/_src/basearray.py @@ -17,10 +17,15 @@ from __future__ import annotations import abc +from collections.abc import Sequence import sys -import numpy as np from typing import Any, Union -from collections.abc import Sequence + +from jax._src.lib import jaxlib_extension_version +from jax._src.lib import xla_client as xc +from jax._src.util import use_cpp_class +import numpy as np + # TODO(jakevdp): fix import cycles and define these. Device = Any @@ -30,7 +35,9 @@ # Array is a type annotation for standard JAX arrays and tracers produced by # core functions in jax.lax and jax.numpy; it is not meant to include # future non-standard array types like KeyArray and BInt. -class Array(abc.ABC): + + +class Array: """Array base class for JAX ``jax.Array`` is the public interface for instance checks and type annotation @@ -48,8 +55,6 @@ def f(x: Array) -> Array: # type annotations are valid for traced and non-trace :func:`jax.numpy.array`, :func:`jax.numpy.zeros`, :func:`jax.numpy.ones`, :func:`jax.numpy.full`, :func:`jax.numpy.arange`, etc. """ - # Note: abstract methods for this class are defined dynamically in - # lax_numpy.py # For the sake of static type analysis, these definitions are mirrored in the # associated basearray.pyi file. @@ -57,42 +62,41 @@ def f(x: Array) -> Array: # type annotations are valid for traced and non-trace __hash__ = None @property - @abc.abstractmethod def dtype(self) -> np.dtype: """The data type (:class:`numpy.dtype`) of the array.""" + raise NotImplementedError @property - @abc.abstractmethod def ndim(self) -> int: """The number of dimensions in the array.""" + raise NotImplementedError @property - @abc.abstractmethod def size(self) -> int: """The total number of elements in the array.""" + raise NotImplementedError @property - @abc.abstractmethod def shape(self) -> tuple[int, ...]: """The shape of the array.""" + raise NotImplementedError # Documentation for sharding-related methods and properties defined on ArrayImpl: - @abc.abstractmethod def addressable_data(self, index: int) -> Array: """Return an array of the addressable data at a particular index.""" + raise NotImplementedError @property - @abc.abstractmethod def addressable_shards(self) -> Sequence[Shard]: """List of addressable shards.""" + raise NotImplementedError @property - @abc.abstractmethod def global_shards(self) -> Sequence[Shard]: """List of global shards.""" + raise NotImplementedError @property - @abc.abstractmethod def is_fully_addressable(self) -> bool: """Is this Array fully addressable? @@ -104,19 +108,19 @@ def is_fully_addressable(self) -> bool: a jax.Array which is fully replicated can span across multiple hosts and is not fully addressable. """ + raise NotImplementedError @property - @abc.abstractmethod def is_fully_replicated(self) -> bool: """Is this Array fully replicated?""" + raise NotImplementedError @property - @abc.abstractmethod def sharding(self) -> Sharding: """The sharding for the array.""" + raise NotImplementedError @property - @abc.abstractmethod def committed(self) -> bool: """Whether the array is committed or not. @@ -141,17 +145,17 @@ def committed(self) -> bool: See https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices for more information. """ + raise NotImplementedError @property - @abc.abstractmethod def device(self) -> Device | Sharding: """Array API-compatible device attribute. For single-device arrays, this returns a Device. For sharded arrays, this returns a Sharding. """ + raise NotImplementedError - @abc.abstractmethod def copy_to_host_async(self): """Copies an ``Array`` to the host asynchronously. @@ -166,10 +170,19 @@ def copy_to_host_async(self): array, but does not wait for the copy to complete. This may speed up a future on-host access to the array's contents. """ + raise NotImplementedError + + +if jaxlib_extension_version >= 325: + Array = use_cpp_class(xc.Array)(Array) +else: + class Array(Array, metaclass=abc.ABCMeta): + ... Array.__module__ = "jax" + # StaticScalar is the Union of all scalar types that can be converted to # JAX arrays, and are possible to mark as static arguments. StaticScalar = Union[ diff --git a/jax/_src/basearray.pyi b/jax/_src/basearray.pyi index a368b593332d..8bf68f622051 100644 --- a/jax/_src/basearray.pyi +++ b/jax/_src/basearray.pyi @@ -14,11 +14,12 @@ import abc from collections.abc import Callable, Sequence from types import ModuleType -from typing import Any, Protocol, Union, runtime_checkable +from typing import Any, Protocol, runtime_checkable, Union import numpy as np -from jax._src.sharding import Sharding from jax._src.partition_spec import PartitionSpec +from jax._src.sharding import Sharding + # TODO(jakevdp) de-duplicate this with the DTypeLike definition in typing.py. # We redefine these here to prevent circular imports. @@ -39,7 +40,8 @@ Traceback = Any PrecisionLike = Any -class Array(abc.ABC): +# TODO(slebedev): Remove the metaclass once ``jax_extension_version >= 325``. +class Array(metaclass=abc.ABCMeta): aval: Any @property diff --git a/jaxlib/xla/py_array.cc b/jaxlib/xla/py_array.cc index a1937bc80327..ce5ceacbad99 100644 --- a/jaxlib/xla/py_array.cc +++ b/jaxlib/xla/py_array.cc @@ -237,12 +237,33 @@ tsl::RCReference CreateIfRtArrayFromSingleDeviceShardedPyArrays( return *std::move(ifrt_array); } +struct PyBaseArrayObject { + PyObject_HEAD; +#if PY_VERSION_HEX < 0x030C0000 + PyObject* weakrefs; +#endif // PY_VERSION_HEX < 0x030C0000 +}; + +extern "C" void PyBaseArray_tp_dealloc(PyBaseArrayObject* self) { + PyObject_GC_UnTrack(self); + PyObject_ClearWeakRefs((PyObject*)self); + PyTypeObject* tp = Py_TYPE(self); + tp->tp_free((PyObject*)self); + Py_DECREF(tp); +} + +extern "C" int PyBaseArray_tp_traverse(PyObject* self, visitproc visit, + void* arg) { + Py_VISIT(Py_TYPE(self)); + return 0; +} + struct PyArrayObject { PyObject_HEAD; #if PY_VERSION_HEX < 0x030C0000 PyObject* weakrefs; PyObject* dict; -#endif // PY_VERSION_HEX < 0x030B0000 +#endif // PY_VERSION_HEX < 0x030C0000 bool initialized; alignas(PyArray::Storage) char array_storage[sizeof(PyArray::Storage)]; }; @@ -1879,6 +1900,23 @@ absl::Status PyHostValue::CopyToHostAsync( } namespace { +PyMemberDef PyBaseArray_members[] = { +#if PY_VERSION_HEX < 0x030C0000 + {"__weaklistoffset__", T_PYSSIZET, + static_cast(offsetof(PyBaseArrayObject, weakrefs)), READONLY, + nullptr}, +#endif // PY_VERSION_HEX < 0x030C0000 + {nullptr, 0, 0, 0, nullptr}, +}; + +PyType_Slot PyBaseArray_slots[] = { + {Py_tp_dealloc, reinterpret_cast(PyBaseArray_tp_dealloc)}, + {Py_tp_members, reinterpret_cast(PyBaseArray_members)}, + {Py_tp_traverse, reinterpret_cast(PyBaseArray_tp_traverse)}, + {Py_tp_hash, reinterpret_cast(PyObject_HashNotImplemented)}, + {0, nullptr}, +}; + PyGetSetDef PyArray_tp_getset[] = { {"__dict__", PyObject_GenericGetDict, PyObject_GenericSetDict, nullptr, nullptr}, @@ -1911,6 +1949,34 @@ PyType_Slot PyArray_slots[] = { } // namespace absl::Status PyArray::RegisterTypes(nb::module_& m) { + // We are not using nanobind to avoid having a non-standard metaclass, which + // would make Array incompatible with abc.ABCMeta. + std::string base_name = + absl::StrCat(nb::cast(m.attr("__name__")), ".Array"); + PyType_Spec PyBaseArray_spec = { +#if PY_VERSION_HEX < 0x030B0000 + // Work around for https://github.com/python/cpython/issues/89478 + // CPython 3.10 and earlier assume that the .name value remains alive + // forever. + /*.name=*/strdup(base_name.c_str()), +#else + /*.name=*/base_name.c_str(), +#endif // PY_VERSION_HEX < 0x030B0000 + /*.basicsize=*/static_cast(sizeof(PyBaseArrayObject)), + /*.itemsize=*/0, +#if PY_VERSION_HEX < 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, +#else // PY_VERSION_HEX >= 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_MANAGED_WEAKREF, +#endif // PY_VERSION_HEX >= 0x030C0000 + /*.slots=*/PyBaseArray_slots}; + auto* base_type = PyType_FromSpec(&PyBaseArray_spec); + if (!base_type) { + throw nb::python_error(); + } + m.attr("Array") = nb::borrow(base_type); + std::string name = absl::StrCat(nb::cast(m.attr("__name__")), ".ArrayImpl"); @@ -1934,7 +2000,7 @@ absl::Status PyArray::RegisterTypes(nb::module_& m) { /*.slots=*/PyArray_slots, }; - type_ = PyType_FromSpec(&PyArray_spec); + type_ = PyType_FromSpecWithBases(&PyArray_spec, base_type); if (!type_) { throw nb::python_error(); } diff --git a/jaxlib/xla/xla_client.py b/jaxlib/xla/xla_client.py index af751a00ab25..523f8bb57b90 100644 --- a/jaxlib/xla/xla_client.py +++ b/jaxlib/xla/xla_client.py @@ -50,7 +50,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version. -_version = 324 +_version = 325 # An internal increasing version number for protecting jaxlib code against # ifrt changes. @@ -486,6 +486,7 @@ def window_padding_type_to_pad_values( FftType = _xla.FftType Client = _xla.Client Memory = _xla.Memory +Array = _xla.Array ArrayImpl = _xla.ArrayImpl LoadedExecutable = _xla.LoadedExecutable DeviceList = _xla.DeviceList diff --git a/jaxlib/xla/xla_extension/__init__.pyi b/jaxlib/xla/xla_extension/__init__.pyi index d002080b17bc..2d759236b8c5 100644 --- a/jaxlib/xla/xla_extension/__init__.pyi +++ b/jaxlib/xla/xla_extension/__init__.pyi @@ -656,6 +656,7 @@ def pjrt_plugin_loaded(plugin_name: str) -> bool: ... def pjrt_plugin_initialized(plugin_name: str) -> bool: ... def initialize_pjrt_plugin(platform_name: str) -> _Status: ... +Array = Any ArrayImpl = Any # TODO(phawkins): this type is problematic because it is not a subtype of From e7a5147638ba7ef2ede25ee7e2b7b29ea355a495 Mon Sep 17 00:00:00 2001 From: Vladimir Belitskiy Date: Thu, 3 Apr 2025 08:43:29 -0700 Subject: [PATCH 359/483] Bump up tolerance in ShardMapSystematicTest.test_vmap_closure for GPUs. There's a mismatch between the resulting and the desired matrixes on H100, but not the older GPUs. PiperOrigin-RevId: 743578025 --- tests/shard_map_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 36966fde2a90..520cc02638df 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -3091,7 +3091,7 @@ def g(*args): else: slices = map(jnp.stack, zip(*expected_slices)) expected = jax.tree.unflatten(treedef, slices) - tol = 1e-2 if jtu.test_device_matches(['tpu']) else None + tol = 1e-2 if jtu.test_device_matches(['gpu', 'tpu']) else None self.assertAllClose(ans, expected, check_dtypes=False, atol=tol, rtol=tol) @jtu.pytest_mark_if_available('multiaccelerator') From d1009a3bcda3aeb4667f9822a2379c1bf7718b56 Mon Sep 17 00:00:00 2001 From: Yu-Hang Tang Date: Thu, 3 Apr 2025 16:35:11 +0000 Subject: [PATCH 360/483] Only trigger K8s CI on changes to cluster config and distributed initialize --- .github/workflows/k8s.yaml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/.github/workflows/k8s.yaml b/.github/workflows/k8s.yaml index 31ee05a03482..a96ce1ead26c 100644 --- a/.github/workflows/k8s.yaml +++ b/.github/workflows/k8s.yaml @@ -4,9 +4,17 @@ on: push: branches: - main + paths: + - 'jax/distributed.py' + - 'jax/_src/distributed.py' + - 'jax/_src/clusters/**' pull_request: branches: - main + paths: + - 'jax/distributed.py' + - 'jax/_src/distributed.py' + - 'jax/_src/clusters/**' permissions: contents: read From 42735d04f1249026ce2fd223e20a02d78c18ff7b Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 3 Apr 2025 09:44:12 -0700 Subject: [PATCH 361/483] Not to use dynamic grid in the ragged paged attention Pallas kernel. We found a hanging issue when we use dynamic grid. We'll disable it for now. PiperOrigin-RevId: 743597352 --- jax/experimental/pallas/ops/tpu/ragged_paged_attention.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py b/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py index e1eacee550a7..d775e1331bcb 100644 --- a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py +++ b/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py @@ -315,7 +315,9 @@ def prefetch_first_kv_blk(): def is_cur_q_blk_needed(q_states): done, cur_seq_idx, _ = q_states - return jnp.logical_and(done == 0, cur_seq_idx < num_seqs) + should_run = jnp.logical_and(q_len_start < cu_q_lens_ref[num_seqs], + cur_seq_idx < num_seqs) + return jnp.logical_and(done == 0, should_run) def compute_with_cur_q_blk(q_states): done, cur_seq_idx, cur_buf_idx = q_states @@ -680,14 +682,14 @@ def ragged_paged_attention( validate_static_inputs(q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs, sliding_window, soft_cap) if mask_value is None: mask_value = DEFAULT_MASK_VALUE - _, num_q_heads, head_dim = q.shape + num_q_tokens, num_q_heads, head_dim = q.shape _, page_size, num_combined_kv_heads, _ = kv_pages.shape assert num_combined_kv_heads % 2 == 0 num_kv_heads = num_combined_kv_heads // 2 num_q_per_blk = num_queries_per_block num_kv_pages_per_blk = num_kv_pages_per_block num_q_heads_per_kv_head = num_q_heads // num_kv_heads - num_q_blks = cdiv(cu_q_lens[num_seqs[0]], num_q_per_blk) + num_q_blks = cdiv(num_q_tokens, num_q_per_blk) num_q_heads_per_blk, num_combined_kv_heads_per_blk = get_min_heads_per_blk( num_q_heads, num_combined_kv_heads, q.dtype, kv_pages.dtype ) From 780c8827f296fa692cb08bb9a1abd198e8cf8efe Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Thu, 3 Apr 2025 11:19:11 -0700 Subject: [PATCH 362/483] [Mosaic GPU] Fix index_invariant slot in warp-specialized pipeline. PiperOrigin-RevId: 743633331 --- jax/_src/pallas/mosaic_gpu/pipeline.py | 3 ++- tests/pallas/mosaic_gpu_test.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index ec088f43c4b2..257ecbf5da4a 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -629,7 +629,8 @@ def compute_loop_body(step, carry): last_slot = lax.rem(num_pipeline_steps - 1, max_concurrent_steps) for bref in out_brefs: if bref.is_index_invariant: - bref.copy_out(last_slot, last_indices, predicate=None) + bref.copy_out(_get_slot(last_slot, has_seq_dim=False), + last_indices, predicate=None) gpu_primitives.commit_smem_to_gmem_group() diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 06dfd453fb19..8fd98f62eab1 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -2142,6 +2142,7 @@ class WarpSpecializedPipelineTest(PallasTest): @parameterized.product(m=[512], n=[512], manual_consumed_barriers=[False, True]) def test_pipelined_copy(self, m, n, manual_consumed_barriers): + self.skipTest("TODO(justinfu): Temporary skip for 3.12 update.") self.skip_if_wg_semantics() # Times out! x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float16) From c2eb9c1d9eff42eb05cac697759bcf8a5aeaf805 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 3 Apr 2025 13:01:35 -0700 Subject: [PATCH 363/483] Eliminate DeprecationWarning in python3.12+ in jax pallas for ~. The code was using ~ with a boolean, which leads to a new DeprecationWarning. That should only be used with ints. PiperOrigin-RevId: 743668386 --- jax/_src/pallas/mosaic_gpu/pipeline.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index 257ecbf5da4a..21efbbec6630 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -429,9 +429,9 @@ def body( # Trace the index maps to determine if they depend on the grid. # Grid-independent values will not be multiple-buffered. in_spec_has_seq_axis = [ - ~_is_index_invariant(spec, grid) for spec in in_specs] + not _is_index_invariant(spec, grid) for spec in in_specs] out_spec_has_seq_axis = [ - ~_is_index_invariant(spec, grid) for spec in out_specs] + not _is_index_invariant(spec, grid) for spec in out_specs] spec_has_seq_axis = [*in_spec_has_seq_axis, *out_spec_has_seq_axis] num_pipeline_steps = math.prod(grid) @@ -516,13 +516,13 @@ def scoped_pipeline( consumed_barrier_refs, ): in_brefs: Sequence[BufferedRef] = [ - BufferedRef(spec, ~has_seq_axis, gmem_ref, smem_ref) + BufferedRef(spec, not has_seq_axis, gmem_ref, smem_ref) for spec, has_seq_axis, gmem_ref, smem_ref in zip( in_specs, in_spec_has_seq_axis, in_gmem_refs, in_smem_refs ) ] out_brefs: Sequence[BufferedRef] = [ - BufferedRef(spec, ~has_seq_axis, gmem_ref, smem_ref) + BufferedRef(spec, not has_seq_axis, gmem_ref, smem_ref) for spec, has_seq_axis, gmem_ref, smem_ref in zip( out_specs, out_spec_has_seq_axis, out_gmem_refs, out_smem_refs ) @@ -553,7 +553,7 @@ def compute_loop_body(step, carry): body_refs = [] for bref in it.chain(in_brefs, out_brefs): - buf_slot = _get_slot(slot, ~bref.is_index_invariant) + buf_slot = _get_slot(slot, not bref.is_index_invariant) body_refs.append(bref.get_ref_for_slot(buf_slot)) body_args = body_refs @@ -586,7 +586,7 @@ def compute_loop_body(step, carry): new_store_slices[idx], ) slices_changed = ~functools.reduce(lax.bitwise_and, are_same_slices) - bref.copy_out(_get_slot(slot, ~bref.is_index_invariant), + bref.copy_out(_get_slot(slot, not bref.is_index_invariant), indices, predicate=slices_changed) gpu_primitives.commit_smem_to_gmem_group() @@ -645,7 +645,7 @@ def memory_block(): # Begin initial copies. for step in range(max_concurrent_steps): for bref, barrier in zip(in_brefs, in_smem_barrier_refs): - buf_slot = _get_slot(step, ~bref.is_index_invariant) + buf_slot = _get_slot(step, not bref.is_index_invariant) bref.copy_in(buf_slot, indices, barrier) indices = _inc_grid_by_1(indices, grid) @@ -668,7 +668,7 @@ def memory_loop_body(step, carry): if manual_consumed_barriers: gpu_primitives.barrier_wait(consumed_barrier.at[slot]) # pytype: disable=attribute-error bref.copy_in( - _get_slot(fetch_slot, ~bref.is_index_invariant), indices, barrier) + _get_slot(fetch_slot, not bref.is_index_invariant), indices, barrier) next_indices = _inc_grid_by_1(indices, grid) return (next_indices,) lax.fori_loop(0, num_pipeline_steps - max_concurrent_steps, From 41868ef06dd7e7da88f800071da040c6819b5707 Mon Sep 17 00:00:00 2001 From: kaixih Date: Thu, 3 Apr 2025 21:46:10 +0000 Subject: [PATCH 364/483] format --- jax/_src/nn/functions.py | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index d0f5f770e196..27436b01216a 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -1218,11 +1218,11 @@ def scaled_matmul( ) -> Array: r"""Scaled matrix multiplication function. - Performs block-scaled matmul of `a` and `b` using `a_scales` and `b_scales`. + Performs block-scaled matmul of `a` and `b` using `a_scales` and `b_scales`. The last dim is the contracting dim, and block size is inferred. Mathematically, this operation is equivalent to:: - + a_block_size = a.shape[-1] // a_scales.shape[-1] b_block_size = b.shape[-1] // b_scales.shape[-1] a_scaled = a * jnp.repeat(a_scales, a_block_size, axis=-1) @@ -1258,26 +1258,26 @@ def scaled_matmul( Basic case: - >>> a = jnp.array([1, 2, 3]).reshape((1, 1, 3)) - >>> b = jnp.array([4, 5, 6]).reshape((1, 1, 3)) - >>> a_scales = jnp.array([0.5]).reshape((1, 1, 1)) - >>> b_scales = jnp.array([0.5]).reshape((1, 1, 1)) - >>> scaled_matmul(a, b, a_scales, b_scales) - Array([[[8.]]], dtype=float32) - + >>> a = jnp.array([1, 2, 3]).reshape((1, 1, 3)) + >>> b = jnp.array([4, 5, 6]).reshape((1, 1, 3)) + >>> a_scales = jnp.array([0.5]).reshape((1, 1, 1)) + >>> b_scales = jnp.array([0.5]).reshape((1, 1, 1)) + >>> scaled_matmul(a, b, a_scales, b_scales) + Array([[[8.]]], dtype=float32) + Using fused cuDNN call on Blackwell GPUs: - >>> a = random.normal(keys[0], (3, 128, 64), dtype=jnp.float8_e4m3fn) - >>> b = random.normal(keys[1], (3, 128, 64), dtype=jnp.float8_e4m3fn) - >>> a_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu) - >>> b_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu) - >>> scaled_matmul(a, b, a_scales, b_scales) + >>> a = random.normal(keys[0], (3, 128, 64), dtype=jnp.float8_e4m3fn) + >>> b = random.normal(keys[1], (3, 128, 64), dtype=jnp.float8_e4m3fn) + >>> a_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu) + >>> b_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu) + >>> scaled_matmul(a, b, a_scales, b_scales) """ if not all(x.ndim == 3 for x in (a, b, a_scales, b_scales)): raise ValueError( "scaled_matmul requires all inputs to be 3-dimensional arrays" ) - + B_a, M_a, K_a = a.shape B_b, N_b, K_b = b.shape if K_a != K_b or B_a != B_b: @@ -1286,7 +1286,7 @@ def scaled_matmul( f"and contract (K) dimensions, but got shapes {a.shape} and " f"{b.shape}" ) - + B_as, M_as, K_as = a_scales.shape B_bs, N_bs, K_bs = b_scales.shape if K_as != K_bs or B_as != B_bs: @@ -1295,7 +1295,7 @@ def scaled_matmul( f"contract (K) dimensions, but got shapes {a_scales.shape} and " f"{b_scales.shape}" ) - + if M_as != M_a or N_bs != N_b: raise ValueError( "scaled_matmul requires scales to match non-contract dimensions of " @@ -1378,7 +1378,7 @@ def scaled_dot_general( lhs, rhs, and gradients. Users can obtain valid configurations via `jax.nn.get_scaled_dot_general_config`. Currently, `nvfp4` and `mxfp8` are supported. If `None`, falls back to `lax.dot_general`. - + Returns: Array: The resulting tensor, with batch dimensions first, followed by non-contracting/non-batch dimensions of lhs, and then those of rhs. @@ -1405,6 +1405,7 @@ def scaled_dot_general( Using scaled_dot_general with the configs: + >>> import functools >>> scaled_dot_general_fn = functools.partial(jax.nn.scaled_dot_general, configs=configs) >>> lhs = random.normal(keys[0], (3, 128, 64)) >>> rhs = random.normal(keys[1], (3, 128, 64)) From cb67d5646f94918b9b4dfb2bc742aec698cc62f7 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Thu, 3 Apr 2025 15:54:42 -0700 Subject: [PATCH 365/483] [Mosaic GPU] Re-enable WS pipelined copy test. PiperOrigin-RevId: 743727350 --- tests/pallas/mosaic_gpu_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 8fd98f62eab1..06dfd453fb19 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -2142,7 +2142,6 @@ class WarpSpecializedPipelineTest(PallasTest): @parameterized.product(m=[512], n=[512], manual_consumed_barriers=[False, True]) def test_pipelined_copy(self, m, n, manual_consumed_barriers): - self.skipTest("TODO(justinfu): Temporary skip for 3.12 update.") self.skip_if_wg_semantics() # Times out! x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float16) From 3901014f9ab8451ec50d04f56844b6d56c6a1fd8 Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Thu, 3 Apr 2025 16:06:54 -0700 Subject: [PATCH 366/483] [pallas:mgpu] General ref transform handling at lowering time. Replace `_handle_reshape()` and `_handle_index()` with a general `_handle_transform()` that applies all transforms except tiling and (optionally) transposes. The implementation is based on `_untransform_{transpose,reshape,index}()` transform methods on transforms that find the conjugate of the transpose/reshape/index wrt the transform. PiperOrigin-RevId: 743731515 --- jax/_src/pallas/mosaic_gpu/core.py | 58 ++++++++++++++ jax/_src/pallas/mosaic_gpu/lowering.py | 99 ++++++++++-------------- jax/_src/pallas/mosaic_gpu/primitives.py | 20 +++-- 3 files changed, 112 insertions(+), 65 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index b0d4f23c792e..0a949840ab62 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -36,6 +36,7 @@ from jax._src.state import indexing from jax._src.state import types as state_types import jax.experimental.mosaic.gpu as mgpu +from jax.experimental.mosaic.gpu import utils as mgpu_utils import jax.numpy as jnp from jaxlib.mlir import ir @@ -263,6 +264,29 @@ def transform_shape(self, shape): def transform_dtype(self, dtype): return dtype + def untransform_transpose( + self, perm: tuple[int, ...] + ) -> tuple[tuple[int, ...], state_types.Transform]: + # The transpose in question is applied to the utiled ref so we + # need to translate it by duplicating and offseting the last part. + off = len(perm) + new_suffix = [i + off for i in perm[-len(self.tiling) :]] + if set(new_suffix) != set(range(off, off + len(self.tiling))): + raise ValueError( + "Transpose cannot be moved before a tiling transform when it changes" + f" the set of tiled dimensions. (permutation: {perm}, tiling:" + f" {self.tiling})" + ) + + new_tiling = tuple(self.tiling[i - off] for i in new_suffix) + return (*perm, *new_suffix), dataclasses.replace(self, tiling=new_tiling) + + def untransform_reshape( + self, dtype: jnp.dtype, shape: tuple[int, ...] + ) -> tuple[tuple[int, ...], state_types.Transform]: + del dtype + raise NotImplementedError("Reshapes don't commute with transposes.") + def untransform_index( self, idxs: tuple[Index, ...] ) -> tuple[tuple[Index, ...], state_types.Transform]: @@ -352,6 +376,19 @@ def transform_shape(self, shape): def transform_dtype(self, dtype): return dtype + def untransform_transpose( + self, perm + ) -> tuple[tuple[int, ...], state_types.Transform]: + raise NotImplementedError( + "Commuting of transpose over transpose is not supported." + ) + + def untransform_reshape( + self, dtype: jnp.dtype | ir.Type, shape: tuple[int, ...] + ) -> tuple[tuple[int, ...], state_types.Transform]: + del shape, dtype + raise NotImplementedError("Can't reshape a transposed memref.") + def untransform_index( self, idxs: tuple[Index, ...] ) -> tuple[tuple[Index, ...], state_types.Transform]: @@ -436,6 +473,27 @@ def __call__(self, aval: jax_core.ShapedArray) -> jax_core.ShapedArray: class UnswizzleRef(state_types.Transform): swizzle: int = dataclasses.field(metadata=dict(static=True)) + def swizzle_elems(self, dtype: jnp.dtype | ir.Type) -> int: + if not isinstance(dtype, ir.Type): + dtype = mgpu_utils.dtype_to_ir_type(dtype) + return (self.swizzle * 8) // mgpu.bitwidth(dtype) + + def untransform_transpose(self, perm) -> tuple[tuple[int, ...], state_types.Transform]: + if perm[-1] != len(perm) - 1: + raise ValueError("Can't transpose the swizzled dimension.") + + return perm, self + + def untransform_reshape( + self, dtype: jnp.dtype | ir.Type, shape: tuple[int, ...] + ) -> tuple[tuple[int, ...], state_types.Transform]: + if shape[-1] == self.swizzle_elems(dtype): + raise ValueError( + f"Reshape shape {shape} is not divisible by swizzle elements" + f" {self.swizzle_elems(dtype)}" + ) + return shape, self + def untransform_index( self, idxs: tuple[Index, ...] ) -> tuple[tuple[Index, ...], state_types.Transform]: diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index aafab927d4c2..d8e2083a845c 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1024,61 +1024,49 @@ def _num_programs_lowering_rule(ctx: LoweringRuleContext, axis): gpu_dialect.block_dim(gpu_dialect.Dimension(axis)), ) - -def _handle_reshaping( - ref: ir.Value, transforms: Sequence[gpu_core.Transform] -) -> tuple[ir.Value, Sequence[gpu_core.Transform]]: - is_trivial_indexer = lambda t: isinstance( - t, indexing.NDIndexer - ) and gpu_core.is_trivial_index(t.indices, t.shape) - - last_reshaper_idx = next( - reversed([i for i, t in enumerate(transforms) if isinstance(t, RefReshaper)]), - None, - ) - if last_reshaper_idx is None: - return ref, transforms - # Check that before the reshape are only trivial indexes and or - # other reshapes. - # TODO(cperivol): Reshapes should bubble up rather than being - # expected to effectively be the first ref transform. - if not all(isinstance(t, RefReshaper) or is_trivial_indexer(t) for t in transforms[:last_reshaper_idx]): - raise NotImplementedError( - "Reshapes do not compose with other transforms and indexers must be" - f" trivial (transforms: {transforms})" - ) - reshaper = cast(RefReshaper, transforms[last_reshaper_idx]) - # Skip all the reshapes and trivial indexes. - return mgpu.memref_reshape(ref, reshaper.shape), transforms[last_reshaper_idx + 1:] - - -def _handle_indexing( - ref: ir.Value, transforms: Sequence[gpu_core.Transform] +def _handle_transforms( + ref: ir.Value, + transforms: Sequence[gpu_core.Transform], + *, + handle_transposes=True, + handle_reshapes=True, ) -> tuple[ir.Value, Sequence[gpu_core.Transform]]: - if not transforms: - pass - indexer_idxs = [ - i for i, t in enumerate(transforms) if isinstance(t, indexing.NDIndexer) - ] - if not indexer_idxs: - return ref, transforms - sliced_ref = ref + transformed_ref = ref new_transforms = [] - for t in transforms: - if not isinstance(t, indexing.NDIndexer): - new_transforms.append(t) - continue - indexer = cast(indexing.NDIndexer, t) - if indexer.int_indexer_shape: - raise NotImplementedError("int_indexer_shape non-empty") - indices = _ndindexer_indices(indexer) + def _bubble_up(untransform_fn, data): + nonlocal new_transforms new_transforms_rev = [] for t in reversed(new_transforms): - indices, new_t = t.untransform_index(indices) + data, new_t = untransform_fn(t, data) new_transforms_rev.append(new_t) - sliced_ref = mgpu.memref_slice(sliced_ref, indices) + new_transforms = list(reversed(new_transforms_rev)) - return sliced_ref, new_transforms + return data + + for t in transforms: + match t: + case indexing.NDIndexer(): + indexer = cast(indexing.NDIndexer, t) + if indexer.int_indexer_shape: + raise NotImplementedError("int_indexer_shape non-empty") + indices = _ndindexer_indices(indexer) + indices = _bubble_up( + lambda t, idxs: t.untransform_index(idxs), indices + ) + transformed_ref = mgpu.memref_slice(transformed_ref, indices) + case gpu_core.TransposeRef(perm) if handle_transposes: + perm = _bubble_up(lambda t, p: t.untransform_transpose(p), + perm) + transformed_ref = mgpu.memref_transpose(transformed_ref, perm) + case RefReshaper(dtype=dtype, shape=shape) if handle_reshapes: + shape = _bubble_up( + lambda t, p: t.untransform_reshape(dtype, p), # pylint: disable=cell-var-from-loop + shape) + transformed_ref = mgpu.memref_reshape(transformed_ref, shape) + case _: + new_transforms.append(t) + + return transformed_ref, new_transforms def _ndindexer_indices(indexer: indexing.NDIndexer) -> tuple[gpu_core.Index, ...]: @@ -1120,8 +1108,7 @@ def _get_lowering_rule(ctx: LoweringRuleContext, x_ref, *leaves, tree): x_aval = ctx.avals_in[0] transforms = jax.tree.unflatten(tree, leaves) - x_smem, transforms = _handle_reshaping(x_ref, transforms) - x_smem, transforms = _handle_indexing(x_smem, transforms) + x_smem, transforms = _handle_transforms(x_ref, transforms) match transforms: case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)): @@ -1152,8 +1139,7 @@ def _get_lowering_rule_wg(ctx: LoweringRuleContext, x_smem, *leaves, tree): x_aval = ctx.avals_in[0] transforms = jax.tree.unflatten(tree, leaves) - x_smem, transforms = _handle_reshaping(x_smem, transforms) - x_smem, transforms = _handle_indexing(x_smem, transforms) + x_smem, transforms = _handle_transforms(x_smem, transforms) if transforms: raise NotImplementedError( @@ -1180,8 +1166,7 @@ def _swap_lowering_rule( raise TypeError(f"Can only store to references (got {x_smem}).") x_aval = ctx.avals_in[0] transforms = jax.tree.unflatten(tree, leaves) - x_smem, transforms = _handle_reshaping(x_smem, transforms) - x_smem, transforms = _handle_indexing(x_smem, transforms) + x_smem, transforms = _handle_transforms(x_smem, transforms) match transforms: case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)): if tiling != (8, swizzle // x_aval.dtype.itemsize): @@ -1227,9 +1212,7 @@ def _swap_lowering_rule_wg( x_aval = ctx.avals_in[0] transforms = jax.tree.unflatten(tree, leaves) - x_smem, transforms = _handle_reshaping(x_smem, transforms) - x_smem, transforms = _handle_indexing(x_smem, transforms) - + x_smem, transforms = _handle_transforms(x_smem, transforms) if transforms: raise NotImplementedError( "Transforms are not yet implemented for warpgroup semantics" diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 46ec8a87082e..fe5319113a03 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -85,8 +85,7 @@ def _load_p_lowering_rule( x_aval = ctx.avals_in[0] transforms = jax.tree.unflatten(args_tree, leaves) - x_ref, transforms = lowering._handle_reshaping(x_ref, transforms) - x_ref, transforms = lowering._handle_indexing(x_ref, transforms) + x_ref, transforms = lowering._handle_transforms(x_ref, transforms) if layout is not None: layout = layout.to_mgpu() @@ -209,7 +208,7 @@ def _copy_smem_to_gmem_lowering( ) src_transforms = src_transforms_treedef.unflatten(flat_src_transforms) dst_transforms = dst_transforms_treedef.unflatten(flat_dst_transforms) - src, src_transforms = lowering._handle_indexing(src, src_transforms) + src, src_transforms = lowering._handle_transforms(src, src_transforms, handle_transposes=False) copy_params = _extract_gmem_copy_params(dst_transforms) | _extract_smem_copy_params(src_transforms) if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: ctx.launch_ctx.async_copy( @@ -382,7 +381,7 @@ def _copy_gmem_to_smem_lowering( ) src_transforms = src_transforms_treedef.unflatten(flat_src_transforms) dst_transforms = dst_transforms_treedef.unflatten(flat_dst_transforms) - dst, dst_transforms = lowering._handle_indexing(dst, dst_transforms) + dst, dst_transforms = lowering._handle_transforms(dst, dst_transforms, handle_transposes=False) copy_params = _extract_smem_copy_params(dst_transforms) | _extract_gmem_copy_params(src_transforms) barrier_indexer = _extract_barrier_indexer( barrier_transforms_treedef.unflatten(flat_barrier_transforms) @@ -743,7 +742,7 @@ def _wgmma_lowering( transforms_leaves, [a_transforms_tree.num_leaves] ) a_transforms = a_transforms_tree.unflatten(a_transforms_leaves) - a, a_transforms = lowering._handle_indexing(a, a_transforms) + a, a_transforms = lowering._handle_transforms(a, a_transforms) match a_transforms: case (gpu_core.UnswizzleRef(lhs_swizzle), gpu_core.UntileRef(tiling)): swizzle_elems = lhs_swizzle // a_aval.dtype.itemsize @@ -760,7 +759,9 @@ def _wgmma_lowering( ) b_transforms = b_transforms_tree.unflatten(b_transforms_leaves) - b, b_transforms = lowering._handle_indexing(b, b_transforms) + b, b_transforms = lowering._handle_transforms( + b, b_transforms, handle_transposes=False, handle_reshapes=False + ) match b_transforms: case (gpu_core.UnswizzleRef(rhs_swizzle), gpu_core.UntileRef(rhs_tiling)): @@ -787,6 +788,8 @@ def _wgmma_lowering( f" {rhs_tiling}." ) + # TODO(cperivol): Find a generic way to move this reshape into + # _handle_transforms. high_dims = [d // t for d, t in util.safe_zip(new_shape, rhs_tiling)] b = mgpu.memref_reshape(b, (*high_dims, *rhs_tiling)) rhs_transpose = False @@ -1107,9 +1110,12 @@ def _jaxpr_call_lowering_rule( for treedef, flat_ref in zip(ref_treedefs, flat_refs): ref = treedef.unflatten(flat_ref) if isinstance(ref, tuple): + ref, transforms = ref # We ignore other transforms here, because they are already embedded # in the jaxpr. - ref, _ = lowering._handle_indexing(*ref) + ref, _ = lowering._handle_transforms( + ref, transforms, handle_reshapes=False, handle_transposes=False + ) args.append(ref) program_ids = program_ids_treedef.unflatten(flat_program_ids) for axis, pid in enumerate(program_ids): From bbdea54ccb1b9338b8aa6932043393551474050e Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Wed, 2 Apr 2025 18:08:39 -0700 Subject: [PATCH 367/483] add an `out_sharding` option to `jax.random.permutation` Drop into `Auto` mode in the implementation. --- jax/_src/random.py | 13 ++++++++++--- tests/pjit_test.py | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/jax/_src/random.py b/jax/_src/random.py index 1d044ec111ff..a21cdf89a61f 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -556,7 +556,8 @@ def _randint(key, minval, maxval, shape, dtype) -> Array: def permutation(key: ArrayLike, x: int | ArrayLike, axis: int = 0, - independent: bool = False) -> Array: + independent: bool = False, + out_sharding=None) -> Array: """Returns a randomly permuted array or range. Args: @@ -573,11 +574,17 @@ def permutation(key: ArrayLike, key, _ = _check_prng_key("permutation", key) check_arraylike("permutation", x) axis = canonicalize_axis(axis, np.ndim(x) or 1) + out_sharding = canonicalize_sharding(out_sharding, "permutation") if not np.ndim(x): if not np.issubdtype(lax.dtype(x), np.integer): raise TypeError("x must be an integer or at least 1-dimensional") - r = core.concrete_or_error(int, x, 'argument x of jax.random.permutation()') - return _shuffle(key, jnp.arange(r), axis) + r = core.concrete_or_error(int, x, "argument x of jax.random.permutation()") + return maybe_auto_axes(lambda key: _shuffle(key, jnp.arange(r), axis), + out_sharding)(key) + return maybe_auto_axes( + _permutation, out_sharding, axis=axis, independent=independent)(key, x) + +def _permutation(key, x, axis, independent): if independent or np.ndim(x) == 1: return _shuffle(key, x, axis) ind = _shuffle(key, jnp.arange(x.shape[axis]), 0) # type: ignore[union-attr] diff --git a/tests/pjit_test.py b/tests/pjit_test.py index d3d9cab7a5ba..580cfcd7ad8d 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7332,6 +7332,45 @@ def f(key): else: self.assertIn('mhlo.sharding = "{devices=[2,2]<=[4]}"}', lowered_text) + @jtu.with_user_mesh((4,), ('x',)) + def test_random_permutation_1d(self, mesh): + @jax.jit + def f(key): + out = jax.random.permutation(key, 8, out_sharding=P('x')) + self.assertEqual(out.aval.sharding.spec, P('x')) + return out + + key = jax.random.key(1) + out = f(key) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + + lowered_text = f.lower(key).as_text() + if config.use_shardy_partitioner.value: + self.assertIn('sdy.sharding_constraint', lowered_text) + self.assertIn('<@mesh, [{"x"}]>', lowered_text) + else: + self.assertIn('mhlo.sharding = "{devices=[4]<=[4]}"}', lowered_text) + + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_random_permutation_2d(self, mesh): + @jax.jit + def f(key): + out = jax.random.permutation(key, jnp.arange(8 * 12).reshape(8, 12), + out_sharding=P('x', 'y')) + self.assertEqual(out.aval.sharding.spec, P('x', 'y')) + return out + + key = jax.random.key(1) + out = f(key) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + + lowered_text = f.lower(key).as_text() + if config.use_shardy_partitioner.value: + self.assertIn('sdy.sharding_constraint', lowered_text) + self.assertIn('<@mesh, [{"x"}, {"y"}]>', lowered_text) + else: + self.assertIn('mhlo.sharding = "{devices=[2,2]<=[4]}"}', lowered_text) + @jtu.with_user_mesh((2, 2), ('x', 'y')) def test_random_normal(self, mesh): @jax.jit From 7583814e35c85b9df55eb6ed65c4559207262f33 Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Thu, 3 Apr 2025 16:28:22 -0700 Subject: [PATCH 368/483] [mgpu:pallas] Changes to allow the use of WGMMA_TRANSPOSED_LAYOUT. It is up to _handle_transposes() to check that the swizzle dimension is not transposed rather than `UnswizzleRef.untransform_transpose()`. This allows us to disable the check in certain situations where mgpu can handle it like wgmma and swap_p when storing a WGMMA_TRANSPOSED_LAYOUT. If this check is completely skipped it can cause the kernel to crash at runtime. Furthermore this CL adds a test to check this behavior. PiperOrigin-RevId: 743738166 --- jax/_src/pallas/mosaic_gpu/lowering.py | 30 ++++++++++++++++++-- tests/pallas/mosaic_gpu_test.py | 38 ++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 3 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index d8e2083a845c..80757ef69e64 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1166,13 +1166,37 @@ def _swap_lowering_rule( raise TypeError(f"Can only store to references (got {x_smem}).") x_aval = ctx.avals_in[0] transforms = jax.tree.unflatten(tree, leaves) - x_smem, transforms = _handle_transforms(x_smem, transforms) + transposed_value = value.layout == mgpu.WGMMA_TRANSPOSED_LAYOUT + x_smem, transforms = _handle_transforms( + x_smem, transforms, handle_transposes=not transposed_value + ) match transforms: - case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)): + case ( + gpu_core.UnswizzleRef(swizzle), + gpu_core.UntileRef(tiling), + *maybe_transpose, + ): if tiling != (8, swizzle // x_aval.dtype.itemsize): raise NotImplementedError("Tiling does not fit swizzle") + + if transposed_value != bool(maybe_transpose): + raise ValueError( + "Either both the ref and the value are transposed or neither is." + ) + + if maybe_transpose: + if maybe_transpose != [gpu_core.TransposeRef((1, 0))]: + raise NotImplementedError( + f"Unsupported transforms: {transforms} ({maybe_transpose})" + ) + + x_smem = mgpu.memref_transpose(x_smem, (1, 0, 3, 2)) + old_value = mgpu.FragmentedArray.load_tiled( - x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype), swizzle=swizzle + x_smem, + is_signed=mgpu_utils.is_signed(x_aval.dtype), + swizzle=swizzle, + layout=value.layout, ) value.store_tiled(x_smem, swizzle=swizzle) return old_value diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 06dfd453fb19..754e53255438 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1372,6 +1372,44 @@ def kernel(o_ref): x = jnp.full(shape, 42.0, jnp.float32) np.testing.assert_array_equal(kernel(), x) + def test_wgmma_transposed_layout(self): + """Tests that the result of wgmma can be store transposed using + the WGMMA_TRNASPOSED layout. + """ + + dtype = jnp.dtype(jnp.float16) + swizzle_elems = 128 // dtype.itemsize + shape = (128, 128) + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct(shape, dtype), + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + scratch_shapes=[ + plgpu.SMEM( + shape, dtype, + transforms=( + plgpu.TilingTransform((8, swizzle_elems)), + plgpu.SwizzleTransform(128), + ), + ) + ] + ) + def kernel(o_ref, smem): + iota = plgpu.broadcasted_iota( + dtype, o_ref.shape, 0, layout=plgpu.Layout.WGMMA + ) * o_ref.shape[0] + iota += plgpu.broadcasted_iota( + dtype, o_ref.shape, 1, layout=plgpu.Layout.WGMMA + ) + + smem_trns = plgpu.transpose_ref(smem, (1, 0)) + smem_trns[...] = plgpu.layout_cast(iota, plgpu.Layout.WGMMA_TRANSPOSED) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(smem, o_ref) + + x = jnp.arange(128 * 128, dtype=dtype).reshape((128, 128)).T + np.testing.assert_array_equal(kernel(), x) + def test_profiler(self): self.skip_if_wg_semantics() # Transform inference fails. From 26fc1cde4cdb593239f796a37834645184ac10fb Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Thu, 3 Apr 2025 17:18:18 -0700 Subject: [PATCH 369/483] [pallas:mgpu] Initial version of inline_mgpu op PiperOrigin-RevId: 743751560 --- jax/_src/pallas/mosaic_gpu/primitives.py | 94 +++++++++++++++++++++++- jax/experimental/pallas/mosaic_gpu.py | 2 + tests/pallas/mosaic_gpu_test.py | 34 +++++++++ 3 files changed, 129 insertions(+), 1 deletion(-) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index fe5319113a03..37d71cd6d1c6 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -16,7 +16,7 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Sequence, Callable import dataclasses import enum import itertools @@ -1218,3 +1218,95 @@ def jaxpr_call( ref_treedefs=ref_treedefs, program_ids_treedef=program_ids_treedef, ) + +inline_mgpu_p = jax_core.Primitive("inline_mgpu_p") +inline_mgpu_p.multiple_results = True + + +@dataclasses.dataclass(frozen=True) +class RefType: + ... + + +def inline_mgpu(*args, arg_types): + flat_args, treedef = jax.tree.flatten(tuple(args)) + flat_types, treedef_ty = jax.tree.flatten(tuple(arg_types)) + if treedef != treedef_ty: + raise ValueError(f"Mismatched type shape: {treedef} != {treedef_ty}") + + # Strip the transforms from the refs since they will be recorded in + # the types. + raw_refs_flat_args = [] + for a, t in zip(flat_args, flat_types): + def traced_ty(ty): + return isinstance(a, jax_core.Tracer) and isinstance(a.aval, ty) + + if isinstance(t, (ParameterizedLayout, Layout)) and traced_ty(jax_core.ShapedArray): + raw_refs_flat_args.append(a) + elif isinstance(t, RefType) and traced_ty(_Ref): + ref, transforms = a, () + if isinstance(a, state_types.TransformedRef): + ref, transforms = ref.ref, ref.transforms + + raw_refs_flat_args.append(ref) + if transforms: + raise NotImplementedError("Transformed refs (or types) are not supported.") + else: + raise ValueError(f"Mismatched type: {a, t}") + + def inner(f): + return inline_mgpu_p.bind( + *raw_refs_flat_args, + args_treedef=treedef, + flat_types=flat_types, + mgpu_fn=f, + ) + return inner + + +@inline_mgpu_p.def_effectful_abstract_eval +def _inline_mgpu_abstract_eval( + *flat_args, + args_treedef, + flat_types, + mgpu_fn, +): + del args_treedef, flat_types, mgpu_fn # Unused. + # TODO(cperivol): Let the user set the effects. + return (), { + gpu_core._wgmma_pipeline_effect, + gpu_core._memory_effect, + *itertools.chain.from_iterable( + (state.ReadEffect(i), state.WriteEffect(i)) + for i, r in enumerate(flat_args) + if isinstance(r, pallas_core.AbstractMemoryRef) + ), + } + + +@discharge.register_partial_discharge_rule(inline_mgpu_p) +def _inline_mgpu_discharge(*args, **kwargs): + raise NotImplementedError("inline_mgpu_p does not support discharge.") + +@lowering.register_lowering_rule(inline_mgpu_p, mgpu.ThreadSemantics.Lane) +def _inline_mgpu_lowering_rule( + ctx: lowering.LoweringRuleContext, + *flat_args, + mgpu_fn: Callable[..., Any], + flat_types, + args_treedef, +): + for a, t in zip(flat_args, flat_types, strict=True): + match a: + case ir.Value() if ir.MemRefType.isinstance(a.type): + # We checked the memory spaces at tracing time. + pass + case mgpu.FragmentedArray(): + if a.layout != t.to_mgpu(): + raise ValueError(f"Unexpected layout for {a} (expected: {t})") + case _: + raise ValueError(f"Unexpected argument {a}") + + args = jax.tree.unflatten(args_treedef, flat_args) + mgpu_fn(ctx.launch_ctx, *args) + return () diff --git a/jax/experimental/pallas/mosaic_gpu.py b/jax/experimental/pallas/mosaic_gpu.py index b44c86ea7a4c..1d3bebbc3757 100644 --- a/jax/experimental/pallas/mosaic_gpu.py +++ b/jax/experimental/pallas/mosaic_gpu.py @@ -43,9 +43,11 @@ from jax._src.pallas.mosaic_gpu.primitives import commit_smem_to_gmem_group as commit_smem_to_gmem_group from jax._src.pallas.mosaic_gpu.primitives import copy_gmem_to_smem as copy_gmem_to_smem from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem as copy_smem_to_gmem +from jax._src.pallas.mosaic_gpu.primitives import inline_mgpu as inline_mgpu from jax._src.pallas.mosaic_gpu.primitives import Layout as Layout from jax._src.pallas.mosaic_gpu.primitives import layout_cast as layout_cast from jax._src.pallas.mosaic_gpu.primitives import load as load +from jax._src.pallas.mosaic_gpu.primitives import RefType as RefType from jax._src.pallas.mosaic_gpu.primitives import set_max_registers as set_max_registers from jax._src.pallas.mosaic_gpu.primitives import wait_smem_to_gmem as wait_smem_to_gmem from jax._src.pallas.mosaic_gpu.primitives import wgmma as wgmma diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 754e53255438..d35446359756 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -33,6 +33,7 @@ from jax._src.pallas.mosaic_gpu import primitives as mgpu_primitives from jax._src.state import discharge from jax.experimental import pallas as pl +import jax.experimental.mosaic.gpu as mgpu from jax.experimental.pallas import mosaic_gpu as plgpu import jax.numpy as jnp import numpy as np @@ -369,6 +370,38 @@ def kernel(o_ref): kernel(), jax.lax.broadcasted_iota(dtype, (128, 128), dimension) ) + def test_inline_mgpu(self): + dtype = jnp.bfloat16 + self.skip_if_wg_semantics() + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((128, 128), dtype), + in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), + scratch_shapes=[ + plgpu.SMEM((128, 128), dtype), + plgpu.Barrier(num_arrivals=1), + ], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + ) + def kernel(x_ref, o_ref, smem_ref, barrier): + plgpu.copy_gmem_to_smem(x_ref, smem_ref, barrier) + plgpu.barrier_wait(barrier) + arr = jnp.ones_like(x_ref) + @plgpu.inline_mgpu( + smem_ref, + o_ref, + arr, + arg_types=[plgpu.RefType(), plgpu.RefType(), plgpu.Layout.WG_SPLAT(x_ref.shape)], + ) + def _(ctx, smem_ref, o_ref, y): + del ctx + x = mgpu.FragmentedArray.load_strided(smem_ref) + (x + y).store_untiled(o_ref) + + key = jax.random.key(0) + x = (jax.random.uniform(key, (128, 128)) * 42).astype(dtype) + np.testing.assert_array_equal(kernel(x), x + 1) + @parameterized.product(indexer=[..., slice(128), slice(None, 128)]) def test_copy_smem_to_gmem(self, indexer): @functools.partial( @@ -1506,6 +1539,7 @@ def test_missing_primitive_lowerings_are_tracked(self): actual_missing_primitives = lane_lowered_primitives - wg_lowered_primitives expected_missing_primitives = { + mgpu_primitives.inline_mgpu_p, mgpu_primitives.broadcasted_iota_p, mgpu_primitives.layout_cast_p, mgpu_primitives.load_p, From d645172765886a0df06a3a7f58393b313681572f Mon Sep 17 00:00:00 2001 From: Zac Mustin Date: Thu, 3 Apr 2025 17:18:20 -0700 Subject: [PATCH 370/483] Delete `PjRtClient.Defragment`. The `Defragment` implementation for GPU is in `py_client.cc`, so this should be a no-op. PiperOrigin-RevId: 743751570 --- jaxlib/xla/py_client.cc | 7 ++++++- jaxlib/xla/xla_extension/__init__.pyi | 1 - tests/array_test.py | 13 ++++++++----- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/jaxlib/xla/py_client.cc b/jaxlib/xla/py_client.cc index 1e41d9cf8a0d..2ce11e7e76c7 100644 --- a/jaxlib/xla/py_client.cc +++ b/jaxlib/xla/py_client.cc @@ -204,9 +204,14 @@ absl::Status PyClient::Defragment() { platform_id == SyclId(); if (!is_gpu_client) { - return pjrt_client()->Defragment(); + return absl::UnimplementedError( + "Defragmentation is not supported on this runtime."); } + // TODO(b/399879011): This is a GPU-specific implementation of `Defragment`. + // Ideally, this would be replaced with some kind of auto-defrag-on-OOM, or at + // least would not live in this file. + struct TmpBuffer { // Non-empty for buffers found in a PyArray_Storage. Multiple Arrays // can reference the same PjRtBuffer. diff --git a/jaxlib/xla/xla_extension/__init__.pyi b/jaxlib/xla/xla_extension/__init__.pyi index 2d759236b8c5..3fe2de1e30a3 100644 --- a/jaxlib/xla/xla_extension/__init__.pyi +++ b/jaxlib/xla/xla_extension/__init__.pyi @@ -579,7 +579,6 @@ class Client: host_callbacks: Sequence[Any] = ..., ) -> LoadedExecutable: ... def heap_profile(self) -> bytes: ... - def defragment(self) -> _Status: ... def make_python_callback_from_host_send_and_recv( self, callable: Callable, diff --git a/tests/array_test.py b/tests/array_test.py index 5891db5a3e36..901ce9521da1 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -655,12 +655,15 @@ def f(x): output_shardings._to_xla_hlo_sharding(x_dummy.ndim), s._to_xla_hlo_sharding(x_dummy.ndim))) - # TODO(skyewm): remove this test when we can remove the workaround manual - # defragment API - @jtu.skip_on_devices('cpu') # defragment not implemented for TFRT CPU + # TODO(b/399879011): GPU is the only platform that has an implementation for + # this, which exists in py_client.cc. Ideally, this would be replaced with + # some kind of auto-defrag-on-OOM. + @jtu.run_on_devices('gpu') def test_defragment(self): + # Since the GPU implementation is in py_client.cc, it cannot be exposed via + # the PjRt C API. if xb.using_pjrt_c_api(): - self.skipTest("Manual defragment not exposed via PJRT C API") + self.skipTest('Manual defragment not exposed via PJRT C API') # Create a few arrays global_mesh = jtu.create_mesh((jax.local_device_count(),), ('x',)) @@ -673,7 +676,7 @@ def test_defragment(self): # Delete one of them arr2.delete() - # Defragment + # Defragment. xb.get_backend().defragment() # Sanity check remaining arrays From f8bbe98a860acd0d16ea0288f10839f7a0ed2d1d Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Thu, 3 Apr 2025 17:25:22 -0700 Subject: [PATCH 371/483] require `out_shardings` as a keyword-only argument on public functions PiperOrigin-RevId: 743753215 --- jax/_src/lax/lax.py | 12 +++++++----- jax/_src/random.py | 5 +++++ 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index ac6054328f73..13511641558c 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -2460,6 +2460,7 @@ def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None, def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionNumbers, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None, + *, out_sharding=None) -> Array: """General dot product/contraction operator. @@ -2667,7 +2668,7 @@ def ragged_dot_general( ) -def broadcast(operand: ArrayLike, sizes: Sequence[int], out_sharding=None +def broadcast(operand: ArrayLike, sizes: Sequence[int], *, out_sharding=None ) -> Array: """Broadcasts an array, adding new leading dimensions @@ -2689,7 +2690,7 @@ def broadcast(operand: ArrayLike, sizes: Sequence[int], out_sharding=None out_sharding=out_sharding) def broadcast_in_dim(operand: ArrayLike, shape: Shape, - broadcast_dimensions: Sequence[int], out_sharding=None + broadcast_dimensions: Sequence[int], *, out_sharding=None ) -> Array: """Wraps XLA's `BroadcastInDim `_ @@ -2732,7 +2733,7 @@ def broadcast_to_rank(x: ArrayLike, rank: int) -> Array: def reshape(operand: ArrayLike, new_sizes: Shape, dimensions: Sequence[int] | None = None, - out_sharding: NamedSharding | P | None = None) -> Array: + *, out_sharding: NamedSharding | P | None = None) -> Array: """Wraps XLA's `Reshape `_ operator. @@ -3378,7 +3379,7 @@ def iota(dtype: DTypeLike, size: int) -> Array: return broadcasted_iota(dtype, (size,), 0) def broadcasted_iota(dtype: DTypeLike, shape: Shape, dimension: int, - out_sharding=None) -> Array: + *, out_sharding=None) -> Array: """Convenience wrapper around ``iota``.""" dtype = dtypes.canonicalize_dtype(dtype) shape = canonicalize_shape(shape) @@ -8430,7 +8431,8 @@ def _propagate_mem_kind_copy(in_mem_kind): pxla.memory_kind_propagate_rule[copy_p] = _propagate_mem_kind_copy def rng_bit_generator(key, shape, dtype=np.uint32, - algorithm=RandomAlgorithm.RNG_DEFAULT, out_sharding=None): + algorithm=RandomAlgorithm.RNG_DEFAULT, + *, out_sharding=None): """Stateless PRNG bit generator. Experimental and its use is discouraged. Returns uniformly distributed random bits with the specified shape and dtype diff --git a/jax/_src/random.py b/jax/_src/random.py index a21cdf89a61f..5cbd966e7a7b 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -359,6 +359,7 @@ def maybe_auto_axes(f, out_shardings, **hoist_kwargs): def bits(key: ArrayLike, shape: Shape = (), dtype: DTypeLikeUInt | None = None, + *, out_sharding=None) -> Array: """Sample uniform bits in the form of unsigned integers. @@ -393,6 +394,7 @@ def uniform(key: ArrayLike, dtype: DTypeLikeFloat = float, minval: RealArray = 0., maxval: RealArray = 1., + *, out_sharding=None) -> Array: """Sample uniform random values in [minval, maxval) with given shape/dtype. @@ -466,6 +468,7 @@ def randint(key: ArrayLike, minval: IntegerArray, maxval: IntegerArray, dtype: DTypeLikeInt = int, + *, out_sharding=None) -> Array: """Sample uniform random values in [minval, maxval) with given shape/dtype. @@ -557,6 +560,7 @@ def permutation(key: ArrayLike, x: int | ArrayLike, axis: int = 0, independent: bool = False, + *, out_sharding=None) -> Array: """Returns a randomly permuted array or range. @@ -707,6 +711,7 @@ def choice(key: ArrayLike, def normal(key: ArrayLike, shape: Shape = (), dtype: DTypeLikeFloat = float, + *, out_sharding=None) -> Array: r"""Sample standard normal random values with given shape and float dtype. From 5b3e419515404de8650c169cb27ff00b1fb53340 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 3 Apr 2025 18:34:35 -0700 Subject: [PATCH 372/483] Add `auto_axes`, `explicit_axes` and `manual_axes` properties to Mesh and AbstractMesh PiperOrigin-RevId: 743767895 --- jax/_src/mesh.py | 15 +++++++++++++++ tests/array_test.py | 3 +++ 2 files changed, 18 insertions(+) diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index a8003e693459..00859f9b3d74 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -174,6 +174,21 @@ def _any_axis_auto(self) -> bool: def _any_axis_explicit(self) -> bool: return any_axis_types_match(self._axis_types, AxisType.Explicit) + @functools.cached_property + def auto_axes(self): + return tuple(n for n, t in safe_zip(self.axis_names, self._axis_types) + if t == AxisType.Auto) + + @functools.cached_property + def explicit_axes(self): + return tuple(n for n, t in safe_zip(self.axis_names, self._axis_types) + if t == AxisType.Explicit) + + @functools.cached_property + def manual_axes(self): + return tuple(n for n, t in safe_zip(self.axis_names, self._axis_types) + if t == AxisType.Manual) + @functools.cached_property def _axis_types_dict(self): if not self.axis_names: diff --git a/tests/array_test.py b/tests/array_test.py index 901ce9521da1..2bdc54607473 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -1391,6 +1391,9 @@ def test_make_mesh_axis_types(self): self.assertDictEqual( mesh._axis_types_dict, {AxisType.Auto: ('y',), AxisType.Explicit: ('x',), AxisType.Manual: ('z',)}) + self.assertEqual(mesh.explicit_axes, ('x',)) + self.assertEqual(mesh.auto_axes, ('y',)) + self.assertEqual(mesh.manual_axes, ('z',)) mesh = jax.make_mesh((1, 1, 1), ('x', 'y', 'z'), axis_types=(Explicit, Explicit, Manual)) From c1bdd1a234ae6fa6c650426e1a4a3c04851e82da Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Thu, 3 Apr 2025 19:39:03 -0700 Subject: [PATCH 373/483] [Mosaic TPU] Allow specify priority in enqueueDMA. For now we only support priority 0 (on-demand thread) and priority 1 (background thread) on local DMA. PiperOrigin-RevId: 743780185 --- jax/BUILD | 1 + jax/_src/tpu_custom_call.py | 14 ++++++++++++-- jaxlib/mosaic/dialect/tpu/tpu.td | 4 +++- jaxlib/mosaic/dialect/tpu/tpu_ops.cc | 13 ++++++++++++- jaxlib/mosaic/dialect/tpu/transforms/serde.cc | 10 +++++++++- 5 files changed, 37 insertions(+), 5 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index f5745df0e5bf..fe2e6b8d7df1 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -1073,6 +1073,7 @@ pytype_strict_library( srcs = ["_src/tpu_custom_call.py"], visibility = [":internal"], deps = [ + ":cloud_tpu_init", ":config", ":core", ":jax", diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index e37d5e064a26..f84db206f4d1 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -32,6 +32,7 @@ from jax._src import config from jax._src import core from jax._src import sharding_impls +from jax._src.cloud_tpu_init import is_cloud_tpu_older_than from jax._src.interpreters import mlir from jax._src.lib import tpu from jax._src.lib import xla_client @@ -64,7 +65,14 @@ # This tracks the latest Mosaic IR version with a monthly delay. -FWD_COMPAT_IR_VERSION = 3 +FWD_COMPAT_IR_VERSION = 4 +DEFAULT_IR_VERSION = None +# TODO(jevinjiang): Remove this once both jaxlib and libtpu are up to date. +if is_cloud_tpu_older_than(2025, 4, 5) or jax.version._version_as_tuple( + jax.lib.__version__ +) < (0, 5, 4): + FWD_COMPAT_IR_VERSION = 3 + DEFAULT_IR_VERSION = 3 tpu_custom_call_p = core.Primitive("tpu_custom_call") @@ -671,7 +679,9 @@ def lower_module_to_custom_call( serialization_format=serialization_format, output_memory_spaces=output_memory_spaces, kernel_name=kernel_name, - ir_version=FWD_COMPAT_IR_VERSION if ctx.is_forward_compat() else None, + ir_version=FWD_COMPAT_IR_VERSION + if ctx.is_forward_compat() + else DEFAULT_IR_VERSION, ) return _tpu_custom_call_lowering( ctx, diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 4b5ed34934d7..0cd045621413 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -752,7 +752,9 @@ def TPU_EnqueueDMAOp : TPU_Op<"enqueue_dma", [AttrSizedOperandSegments]> { AnyMemRef:$target, MemRefOf<[TPU_DMASemaphoreType]>:$target_semaphore, Optional:$device_id, // For remote DMAs - Optional:$core_id // For megacore + Optional:$core_id, // For megacore + // Smaller number means higher priority. 0 is the highest and the default. + DefaultValuedAttr:$priority ); let hasVerifier = 1; } diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index 41342efeb1b4..5ed5e94b13c0 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -955,13 +955,24 @@ LogicalResult EnqueueDMAOp::verify() { "device_id or core_id is specified"); } } + bool is_remote = getDeviceId() || getCoreId(); if (getSourceSemaphore()) { - if (!getDeviceId() && !getCoreId()) { + if (!is_remote) { return emitOpError( "DMA destination device_id or core_id must be specified when source " "semaphore is specified"); } } + int priority = getPriority(); + if (priority < 0 || priority > 1) { + return emitOpError( + "Not implemented: only support priority 0 or 1, but got ") + << priority; + } + if (priority != 0 && is_remote) { + return emitOpError( + "Not implemented: non-zero priority is not supported for remote DMA"); + } return success(); } diff --git a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc index 5f6c9bd712ff..e08149fe44fc 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc @@ -40,7 +40,7 @@ constexpr StringRef kMangledDialect = "stable_mosaic."; constexpr StringRef kVersionAttrName = "stable_mosaic.version"; // When this is bumped, we should file a TODO to update the forward-compatible // version in tpu_custom_call.py in a month! -constexpr int kVersion = 3; +constexpr int kVersion = 4; using SerdeRuleType = jaxlib::mosaic::SerdeRuleType; @@ -62,6 +62,11 @@ LogicalResult enqueue_dma_upgrade(Operation* op, int version) { << op->getNumOperands(); } } + if (version < 4) { + op->setAttr("priority", + mlir::IntegerAttr::get( + mlir::IntegerType::get(op->getContext(), 32), 0)); + } return success(); } @@ -69,6 +74,9 @@ LogicalResult enqueue_dma_downgrade(Operation* op, int version) { if (version < 2) { return op->emitError("Downgrade to version ") << version << " unsupported"; } + if (version < 4) { + op->removeAttr("priority"); + } return success(); } From a9bd1e3f9df474e769210d78f86ae829544c0e7b Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Thu, 3 Apr 2025 22:25:46 -0700 Subject: [PATCH 374/483] [Pallas TPU] Support DMA priority in async copy start For now, we can only specify priority 0 (on-demand) or priority 1 (background) in local DMA. Also added priority to pretty print by making `dma_start` to `dma_start(px)` which means priority x. Full example: ``` { lambda ; a:MemRef{int32[8,128]} b:MemRef{int32[8,128]} c:MemRef{int32[8,128]} d:MemRef{int32[8,128]} e:MemRef{int32[8,128]} f:MemRef{int32[8,128]} g:MemRef{dma_sem[]} h:MemRef{dma_sem[]}. let dma_start(p1) a[...] -> e[...] g[...] dma_start(p0) b[...] -> f[...] h[...] dma_wait e[...] g[...] dma_wait f[...] h[...] dma_start(p0) e[...] -> c[...] g[...] dma_start(p1) f[...] -> d[...] h[...] dma_wait c[...] g[...] dma_wait d[...] h[...] in () } ``` PiperOrigin-RevId: 743815050 --- jax/_src/pallas/mosaic/lowering.py | 23 ++++++++++++---- jax/_src/pallas/mosaic/primitives.py | 32 ++++++++++++++++------ tests/pallas/tpu_pallas_test.py | 41 +++++++++++++++++++++++++--- 3 files changed, 79 insertions(+), 17 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 6c8b3c646a0d..0ca298c88dc5 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -3531,8 +3531,14 @@ def _semaphore_wait_lowering_rule(ctx: LoweringRuleContext, *args, args_tree): return [] lowering_rules[primitives.semaphore_wait_p] = _semaphore_wait_lowering_rule -def _dma_start_lowering_rule(ctx: LoweringRuleContext, *args, tree, - device_id_type: primitives.DeviceIdType): + +def _dma_start_lowering_rule( + ctx: LoweringRuleContext, + *args, + tree, + device_id_type: primitives.DeviceIdType, + priority: int, +): ( src_ref, src_transforms, @@ -3564,10 +3570,17 @@ def _dma_start_lowering_rule(ctx: LoweringRuleContext, *args, tree, sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, sem_transforms) if device_id is not None: device_id = _device_id_to_logical(ctx, device_id, device_id_type) - tpu.enqueue_dma(src_ref, dst_ref, sem, source_semaphore=src_sem, - device_id=device_id) - + tpu.enqueue_dma( + src_ref, + dst_ref, + sem, + source_semaphore=src_sem, + device_id=device_id, + priority=priority, + ) return [] + + lowering_rules[tpu_primitives.dma_start_p] = _dma_start_lowering_rule diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py index c50a21218117..59856c0ca7b2 100644 --- a/jax/_src/pallas/mosaic/primitives.py +++ b/jax/_src/pallas/mosaic/primitives.py @@ -208,9 +208,14 @@ def _get_args_and_tree(self, swap_src_and_dst: bool = False): self.device_id, )) - def start(self): + def start(self, priority: int = 0): flat_args, tree = self._get_args_and_tree() - dma_start_p.bind(*flat_args, tree=tree, device_id_type=self.device_id_type) + dma_start_p.bind( + *flat_args, + tree=tree, + device_id_type=self.device_id_type, + priority=priority, + ) def wait(self): if self.is_remote: @@ -239,7 +244,9 @@ def wait_send(self): dma_start_p.multiple_results = True @dma_start_p.def_effectful_abstract_eval -def _dma_start_abstract_eval(*args, tree, device_id_type): +def _dma_start_abstract_eval(*args, tree, device_id_type, priority): + if priority < 0: + raise ValueError(f"DMA start priority must be non-negative: {priority}") ( src_ref_aval, src_transforms_avals, @@ -274,6 +281,7 @@ def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn, settings: jax_core.JaxprPpSettings): invars = eqn.invars tree = eqn.params["tree"] + priority = eqn.params["priority"] ( src_ref, src_transforms, @@ -290,7 +298,7 @@ def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn, if src_sem or device_id: return jax_core._pp_eqn(eqn, context, settings) return pp.concat([ - pp.text("dma_start"), + pp.text(f"dma_start(p{priority})"), pp.text(" "), sp.pp_ref_transforms(context, src_ref, src_transforms), pp.text(" -> "), @@ -301,8 +309,12 @@ def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn, jax_core.pp_eqn_rules[dma_start_p] = _dma_start_pp_eqn -def dma_start_partial_discharge_rule(should_discharge, in_avals, out_avals, - *args, tree, device_id_type): + +def dma_start_partial_discharge_rule( + should_discharge, in_avals, out_avals, *args, tree, device_id_type, priority +): + # Note: we ignore the DMA priority in discharge rules. + del priority ( src_ref, src_transforms, @@ -461,6 +473,7 @@ def do_discharge_src_sem(src_sem=src_sem): return new_vals, [] + state_discharge.register_partial_discharge_rule(dma_start_p)(dma_start_partial_discharge_rule) @@ -550,6 +563,7 @@ def _get_ref_and_transforms(ref): return ref.ref, ref.transforms return ref, () + def make_async_copy(src_ref, dst_ref, sem): """Issues a DMA copying from src_ref to dst_ref.""" src_ref, src_transforms = _get_ref_and_transforms(src_ref) @@ -568,12 +582,14 @@ def make_async_copy(src_ref, dst_ref, sem): primitives.DeviceIdType.MESH, ) -def async_copy(src_ref, dst_ref, sem): + +def async_copy(src_ref, dst_ref, sem, *, priority: int = 0): """Issues a DMA copying from src_ref to dst_ref.""" copy_descriptor = make_async_copy(src_ref, dst_ref, sem) - copy_descriptor.start() + copy_descriptor.start(priority=priority) return copy_descriptor + def make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id, device_id_type: primitives.DeviceIdType = primitives.DeviceIdType.MESH): """Creates a description of a remote copy operation. diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 128fe50687a0..2e773b88fbad 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -1135,6 +1135,39 @@ def kernel(x_hbm_ref, y_hbm_ref, sem_val_ref, dma_sem): np.testing.assert_array_equal(y, x) np.testing.assert_array_equal(sem_val, 0) + def test_set_dma_priority(self): + if not jtu.if_cloud_tpu_at_least(2025, 4, 5): + self.skipTest('Needs a newer libTPU') + if jtu.get_tpu_version() < 5: + self.skipTest('Target does not support DMA prefetch between HBM and VMEM') + def kernel(x1, x2, y1, y2, scratch1, scratch2, sem1, sem2): + copy1 = pltpu.async_copy(x1, scratch1, sem1, priority=1) + copy2 = pltpu.async_copy(x2, scratch2, sem2, priority=0) + copy1.wait() + copy2.wait() + copy1 = pltpu.async_copy(scratch1, y1, sem1, priority=0) + copy2 = pltpu.async_copy(scratch2, y2, sem2, priority=1) + copy1.wait() + copy2.wait() + + shape = (8, 128) + dtype = jnp.int32 + x1 = jnp.arange(np.prod(shape), dtype=dtype).reshape(shape) + x2 = x1 + 1 + y1, y2 = self.pallas_call( + kernel, + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[pl.BlockSpec(memory_space=pl.ANY)] * 2, + scratch_shapes=[pltpu.VMEM(shape, dtype)] * 2 + + [pltpu.SemaphoreType.DMA] * 2, + out_specs=[pl.BlockSpec(memory_space=pl.ANY)] * 2, + ), + out_shape=[jax.ShapeDtypeStruct(shape, dtype)] * 2, + )(x1, x2) + np.testing.assert_array_equal(y1, x1) + np.testing.assert_array_equal(y2, x2) + def test_hbm_hbm_dma(self): def kernel(x_hbm_ref, y_hbm_ref): def body(sem): @@ -2665,19 +2698,19 @@ class PrettyPrintingTest(PallasBaseTest): @parameterized.parameters( ( lambda i: (i, pl.ds(0, 8), pl.ds(0, 128)), - 'dma_start c[d,:,:] -> e[...] f', + 'dma_start(p0) c[d,:,:] -> e[...] f', ), ( lambda i: (0, pl.ds(i, 8), pl.ds(0, 128)), - 'dma_start c[0,d:d+8,:] -> e[...] f', + 'dma_start(p0) c[0,d:d+8,:] -> e[...] f', ), ( lambda i: (i, pl.ds(2, 4), pl.ds(0, 100)), - 'dma_start c[d,2:6,:100] -> e[...] f', + 'dma_start(p0) c[d,2:6,:100] -> e[...] f', ), ( lambda i: (i, pl.ds(2, 6), pl.ds(4, 100)), - 'dma_start c[d,2:,4:104] -> e[...] f', + 'dma_start(p0) c[d,2:,4:104] -> e[...] f', ), ) def test_dma_custom_pretty_print(self, indexer, expected): From 4f00249aa8bff45b379f76304e79e293273f9ad6 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 3 Apr 2025 22:30:45 -0700 Subject: [PATCH 375/483] [pallas:mosaic_gpu] Do not specify the default `index_map` in tests PiperOrigin-RevId: 743816110 --- tests/pallas/mosaic_gpu_test.py | 108 ++++++++++---------------------- 1 file changed, 32 insertions(+), 76 deletions(-) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index d35446359756..0cfe9197db36 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -644,8 +644,6 @@ def kernel(x_ref, o_ref, barrier_ref): in_spec = pl.BlockSpec(memory_space=plgpu.GMEM) out_spec = plgpu.GPUBlockSpec( - (128, 128), - lambda: (0, 0), transforms=( plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(128), @@ -676,9 +674,7 @@ def body(tmp_ref): pl.run_scoped(body, plgpu.SMEM((128, 128), jnp.float32, transforms=ts)) in_spec = pl.BlockSpec(memory_space=plgpu.GMEM) - out_spec = plgpu.GPUBlockSpec( - (128, 128), lambda: (0, 0), transforms=ts, memory_space=plgpu.SMEM, - ) + out_spec = plgpu.GPUBlockSpec(transforms=ts, memory_space=plgpu.SMEM) f = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct([128, 128], jnp.float32), @@ -719,8 +715,6 @@ def kernel(x_ref, o_ref, barrier_ref): in_spec = pl.BlockSpec(memory_space=plgpu.GMEM) out_spec = plgpu.GPUBlockSpec( - (2, 128, 128), - lambda: (0, 0, 0), transforms=( plgpu.TilingTransform((8, 32)), plgpu.TransposeTransform((0, 2, 1, 3, 4)), @@ -750,11 +744,7 @@ def test_load_to_strided_layout_with_indexing(self, src_memory_space, layout): self.pallas_call, out_shape=jax.ShapeDtypeStruct([2, 128], jnp.float32), in_specs=[pl.BlockSpec(memory_space=src_memory_space)], - out_specs=plgpu.GPUBlockSpec( - (2, 128), - lambda: (0, 0), - memory_space=plgpu.SMEM, - ), + out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.SMEM), ) def kernel(x_ref, o_ref): for i in range(2): @@ -776,11 +766,7 @@ def test_load_to_wgmma_row_col_layout_with_indexing(self, src_memory_space, layo self.pallas_call, out_shape=jax.ShapeDtypeStruct([2, m], jnp.float32), in_specs=[pl.BlockSpec(memory_space=src_memory_space)], - out_specs=plgpu.GPUBlockSpec( - (2, m), - lambda: (0, 0), - memory_space=plgpu.SMEM, - ), + out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.SMEM), ) def kernel(x_ref, o_ref): for i in range(2): @@ -819,24 +805,19 @@ def compute(acc_ref): out = pl.run_scoped(compute, plgpu.ACC((m, n), jnp.float32)) o_ref[...] = out - - out_spec = plgpu.GPUBlockSpec( - (m, n), lambda: (0, 0), memory_space=plgpu.SMEM, - ) f = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct([m, n], jnp.float32), in_specs=( pl.BlockSpec(memory_space=src_memory_space), plgpu.GPUBlockSpec( - (k, n), - lambda: (0, 0), transforms=( - plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128), + plgpu.TilingTransform((8, 64)), + plgpu.SwizzleTransform(128), ), ), ), - out_specs=out_spec, + out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.SMEM), ) out_ref = ( @@ -855,9 +836,7 @@ def kernel(x_ref, o_ref, barrier_ref): plgpu.barrier_wait(barrier_ref) in_spec = pl.BlockSpec(memory_space=plgpu.GMEM) - out_spec = plgpu.GPUBlockSpec( - (2, 64, 2, 128), lambda: (0, 0, 0, 0), memory_space=plgpu.SMEM, - ) + out_spec = plgpu.GPUBlockSpec(memory_space=plgpu.SMEM) f = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct([2, 64, 2, 128], jnp.float32), @@ -960,11 +939,10 @@ def test_print_wgmma_tiled_layout(self): out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), in_specs=[ plgpu.GPUBlockSpec( - shape, - lambda: (0, 0), transforms=( - plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(128), - ), + plgpu.TilingTransform((8, 32)), + plgpu.SwizzleTransform(128), + ) ) ], ) @@ -1143,7 +1121,8 @@ def test_swizzled_blockspec_shapes(self): (128, 64), lambda *i: i, transforms=( - plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128), + plgpu.TilingTransform((8, 64)), + plgpu.SwizzleTransform(128), ), ) @functools.partial( @@ -1327,9 +1306,7 @@ def test_tile_slicing(self): shape = (256, 128) block_spec = plgpu.GPUBlockSpec( - transforms=( - plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128), - ) + transforms=(plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) ) @functools.partial( self.pallas_call, @@ -1380,12 +1357,7 @@ def rotate(src, dst): x = jnp.arange(128 * 128).astype(jnp.float16).reshape(128, 128) spec = plgpu.GPUBlockSpec( - (128, 128), - lambda: (0, 0), - transforms=( - plgpu.TilingTransform((8, 64)), - plgpu.SwizzleTransform(128), - ), + transforms=(plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) ) f = self.pallas_call(rotate, out_shape=x, in_specs=[spec], out_specs=spec) expected = np.empty_like(x) @@ -1560,11 +1532,9 @@ def test_fori_loop_accumulator(self, force_while): transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) @functools.partial( self.pallas_call, - in_specs=[ - plgpu.GPUBlockSpec((64, 64), lambda: (0, 0), transforms=transforms) - ], + in_specs=[plgpu.GPUBlockSpec((64, 64), transforms=transforms)], out_shape=jax.ShapeDtypeStruct((64, 64), jnp.float16), - out_specs=plgpu.GPUBlockSpec((64, 64), lambda: (0, 0)), + out_specs=plgpu.GPUBlockSpec((64, 64)), ) def kernel(i_ref, o_ref): def scope(acc_ref): @@ -1613,25 +1583,28 @@ def _epilogue(): if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Lane: lhs_spec = plgpu.GPUBlockSpec( - lhs_spec.block_shape, lhs_spec.index_map, + lhs_spec.block_shape, + lhs_spec.index_map, transforms=( plgpu.TilingTransform((8, elems_128b)), plgpu.SwizzleTransform(128), - ) + ), ) rhs_spec = plgpu.GPUBlockSpec( - rhs_spec.block_shape, rhs_spec.index_map, + rhs_spec.block_shape, + rhs_spec.index_map, transforms=( plgpu.TilingTransform((8, elems_128b)), plgpu.SwizzleTransform(128), - ) + ), ) out_spec = plgpu.GPUBlockSpec( - out_spec.block_shape, out_spec.index_map, + out_spec.block_shape, + out_spec.index_map, transforms=( plgpu.TilingTransform((8, elems_128b)), plgpu.SwizzleTransform(128), - ) + ), ) res = self.pallas_call( @@ -1717,14 +1690,9 @@ def scope(acc_ref): res = self.pallas_call( kernel, in_specs=[ - plgpu.GPUBlockSpec( - (64, 128), lambda: (0, 0), transforms=transforms - ), - plgpu.GPUBlockSpec( - (128, 192), lambda: (0, 0), transforms=transforms - ), + plgpu.GPUBlockSpec(transforms=transforms), + plgpu.GPUBlockSpec(transforms=transforms), ], - out_specs=plgpu.GPUBlockSpec((64, 192), lambda: (0, 0)), out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32), )(a, b) np.testing.assert_allclose(res, a @ b, rtol=1e-3) @@ -1747,17 +1715,10 @@ def scope(acc_ref): res = self.pallas_call( kernel, in_specs=[ - plgpu.GPUBlockSpec( - (64, 128), lambda: (0, 0), transforms=transforms - ), - plgpu.GPUBlockSpec( - (128, 192), lambda: (0, 0), transforms=transforms - ), - plgpu.GPUBlockSpec( - (64, 192), lambda: (0, 0), transforms=transforms - ), + plgpu.GPUBlockSpec(transforms=transforms), + plgpu.GPUBlockSpec(transforms=transforms), + plgpu.GPUBlockSpec(transforms=transforms), ], - out_specs=plgpu.GPUBlockSpec((64, 192), lambda: (0, 0)), out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float16), )(a, b, i) np.testing.assert_allclose(res, i + a @ b, rtol=2e-3) @@ -1783,14 +1744,9 @@ def scope(acc_ref): res = self.pallas_call( kernel, in_specs=[ - plgpu.GPUBlockSpec( - (2, 64, 128), lambda: (0, 0, 0), transforms=transforms - ), - plgpu.GPUBlockSpec( - (2, 128, 192), lambda: (0, 0, 0), transforms=transforms - ), + plgpu.GPUBlockSpec(transforms=transforms), + plgpu.GPUBlockSpec(transforms=transforms), ], - out_specs=plgpu.GPUBlockSpec((64, 192), lambda: (0, 0)), out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32), )(a, b) np.testing.assert_allclose(res, a[0] @ b[0], rtol=1e-3) From 97cecdf862690e30da2296c90337176492f08e9e Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Thu, 3 Apr 2025 21:40:40 -0700 Subject: [PATCH 376/483] add an `out_sharding` option to `jax.random.truncated_normal` Drop into `Auto` mode in the implementation. --- jax/_src/random.py | 7 +++++-- tests/pjit_test.py | 20 ++++++++++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/jax/_src/random.py b/jax/_src/random.py index 5cbd966e7a7b..e632d4a9a2fa 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -851,7 +851,8 @@ def truncated_normal(key: ArrayLike, lower: RealArray, upper: RealArray, shape: Shape | None = None, - dtype: DTypeLikeFloat = float) -> Array: + dtype: DTypeLikeFloat = float, + *, out_sharding=None) -> Array: r"""Sample truncated standard normal random values with given shape and dtype. The values are returned according to the probability density function: @@ -882,12 +883,14 @@ def truncated_normal(key: ArrayLike, if shape is not None: shape = core.canonicalize_shape(shape) key, _ = _check_prng_key("truncated_normal", key) + out_sharding = canonicalize_sharding(out_sharding, "truncated_normal") dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `truncated_normal` must be a float " f"dtype, got {dtype}") dtype = dtypes.canonicalize_dtype(dtype) - return _truncated_normal(key, lower, upper, shape, dtype) + return maybe_auto_axes(_truncated_normal, out_sharding, + shape=shape, dtype=dtype)(key, lower, upper) @partial(jit, static_argnums=(3, 4)) def _truncated_normal(key, lower, upper, shape, dtype) -> Array: diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 580cfcd7ad8d..f2db913af736 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7423,6 +7423,26 @@ def f(arr, key): out = f(jax.device_put(np.arange(8 * 12.).reshape(8, 12), s), key) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_random_truncated_normal(self, mesh): + @jax.jit + def f(key, lower): + out = jax.random.truncated_normal(key, lower, 2., shape=(8, 12), + out_sharding=P('x', 'y')) + self.assertEqual(out.aval.sharding.spec, P('x', 'y')) + return out + + key = jax.random.key(1) + out = f(key, -1.) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + + lowered_text = f.lower(key, -1.).as_text() + if config.use_shardy_partitioner.value: + self.assertIn('sdy.sharding_constraint', lowered_text) + self.assertIn('<@mesh, [{"x"}, {"y"}]>', lowered_text) + else: + self.assertIn('mhlo.sharding = "{devices=[2,2]<=[4]}"}', lowered_text) + def test_auto_axes_no_context_mesh(self): mesh = jtu.create_mesh((2, 2), ('x', 'y'), axis_types=(AxisType.Explicit,) * 2) np_inp = np.arange(16.).reshape(8, 2) From 5eb4e7b2dc4d1ab9677cb7a22feac3f7933142ae Mon Sep 17 00:00:00 2001 From: Gleb Pobudzey Date: Thu, 3 Apr 2025 23:17:25 -0700 Subject: [PATCH 377/483] [Mosaic GPU] Return the combined softmax residuals. It's scaled so that it can be used directly as an input to exp2 in the backwards pass. PiperOrigin-RevId: 743825330 --- .../pallas/ops/gpu/attention_mgpu.py | 112 +++++++++++++----- tests/pallas/mgpu_attention_test.py | 11 +- 2 files changed, 94 insertions(+), 29 deletions(-) diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py index d06d3b39cb7a..6a20b448ca54 100644 --- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py +++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py @@ -43,8 +43,8 @@ def __post_init__(self): raise ValueError(f"{self.max_concurrent_steps=} must be at least 2") -@functools.partial(jax.jit, static_argnames=["config"]) -def attention(q, k, v, config: TuningConfig): +@functools.partial(jax.jit, static_argnames=["config", "save_residuals"]) +def attention(q, k, v, config: TuningConfig, save_residuals: bool = False): if q.ndim != 4 or k.ndim != 4 or v.ndim != 4: raise ValueError(f"q, k, and v should all be 4D, got: {q.ndim=}, {k.ndim=}, {v.ndim=}") batch_size, q_seq_len, num_q_heads, head_dim = q.shape @@ -69,12 +69,12 @@ def attention(q, k, v, config: TuningConfig): ) block_q, block_kv = config.block_q, config.block_kv - def kernel(q_ref, k_ref, v_ref, out_ref, scoped): + def kernel(q_ref, k_ref, v_ref, out_ref, lse_ref, scoped): batch = lax.axis_index("batch") q_head = lax.axis_index("heads") smem_buffers, buffer_barriers, consumed_barriers, schedule_barrier = scoped wg_idx = lax.axis_index("wg") - qo_smem2, k_smem, v_smem = smem_buffers + qo_smem2, k_smem, v_smem, lse_smem2 = smem_buffers k_barriers, v_barriers, q_barriers = buffer_barriers k_consumed_barriers, v_consumed_barriers = consumed_barriers def perform_schedule_barrier(): @@ -85,6 +85,7 @@ def perform_schedule_barrier(): def _compute_wg(): plgpu.set_max_registers(232, action="increase") qo_smem = qo_smem2.at[wg_idx] + lse_smem = lse_smem2.at[wg_idx] if lse_smem2 is not None else None q_seq_base = lax.axis_index("q_seq") * (2 * block_q) + wg_idx * block_q plgpu.copy_gmem_to_smem( @@ -162,15 +163,23 @@ def _wait(): 0, kv_seq_len // block_kv, kv_loop, (acc, m_i, l_i) ) pl.when(wg_idx == 0)(perform_schedule_barrier) - del m_i # Not needed anymore # TODO(apaszke): Invert and multiply to avoid expensive divisions. acc /= lax.broadcast_in_dim(l_i, (block_q, head_dim), [0]) qo_smem[...] = acc.astype(dtype) + if lse_smem is not None: + RCP_LN2 = 1.4426950408889634 + log2 = lambda x: jnp.log(x) * RCP_LN2 + lse_smem[...] = m_i + log2(l_i) plgpu.commit_smem() plgpu.copy_smem_to_gmem( qo_smem, out_ref.at[batch, pl.ds(q_seq_base, block_q), q_head], ) + if lse_smem is not None: + plgpu.copy_smem_to_gmem( + lse_smem, + lse_ref.at[batch, q_head, pl.ds(q_seq_base, block_q)], + ) plgpu.wait_smem_to_gmem(0) @pl.when(wg_idx == 2) def _memory_wg(): @@ -191,7 +200,7 @@ def kv_loop(kv_step, _): plgpu.copy_gmem_to_smem(v_ref.at[s], v_smem.at[tma_slot], v_barriers.at[tma_slot]) lax.fori_loop(0, kv_seq_len // block_kv - max_concurrent_steps, kv_loop, None) - def entry(q_ref, k_ref, v_ref, out_ref): + def entry(q_ref, k_ref, v_ref, out_ref, lse_ref): compute_wgs = 2 tiling = plgpu.TilingTransform((8, 64)) swizzle = plgpu.SwizzleTransform(128) @@ -207,9 +216,12 @@ def entry(q_ref, k_ref, v_ref, out_ref): (max_concurrent_steps, block_kv, head_dim), jnp.float16, transforms=(tiling, swizzle), ) + scratch = [qo_scratch, k_scratch, v_scratch, None] + if save_residuals: + scratch[3] = plgpu.SMEM((compute_wgs, block_q), jnp.float32) pl.run_scoped( - lambda *args: kernel(q_ref, k_ref, v_ref, out_ref, args), - (qo_scratch, k_scratch, v_scratch), + lambda *args: kernel(q_ref, k_ref, v_ref, out_ref, lse_ref, args), + scratch, ( plgpu.Barrier(1, num_barriers=max_concurrent_steps), plgpu.Barrier(1, num_barriers=max_concurrent_steps), @@ -223,9 +235,17 @@ def entry(q_ref, k_ref, v_ref, out_ref): if rem: raise NotImplementedError(f"{q_seq_len=} must be a multiple of {block_q * 2=}") - return plgpu.kernel( + out_shape = [q, None] + if save_residuals: + # Note that we keep seq_len in the minor-most dimension so that we can do + # 1D TMAs on chunks of `block_q`. + out_shape[1] = jax.ShapeDtypeStruct( + (batch_size, num_q_heads, q_seq_len), jnp.float32 + ) + + out, lse = plgpu.kernel( entry, - out_shape=q, + out_shape=out_shape, grid=(batch_size, num_q_tiles, num_q_heads), grid_names=("batch", "q_seq", "heads"), num_threads=3, @@ -233,8 +253,14 @@ def entry(q_ref, k_ref, v_ref, out_ref): compiler_params=plgpu.GPUCompilerParams(approx_math=True), )(q, k, v) -@functools.partial(jax.jit, static_argnames=["config"]) -def attention_with_pipeline_emitter(q, k, v, config: TuningConfig): + if save_residuals: + assert lse is not None + return out, (lse,) + + return out + +@functools.partial(jax.jit, static_argnames=["config", "save_residuals"]) +def attention_with_pipeline_emitter(q, k, v, config: TuningConfig, save_residuals=False): if q.ndim != 4 or k.ndim != 4 or v.ndim != 4: raise ValueError(f"q, k, and v should all be 4D, got: {q.ndim=}, {k.ndim=}, {v.ndim=}") batch_size, q_seq_len, num_q_heads, head_dim = q.shape @@ -266,10 +292,11 @@ def attention_with_pipeline_emitter(q, k, v, config: TuningConfig): tiling = plgpu.TilingTransform((8, 64)) swizzle = plgpu.SwizzleTransform(128) - def fa3_kernel(q_ref, k_ref, v_ref, out_ref, scoped): + def fa3_kernel(q_ref, k_ref, v_ref, out_ref, lse_ref, scoped): batch = lax.axis_index("batch") wg_idx = lax.axis_index("wg") - qo_smem2, q_barriers, schedule_barrier = scoped + smem_buffers, q_barriers, schedule_barrier = scoped + qo_smem2, lse_smem2 = smem_buffers q_seq_base = lax.axis_index("q_seq") * (2 * block_q) + wg_idx * block_q q_head = lax.axis_index("heads") kv_head = lax.div(q_head, jnp.array(q_heads_per_kv_head, q_head.dtype)) @@ -281,6 +308,7 @@ def perform_schedule_barrier(): def _compute_thread(): qo_smem = qo_smem2.at[wg_idx] + lse_smem = lse_smem2.at[wg_idx] if lse_smem2 is not None else None m_i = plgpu.layout_cast( jnp.full((block_q,), -jnp.inf, dtype=jnp.float32), plgpu.Layout.WGMMA_ROW, ) @@ -299,15 +327,23 @@ def _compute_thread(): plgpu.barrier_wait(q_barriers.at[wg_idx]) pl.when(wg_idx == 1)(perform_schedule_barrier) final_carry = (yield (acc, m_i, l_i)) - del m_i # Unused pl.when(wg_idx == 0)(perform_schedule_barrier) - acc, _, l_i = final_carry + acc, m_i, l_i = final_carry acc /= lax.broadcast_in_dim(l_i, (block_q, head_dim), [0]) qo_smem[...] = acc.astype(dtype) + if lse_smem is not None: + RCP_LN2 = 1.4426950408889634 + log2 = lambda x: jnp.log(x) * RCP_LN2 + lse_smem[...] = m_i + log2(l_i) plgpu.commit_smem() plgpu.copy_smem_to_gmem( qo_smem, out_ref.at[batch, pl.ds(q_seq_base, block_q), q_head], ) + if lse_smem is not None: + plgpu.copy_smem_to_gmem( + lse_smem, + lse_ref.at[batch, q_head, pl.ds(q_seq_base, block_q)], + ) plgpu.wait_smem_to_gmem(0) def kv_pipeline(_, k_smem, v_smem, @@ -371,7 +407,7 @@ def compute_pv(acc_ref): thread_name="wg", ) def run(refs): - q_ref, k_ref, v_ref, out_ref = refs + q_ref, k_ref, v_ref, out_ref, lse_ref = refs @pl.core_map(mesh, compiler_params=plgpu.GPUCompilerParams(approx_math=True), ) @@ -380,22 +416,36 @@ def _kernel_entry(): (compute_wgs, block_q, head_dim), jnp.float16, transforms=(tiling, swizzle), ) + scratch = [qo_scratch, None] + if save_residuals: + scratch[1] = plgpu.SMEM((compute_wgs, block_q), jnp.float32) pl.run_scoped( - lambda *args: fa3_kernel(q_ref, k_ref, v_ref, out_ref, args), - qo_scratch, + lambda *args: fa3_kernel(q_ref, k_ref, v_ref, out_ref, lse_ref, args), + scratch, plgpu.Barrier(1, num_barriers=compute_wgs), plgpu.Barrier(num_arrivals=compute_wgs), ) @jax.jit - def run_function(q, k, v, o): - _, _, _, out = pl.run_state(run)((q, k, v, o)) - return out - out = run_function(q, k, v, jnp.full_like(q, jnp.inf)) + def run_function(q, k, v, o, lse): + _, _, _, out, lse = pl.run_state(run)((q, k, v, o, lse)) + return out, lse + + lse = ( + jnp.full((batch_size, num_q_heads, q_seq_len), -jnp.inf, dtype=jnp.float32) + if save_residuals + else None + ) + out, lse = run_function(q, k, v, jnp.full_like(q, jnp.inf), lse) + + if save_residuals: + assert lse is not None + return out, (lse,) + return out -@jax.jit -def attention_reference(q, k, v): +@functools.partial(jax.jit, static_argnames=["save_residuals"]) +def attention_reference(q, k, v, save_residuals=False): batch_size, q_seq_len, num_q_heads, head_dim = q.shape num_kv_heads = k.shape[2] q, k, v = map(lambda x: x.astype(jnp.float32), (q, k, v)) @@ -407,8 +457,16 @@ def attention_reference(q, k, v): unnormalized = jnp.exp(logits - m) l = unnormalized.sum(axis=-1, keepdims=True) weights = unnormalized / l - return jnp.einsum("bqHhk,bkHc->bqHhc", weights, v).reshape(*q.shape) - + out = jnp.einsum("bqHhk,bkHc->bqHhc", weights, v).reshape(*q.shape) + + if save_residuals: + log2e = math.log2(math.e) + l = l.reshape(*q.shape[:-1]) + m = m.reshape(*q.shape[:-1]) + lse = m * log2e + jnp.log2(l) + return out, (lse.swapaxes(-1, -2),) + else: + return out def main(unused_argv): num_q_heads = 16 diff --git a/tests/pallas/mgpu_attention_test.py b/tests/pallas/mgpu_attention_test.py index cf8ed30925bf..27588683d0e9 100644 --- a/tests/pallas/mgpu_attention_test.py +++ b/tests/pallas/mgpu_attention_test.py @@ -62,6 +62,7 @@ def setUp(self): attention_mgpu.attention, attention_mgpu.attention_with_pipeline_emitter, ), + save_residuals=(True,), ) def test_flash_attention( self, @@ -71,22 +72,28 @@ def test_flash_attention( num_q_and_kv_heads, head_dim, attention_impl, + save_residuals, ): num_q_heads, num_kv_heads = num_q_and_kv_heads k1, k2, k3 = jax.random.split(jax.random.key(42), 3) q = jax.random.normal(k1, (batch_size, q_seq_len, num_q_heads, head_dim), jnp.float16) k = jax.random.normal(k2, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16) v = jax.random.normal(k3, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16) - out = attention_impl( + out, *res = attention_impl( q, k, v, attention_mgpu.TuningConfig( block_q=64, block_kv=64, max_concurrent_steps=2 ), + save_residuals=save_residuals, ) - out_ref = attention_mgpu.attention_reference(q, k, v) + out_ref, *res_ref = attention_mgpu.attention_reference(q, k, v, save_residuals=save_residuals) np.testing.assert_allclose(out, out_ref, atol=2e-3, rtol=1e-3) + if save_residuals: + (lse,) = res[0] + (lse_ref,) = res_ref[0] + np.testing.assert_allclose(lse, lse_ref, atol=2e-3, rtol=1e-3) if __name__ == "__main__": From 12b1a99ad943de56f776d9e18bcdfb351908927c Mon Sep 17 00:00:00 2001 From: Meesum Qazalbash Date: Fri, 4 Apr 2025 11:35:46 +0500 Subject: [PATCH 378/483] fix(docs): corrected the name of the function call in the document --- docs/gradient-checkpointing.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/gradient-checkpointing.md b/docs/gradient-checkpointing.md index 0938a5da944f..e4e842df49f0 100644 --- a/docs/gradient-checkpointing.md +++ b/docs/gradient-checkpointing.md @@ -341,7 +341,7 @@ def predict(params, x): return x ``` -By itself, {func}`jax.ad_checkpoint import.checkpoint_name` is just an identity function. But because some policy functions know to look for them, you can use the names to control whether certain values output by {func}`jax.ad_checkpoint import.checkpoint_name` are considered saveable: +By itself, {func}`jax.ad_checkpoint.checkpoint_name` is just an identity function. But because some policy functions know to look for them, you can use the names to control whether certain values output by {func}`jax.ad_checkpoint.checkpoint_name` are considered saveable: ```{code-cell} print_saved_residuals(loss, params, x, y) From e619fc0b72570cc4b8fe305ccc70c75f27c1a52f Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 4 Apr 2025 00:02:47 -0700 Subject: [PATCH 379/483] Avoid double buffering when no windowing info is present. PiperOrigin-RevId: 743834475 --- jax/_src/pallas/mosaic/lowering.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 0ca298c88dc5..67cacbc8dcf9 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -743,6 +743,15 @@ def dynamic_shape_replacement_fn( block_shape = [ 1 if b is pallas_core.mapped else b for b in bm.block_shape ] + + # No sense in double-buffering without any windowing pattern. + buffer_count = 0 + if ( + tpu_memory_space == tpu_core.TPUMemorySpace.VMEM + and bm.has_trivial_window() + ): + buffer_count = 1 + # If we have an extended dtype, we need to add the block shape for the # remaining physical dtype. block_shape += list(_get_aval_physical_dtype_shape(bm.block_aval.inner_aval)) @@ -765,7 +774,8 @@ def dynamic_shape_replacement_fn( raise LoweringException( f"Unsupported pipeline mode: {bm.pipeline_mode}." ) - buffer_count = bm.pipeline_mode.buffer_count + if buffer_count == 0: + buffer_count = bm.pipeline_mode.buffer_count if buffer_count < 1 or buffer_count > 2: raise LoweringException( "Only single (1) and double (2) buffering are supported. Got" From 1b63d5e26f72cee57c72e8abee842b7a5ee35405 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Thu, 3 Apr 2025 16:26:53 +0000 Subject: [PATCH 380/483] Fixed deadlock in NamedSharding ctor Description: - Test timeout were seen in ColocatedPythonTest test case - GDB report: https://gist.github.com/vfdev-5/d64183f7b5dde3e666eea6cd61670128 --- jaxlib/xla/sharding.cc | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/jaxlib/xla/sharding.cc b/jaxlib/xla/sharding.cc index 5a80c03e01da..858c025745e4 100644 --- a/jaxlib/xla/sharding.cc +++ b/jaxlib/xla/sharding.cc @@ -223,9 +223,22 @@ NamedSharding::NamedSharding(nb::object mesh, nb::object spec, // TODO(phawkins): this leaks a reference to the check_pspec function. // A better way to fix this would be to move PartitionSpec and this check into // C++. - static nb::object* check_pspec = []() { + nb::object* check_pspec = [](){ + static absl::Mutex mu; + static nb::object* output = nullptr; + { + absl::MutexLock lock(&mu); + if (output) { + return output; + } + } nb::module_ si = nb::module_::import_("jax._src.named_sharding"); - return new nb::object(si.attr("check_pspec")); + nb::object attr = si.attr("check_pspec"); + absl::MutexLock lock(&mu); + if (!output) { + output = new nb::object(attr); + } + return output; }(); (*check_pspec)(mesh_, spec_, manual_axes_); } From 206dec859d30e10970e60664802295ec1737131c Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 4 Apr 2025 02:33:24 -0700 Subject: [PATCH 381/483] [pallas:mosaic_gpu] Added pretty printing to primitives consuming refs I also changed existing pretty printers for transforms to use {} instead of [], so that transforms are visually distinct from slicing. PiperOrigin-RevId: 743869470 --- jax/_src/pallas/mosaic_gpu/BUILD | 3 +- jax/_src/pallas/mosaic_gpu/core.py | 10 ++ jax/_src/pallas/mosaic_gpu/primitives.py | 181 +++++++++++++++++++++++ jax/_src/state/indexing.py | 34 +++++ jax/_src/state/primitives.py | 68 +-------- jax/_src/state/types.py | 11 ++ 6 files changed, 239 insertions(+), 68 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/BUILD b/jax/_src/pallas/mosaic_gpu/BUILD index 33883326e58c..554b9db878f6 100644 --- a/jax/_src/pallas/mosaic_gpu/BUILD +++ b/jax/_src/pallas/mosaic_gpu/BUILD @@ -78,6 +78,7 @@ pytype_strict_library( "//jax:dtypes", "//jax:effects", "//jax:mosaic_gpu", + "//jax:pretty_printer", "//jax:state_types", "//jax:tree_util", "//jax/_src/lib", @@ -94,8 +95,8 @@ pytype_strict_library( ":lowering", "//jax", "//jax:core", - "//jax:mlir", "//jax:mosaic_gpu", + "//jax:pretty_printer", "//jax:tree_util", "//jax:util", "//jax/_src/lib", diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 0a949840ab62..2150b48b5108 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -29,6 +29,7 @@ from jax._src import dtypes from jax._src import effects from jax._src import tree_util +from jax._src import pretty_printer as pp from jax._src.lib.mlir.dialects import arith as arith_dialect from jax._src.pallas import core as pallas_core from jax._src.pallas import primitives as pallas_primitives @@ -328,6 +329,9 @@ def untransform_index( def undo_to_gpu_transform(self) -> mgpu.MemRefTransform: return mgpu.TileTransform(self.tiling) + def pretty_print(self, context: jax_core.JaxprPpContext) -> pp.Doc: + return pp.text(f"{{untile({list(self.tiling)})}}") + def _perm_inverse(permutation: tuple[int, ...]) -> tuple[int, ...]: inverse = [-1] * len(permutation) @@ -406,6 +410,9 @@ def untransform_index( def undo_to_gpu_transform(self) -> mgpu.MemRefTransform: return mgpu.TransposeTransform(_perm_inverse(self.permutation)) + def pretty_print(self, context: jax_core.JaxprPpContext) -> pp.Doc: + return pp.text(f"{{transpose({list(self.permutation)})}}") + def transform_ref( ref: pallas_core.TransformedRef, @@ -517,6 +524,9 @@ def untransform_index( raise ValueError("Swizzled dims cannot be sliced") return idxs, self + def pretty_print(self, context: jax_core.JaxprPpContext) -> pp.Doc: + return pp.text(f"{{unswizzle({self.swizzle})}}") + @dataclasses.dataclass class GPUBlockSpec(pallas_core.BlockSpec): diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 37d71cd6d1c6..b909a31496bf 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -25,6 +25,7 @@ import jax from jax._src import core as jax_core +from jax._src import pretty_printer as pp from jax._src import state from jax._src import tree_util from jax._src import util @@ -172,6 +173,40 @@ def _copy_smem_to_gmem_abstract_eval(src, dst, *args, **params): return (), {state.ReadEffect(0), state.WriteEffect(1)} +def _copy_smem_to_gmem_pp_eqn( + eqn: jax_core.JaxprEqn, + context: jax_core.JaxprPpContext, + settings: jax_core.JaxprPpSettings, +): + src, dst, *flat_args = eqn.invars + src_transforms_treedef = eqn.params["src_transforms_treedef"] + dst_transforms_treedef = eqn.params["dst_transforms_treedef"] + pp_params = {} + if not (commit_group := eqn.params["commit_group"]): + pp_params["commit_group"] = commit_group + if has_user_predicate := eqn.params["has_user_predicate"]: + pp_params["has_user_predicate"] = has_user_predicate + if reduction_op := eqn.params["reduction_op"]: + pp_params["reduction_op"] = reduction_op + flat_src_transforms, flat_dst_transforms = util.split_list( + flat_args, + [src_transforms_treedef.num_leaves], + ) + src_transforms = src_transforms_treedef.unflatten(flat_src_transforms) + dst_transforms = dst_transforms_treedef.unflatten(flat_dst_transforms) + return pp.concat([ + pp.text("copy_smem_to_gmem"), + jax_core.pp_kv_pairs(pp_params.items(), context, settings), + pp.text(" "), + state_primitives.pp_ref_transforms(context, src, src_transforms), + pp.text(" -> "), + state_primitives.pp_ref_transforms(context, dst, dst_transforms), + ]) + + +jax_core.pp_eqn_rules[copy_smem_to_gmem_p] = _copy_smem_to_gmem_pp_eqn + + @lowering.register_lowering_rule(copy_smem_to_gmem_p, mgpu.ThreadSemantics.Lane) @lowering.register_lowering_rule( copy_smem_to_gmem_p, mgpu.ThreadSemantics.Warpgroup @@ -355,6 +390,47 @@ def _copy_gmem_to_smem_abstract_eval(src, dst, barrier, *args, **params): return (), {state.ReadEffect(0), state.WriteEffect(1)} +def _copy_gmem_to_smem_pp_eqn( + eqn: jax_core.JaxprEqn, + context: jax_core.JaxprPpContext, + settings: jax_core.JaxprPpSettings, +): + src, dst, barrier, *flat_args = eqn.invars + src_transforms_treedef = eqn.params["src_transforms_treedef"] + dst_transforms_treedef = eqn.params["dst_transforms_treedef"] + barrier_transforms_treedef = eqn.params["barrier_transforms_treedef"] + pp_params = {} + if collective_axes := eqn.params["collective_axes"]: + pp_params["collective_axes"] = collective_axes + flat_src_transforms, flat_dst_transforms, flat_barrier_transforms = ( + util.split_list( + flat_args, + [ + src_transforms_treedef.num_leaves, + dst_transforms_treedef.num_leaves, + ], + ) + ) + src_transforms = src_transforms_treedef.unflatten(flat_src_transforms) + dst_transforms = dst_transforms_treedef.unflatten(flat_dst_transforms) + barrier_transforms = barrier_transforms_treedef.unflatten( + flat_barrier_transforms + ) + return pp.concat([ + pp.text("copy_gmem_to_smem"), + jax_core.pp_kv_pairs(pp_params.items(), context, settings), + pp.text(" "), + state_primitives.pp_ref_transforms(context, src, src_transforms), + pp.text(" -> "), + state_primitives.pp_ref_transforms(context, dst, dst_transforms), + pp.text(" using "), + state_primitives.pp_ref_transforms(context, barrier, barrier_transforms), + ]) + + +jax_core.pp_eqn_rules[copy_gmem_to_smem_p] = _copy_gmem_to_smem_pp_eqn + + @lowering.register_lowering_rule(copy_gmem_to_smem_p, mgpu.ThreadSemantics.Lane) @lowering.register_lowering_rule( copy_gmem_to_smem_p, mgpu.ThreadSemantics.Warpgroup @@ -521,6 +597,25 @@ def _barrier_arrive_abstract_eval(barrier, *args, **params): return (), {gpu_core._memory_effect} +def _barrier_arrive_pp_eqn( + eqn: jax_core.JaxprEqn, + context: jax_core.JaxprPpContext, + settings: jax_core.JaxprPpSettings, +): + del settings + barrier, *flat_transforms = eqn.invars + transforms_treedef = eqn.params["transforms_tree"] + transforms = transforms_treedef.unflatten(flat_transforms) + return pp.concat([ + pp.text("barrier_arrive"), + pp.text(" "), + state_primitives.pp_ref_transforms(context, barrier, transforms), + ]) + + +jax_core.pp_eqn_rules[barrier_arrive_p] = _barrier_arrive_pp_eqn + + @lowering.register_lowering_rule(barrier_arrive_p, mgpu.ThreadSemantics.Lane) @lowering.register_lowering_rule(barrier_arrive_p, mgpu.ThreadSemantics.Warpgroup) def _barrier_arrive_lowering( @@ -560,6 +655,25 @@ def _barrier_wait_abstract_eval(barrier, *args, **params): return (), {gpu_core._memory_effect} +def _barrier_wait_pp_eqn( + eqn: jax_core.JaxprEqn, + context: jax_core.JaxprPpContext, + settings: jax_core.JaxprPpSettings, +): + del settings + barrier, *flat_transforms = eqn.invars + transforms_treedef = eqn.params["transforms_treedef"] + transforms = transforms_treedef.unflatten(flat_transforms) + return pp.concat([ + pp.text("barrier_wait"), + pp.text(" "), + state_primitives.pp_ref_transforms(context, barrier, transforms), + ]) + + +jax_core.pp_eqn_rules[barrier_wait_p] = _barrier_wait_pp_eqn + + @lowering.register_lowering_rule(barrier_wait_p, mgpu.ThreadSemantics.Lane) @lowering.register_lowering_rule(barrier_wait_p, mgpu.ThreadSemantics.Warpgroup) def _barrier_wait_lowering( @@ -715,6 +829,39 @@ def _wgmma_ref_effectful_abstract_eval(acc_aval, a_aval, b_aval, *_, **params): } +def _wgmma_ref_pp_eqn( + eqn: jax_core.JaxprEqn, + context: jax_core.JaxprPpContext, + settings: jax_core.JaxprPpSettings, +): + del settings + acc, a, b, *leaves = eqn.invars + a_transforms_treedef = eqn.params["a_transforms_tree"] + b_transforms_treedef = eqn.params["b_transforms_tree"] + a_transforms = ( + a_transforms_treedef.unflatten(leaves[: a_transforms_treedef.num_leaves]) + if a_transforms_treedef is not None + else [] + ) + b_transforms = ( + b_transforms_treedef.unflatten(leaves[a_transforms_treedef.num_leaves :]) + if b_transforms_treedef is not None + else [] + ) + return pp.concat([ + pp.text("wgmma_ref"), + pp.text(" "), + pp.text(jax_core.pp_var(acc, context)), + pp.text(" <- "), + state_primitives.pp_ref_transforms(context, a, a_transforms), + pp.text(" @ "), + state_primitives.pp_ref_transforms(context, b, b_transforms), + ]) + + +jax_core.pp_eqn_rules[wgmma_ref_p] = _wgmma_ref_pp_eqn + + @discharge.register_discharge_rule(wgmma_ref_p) def _wgmma_ref_discharge(in_avals, out_avals, *args, **kwargs): del in_avals, out_avals @@ -1090,6 +1237,40 @@ def _jaxpr_call_abstract_eval(*args, jaxpr: jax_core.Jaxpr, **params): return [v.aval for v in jaxpr.outvars] +def _jaxpr_call_pp_eqn( + eqn: jax_core.JaxprEqn, + context: jax_core.JaxprPpContext, + settings: jax_core.JaxprPpSettings, +): + flat_args = eqn.invars + ref_treedefs = eqn.params["ref_treedefs"] + flat_refs, _ = util.split_list( + flat_args, [sum(treedef.num_leaves for treedef in ref_treedefs)] + ) + flat_refs = util.split_list( + flat_refs, + [treedef.num_leaves for treedef in ref_treedefs[: len(ref_treedefs) - 1]], + ) + trailer = [] + for treedef, flat_ref in zip(ref_treedefs, flat_refs): + ref = treedef.unflatten(flat_ref) + transforms = [] + if isinstance(ref, tuple): + ref, transforms = ref + trailer.append(pp.text(" ")) + trailer.append(state_primitives.pp_ref_transforms(context, ref, transforms)) + return pp.concat([ + pp.text("jaxpr_call"), + pp.text("["), + jax_core.pp_kv_pair("jaxpr", eqn.params["jaxpr"], context, settings), + pp.text("]"), + pp.concat(trailer), + ]) + + +jax_core.pp_eqn_rules[jaxpr_call_p] = _jaxpr_call_pp_eqn + + @lowering.register_lowering_rule(jaxpr_call_p, mgpu.ThreadSemantics.Lane) @lowering.register_lowering_rule(jaxpr_call_p, mgpu.ThreadSemantics.Warpgroup) def _jaxpr_call_lowering_rule( diff --git a/jax/_src/state/indexing.py b/jax/_src/state/indexing.py index 4b627c1cd581..e7b581680efe 100644 --- a/jax/_src/state/indexing.py +++ b/jax/_src/state/indexing.py @@ -20,6 +20,7 @@ from typing import Any, Sequence, Union from jax._src import core +from jax._src import pretty_printer as pp from jax._src import tree_util from jax._src.typing import Array from jax._src.util import merge_lists @@ -78,6 +79,30 @@ def from_slice(cls, slc: slice, size: int) -> Slice: return cls(start, size, step) +def _pp_slice(context: core.JaxprPpContext, dim, slc: Slice) -> str: + start, size = slc.start, slc.size + if isinstance(start, core.Var): + start_str = core.pp_var(start, context) + size_str = ( + core.pp_var(size, context) if isinstance(size, core.Var) else str(size) + ) + return f"{start_str}:{start_str}+{size_str}" + else: + start_str = str(start) + if start == 0: + start_str = "" + if isinstance(size, core.Var): + size_str = core.pp_var(size, context) + if start_str: + return f"{start_str}:{start_str}+{size_str}" + else: + return f":{size_str}" + else: + end = start + size + end_str = "" if end == dim else str(end) + return f"{start_str}:{end_str}" + + def dslice( start: int | Array | None, size: int | Array | None = None, @@ -282,3 +307,12 @@ def transform_sharding(self, sharding): f"along unsharded axes, but ref of shape {self.shape} " f"was sliced on axis {i}, which is sharded like {s}") return sharding + + def pretty_print(self, context: core.JaxprPpContext) -> pp.Doc: + indices = [] + for idx, dim in zip(self.indices, self.shape): + if isinstance(idx, Slice): + indices.append(_pp_slice(context, dim, idx)) + else: + indices.append(core.pp_var(idx, context)) # type: ignore + return pp.concat([pp.text("["), pp.text(",".join(indices)), pp.text("]")]) diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index 6f7570a5f3cd..f992f96992da 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -34,8 +34,6 @@ AbstractRef, AccumEffect, ReadEffect, - RefBitcaster, - RefReshaper, Transform, TransformedRef, WriteEffect, @@ -297,70 +295,6 @@ def _addupdate_abstract_eval(ref_aval: AbstractRef, pp_ref_var = partial(pp.color, intensity=pp.Intensity.NORMAL, foreground=pp.Color.GREEN) -def _pp_slice(context: core.JaxprPpContext, dim, slc: indexing.Slice - ) -> str: - start, size = slc.start, slc.size - if isinstance(start, core.Var): - start_str = core.pp_var(start, context) - size_str = ( - core.pp_var(size, context) - if isinstance(size, core.Var) - else str(size) - ) - return f'{start_str}:{start_str}+{size_str}' - else: - start_str = str(start) - if start == 0: - start_str = '' - if isinstance(size, core.Var): - size_str = core.pp_var(size, context) - if start_str: - return f'{start_str}:{start_str}+{size_str}' - else: - return f':{size_str}' - else: - end = start + size - end_str = '' if end == dim else str(end) - return f'{start_str}:{end_str}' - -def pp_indexer(context: core.JaxprPpContext,indexer: indexing.NDIndexer - ) -> pp.Doc: - indices = [] - for idx, dim in zip(indexer.indices, indexer.shape): - if isinstance(idx, indexing.Slice): - indices.append(_pp_slice(context, dim, idx)) - else: - indices.append(core.pp_var(idx, context)) # type: ignore - return pp.concat([pp.text("["), pp.text(','.join(indices)), pp.text("]")]) - - -def pp_bitcaster( - context: core.JaxprPpContext, bitcaster: RefBitcaster -) -> pp.Doc: - del context - return pp.text( - f"[bitcast({bitcaster.dtype}[{','.join(str(d) for d in bitcaster.shape)}])]" - ) - - -def pp_reshaper(context: core.JaxprPpContext, reshaper: RefReshaper) -> pp.Doc: - del context - return pp.text( - f"[reshape({reshaper.dtype}[{','.join(str(d) for d in reshaper.shape)}])]" - ) - - -def pp_transform(context: core.JaxprPpContext, transform: Transform) -> pp.Doc: - match transform: - case indexing.NDIndexer(): - return pp_indexer(context, transform) - case RefBitcaster(): - return pp_bitcaster(context, transform) - case RefReshaper(): - return pp_reshaper(context, transform) - case _: - return pp.text(f"[{transform}]") - def _pp_transforms( context: core.JaxprPpContext, @@ -369,7 +303,7 @@ def _pp_transforms( if not transforms: return pp.text("[...]") return pp.concat( - [pp_transform(context, transform) for transform in transforms] + [transform.pretty_print(context) for transform in transforms] ) diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index b9dbaf35c5d2..1acb856fd1ba 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -125,6 +125,10 @@ def transform_sharding(self, sharding): return sharding raise NotImplementedError + def pretty_print(self, context: core.JaxprPpContext) -> pp.Doc: + del context # Unused. + return pp.text(f"{{bitcast({self.dtype}{list(self.shape)}])}}") + @tree_util.register_pytree_node_class @dataclasses.dataclass(frozen=True) @@ -178,6 +182,10 @@ def transform_sharding(self, sharding): return sharding raise NotImplementedError + def pretty_print(self, context: core.JaxprPpContext) -> pp.Doc: + del context # Unused. + return pp.text(f"{{reshape({self.dtype}{list(self.shape)})}}") + class Transform(Protocol): @@ -205,6 +213,9 @@ def transform_sharding(self, sharding): if all(p is None for p in sharding.spec): return sharding # no explicit axes raise NotImplementedError + def pretty_print(self, context: core.JaxprPpContext) -> pp.Doc: + return pp.text(f"{{{self}}}") + @dataclasses.dataclass class RefIndexer: From b0a920dd92480962ecfb1fa55232fa2c0e584038 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 4 Apr 2025 05:11:10 -0700 Subject: [PATCH 382/483] [Mosaic GPU] Don't force TiledLayout.lane_dims to partition data This allows us to replicate elements across a warp and replace the special WGMMAFragRowLayout with a TiledLayout. PiperOrigin-RevId: 743903003 --- jax/_src/pallas/mosaic_gpu/lowering.py | 9 +- jax/_src/pallas/mosaic_gpu/primitives.py | 49 ++-- jax/experimental/mosaic/gpu/__init__.py | 2 +- .../mosaic/gpu/fragmented_array.py | 231 +++++++++--------- jax/experimental/mosaic/gpu/layouts.py | 4 - jaxlib/mosaic/dialect/gpu/mosaic_gpu.td | 12 - tests/mosaic/gpu_test.py | 16 +- tests/pallas/mosaic_gpu_test.py | 4 +- 8 files changed, 163 insertions(+), 164 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 80757ef69e64..d26c71cecc31 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1202,9 +1202,12 @@ def _swap_lowering_rule( return old_value case (): match value.layout: - case mgpu.WGMMARowFragLayout(): - old_value = mgpu.FragmentedArray.load_wgmma_row( - x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype) + case mgpu.TiledLayout(): + old_value = mgpu.FragmentedArray.load_untiled( + x_smem, + layout=value.layout, + is_signed=mgpu_utils.is_signed(x_aval.dtype), + optimized=False, ) value.store_untiled(x_smem) return old_value diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index b909a31496bf..76759d6bcb83 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -67,9 +67,8 @@ def _check_ref( load_p = jax_core.Primitive("load") @load_p.def_effectful_abstract_eval -def _load_abstract_eval(src, *avals_flat, args_tree, layout): - del layout # Unused. - +def _load_abstract_eval(src, *avals_flat, args_tree, layout, optimized): + del layout, optimized # Unused. transforms = args_tree.unflatten(avals_flat) return ( jax_core.ShapedArray(transforms[-1].get_indexer_shape(), src.dtype), @@ -78,7 +77,7 @@ def _load_abstract_eval(src, *avals_flat, args_tree, layout): @lowering.register_lowering_rule(load_p, mgpu.ThreadSemantics.Lane) def _load_p_lowering_rule( - ctx: lowering.LoweringRuleContext, x_ref, *leaves, args_tree, layout + ctx: lowering.LoweringRuleContext, x_ref, *leaves, args_tree, layout, optimized ): if not isinstance(x_ref, ir.Value) or not ir.MemRefType.isinstance(x_ref.type): raise TypeError(f"Can only load from references (got {x_ref}).") @@ -91,29 +90,36 @@ def _load_p_lowering_rule( if layout is not None: layout = layout.to_mgpu() + is_signed = mgpu_utils.is_signed(x_aval.dtype) match transforms: case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)): if tiling != (8, swizzle // x_aval.dtype.itemsize): raise NotImplementedError("Tiling does not fit swizzle") return mgpu.FragmentedArray.load_tiled( - x_ref, is_signed=mgpu_utils.is_signed(x_aval.dtype), swizzle=swizzle, - layout=layout + x_ref, + is_signed=is_signed, + swizzle=swizzle, + layout=layout, ) case (): # Handle scalar indexing. if not ctx.avals_out[0].shape: is_signed = mgpu_utils.is_signed(x_aval.dtype) val = memref_dialect.load(x_ref, []) - return mgpu.FragmentedArray.splat(val, shape=(), layout=layout, is_signed=is_signed) + return mgpu.FragmentedArray.splat( + val, shape=(), layout=layout, is_signed=is_signed + ) match layout: - case mgpu.WGMMARowFragLayout(): - return mgpu.FragmentedArray.load_wgmma_row( - x_ref, is_signed=mgpu_utils.is_signed(x_aval.dtype) + case mgpu.WGMMA_ROW_LAYOUT: + return mgpu.FragmentedArray.load_untiled( + x_ref, + is_signed=is_signed, + layout=layout, + swizzle=16, + optimized=optimized, ) case mgpu.WGMMAColFragLayout(): - return mgpu.FragmentedArray.load_wgmma_col( - x_ref, is_signed=mgpu_utils.is_signed(x_aval.dtype) - ) + return mgpu.FragmentedArray.load_wgmma_col(x_ref, is_signed=is_signed) case mgpu.WGStridedFragLayout(shape=shape, vec_size=vec_size): ref_ty = ir.MemRefType(x_ref.type) if shape != tuple(ref_ty.shape): @@ -122,12 +128,10 @@ def _load_p_lowering_rule( ) return mgpu.FragmentedArray.load_strided( - x_ref, is_signed=mgpu_utils.is_signed(x_aval.dtype), vec_size=vec_size, + x_ref, is_signed=is_signed, vec_size=vec_size, ) case None: - return mgpu.FragmentedArray.load_strided( - x_ref, is_signed=mgpu_utils.is_signed(x_aval.dtype) - ) + return mgpu.FragmentedArray.load_strided(x_ref, is_signed=is_signed) case _: raise NotImplementedError(f"Unsupported layout: {layout}") case _: @@ -135,7 +139,11 @@ def _load_p_lowering_rule( def load( - src: _Ref, idx, *, layout: Layout | ParameterizedLayout | None = None + src: _Ref, + idx, + *, + layout: Layout | ParameterizedLayout | None = None, + optimized: bool = True, ) -> jax.Array: """Loads from a reference into an array with the specified layout. @@ -143,6 +151,8 @@ def load( src: The reference to load from. Can be either in SMEM or GMEM. idx: The index to load from. layout: The optional layout to use for the resulting array. + optimized: If True, a compilation error will be raised if no optimized + implementation for the load is available. Returns: The loaded array. @@ -157,7 +167,8 @@ def load( src, *flat_src_transforms, args_tree=src_transforms_treedef, - layout=layout + layout=layout, + optimized=optimized, ) diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index e645115940e4..d5b3bea6d36b 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -53,11 +53,11 @@ from .fragmented_array import ( FragmentedArray as FragmentedArray, FragmentedLayout as FragmentedLayout, + TiledLayout as TiledLayout, WGMMA_LAYOUT as WGMMA_LAYOUT, WGMMA_ROW_LAYOUT as WGMMA_ROW_LAYOUT, WGMMA_COL_LAYOUT as WGMMA_COL_LAYOUT, WGMMA_TRANSPOSED_LAYOUT as WGMMA_TRANSPOSED_LAYOUT, - WGMMARowFragLayout as WGMMARowFragLayout, WGMMAColFragLayout as WGMMAColFragLayout, WGSplatFragLayout as WGSplatFragLayout, WGStridedFragLayout as WGStridedFragLayout, diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index b730e34e2ed0..f7ce36c62c9e 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -202,6 +202,11 @@ def enumerate_negative(elems: Sequence[T]) -> Iterable[tuple[int, T]]: yield i - offset, e +@dataclasses.dataclass(frozen=True) +class Replicated: + times: int + + @dataclasses.dataclass(frozen=True) class TiledLayout: """A FragmentedArray layout derived from a tiling expression. @@ -248,7 +253,7 @@ class TiledLayout: """ tiling: Tiling warp_dim: int - lane_dims: tuple[int, ...] # major-to-minor + lane_dims: tuple[int | Replicated, ...] # major-to-minor vector_dim: int def __post_init__(self): @@ -256,8 +261,8 @@ def __post_init__(self): raise ValueError("Tiling must have at least one tile") min_shape = self.tiling.tiles[0] min_tiled_shape = self.tiling.tile_shape(min_shape) - dims_set = {self.warp_dim, *self.lane_dims, self.vector_dim} - if len(dims_set) != len(self.lane_dims) + 2: + dims_set = {self.warp_dim, *self.partitioned_lane_dims, self.vector_dim} + if len(dims_set) != len(self.partitioned_lane_dims) + 2: raise ValueError for d in dims_set: if d >= 0: @@ -266,9 +271,19 @@ def __post_init__(self): raise ValueError("Dimension out of range") if min_tiled_shape[self.warp_dim] != WARPS_IN_WARPGROUP: raise ValueError - if math.prod(min_tiled_shape[d] for d in self.lane_dims) != WARP_SIZE: + lane_dims_prod = math.prod( + d.times if isinstance(d, Replicated) else min_tiled_shape[d] + for d in self.lane_dims + ) + if lane_dims_prod != WARP_SIZE: raise ValueError + @functools.cached_property + def partitioned_lane_dims(self) -> tuple[int, ...]: + return tuple( + d for d in self.lane_dims if not isinstance(d, Replicated) + ) + def thread_idxs(self, shape: tuple[int, ...]) -> Iterable[tuple[ir.Value, ...]]: # We first find the linear index and then divide by the shape to # get the index. @@ -326,7 +341,7 @@ def registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: """Returns the shape of the register array needed to represent an array of the given logical shape.""" tiled_shape = list(self.tiling.tile_shape(shape)) tiled_shape[self.warp_dim] = 1 - for d in self.lane_dims: + for d in self.partitioned_lane_dims: tiled_shape[d] = 1 tiled_shape[self.vector_dim] = 1 return tuple(tiled_shape) @@ -339,15 +354,18 @@ def shape_from_registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: tiled_tiling = self.tiled_tiling_shape shape = list(shape) shape[self.warp_dim] = WARPS_IN_WARPGROUP - for d in self.lane_dims: + for d in self.partitioned_lane_dims: shape[d] = tiled_tiling[d] shape[self.vector_dim] = tiled_tiling[self.vector_dim] return self.tiling.untile_shape(tuple(shape)) - def lane_indices(self) -> tuple[ir.Value, ...]: + def _full_lane_indices(self) -> tuple[ir.Value, ...]: i32 = ir.IntegerType.get_signless(32) tiled_shape = self.tiled_tiling_shape - lanes_shape = tuple(tiled_shape[d] for d in self.lane_dims) + lanes_shape = tuple( + d.times if isinstance(d, Replicated) else tiled_shape[d] + for d in self.lane_dims + ) assert math.prod(lanes_shape) == WARP_SIZE lane_strides = utils.get_contiguous_strides(lanes_shape) lane_idx = arith.remui(utils.thread_idx(), c(WARP_SIZE, i32)) @@ -355,8 +373,16 @@ def lane_indices(self) -> tuple[ir.Value, ...]: arith.remui(arith.divui(lane_idx, c(stride, i32)), c(size, i32)) for stride, size in zip(lane_strides, lanes_shape) ) + return lane_indices + + def lane_indices(self) -> tuple[ir.Value, ...]: + i32 = ir.IntegerType.get_signless(32) + tiled_shape = self.tiled_tiling_shape + lane_indices = self._full_lane_indices() full_indices = [arith.constant(i32, 0)] * len(tiled_shape) for d, i in zip(self.lane_dims, lane_indices): + if isinstance(d, Replicated): + continue full_indices[d] = i return tuple(full_indices) @@ -385,41 +411,6 @@ def _tiled_wgmma_layout(shape: tuple[int, ...]): return WGMMA_LAYOUT -@dataclasses.dataclass(frozen=True) -class WGMMARowFragLayout: - """[m] matrix, where m % 64 == 0.""" - - def registers_element_type(self, t: ir.Type) -> ir.Type: - return t - - def registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: - """Returns the shape of the register array needed to represent an array of the given logical shape.""" - if len(shape) != 1: - raise ValueError("WGMMARowFragLayout requires a 1D shape") - if shape[0] % 64: - raise ValueError( - "WGMMARowFragLayout requires shape[0] to be a multiple of 64" - ) - return (shape[0] // 64, 2) - - def thread_idxs(self, shape): - index = ir.IndexType.get() - assert len(shape) == 1 - assert shape[0] % 64 == 0 - tid = arith.index_cast(ir.IndexType.get(), mgpu.thread_idx()) - tid_wg = arith.remui(tid, c(WARPGROUP_SIZE, index)) - warp_idx = arith.divui(tid_wg, c(32, index)) - lane_id = arith.remui(tid_wg, c(32, index)) - row_base = arith.addi( - arith.divui(lane_id, c(4, index)), arith.muli(warp_idx, c(16, index)) - ) - - for row_group in range(0, shape[0], 64): - for row_subgroup in (0, 8): - row = arith.addi(row_base, c(row_group + row_subgroup, index)) - yield (row,) - - @dataclasses.dataclass(frozen=True) class WGMMAColFragLayout: """[n] matrix, where n % 8 == 0.""" @@ -547,11 +538,16 @@ def linear_thread_idxs(self): yield arith.addi(off, c(i * WARPGROUP_SIZE * self.vec_size, tidx.type)) -FragmentedLayout = WGSplatFragLayout | WGStridedFragLayout | WGMMARowFragLayout | WGMMAColFragLayout | TiledLayout +FragmentedLayout = WGSplatFragLayout | WGStridedFragLayout | WGMMAColFragLayout | TiledLayout -WGMMA_ROW_LAYOUT = WGMMARowFragLayout() WGMMA_COL_LAYOUT = WGMMAColFragLayout() +WGMMA_ROW_LAYOUT = TiledLayout( + Tiling(((64,), (16,), (8,), (1,))), + warp_dim=-4, + lane_dims=(-2, Replicated(4)), + vector_dim=-1, +) # The tiled layout is equivalent to one described here in PTX documentation: # https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-d @@ -663,12 +659,6 @@ def __init__( ) match self.layout: - # Registers are [m_tiles, 2 rows] in WGMMA_ROW layout - # Each element is a dtype scalar - case WGMMARowFragLayout(): - if _registers.ndim != 2 or _registers.shape[-1] != 2: - raise ValueError(f"Invalid register array shape: {_registers.shape}") - # Registers are [n_tiles] in WGMMA_COL layout # Each element is a vector of size 2. case WGMMAColFragLayout(): @@ -731,30 +721,6 @@ def load_strided( vecs = [vector.load(vec_ty, ref, vec_idx) for vec_idx in layout.thread_idxs(shape)] return cls(_registers=np.array(vecs), _layout=layout, _is_signed=is_signed) - @classmethod - def load_wgmma_row( - cls, - ref: ir.Value, - *, - is_signed: bool | None = None, - ): - if not ir.MemRefType.isinstance(ref.type): - raise TypeError(ref.type) - - ref_ty = ir.MemRefType(ref.type) - shape = tuple(ref_ty.shape) - if len(shape) != 1: - raise ValueError("WGMMARowFragLayout requires a 1D shape") - if shape[0] % 64: - raise ValueError( - "WGMMARowFragLayout requires shape[0] to be a multiple of 64" - ) - - layout = WGMMARowFragLayout() - registers = [memref.load(ref, [idx]) for (idx,) in layout.thread_idxs(shape)] - registers = np.array(registers).reshape(-1, 2) - return cls(_registers=registers, _layout=layout, _is_signed=is_signed) - @classmethod def load_wgmma_col( cls, @@ -790,7 +756,7 @@ def load_wgmma_col( def splat(cls, value, shape, layout=None, *, is_signed: bool | None = None): layout = layout or WGSplatFragLayout(shape) match layout: - case WGMMARowFragLayout() | WGSplatFragLayout(): + case WGSplatFragLayout(): pass case WGStridedFragLayout() | TiledLayout(): value = vector.splat(layout.registers_element_type(value.type), value) @@ -806,9 +772,6 @@ def splat(cls, value, shape, layout=None, *, is_signed: bool | None = None): @property def shape(self): match self.layout: - case WGMMARowFragLayout(): - row_tiles = self.registers.shape[0] - return (row_tiles * 64,) case WGMMAColFragLayout(): col_tiles = self.registers.shape[0] return (col_tiles * 8,) @@ -827,7 +790,7 @@ def mlir_dtype(self): match self.layout: case WGStridedFragLayout() | WGMMAColFragLayout() | TiledLayout(): return ir.VectorType(reg_ty).element_type - case WGMMARowFragLayout() | WGSplatFragLayout(): + case WGSplatFragLayout(): return reg_ty case _: raise NotImplementedError @@ -1589,7 +1552,7 @@ def upcast_to_bf16(reg, high): case WGStridedFragLayout() | TiledLayout(): shape = ir.VectorType(self.registers.flat[0].type).shape upcast_ty = ir.VectorType.get(shape, larger_ty) - case WGMMARowFragLayout() | WGSplatFragLayout(): + case WGSplatFragLayout(): upcast_ty = larger_ty case _: raise NotImplementedError(f"Unsupported layout {self.layout}") @@ -1614,7 +1577,7 @@ def upcast_to_bf16(reg, high): case WGStridedFragLayout() | TiledLayout(): shape = ir.VectorType(self.registers.flat[0].type).shape new_reg_ty = ir.VectorType.get(shape, new_dtype) - case WGMMARowFragLayout() | WGSplatFragLayout(): + case WGSplatFragLayout(): new_reg_ty = new_dtype case _: raise NotImplementedError(f"Unsupported layout {self.layout}") @@ -1713,9 +1676,9 @@ def reduce(self, op: str | Callable[[ir.Value, ir.Value], ir.Value], axis): i32 = ir.IntegerType.get_signless(32) row_tile_dim = self.registers.shape[0] row_subtile_dim = self.registers.shape[4] - new_regs = np.empty((row_tile_dim, row_subtile_dim), dtype=object) + new_regs = np.empty((row_tile_dim, 1, row_subtile_dim, 1, 1), dtype=object) assert self.registers.shape[-1] == 1 - for row_tile, row_subtile in np.ndindex(new_regs.shape): + for row_tile, row_subtile in np.ndindex(row_tile_dim, row_subtile_dim): # Reduce the registers owned by the current thread over n tiles reg_index = [0] * self.registers.ndim reg_index[0] = row_tile @@ -1746,7 +1709,9 @@ def reduce(self, op: str | Callable[[ir.Value, ir.Value], ir.Value], axis): nvvm.ShflKind.bfly, ) result = op(result, other_result) - new_regs[row_tile, row_subtile] = result + new_regs[row_tile, :, row_subtile] = vector.splat( + ir.VectorType.get((1,), self.mlir_dtype), result + ) return FragmentedArray( _registers=new_regs, _layout=WGMMA_ROW_LAYOUT, _is_signed=self.is_signed ) @@ -1791,12 +1756,14 @@ def broadcast_minor(self, n): reg_shape = WGMMA_LAYOUT.registers_shape((self.shape[0], n)) new_regs = np.empty(reg_shape, dtype=object) dtype = self.mlir_dtype - for (row_tile, row_subtile), reg in np.ndenumerate(self.registers): + i0 = arith.constant(ir.IndexType.get(), 0) + for (row_tile, _, row_subtile, *__), reg in np.ndenumerate(self.registers): tile = [slice(None)] * len(new_regs.shape) tile[0] = row_tile tile[4] = row_subtile new_regs[tuple(tile)] = vector.splat( - ir.VectorType.get((WGMMA_LAYOUT.vector_length,), dtype), reg + ir.VectorType.get((WGMMA_LAYOUT.vector_length,), dtype), + vector.extractelement(reg, position=i0), ) return FragmentedArray( _registers=new_regs, _layout=WGMMA_LAYOUT, _is_signed=self.is_signed @@ -1874,8 +1841,6 @@ def vs_unsupported(): ) match self.layout: - case WGMMARowFragLayout(): - self._store_untiled_wgmma_row(ref) case WGMMAColFragLayout(): self._store_untiled_wgmma_col(ref) case WGSplatFragLayout(): @@ -1889,6 +1854,22 @@ def vs_unsupported(): case _: raise NotImplementedError(self.layout) + @classmethod + def load_untiled( + cls, + ref: ir.Value, + *, + layout: TiledLayout, + swizzle: int = 16, + is_signed: bool | None = None, + optimized: bool = True, + ): + ref_shape = ir.MemRefType(ref.type).shape + ref = utils.memref_reshape(ref, (*(1 for _ in ref_shape), *ref_shape)) + return cls.load_tiled( + ref, swizzle=swizzle, is_signed=is_signed, layout=layout, optimized=optimized + ) + def _store_untiled_splat(self, ref: ir.Value): vec_size = 64 // mgpu.bitwidth(self.mlir_dtype) if np.prod(self.shape) < vec_size * WARPGROUP_SIZE: @@ -1924,23 +1905,6 @@ def _store_untiled_wg_strided(self, ref: ir.Value): for idx, reg in zip(idxs, self.registers.flat): vector.store(reg, ref_, idx) - def _store_untiled_wgmma_row(self, ref: ir.Value): - """Stores an array with a WGMMA row layout.""" - assert self.layout == WGMMA_ROW_LAYOUT - index = ir.IndexType.get() - tid = arith.index_cast(ir.IndexType.get(), mgpu.thread_idx()) - - is_first = arith.cmpi( - arith.CmpIPredicate.eq, arith.remui(tid, c(4, index)), c(0, index) - ) - # Consecutive groups of 4 threads hold the same value in this layout, - # therefore we only need to transfer data from one of them. - with utils.when(is_first): - for (idx,), value in zip( - self.layout.thread_idxs(self.shape), self.registers.flatten() - ): - memref.store(value, ref, [idx]) - def _store_untiled_wgmma_col(self, ref: ir.Value): """Stores an array with a WGMMA col layout.""" assert isinstance(self.layout, WGMMAColFragLayout) @@ -2007,6 +1971,9 @@ def store_tiled(self, ref, swizzle: int | None): if not isinstance(self.layout, TiledLayout): raise NotImplementedError(self.layout) layout, shape = self.layout, self.shape + # Note that the loop below will "race" for layouts that replicate data. + # However, in that case all of the racing writes store the same data, which + # is ok in the CUDA memory model. for get, _, ptr in self.transfer_tiled2(ref, swizzle, layout, shape): llvm.store(get(self.registers), ptr) @@ -2018,6 +1985,7 @@ def load_tiled( *, is_signed: bool | None = None, layout: FragmentedLayout = WGMMA_LAYOUT, + optimized: bool = True, ): ref_ty = ir.MemRefType(ref.type) dtype = ref_ty.element_type @@ -2036,7 +2004,8 @@ def load_tiled( ) registers = np.full(layout.registers_shape(shape), zero, dtype=object) reg_ty = ir.VectorType.get((layout.vector_length,), ref_ty.element_type) - for _, update, ptr in cls.transfer_tiled2(ref, swizzle, layout, shape): + loads = cls.transfer_tiled2(ref, swizzle, layout, shape, optimized) + for _, update, ptr in loads: update(registers, llvm.load(reg_ty, ptr)) case _: raise NotImplementedError(layout) @@ -2132,6 +2101,7 @@ def transfer_tiled2( swizzle: int | None, layout: TiledLayout, shape: tuple[int, ...], + optimized: bool = True, ): """Generate a transfer schedule for a tiled layout. @@ -2183,11 +2153,15 @@ def transfer_tiled2( raise NotImplementedError("Memory and register tiling incompatible") tiled_shape = list(itertools.chain.from_iterable(tiled_nested_shape)) elem_tiled_strides = list(itertools.chain.from_iterable(tiled_nested_strides)) - elem_lane_strides = [elem_tiled_strides[d] for d in layout.lane_dims] - lane_shape = [tiled_shape[d] for d in layout.lane_dims] + lane_shape = [ + d.times if isinstance(d, Replicated) else tiled_shape[d] for d in layout.lane_dims + ] + lane_strides = [ + 0 if isinstance(d, Replicated) else elem_tiled_strides[d] for d in layout.lane_dims + ] if elem_tiled_strides[layout.vector_dim] != 1: raise ValueError("Stride of the vectorized dimension should be 1") - for d in (layout.warp_dim, *layout.lane_dims, layout.vector_dim): + for d in (layout.warp_dim, *layout.partitioned_lane_dims, layout.vector_dim): tiled_shape[d] = 1 element_bits = mgpu.bitwidth(dtype) @@ -2223,10 +2197,22 @@ def transfer_tiled2( transfer_tiled_strides = [s // layout.vector_length for s in elem_tiled_strides] transfer_dtype = ir.VectorType.get((layout.vector_length,), dtype) - plan = plan_tiled_transfer( - tiled_shape, elem_tiled_strides, lane_shape, elem_lane_strides, layout, - element_bits, swizzle - ) + if ref_ty.memory_space is None: + llvm_memory_space = None + elif ref_ty.memory_space == ir.Attribute.parse("#gpu.address_space"): + llvm_memory_space = 3 + else: + raise ValueError(f"Unsupported memory space: {ref_ty.memory_space}") + + if optimized: + if llvm_memory_space != 3: + raise NotImplementedError("Only optimized transfers to SMEM supported") + plan = plan_tiled_transfer( + tiled_shape, elem_tiled_strides, lane_shape, lane_strides, + layout, element_bits, swizzle + ) + else: + plan = TrivialTransferPlan() # All offsets are in units of transfer_dtype. dyn_tiled_strides = [ @@ -2235,9 +2221,7 @@ def transfer_tiled2( lane_offset = utils.dyn_dot(layout.lane_indices(), dyn_tiled_strides) warp_offset = utils.dyn_dot(layout.warp_indices(), dyn_tiled_strides) dyn_offset = arith.addi(lane_offset, warp_offset) - if ref_ty.memory_space != ir.Attribute.parse("#gpu.address_space"): - raise ValueError("Tiled stores can be performed into SMEM") - ptr = utils.memref_ptr(ref, memory_space=3) + ptr = utils.memref_ptr(ref, memory_space=llvm_memory_space) _as_consts = lambda consts: [c(const) for const in consts.tolist()] # This has bits set only for the offset bits that influence swizzling. swizzle_mask = swizzle_block_transfers - swizzle_tile_transfers @@ -2416,9 +2400,18 @@ def plan_tiled_transfer( num_wavefronts = max(transfer_bytes // smem_bank_bytes, 1) wavefront_lanes = WARP_SIZE // num_wavefronts + lane_mask = np.full(lane_shape, False) + lane_mask[tuple(slice(0, 1) if s == 0 else slice(None) for s in lane_strides)] = True + wavefront_mask = lane_mask.reshape(num_wavefronts, wavefront_lanes) + wavefront_active_lanes = wavefront_mask.sum(-1) + # We make a simplifying assumption: wavefronts have the same number of lanes + if any(act != wavefront_active_lanes[0] for act in wavefront_active_lanes): + raise NotImplementedError + lane_offsets_in_tile = np.dot(list(np.ndindex(*lane_shape)), lane_strides) def has_bank_conflicts(tile_idx_transform): - tile_idxs = np.unravel_index(np.arange(math.prod(tiled_shape)), tiled_shape) + num_tiles = math.prod(tiled_shape) + tile_idxs = np.unravel_index(np.arange(num_tiles), tiled_shape) tile_idxs = np.expand_dims(np.stack(tile_idxs, 1), 1) # [#tiles, 1, #dims] lane_tile_idx = tile_idx_transform(tile_idxs) # [#tiles, #lanes/1, #dims] assert lane_tile_idx.shape[1] in {1, WARP_SIZE} @@ -2429,6 +2422,8 @@ def has_bank_conflicts(tile_idx_transform): swizzle_bits = swizzle_groups * swizzle_tile_elems lane_banks = ((offsets ^ swizzle_bits) // elems_per_bank) % num_banks wavefront_banks = lane_banks.reshape(-1, num_wavefronts, wavefront_lanes) + # Mask out the inactive lanes in each wavefront + wavefront_banks = wavefront_banks[:, wavefront_mask].reshape(num_tiles, num_wavefronts, -1) # Order of threads within the wavefront is unimportant. wavefront_banks = np.sort(wavefront_banks, axis=-1) # There are no conflicts if each wavefront only contains unique banks. diff --git a/jax/experimental/mosaic/gpu/layouts.py b/jax/experimental/mosaic/gpu/layouts.py index 5c3b23119779..cb94c3eaf749 100644 --- a/jax/experimental/mosaic/gpu/layouts.py +++ b/jax/experimental/mosaic/gpu/layouts.py @@ -155,7 +155,6 @@ def to_layout_attr( fa.WGSplatFragLayout | fa.WGStridedFragLayout | fa.TiledLayout - | fa.WGMMARowFragLayout ), ) -> ir.Attribute: """Constructs an MLIR attribute that corresponds to the given layout.""" @@ -166,8 +165,6 @@ def to_layout_attr( return to_strided_fragmented_layout_attr(layout) case fa.TiledLayout(): return to_tiled_layout_attr(layout) - case fa.WGMMARowFragLayout(): - return ir.Attribute.parse("#mosaic_gpu.WGMMARowFragLayout") case _: raise NotImplementedError( f"Unsupported layout for conversion to MLIR attribute: {layout}" @@ -189,7 +186,6 @@ def from_layout_attr( fa.WGSplatFragLayout | fa.WGStridedFragLayout | fa.TiledLayout - | fa.WGMMARowFragLayout ): """Constructs a layout from an MLIR attribute.""" if is_splat_fragmented_layout(attr): diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td index cda521855250..6b934b951d93 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td @@ -142,18 +142,6 @@ def MosaicGPU_WGSplatFragLayout : AttrDef { - let summary = "1D array that is a row that can be tiled by supported WGMMA shapes."; - let description = [{ - This layout is used to handle rows that are fragmented across all threads - in a warpgroup that is executing a WGMMA operation. The length of the array - must be divisible by 64. - }]; - - let mnemonic = "WGMMARowFragLayout"; - let assemblyFormat = ""; -} - def MosaicGPU_TiledLayout : AttrDef { let summary = "A layout derived from a tiling expression."; let description = [{ diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 6c63e3ce40e1..cc0e6a04bdc3 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -2004,12 +2004,16 @@ def kernel(ctx, *args): )(inp) np.testing.assert_array_equal(inp, result) - @parameterized.product(in_shape=((128,), (64,))) - def test_wgmma_row_load_store_with_layout(self, in_shape): - def kernel(ctx, *args): - gmem_input, gmem_output, (smem_input, smem_output) = args - copy(gmem_input, smem_input) - t = mgpu.FragmentedArray.load_wgmma_row(smem_input) + @parameterized.product( + in_shape=((1024,), (256,), (128,), (64,)), swizzle=(16, 32, 64, 128) + ) + def test_wgmma_row_load_store_with_layout(self, in_shape, swizzle): + def kernel(ctx, gmem_input, gmem_output, smem): + smem_input, smem_output = smem + copy(gmem_input, smem_input, swizzle=swizzle) + t = mgpu.FragmentedArray.load_untiled( + smem_input, layout=mgpu.WGMMA_ROW_LAYOUT, swizzle=swizzle + ) t.store_untiled(smem_output) copy(smem_output, gmem_output) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 0cfe9197db36..67472dbdd9e5 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -770,7 +770,9 @@ def test_load_to_wgmma_row_col_layout_with_indexing(self, src_memory_space, layo ) def kernel(x_ref, o_ref): for i in range(2): - x = plgpu.load(x_ref, (i,), layout=layout) + x = plgpu.load( + x_ref, (i,), layout=layout, optimized=src_memory_space != plgpu.GMEM + ) o_ref[i, ...] = x x = jnp.arange(2 * m, dtype=jnp.float32).reshape(2, m) From 635805e9b02a9b400b54efe0e766964278e178dd Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 4 Apr 2025 05:43:03 -0700 Subject: [PATCH 383/483] [Mosaic GPU] Allow replicating data over warps This extends the tiled layouts further and allows us to replace WGMMA_COL_LAYOUT implementation with a TiledLayout. PiperOrigin-RevId: 743909503 --- jax/_src/pallas/mosaic_gpu/lowering.py | 6 - jax/_src/pallas/mosaic_gpu/primitives.py | 5 +- jax/experimental/mosaic/gpu/__init__.py | 1 - .../mosaic/gpu/fragmented_array.py | 146 +++++------------- tests/mosaic/gpu_test.py | 31 ++-- 5 files changed, 61 insertions(+), 128 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index d26c71cecc31..827794d37e2b 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1211,12 +1211,6 @@ def _swap_lowering_rule( ) value.store_untiled(x_smem) return old_value - case mgpu.WGMMAColFragLayout(): - old_value = mgpu.FragmentedArray.load_wgmma_col( - x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype) - ) - value.store_untiled(x_smem) - return old_value case _: old_value = mgpu.FragmentedArray.load_strided( x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 76759d6bcb83..b2beec700fad 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -110,7 +110,7 @@ def _load_p_lowering_rule( val, shape=(), layout=layout, is_signed=is_signed ) match layout: - case mgpu.WGMMA_ROW_LAYOUT: + case mgpu.WGMMA_ROW_LAYOUT | mgpu.WGMMA_COL_LAYOUT: return mgpu.FragmentedArray.load_untiled( x_ref, is_signed=is_signed, @@ -118,15 +118,12 @@ def _load_p_lowering_rule( swizzle=16, optimized=optimized, ) - case mgpu.WGMMAColFragLayout(): - return mgpu.FragmentedArray.load_wgmma_col(x_ref, is_signed=is_signed) case mgpu.WGStridedFragLayout(shape=shape, vec_size=vec_size): ref_ty = ir.MemRefType(x_ref.type) if shape != tuple(ref_ty.shape): raise ValueError( f"Unsupported shape {shape}, (expected {tuple(ref_ty.shape)})" ) - return mgpu.FragmentedArray.load_strided( x_ref, is_signed=is_signed, vec_size=vec_size, ) diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index d5b3bea6d36b..a4f1e0a9cfe0 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -58,7 +58,6 @@ WGMMA_ROW_LAYOUT as WGMMA_ROW_LAYOUT, WGMMA_COL_LAYOUT as WGMMA_COL_LAYOUT, WGMMA_TRANSPOSED_LAYOUT as WGMMA_TRANSPOSED_LAYOUT, - WGMMAColFragLayout as WGMMAColFragLayout, WGSplatFragLayout as WGSplatFragLayout, WGStridedFragLayout as WGStridedFragLayout, optimization_barrier as optimization_barrier, diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index f7ce36c62c9e..f6c5e7d1ed19 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -252,7 +252,7 @@ class TiledLayout: by a single (logical) register. """ tiling: Tiling - warp_dim: int + warp_dim: int | Replicated lane_dims: tuple[int | Replicated, ...] # major-to-minor vector_dim: int @@ -261,15 +261,20 @@ def __post_init__(self): raise ValueError("Tiling must have at least one tile") min_shape = self.tiling.tiles[0] min_tiled_shape = self.tiling.tile_shape(min_shape) - dims_set = {self.warp_dim, *self.partitioned_lane_dims, self.vector_dim} - if len(dims_set) != len(self.partitioned_lane_dims) + 2: + dims_set = {*self.partitioned_lane_dims, self.vector_dim} + if partitions_warp_dim := not isinstance(self.warp_dim, Replicated): + dims_set.add(self.warp_dim) + if len(dims_set) != len(self.partitioned_lane_dims) + 1 + partitions_warp_dim: raise ValueError for d in dims_set: if d >= 0: raise ValueError("All dimensions must be negative") if d < -(len(min_tiled_shape) - len(min_shape)): raise ValueError("Dimension out of range") - if min_tiled_shape[self.warp_dim] != WARPS_IN_WARPGROUP: + if isinstance(self.warp_dim, Replicated): + if self.warp_dim.times != WARPS_IN_WARPGROUP: + raise ValueError + elif min_tiled_shape[self.warp_dim] != WARPS_IN_WARPGROUP: raise ValueError lane_dims_prod = math.prod( d.times if isinstance(d, Replicated) else min_tiled_shape[d] @@ -340,7 +345,8 @@ def registers_element_type(self, t: ir.Type) -> ir.Type: def registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: """Returns the shape of the register array needed to represent an array of the given logical shape.""" tiled_shape = list(self.tiling.tile_shape(shape)) - tiled_shape[self.warp_dim] = 1 + if not isinstance(self.warp_dim, Replicated): + tiled_shape[self.warp_dim] = 1 for d in self.partitioned_lane_dims: tiled_shape[d] = 1 tiled_shape[self.vector_dim] = 1 @@ -353,7 +359,8 @@ def shape_from_registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: """ tiled_tiling = self.tiled_tiling_shape shape = list(shape) - shape[self.warp_dim] = WARPS_IN_WARPGROUP + if not isinstance(self.warp_dim, Replicated): + shape[self.warp_dim] = WARPS_IN_WARPGROUP for d in self.partitioned_lane_dims: shape[d] = tiled_tiling[d] shape[self.vector_dim] = tiled_tiling[self.vector_dim] @@ -389,12 +396,13 @@ def lane_indices(self) -> tuple[ir.Value, ...]: def warp_indices(self) -> tuple[ir.Value, ...]: i32 = ir.IntegerType.get_signless(32) tiled_shape_rank = len(self.tiled_tiling_shape) - warp_idx = arith.remui( - arith.divui(utils.thread_idx(), c(WARP_SIZE, i32)), - c(WARPS_IN_WARPGROUP, i32), - ) indices = [arith.constant(i32, 0)] * tiled_shape_rank - indices[self.warp_dim] = warp_idx + if not isinstance(self.warp_dim, Replicated): + warp_idx = arith.remui( + arith.divui(utils.thread_idx(), c(WARP_SIZE, i32)), + c(WARPS_IN_WARPGROUP, i32), + ) + indices[self.warp_dim] = warp_idx return tuple(indices) @@ -411,23 +419,6 @@ def _tiled_wgmma_layout(shape: tuple[int, ...]): return WGMMA_LAYOUT -@dataclasses.dataclass(frozen=True) -class WGMMAColFragLayout: - """[n] matrix, where n % 8 == 0.""" - - def thread_idxs(self, shape): - index = ir.IndexType.get() - assert len(shape) == 1 - assert shape[0] % 8 == 0 - - tid = arith.index_cast(ir.IndexType.get(), mgpu.thread_idx()) - lane_id = arith.remui(tid, c(WARP_SIZE, index)) - col_base = arith.muli(arith.remui(lane_id, c(4, index)), c(2, index)) - - for col_group in range(0, shape[0], 8): - col = arith.addi(col_base, c(col_group, index)) - yield (col,) - @dataclasses.dataclass(frozen=True) class WGSplatFragLayout: """A fragmented array where all the values are equal represented as a register per thread. @@ -538,10 +529,15 @@ def linear_thread_idxs(self): yield arith.addi(off, c(i * WARPGROUP_SIZE * self.vec_size, tidx.type)) -FragmentedLayout = WGSplatFragLayout | WGStridedFragLayout | WGMMAColFragLayout | TiledLayout +FragmentedLayout = WGSplatFragLayout | WGStridedFragLayout | TiledLayout -WGMMA_COL_LAYOUT = WGMMAColFragLayout() +WGMMA_COL_LAYOUT = TiledLayout( + Tiling(((8,), (2,))), + warp_dim=Replicated(4), + lane_dims=(Replicated(8), -2), + vector_dim=-1, +) WGMMA_ROW_LAYOUT = TiledLayout( Tiling(((64,), (16,), (8,), (1,))), warp_dim=-4, @@ -659,12 +655,6 @@ def __init__( ) match self.layout: - # Registers are [n_tiles] in WGMMA_COL layout - # Each element is a vector of size 2. - case WGMMAColFragLayout(): - if _registers.ndim != 1: - raise ValueError(f"Invalid register array shape: {_registers.shape}") - # Registers are flat case WGStridedFragLayout(shape): [reg_size] = ir.VectorType(_registers.flat[0].type).shape @@ -721,37 +711,6 @@ def load_strided( vecs = [vector.load(vec_ty, ref, vec_idx) for vec_idx in layout.thread_idxs(shape)] return cls(_registers=np.array(vecs), _layout=layout, _is_signed=is_signed) - @classmethod - def load_wgmma_col( - cls, - ref: ir.Value, - *, - is_signed: bool | None = None, - ): - if not ir.MemRefType.isinstance(ref.type): - raise TypeError(ref.type) - - ref_ty = ir.MemRefType(ref.type) - shape = tuple(ref_ty.shape) - layout = WGMMAColFragLayout() - - if len(shape) != 1: - raise ValueError("WGMMAColFragLayout requires a 1D shape.") - - if shape[0] % 8: - raise ValueError( - f"WGMMAColFragLayout requires {shape[0]=} to be a multiple of 8." - ) - - vec_ty = ir.VectorType.get((2,), ref_ty.element_type) - new_regs = np.full((shape[0] // 8,), llvm.mlir_undef(vec_ty)) - - for col_tile, (idx,) in enumerate(layout.thread_idxs(shape)): - reg = vector.load(vec_ty, ref, [idx]) - new_regs[col_tile] = reg - - return cls(_registers=new_regs, _layout=layout, _is_signed=is_signed) - @classmethod def splat(cls, value, shape, layout=None, *, is_signed: bool | None = None): layout = layout or WGSplatFragLayout(shape) @@ -772,9 +731,6 @@ def splat(cls, value, shape, layout=None, *, is_signed: bool | None = None): @property def shape(self): match self.layout: - case WGMMAColFragLayout(): - col_tiles = self.registers.shape[0] - return (col_tiles * 8,) case WGStridedFragLayout(shape): return shape case WGSplatFragLayout(shape=shape): @@ -788,7 +744,7 @@ def shape(self): def mlir_dtype(self): reg_ty = self.registers.flat[0].type match self.layout: - case WGStridedFragLayout() | WGMMAColFragLayout() | TiledLayout(): + case WGStridedFragLayout() | TiledLayout(): return ir.VectorType(reg_ty).element_type case WGSplatFragLayout(): return reg_ty @@ -1770,15 +1726,11 @@ def broadcast_minor(self, n): ) def broadcast_major(self, m): - if not isinstance(self.layout, WGMMAColFragLayout): - raise NotImplementedError - if m % 64: raise ValueError("Number of rows must be divisible by 64") - reg_shape = WGMMA_LAYOUT.registers_shape((m, self.shape[0])) new_regs = np.empty(reg_shape, dtype=object) - for col_tile, reg in np.ndenumerate(self.registers): + for (col_tile, *_), reg in np.ndenumerate(self.registers): tile = [slice(None)] * len(new_regs.shape) tile[1] = col_tile new_regs[tuple(tile)] = reg @@ -1841,8 +1793,6 @@ def vs_unsupported(): ) match self.layout: - case WGMMAColFragLayout(): - self._store_untiled_wgmma_col(ref) case WGSplatFragLayout(): vs_unsupported() self._store_untiled_splat(ref) @@ -1905,21 +1855,6 @@ def _store_untiled_wg_strided(self, ref: ir.Value): for idx, reg in zip(idxs, self.registers.flat): vector.store(reg, ref_, idx) - def _store_untiled_wgmma_col(self, ref: ir.Value): - """Stores an array with a WGMMA col layout.""" - assert isinstance(self.layout, WGMMAColFragLayout) - index = ir.IndexType.get() - tid = arith.index_cast(ir.IndexType.get(), mgpu.thread_idx()) - tid_wg = arith.remui(tid, c(WARPGROUP_SIZE, index)) - - # Consecutive groups of 4 threads replicate the same data, so we only need to - # transfer data from one group. - is_first = arith.cmpi(arith.CmpIPredicate.ult, tid_wg, c(4, index)) - - with utils.when(is_first): - for (idx,), reg in zip(self.layout.thread_idxs(self.shape), self.registers): - vector.store(reg, ref, [idx]) - def _store_untiled_tiled(self, ref: ir.Value, *, vector_store: bool = True): """Stores an array with a tiled layout. Not optimized at the moment.""" if utils.bitwidth(self.mlir_dtype) < 8: @@ -2161,8 +2096,10 @@ def transfer_tiled2( ] if elem_tiled_strides[layout.vector_dim] != 1: raise ValueError("Stride of the vectorized dimension should be 1") - for d in (layout.warp_dim, *layout.partitioned_lane_dims, layout.vector_dim): + for d in (*layout.partitioned_lane_dims, layout.vector_dim): tiled_shape[d] = 1 + if not isinstance(layout.warp_dim, Replicated): + tiled_shape[layout.warp_dim] = 1 element_bits = mgpu.bitwidth(dtype) if (layout.vector_length * element_bits) % 8 != 0: @@ -2403,10 +2340,6 @@ def plan_tiled_transfer( lane_mask = np.full(lane_shape, False) lane_mask[tuple(slice(0, 1) if s == 0 else slice(None) for s in lane_strides)] = True wavefront_mask = lane_mask.reshape(num_wavefronts, wavefront_lanes) - wavefront_active_lanes = wavefront_mask.sum(-1) - # We make a simplifying assumption: wavefronts have the same number of lanes - if any(act != wavefront_active_lanes[0] for act in wavefront_active_lanes): - raise NotImplementedError lane_offsets_in_tile = np.dot(list(np.ndindex(*lane_shape)), lane_strides) def has_bank_conflicts(tile_idx_transform): @@ -2422,12 +2355,17 @@ def has_bank_conflicts(tile_idx_transform): swizzle_bits = swizzle_groups * swizzle_tile_elems lane_banks = ((offsets ^ swizzle_bits) // elems_per_bank) % num_banks wavefront_banks = lane_banks.reshape(-1, num_wavefronts, wavefront_lanes) - # Mask out the inactive lanes in each wavefront - wavefront_banks = wavefront_banks[:, wavefront_mask].reshape(num_tiles, num_wavefronts, -1) - # Order of threads within the wavefront is unimportant. - wavefront_banks = np.sort(wavefront_banks, axis=-1) - # There are no conflicts if each wavefront only contains unique banks. - return np.any(wavefront_banks[..., 1:] == wavefront_banks[..., :-1]) + # We step over wavefronts since they might have a different number of lanes. + wavefront_banks = wavefront_banks.swapaxes(0, 1) + for banks, mask in zip(wavefront_banks, wavefront_mask): + banks = banks[:, mask] + # Order of threads within the wavefront is unimportant. + banks = np.sort(banks, axis=-1) + # There are no conflicts if each wavefront only contains unique banks. + repeats = np.any(banks[..., 1:] == banks[..., :-1]) + if repeats: + return True + return False # We don't need any special treatment if there are no conflicts when each lane # transfers the same tile at a time. diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index cc0e6a04bdc3..9d9d3fa8979c 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -2005,9 +2005,11 @@ def kernel(ctx, *args): np.testing.assert_array_equal(inp, result) @parameterized.product( - in_shape=((1024,), (256,), (128,), (64,)), swizzle=(16, 32, 64, 128) + in_shape=((1024,), (256,), (128,), (64,)), + dtype=(jnp.float16, jnp.float32), + swizzle=(16, 32, 64, 128) ) - def test_wgmma_row_load_store_with_layout(self, in_shape, swizzle): + def test_wgmma_row_load_store_with_layout(self, in_shape, dtype, swizzle): def kernel(ctx, gmem_input, gmem_output, smem): smem_input, smem_output = smem copy(gmem_input, smem_input, swizzle=swizzle) @@ -2017,20 +2019,24 @@ def kernel(ctx, gmem_input, gmem_output, smem): t.store_untiled(smem_output) copy(smem_output, gmem_output) - inp = out = self.prng.uniform(-1, 1, in_shape).astype(jnp.float32) + inp = out = self.prng.uniform(-1, 1, in_shape).astype(dtype) result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (inp,), out, [inp, out], )(inp) np.testing.assert_array_equal(inp, result) @parameterized.product( - in_shape=((128,), (64,)), dtype=[jnp.float16, jnp.float32] + in_shape=((128,), (64,)), + dtype=(jnp.float16, jnp.float32), + swizzle=(16, 32, 64, 128), ) - def test_wgmma_col_load_store_with_layout(self, in_shape, dtype): + def test_wgmma_col_load_store_with_layout(self, in_shape, dtype, swizzle): def kernel(ctx, *args): gmem_input, gmem_output, (smem_input, smem_output) = args - copy(gmem_input, smem_input) - t = mgpu.FragmentedArray.load_wgmma_col(smem_input) + copy(gmem_input, smem_input, swizzle=swizzle) + t = mgpu.FragmentedArray.load_untiled( + smem_input, swizzle=swizzle, layout=mgpu.WGMMA_COL_LAYOUT + ) t.store_untiled(smem_output) copy(smem_output, gmem_output) @@ -2042,18 +2048,17 @@ def kernel(ctx, *args): @parameterized.parameters((128, 128), (128, 64), (64, 128)) def test_broadcast_major(self, m, n): - def kernel(ctx, *args): - gmem_input, gmem_output, () = args - t = mgpu.FragmentedArray.load_wgmma_col(gmem_input) + def kernel(ctx, gmem_input, gmem_output, _): + t = mgpu.FragmentedArray.load_untiled( + gmem_input, layout=mgpu.WGMMA_COL_LAYOUT, optimized=False + ) t.broadcast_major(m).store_untiled(gmem_output) inp = self.prng.uniform(-1, 1, (n,)).astype(jnp.float16) out_shape = jax.ShapeDtypeStruct((m, n), jnp.float16) - result = mgpu.as_gpu_kernel( - kernel, (1, 1, 1), (128, 1, 1), (inp,), out_shape, () + kernel, (1, 1, 1), (128, 1, 1), (inp,), out_shape, inp )(inp) - out_ref = jax.lax.broadcast_in_dim(inp, (m, n), (1,)) np.testing.assert_array_equal(result, out_ref) From e4a381c12e42c62d149f694122472d237dba333b Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Fri, 4 Apr 2025 05:47:54 -0700 Subject: [PATCH 384/483] [pallas:mgpu] Check that swizzle dim is not transposed in copy_smem_to_gmem() PiperOrigin-RevId: 743910324 --- jax/experimental/mosaic/gpu/launch_context.py | 8 ++++++++ tests/pallas/mosaic_gpu_test.py | 11 ++++++++--- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/jax/experimental/mosaic/gpu/launch_context.py b/jax/experimental/mosaic/gpu/launch_context.py index aca3fc723882..243c5e5df15c 100644 --- a/jax/experimental/mosaic/gpu/launch_context.py +++ b/jax/experimental/mosaic/gpu/launch_context.py @@ -325,6 +325,14 @@ def init_tma_desc(host_ptr): ref = t.apply(ref) ref_ty = ir.MemRefType(ref.type) # TODO(apaszke): Use utils.memref_ptr to compute base_ptr + strides, _ = ref_ty.get_strides_and_offset() + if strides[-1] != 1: + raise ValueError( + "TMA requires the stride of the last dimension after" + " transforming the GMEM reference to be 1, but it is" + f" {strides[-1]}." + ) + _, offset, *sizes_and_strides = memref.extract_strided_metadata(ref) aligned_ptr_idx = memref.extract_aligned_pointer_as_index(ref) as_i64 = lambda i: arith.index_cast(i64, i) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 67472dbdd9e5..0a3087f5902d 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1379,7 +1379,8 @@ def kernel(o_ref): x = jnp.full(shape, 42.0, jnp.float32) np.testing.assert_array_equal(kernel(), x) - def test_wgmma_transposed_layout(self): + @parameterized.parameters(False, True) + def test_wgmma_transposed_layout(self, store_transposed): """Tests that the result of wgmma can be store transposed using the WGMMA_TRNASPOSED layout. """ @@ -1412,10 +1413,14 @@ def kernel(o_ref, smem): smem_trns = plgpu.transpose_ref(smem, (1, 0)) smem_trns[...] = plgpu.layout_cast(iota, plgpu.Layout.WGMMA_TRANSPOSED) plgpu.commit_smem() - plgpu.copy_smem_to_gmem(smem, o_ref) + plgpu.copy_smem_to_gmem(smem_trns if store_transposed else smem, o_ref) x = jnp.arange(128 * 128, dtype=dtype).reshape((128, 128)).T - np.testing.assert_array_equal(kernel(), x) + if store_transposed: + with self.assertRaises(ValueError): + kernel() + else: + np.testing.assert_array_equal(kernel(), x) def test_profiler(self): self.skip_if_wg_semantics() # Transform inference fails. From cbae2539d4724e490aefd2aa3e8e661223c57e35 Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Fri, 4 Apr 2025 06:32:27 -0700 Subject: [PATCH 385/483] [mgpu:pallas] Typo in `UnswizzleRef.untransform_reshape()` check. PiperOrigin-RevId: 743920665 --- jax/_src/pallas/mosaic_gpu/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 2150b48b5108..c0bf602e0962 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -494,7 +494,7 @@ def untransform_transpose(self, perm) -> tuple[tuple[int, ...], state_types.Tran def untransform_reshape( self, dtype: jnp.dtype | ir.Type, shape: tuple[int, ...] ) -> tuple[tuple[int, ...], state_types.Transform]: - if shape[-1] == self.swizzle_elems(dtype): + if shape[-1] != self.swizzle_elems(dtype): raise ValueError( f"Reshape shape {shape} is not divisible by swizzle elements" f" {self.swizzle_elems(dtype)}" From 5a29311c8b97922b4fc6f1d942a87d5a784d86c7 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 4 Apr 2025 06:36:33 -0700 Subject: [PATCH 386/483] Update XLA dependency to use revision http://github.com/openxla/xla/commit/ffa8cf08e295cec70a27a6b27bfaa19c5d0daeec. PiperOrigin-RevId: 743921735 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index c30648a2b3a1..7abf3da775d2 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "921c164a67e8ac4cf052aab26e849f29b719f802" -XLA_SHA256 = "9e734da4a0211ac09a00cc07969645e31f107cfee19bbc5d2d1e21ddbb19090d" +XLA_COMMIT = "ffa8cf08e295cec70a27a6b27bfaa19c5d0daeec" +XLA_SHA256 = "d5de319756b6a32748d2821f5319f831b062d0f5f22b7f0bde1d9564dc6b6f5e" def repo(): tf_http_archive( From da7b1577e24784c6e1edcc8167407e85bb85195e Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Fri, 4 Apr 2025 07:20:58 -0700 Subject: [PATCH 387/483] [mgpu:pallas] Swizzle elements computed using bitwidth rather than bytewidth. PiperOrigin-RevId: 743933866 --- jax/_src/pallas/mosaic_gpu/core.py | 3 ++- jax/_src/pallas/mosaic_gpu/lowering.py | 2 +- jax/_src/pallas/mosaic_gpu/primitives.py | 8 +++++++- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index c0bf602e0962..d3d7f89812c3 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -33,6 +33,7 @@ from jax._src.lib.mlir.dialects import arith as arith_dialect from jax._src.pallas import core as pallas_core from jax._src.pallas import primitives as pallas_primitives +import jax._src.pallas.utils as pallas_utils from jax._src.state import discharge as state_discharge from jax._src.state import indexing from jax._src.state import types as state_types @@ -466,7 +467,7 @@ def undo_to_gpu_transform(self) -> mgpu.MemRefTransform: raise NotImplementedError def __call__(self, aval: jax_core.ShapedArray) -> jax_core.ShapedArray: - swizzle_elems = self.swizzle // aval.dtype.itemsize + swizzle_elems = (self.swizzle * 8) // pallas_utils.dtype_bitwidth(aval.dtype) if swizzle_elems != aval.shape[-1]: raise ValueError( f"Swizzle {self.swizzle} requires the trailing dimension to be of" diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 827794d37e2b..f8932b08b90a 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1112,7 +1112,7 @@ def _get_lowering_rule(ctx: LoweringRuleContext, x_ref, *leaves, tree): match transforms: case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)): - if tiling != (8, swizzle // x_aval.dtype.itemsize): + if tiling != (8, (swizzle * 8) // pallas_utils.dtype_bitwidth(x_aval.dtype)): raise NotImplementedError("Tiling does not fit swizzle") return mgpu.FragmentedArray.load_tiled( x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype), swizzle=swizzle diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index b2beec700fad..bfe1bf3fe5bc 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -481,7 +481,13 @@ def _copy_gmem_to_smem_lowering( for axis in collective_axes ) dst_ty = ir.MemRefType(dst.type) - bytes = math.prod(dst_ty.shape) * mgpu.bytewidth(dst_ty.element_type) + bits = math.prod(dst_ty.shape) * mgpu.bitwidth(dst_ty.element_type) + if bits % 8: + raise ValueError( + f"Can only transfer integer bytes (shape={dst_ty.shape}," + f" dtype={dst_ty.element_type})" + ) + bytes = bits // 8 if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: if bytes % WARPGROUP_SIZE: raise NotImplementedError("Only aligned copies are supported") From 53abbd5606495633fbe2eb0ea720d9d1f4e4f937 Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Fri, 4 Apr 2025 08:30:54 -0700 Subject: [PATCH 388/483] [mgpu] Foreach to handle scalar registers in fragmented arrays. PiperOrigin-RevId: 743953606 --- .../mosaic/gpu/fragmented_array.py | 16 ++++++++---- tests/mosaic/gpu_test.py | 25 +++++++++++++++++++ 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index f6c5e7d1ed19..ecd51f79eab0 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -1765,12 +1765,18 @@ def foreach( for mlir_idx, reg_idx in zip(self.layout.thread_idxs(self.shape), np.ndindex(self.registers.shape), strict=True): reg = self.registers[reg_idx] assert len(mlir_idx) == len(self.shape), (mlir_idx, self.shape) - [elems] = ir.VectorType(reg.type).shape - for i in range(elems): - i = c(i, index) - val = fn(vector.extractelement(reg, position=i), (*mlir_idx[:-1], arith.addi(mlir_idx[-1], i))) + if ir.VectorType.isinstance(reg.type): + [elems] = ir.VectorType(reg.type).shape + for i in range(elems): + i = c(i, index) + val = fn(vector.extractelement(reg, position=i), (*mlir_idx[:-1], arith.addi(mlir_idx[-1], i))) + if create_array: + new_regs[reg_idx] = vector.insertelement(val, new_regs[reg_idx], position=i) + else: + val = fn(reg, mlir_idx) if create_array: - new_regs[reg_idx] = vector.insertelement(val, new_regs[reg_idx], position=i) + new_regs[reg_idx] = val + if create_array: return FragmentedArray(_registers=new_regs, _layout=self.layout, _is_signed=is_signed) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 9d9d3fa8979c..80e5380d165b 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1756,6 +1756,31 @@ def kernel(ctx, dst, _): rhs = rhs = 0 if rhs_is_literal else iota + 1 np.testing.assert_array_equal(result, op(iota, rhs).astype(jnp.int8)) + def test_foreach_wgmma_row_array(self): + def kernel(ctx, out, smem): + del ctx, smem + x = iota_tensor(128, 128, jnp.float32) + row = x.reduce("add", 1) + # Test returning an array + row = row.foreach( + lambda x, _: arith.addf(x, c(1, row.mlir_dtype)), create_array=True + ) + # Test no array return + @row.foreach + def _(v, idx): + memref.store(v, out, idx) + + result = mgpu.as_gpu_kernel( + kernel, + grid=(1, 1, 1), + block=(128, 1, 1), + in_shape=(), + out_shape=jax.ShapeDtypeStruct(shape=(128,), dtype=jnp.float32), + smem_scratch_shape=(), + )() + iota = np.arange(128 * 128, dtype=jnp.float32).reshape(128, 128) + np.testing.assert_array_equal(result, iota.sum(axis=1) + 1) + def test_foreach(self): dtype = jnp.int32 swizzle = 128 From b9007145d7c4f6f44c41c7111edc56b61be921d7 Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Fri, 4 Apr 2025 08:31:49 -0700 Subject: [PATCH 389/483] [mgpu:pallas] Fix swizzling check bug where it was comparing w/ #bytes rather than #elems. PiperOrigin-RevId: 743953910 --- jax/_src/pallas/mosaic_gpu/core.py | 13 ++++++++----- jax/_src/pallas/mosaic_gpu/lowering.py | 3 ++- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index d3d7f89812c3..444fe6e50f88 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -290,8 +290,9 @@ def untransform_reshape( raise NotImplementedError("Reshapes don't commute with transposes.") def untransform_index( - self, idxs: tuple[Index, ...] + self, dtype: jnp.dtype | ir.Type, idxs: tuple[Index, ...] ) -> tuple[tuple[Index, ...], state_types.Transform]: + del dtype untiled_idxs = idxs[: -len(self.tiling)] tiled_idxs = idxs[-len(self.tiling) :] idxs_after_tiling: list[Index] = [] @@ -395,8 +396,9 @@ def untransform_reshape( raise NotImplementedError("Can't reshape a transposed memref.") def untransform_index( - self, idxs: tuple[Index, ...] + self, dtype: jnp.dtype | ir.Type, idxs: tuple[Index, ...] ) -> tuple[tuple[Index, ...], state_types.Transform]: + del dtype removed_dims = [ i for i, idx in enumerate(idxs) if not isinstance(idx, (slice, mgpu.ds)) ] @@ -503,8 +505,9 @@ def untransform_reshape( return shape, self def untransform_index( - self, idxs: tuple[Index, ...] + self, dtype: jnp.dtype | ir.Type, idxs: tuple[Index, ...] ) -> tuple[tuple[Index, ...], state_types.Transform]: + swizzle_elems = self.swizzle_elems(dtype) if not idxs: return idxs, self if not all(isinstance(idx, (slice, mgpu.ds)) for idx in idxs[-2:]): @@ -513,14 +516,14 @@ def untransform_index( ) last_idx = idxs[-1] if isinstance(last_idx, mgpu.DynamicSlice): - if last_idx.base != 0 or last_idx.length != self.swizzle: + if last_idx.base != 0 or last_idx.length != swizzle_elems: raise ValueError("Swizzled dims cannot be sliced") else: assert isinstance(last_idx, slice) if ( (last_idx.step is not None and last_idx.step != 1) or (last_idx.start is not None and last_idx.start != 0) - or (last_idx.stop is not None and last_idx.stop != self.swizzle) + or (last_idx.stop is not None and last_idx.stop != swizzle_elems) ): raise ValueError("Swizzled dims cannot be sliced") return idxs, self diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index f8932b08b90a..2423e1c1a2a7 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1032,6 +1032,7 @@ def _handle_transforms( handle_reshapes=True, ) -> tuple[ir.Value, Sequence[gpu_core.Transform]]: transformed_ref = ref + mlir_dtype = ir.MemRefType(ref.type).element_type new_transforms = [] def _bubble_up(untransform_fn, data): nonlocal new_transforms @@ -1051,7 +1052,7 @@ def _bubble_up(untransform_fn, data): raise NotImplementedError("int_indexer_shape non-empty") indices = _ndindexer_indices(indexer) indices = _bubble_up( - lambda t, idxs: t.untransform_index(idxs), indices + lambda t, idxs: t.untransform_index(mlir_dtype, idxs), indices ) transformed_ref = mgpu.memref_slice(transformed_ref, indices) case gpu_core.TransposeRef(perm) if handle_transposes: From 35d75183c70e5f83d5df2065956547b0845c29a6 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 4 Apr 2025 10:09:44 -0700 Subject: [PATCH 390/483] `_attempt_rewriting_take_via_slice()`: canonicalize the slice index before checking it's not too long, so that e.g. `my_1d_array[:, ...]` can be treated as a slice rather than generating a gather operation. PiperOrigin-RevId: 743986126 --- jax/_src/numpy/indexing.py | 9 +++++---- tests/lax_numpy_indexing_test.py | 10 +++++++++- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/jax/_src/numpy/indexing.py b/jax/_src/numpy/indexing.py index 863f0c775ec6..05169dd541ce 100644 --- a/jax/_src/numpy/indexing.py +++ b/jax/_src/numpy/indexing.py @@ -526,8 +526,6 @@ def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None) -> if not all(isinstance(i, int) for i in arr.shape): return None - if len(idx) > arr.ndim: - return None if any(i is None for i in idx): return None # TODO(jakevdp): handle newaxis case # For symbolic dimensions fallback to gather @@ -535,10 +533,13 @@ def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None) -> for i in idx if isinstance(i, slice) for elt in (i.start, i.stop, i.step)): return None - if any(i is Ellipsis for i in idx): - # Remove ellipses and add trailing `slice(None)`. + # Remove ellipses and pad with trailing `slice(None)` if necessary. + # Do this before checking against rank of `arr` so that `...` can + # count as no dimensions at all (e.g. `my_1d_array[:, ...]` succeeds) idx = _canonicalize_tuple_index(arr.ndim, idx=idx) + if len(idx) > arr.ndim: + return None simple_revs = {i for i, ind in enumerate(idx) if _is_simple_reverse_slice(ind)} int_indices = {i for i, (ind, size) in enumerate(zip(idx, arr.shape)) diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index 63a725ad3643..ca9ba9c88806 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -926,12 +926,20 @@ def testSimpleIndexingUsesSlice(self): self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.slice_p) self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p) - # Indexing with `Ellipsis` is not lowered to `gather`. + # Indexing with `Ellipsis` is not lowered to `gather` ... jaxpr = jax.make_jaxpr(lambda x: x[..., 0])(jnp.ones((3, 4, 5))) self.assertLen((jaxpr.jaxpr.eqns), 2) self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.slice_p) self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p) + # ... even when the ellipsis expands to no dimensions. + jaxpr = jax.make_jaxpr(lambda x: x[..., 0:1])(jnp.ones((3,))) + self.assertLen((jaxpr.jaxpr.eqns), 1) + self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.slice_p) + jaxpr = jax.make_jaxpr(lambda x: x[0:1, ...])(jnp.ones((3,))) + self.assertLen((jaxpr.jaxpr.eqns), 1) + self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.slice_p) + # Simple reverses lower to lax.rev_p jaxpr = jax.make_jaxpr(lambda x: x[:, ::-1])(jnp.ones((3, 4))) self.assertEqual(len(jaxpr.jaxpr.eqns), 1) From e2f67e0ef1af19f4d32e02f6fb927502469b32c6 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 4 Apr 2025 10:32:47 -0700 Subject: [PATCH 391/483] Always force synchronous pipelining when we have vmem storage and trivial PiperOrigin-RevId: 743993611 --- jax/_src/pallas/mosaic/lowering.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 67cacbc8dcf9..87e06f486366 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -744,13 +744,13 @@ def dynamic_shape_replacement_fn( 1 if b is pallas_core.mapped else b for b in bm.block_shape ] - # No sense in double-buffering without any windowing pattern. - buffer_count = 0 + # Force single-buffering pipelining for trivial windowing in VMEM. + pipeline_mode = bm.pipeline_mode if ( tpu_memory_space == tpu_core.TPUMemorySpace.VMEM and bm.has_trivial_window() ): - buffer_count = 1 + pipeline_mode = pallas_core.Buffered(1) # If we have an extended dtype, we need to add the block shape for the # remaining physical dtype. @@ -769,21 +769,20 @@ def dynamic_shape_replacement_fn( block_params["window_kind"] = ir.Attribute.parse( f"#tpu.element_window<{pad_low},{pad_high}>" ) - if bm.pipeline_mode is not None: - if not isinstance(bm.pipeline_mode, pallas_core.Buffered): + if pipeline_mode is not None: + if not isinstance(pipeline_mode, pallas_core.Buffered): raise LoweringException( - f"Unsupported pipeline mode: {bm.pipeline_mode}." + f"Unsupported pipeline mode: {pipeline_mode}." ) - if buffer_count == 0: - buffer_count = bm.pipeline_mode.buffer_count + buffer_count = pipeline_mode.buffer_count if buffer_count < 1 or buffer_count > 2: raise LoweringException( "Only single (1) and double (2) buffering are supported. Got" f" {buffer_count}." ) - pipeline_mode = "synchronous" if buffer_count == 1 else "double_buffered" + pipeline_mode_str = "synchronous" if buffer_count == 1 else "double_buffered" block_params["pipeline_mode"] = ir.Attribute.parse( - f"#tpu.pipeline_mode<{pipeline_mode}>" + f"#tpu.pipeline_mode<{pipeline_mode_str}>" ) window_params.append(ir.DictAttr.get(block_params)) m.body.append(mlir_func) From e6b01bd1ed18dcbe92041c4bc7254470a11bd0b1 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 4 Apr 2025 10:52:20 -0700 Subject: [PATCH 392/483] Parameterize the random tests taking out_sharding argument in pjit_test.py PiperOrigin-RevId: 744000229 --- tests/pjit_test.py | 142 +++++++++++---------------------------------- 1 file changed, 33 insertions(+), 109 deletions(-) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index f2db913af736..2570c6090351 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7274,116 +7274,60 @@ def f(key): out = f(key) self.assertEqual(out.sharding, NamedSharding(mesh, P())) + @parameterized.named_parameters( + ("bits", partial(jax.random.bits, shape=(8, 12)), P('x', 'y')), + ("uniform", partial(jax.random.uniform, shape=(8, 12)), P('x', 'y')), + ("normal", partial(jax.random.normal, shape=(8, 12)), P('x', 'y')), + ("randint", partial(jax.random.randint, shape=(8, 12), minval=0, maxval=10), + P('x', 'y')), + ("permutation_1d", partial(jax.random.permutation, x=8), P('x')), + ("permutation_2d", partial(jax.random.permutation, + x=np.arange(8 * 12).reshape(8, 12)), + P('x', 'y')), + ) @jtu.with_user_mesh((2, 2), ('x', 'y')) - def test_random_bits(self, mesh): - @jax.jit - def f(key): - out = jax.random.bits(key, shape=(8, 12), out_sharding=P('x', 'y')) - self.assertEqual(out.aval.sharding.spec, P('x', 'y')) - return out - - key = jax.random.key(1) - out = f(key) - self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) - - lowered_text = f.lower(key).as_text() - if config.use_shardy_partitioner.value: - self.assertIn('sdy.sharding_constraint', lowered_text) - self.assertIn('<@mesh, [{"x"}, {"y"}]>', lowered_text) - else: - self.assertIn('mhlo.sharding = "{devices=[2,2]<=[4]}"}', lowered_text) - - @jtu.with_user_mesh((2, 2), ('x', 'y')) - def test_random_uniform(self, mesh): - @jax.jit - def f(key): - out = jax.random.uniform(key, shape=(8, 12), out_sharding=P('x', 'y')) - self.assertEqual(out.aval.sharding.spec, P('x', 'y')) - return out - - key = jax.random.key(1) - out = f(key) - self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) - - lowered_text = f.lower(key).as_text() - if config.use_shardy_partitioner.value: - self.assertIn('sdy.sharding_constraint', lowered_text) - self.assertIn('<@mesh, [{"x"}, {"y"}]>', lowered_text) - else: - self.assertIn('mhlo.sharding = "{devices=[2,2]<=[4]}"}', lowered_text) - - @jtu.with_user_mesh((2, 2), ('x', 'y')) - def test_random_randint(self, mesh): - @jax.jit - def f(key): - out = jax.random.randint(key, shape=(8, 12), minval=0, maxval=10, - out_sharding=P('x', 'y')) - self.assertEqual(out.aval.sharding.spec, P('x', 'y')) - return out - - key = jax.random.key(1) - out = f(key) - self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) - - lowered_text = f.lower(key).as_text() - if config.use_shardy_partitioner.value: - self.assertIn('sdy.sharding_constraint', lowered_text) - self.assertIn('<@mesh, [{"x"}, {"y"}]>', lowered_text) - else: - self.assertIn('mhlo.sharding = "{devices=[2,2]<=[4]}"}', lowered_text) - - @jtu.with_user_mesh((4,), ('x',)) - def test_random_permutation_1d(self, mesh): - @jax.jit - def f(key): - out = jax.random.permutation(key, 8, out_sharding=P('x')) - self.assertEqual(out.aval.sharding.spec, P('x')) - return out - - key = jax.random.key(1) - out = f(key) - self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) - - lowered_text = f.lower(key).as_text() - if config.use_shardy_partitioner.value: - self.assertIn('sdy.sharding_constraint', lowered_text) - self.assertIn('<@mesh, [{"x"}]>', lowered_text) - else: - self.assertIn('mhlo.sharding = "{devices=[4]<=[4]}"}', lowered_text) - - @jtu.with_user_mesh((2, 2), ('x', 'y')) - def test_random_permutation_2d(self, mesh): + def test_random_functions(self, fun, out_spec, mesh): @jax.jit def f(key): - out = jax.random.permutation(key, jnp.arange(8 * 12).reshape(8, 12), - out_sharding=P('x', 'y')) - self.assertEqual(out.aval.sharding.spec, P('x', 'y')) + out = fun(key, out_sharding=out_spec) + self.assertEqual(out.aval.sharding.spec, out_spec) return out key = jax.random.key(1) out = f(key) - self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + self.assertEqual(out.sharding, NamedSharding(mesh, out_spec)) lowered_text = f.lower(key).as_text() if config.use_shardy_partitioner.value: self.assertIn('sdy.sharding_constraint', lowered_text) - self.assertIn('<@mesh, [{"x"}, {"y"}]>', lowered_text) + if out_spec == P('x', 'y'): + self.assertIn('<@mesh, [{"x"}, {"y"}]>', lowered_text) + else: + assert out_spec == P('x') + self.assertIn('<@mesh, [{"x"}]>', lowered_text) else: - self.assertIn('mhlo.sharding = "{devices=[2,2]<=[4]}"}', lowered_text) + if out_spec == P('x', 'y'): + self.assertIn('mhlo.sharding = "{devices=[2,2]<=[4]}"}', lowered_text) + else: + assert out_spec == P('x') + self.assertIn( + 'mhlo.sharding = "{devices=[2,2]<=[4] last_tile_dim_replicate}"}', + lowered_text) @jtu.with_user_mesh((2, 2), ('x', 'y')) - def test_random_normal(self, mesh): + def test_random_truncated_normal(self, mesh): @jax.jit - def f(key): - out = jax.random.normal(key, shape=(8, 12), out_sharding=P('x', 'y')) + def f(key, lower): + out = jax.random.truncated_normal(key, lower, 2., shape=(8, 12), + out_sharding=P('x', 'y')) self.assertEqual(out.aval.sharding.spec, P('x', 'y')) return out key = jax.random.key(1) - out = f(key) + out = f(key, -1.) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) - lowered_text = f.lower(key).as_text() + lowered_text = f.lower(key, -1.).as_text() if config.use_shardy_partitioner.value: self.assertIn('sdy.sharding_constraint', lowered_text) self.assertIn('<@mesh, [{"x"}, {"y"}]>', lowered_text) @@ -7423,26 +7367,6 @@ def f(arr, key): out = f(jax.device_put(np.arange(8 * 12.).reshape(8, 12), s), key) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) - @jtu.with_user_mesh((2, 2), ('x', 'y')) - def test_random_truncated_normal(self, mesh): - @jax.jit - def f(key, lower): - out = jax.random.truncated_normal(key, lower, 2., shape=(8, 12), - out_sharding=P('x', 'y')) - self.assertEqual(out.aval.sharding.spec, P('x', 'y')) - return out - - key = jax.random.key(1) - out = f(key, -1.) - self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) - - lowered_text = f.lower(key, -1.).as_text() - if config.use_shardy_partitioner.value: - self.assertIn('sdy.sharding_constraint', lowered_text) - self.assertIn('<@mesh, [{"x"}, {"y"}]>', lowered_text) - else: - self.assertIn('mhlo.sharding = "{devices=[2,2]<=[4]}"}', lowered_text) - def test_auto_axes_no_context_mesh(self): mesh = jtu.create_mesh((2, 2), ('x', 'y'), axis_types=(AxisType.Explicit,) * 2) np_inp = np.arange(16.).reshape(8, 2) From be1a554d0bbce75a7fbc3e66fee81435f184676d Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 4 Apr 2025 11:18:27 -0700 Subject: [PATCH 393/483] Fix a possible race in pjit.cc. We need to be careful not to destroy Python objects while using a Python 3.13- critical section to protect C++ state. The critical section might be released when calling back into Python code (much as the GIL may be released in GIL mode). In this code Key is kept alive by the function already, but the Value may be deleted before the hash table updates are done. PiperOrigin-RevId: 744008939 --- jaxlib/xla/pjit.cc | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/jaxlib/xla/pjit.cc b/jaxlib/xla/pjit.cc index 508bf79f9ec0..503e8ef23f4b 100644 --- a/jaxlib/xla/pjit.cc +++ b/jaxlib/xla/pjit.cc @@ -245,9 +245,14 @@ std::shared_ptr PjitFunctionCache::DefaultCache() { std::shared_ptr cache = std::make_shared(&self->lru_list_); auto callback = nb::cpp_function([self, key{std::move(key)}](nb::handle weakref) { - nb::ft_object_guard lock(self); - auto it = self->functions_.find(key); - if (it != self->functions_.end()) { + std::unique_ptr value; + { + nb::ft_object_guard lock(self); + auto it = self->functions_.find(key); + if (it == self->functions_.end()) { + return; + } + value = std::move(it->second); self->functions_.erase(it); } }); From 5d4ac775dd210d2e5deca46c5006b07d7e08d6e9 Mon Sep 17 00:00:00 2001 From: Georg Stefan Schmid Date: Fri, 4 Apr 2025 11:28:13 -0700 Subject: [PATCH 394/483] PR #26906: [jax.distributed] Allow explicitly setting slice_index Imported from GitHub PR https://github.com/jax-ml/jax/pull/26906 Allows overriding the slice index used by XLA. More explicit control over which slice a device ends up in is desirable: - Various parts of the ecosystem equate slices with "devices communicating via fast interconnect". With the arrival of NVL72 we want devices managed by multiple hosts to form a single slice. - For debugging purposes it can be useful to allow devices on the same host (managed in separate processes) to be treated as different slices. For example, [Orbax](https://github.com/google/orbax)'s local checkpointing presumes the existence of at least two slices, so overriding the boot id will allow us to test local checkpointing on a single host. (Companion PR in XLA: https://github.com/openxla/xla/pull/23347) Copybara import of the project: -- 45aa7ce316bb05ebcc3f3ed2d888385923285e58 by Georg Stefan Schmid : [jax.distributed] Allow overriding XLA slice_index Merging this change closes #26906 COPYBARA_INTEGRATE_REVIEW=https://github.com/jax-ml/jax/pull/26906 from gspschmid:gschmid/jax-override-boot-id 45aa7ce316bb05ebcc3f3ed2d888385923285e58 PiperOrigin-RevId: 744012253 --- jax/_src/distributed.py | 16 +++++++++++++--- jax/_src/xla_bridge.py | 2 ++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py index af50e2e9e31a..a7551465c425 100644 --- a/jax/_src/distributed.py +++ b/jax/_src/distributed.py @@ -41,6 +41,7 @@ class State: client: Any | None = None preemption_sync_manager: Any | None = None coordinator_address: str | None = None + slice_index: int | None = None def initialize(self, coordinator_address: str | None = None, @@ -53,7 +54,8 @@ def initialize(self, service_heartbeat_interval_seconds: int = 10, service_max_missing_heartbeats: int = 10, client_heartbeat_interval_seconds: int = 10, - client_max_missing_heartbeats: int = 10): + client_max_missing_heartbeats: int = 10, + slice_index: int | None = None): coordinator_address = (coordinator_address or os.environ.get('JAX_COORDINATOR_ADDRESS')) if isinstance(local_device_ids, int): @@ -149,6 +151,10 @@ def initialize(self, self.initialize_preemption_sync_manager() + if slice_index is None and 'JAX_SLICE_INDEX' in os.environ: + slice_index = int(os.environ.get('JAX_SLICE_INDEX')) + self.slice_index = slice_index + def shutdown(self): if self.client: self.client.shutdown() @@ -175,7 +181,8 @@ def initialize(coordinator_address: str | None = None, local_device_ids: int | Sequence[int] | None = None, cluster_detection_method: str | None = None, initialization_timeout: int = 300, - coordinator_bind_address: str | None = None): + coordinator_bind_address: str | None = None, + slice_index: int | None = None): """Initializes the JAX distributed system. Calling :func:`~jax.distributed.initialize` prepares JAX for execution on @@ -236,6 +243,8 @@ def initialize(coordinator_address: str | None = None, all available addresses on the same port as ``coordinator_address``. On systems that have multiple network interfaces per node it may be insufficient to only have the coordinator service listen on one address/interface. + slice_index: The slice index assigned to this process' local devices. If any process sets ``slice_index``, + then all processes must do so. If ``None`` the slice indices will be chosen automatically. Raises: RuntimeError: If :func:`~jax.distributed.initialize` is called more than once @@ -261,7 +270,8 @@ def initialize(coordinator_address: str | None = None, "This includes any computation, but also calls to jax.devices, jax.device_put, and others.") global_state.initialize(coordinator_address, num_processes, process_id, local_device_ids, cluster_detection_method, - initialization_timeout, coordinator_bind_address) + initialization_timeout, coordinator_bind_address, + slice_index=slice_index) def is_initialized() -> bool: diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 227359dc4676..178ac5e6fc01 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -636,6 +636,8 @@ def factory(): 'node_id': distributed.global_state.process_id, 'num_nodes': distributed.global_state.num_processes, } + if (slice_index := distributed.global_state.slice_index) is not None: + distribute_options['slice_index'] = slice_index if options is not None: distribute_options.update(updated_options) return xla_client.make_c_api_client( From 549f1cd856bbed820c23c3cec56a084ab5c31d9e Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 4 Apr 2025 14:50:19 -0700 Subject: [PATCH 395/483] Don't set `memory_kind` to `None` if the mesh is AbstractMesh and the PiperOrigin-RevId: 744077517 --- jax/_src/distributed.py | 2 +- jaxlib/xla/sharding.cc | 21 ++++++++++++++++++++- jaxlib/xla/xla_client.py | 2 +- tests/array_test.py | 18 +++++++++++++++++- 4 files changed, 39 insertions(+), 4 deletions(-) diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py index a7551465c425..fb0aebb0e642 100644 --- a/jax/_src/distributed.py +++ b/jax/_src/distributed.py @@ -152,7 +152,7 @@ def initialize(self, self.initialize_preemption_sync_manager() if slice_index is None and 'JAX_SLICE_INDEX' in os.environ: - slice_index = int(os.environ.get('JAX_SLICE_INDEX')) + slice_index = int(os.environ.get('JAX_SLICE_INDEX')) # type: ignore self.slice_index = slice_index def shutdown(self): diff --git a/jaxlib/xla/sharding.cc b/jaxlib/xla/sharding.cc index 858c025745e4..b6b58b0600ad 100644 --- a/jaxlib/xla/sharding.cc +++ b/jaxlib/xla/sharding.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include #include #include #include @@ -192,6 +193,14 @@ bool ShardingEqual(nb::handle a, nb::handle b) { return a.equal(b); } +// This list is to check for valid memory kinds when an AbstractMesh is passed +// to NamedSharding. +static const std::array valid_memory_kinds = { + "device", + "pinned_host", + "unpinned_host", +}; + NamedSharding::NamedSharding(nb::object mesh, nb::object spec, nb::object memory_kind, nb::object manual_axes, nb::object logical_device_ids) @@ -217,7 +226,17 @@ NamedSharding::NamedSharding(nb::object mesh, nb::object spec, memory_kind_ = CheckAndCanonicalizeMemoryKind(memory_kind_, *internal_device_list_); } else { - memory_kind_ = nb::none(); + if (!memory_kind_.is_none() && + (std::find(valid_memory_kinds.begin(), valid_memory_kinds.end(), + nb::cast(memory_kind_)) == + valid_memory_kinds.end())) { + throw nb::value_error( + absl::StrCat("Got invalid memory kind: ", + nb::cast(memory_kind_), + ". Valid memory kinds are: ", + absl::StrJoin(valid_memory_kinds, ", ")) + .c_str()); + } } // TODO(phawkins): this leaks a reference to the check_pspec function. diff --git a/jaxlib/xla/xla_client.py b/jaxlib/xla/xla_client.py index 523f8bb57b90..58e0cb070e29 100644 --- a/jaxlib/xla/xla_client.py +++ b/jaxlib/xla/xla_client.py @@ -50,7 +50,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version. -_version = 325 +_version = 326 # An internal increasing version number for protecting jaxlib code against # ifrt changes. diff --git a/tests/array_test.py b/tests/array_test.py index 2bdc54607473..76aa1093ede3 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -29,9 +29,10 @@ from jax._src import test_util as jtu from jax._src import xla_bridge as xb from jax._src.lib import xla_client as xc +from jax._src.lib import jaxlib_extension_version from jax._src.lib.mlir import dialects, ir from jax._src.util import safe_zip -from jax._src.mesh import AxisType +from jax._src.mesh import AxisType, AbstractMesh from jax._src.sharding import common_devices_indices_map from jax._src.sharding_impls import ( _op_sharding_to_pos_sharding, pmap_sharding_devices_indices_map, @@ -1418,6 +1419,21 @@ def test_make_mesh_axis_types(self): self.assertNotEqual(mesh1, mesh2) self.assertNotEqual(hash(mesh1), hash(mesh2)) + def test_memory_kind_with_abstract_mesh(self): + if jaxlib_extension_version < 326: + self.skipTest('Requires jaxlib_extension_version >= 326') + + abstract_mesh = AbstractMesh((2,), ('x',)) + ns = NamedSharding(abstract_mesh, P(), memory_kind='pinned_host') + self.assertEqual(ns.memory_kind, 'pinned_host') + + ns = NamedSharding(abstract_mesh, P()) + self.assertIsNone(ns.memory_kind) + + with self.assertRaisesRegex( + ValueError, 'Got invalid memory kind'): + NamedSharding(abstract_mesh, P(), memory_kind='weird_device') + @jtu.with_config(jax_use_shardy_partitioner=True) class ShardyShardingTest(jtu.JaxTestCase): From d81c0ffeb744d51e9e428fed0922f2df06dadfd5 Mon Sep 17 00:00:00 2001 From: Gleb Pobudzey Date: Fri, 4 Apr 2025 15:10:08 -0700 Subject: [PATCH 396/483] [Mosaic GPU] Limit the maximum number of registers per thread to 255. PiperOrigin-RevId: 744083257 --- jax/_src/pallas/mosaic_gpu/pipeline.py | 14 +++++++++++--- tests/pallas/mosaic_gpu_test.py | 3 ++- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index 21efbbec6630..df9c6668a51d 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -686,8 +686,16 @@ def _compute_registers( memory_registers: int, num_compute_wgs: int, ) -> int: - """Returns the number of registers to use for the compute thread.""" - # TODO(justinfu): Configure this per-platform. - n_registers = (512 - memory_registers) / num_compute_wgs + """Returns the max number of registers to use in compute threads. + + We start with the theoretical max registers per thread if one wargroup + (128 threads) used the entire SM's 64k register file (64k / 128 = 512). + Then reserve `memory_registers` for the producer warpgroup and distribute + the remaining registers evenly among the compute warpgroups. + + Note: The maximum number of registers per thread is 255, so we clamp + the value. + """ + n_registers = min(256, (512 - memory_registers) / num_compute_wgs) # Round down to the nearest multiple of 8. return int((n_registers // 8) * 8) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 0a3087f5902d..809c9c8fcaeb 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -2228,7 +2228,8 @@ def copy_kernel(_, x_smem, o_smem, o_last_block_smem, *consumed_barriers): np.testing.assert_array_equal(out, x) np.testing.assert_array_equal(out_last_block, x[-blk_m:, -blk_n:]) - def test_elementwise_add(self, m=256, n=256, num_compute_wgs=2): + @parameterized.product(m=[256], n=[256], num_compute_wgs=[1, 2]) + def test_elementwise_add(self, m, n, num_compute_wgs): self.skip_if_wg_semantics() # Crashes! blk_m = blk_n = 64 From aab6613944857c76c19e7e4732870885a1fa0a27 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 4 Apr 2025 15:33:20 -0700 Subject: [PATCH 397/483] [pallas:mosaic_gpu] Fixed a typo in `_barrier_arrive_pp_eqn` PiperOrigin-RevId: 744089477 --- jax/_src/pallas/mosaic_gpu/primitives.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index bfe1bf3fe5bc..8c04bd23cf22 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -618,7 +618,7 @@ def _barrier_arrive_pp_eqn( ): del settings barrier, *flat_transforms = eqn.invars - transforms_treedef = eqn.params["transforms_tree"] + transforms_treedef = eqn.params["transforms_treedef"] transforms = transforms_treedef.unflatten(flat_transforms) return pp.concat([ pp.text("barrier_arrive"), From fc5d9a4fcee2a4606f36d2d2bd517458afde24a3 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 4 Apr 2025 19:22:31 -0700 Subject: [PATCH 398/483] Check that memory_kind of an aval is always None PiperOrigin-RevId: 744136969 --- jax/_src/core.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jax/_src/core.py b/jax/_src/core.py index 1c35f5406543..9a5a6061cc5e 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1899,6 +1899,7 @@ def get_sharding(sharding, shape): raise ValueError("Mesh of an aval must be an AbstractMesh. " f"Got {out_s.mesh} of type {type(out_s.mesh)}") _check_divisibility(out_s, shape) + assert out_s.memory_kind is None return out_s def str_short_aval(shape, dtype, mesh, spec, vma, From 2e62693f72b4ce217dbd798830313439c8fbc1a6 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sat, 5 Apr 2025 06:36:53 -0700 Subject: [PATCH 399/483] Update XLA dependency to use revision http://github.com/openxla/xla/commit/8118a02a2d8af30563d2942818ddb7c07c373877. PiperOrigin-RevId: 744248817 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 7abf3da775d2..e632798e3132 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "ffa8cf08e295cec70a27a6b27bfaa19c5d0daeec" -XLA_SHA256 = "d5de319756b6a32748d2821f5319f831b062d0f5f22b7f0bde1d9564dc6b6f5e" +XLA_COMMIT = "8118a02a2d8af30563d2942818ddb7c07c373877" +XLA_SHA256 = "080edaa896d1537bb838428c164cab88532ab5b9609cb6b58ddaf19bad37f88b" def repo(): tf_http_archive( From b1b54f9b5ecbcd12bd09f91e7c36cfd40dd0ce15 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 6 Apr 2025 05:43:14 -0700 Subject: [PATCH 400/483] Update XLA dependency to use revision http://github.com/openxla/xla/commit/3889bec6b7f48e304953a485b713e9982dff0441. PiperOrigin-RevId: 744444688 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index e632798e3132..1a7522bda0fa 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "8118a02a2d8af30563d2942818ddb7c07c373877" -XLA_SHA256 = "080edaa896d1537bb838428c164cab88532ab5b9609cb6b58ddaf19bad37f88b" +XLA_COMMIT = "3889bec6b7f48e304953a485b713e9982dff0441" +XLA_SHA256 = "f23bb226d334f933cd5e6ebc4b20dec9ad879137763975546120ddf582a472b8" def repo(): tf_http_archive( From ad36f7f2532528c661ed27fc7c71dbc0e2e11c9d Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 6 Apr 2025 09:57:16 -0700 Subject: [PATCH 401/483] Automated Code Change PiperOrigin-RevId: 744478350 --- jaxlib/xla/BUILD | 2 ++ jaxlib/xla/custom_call_sharding.cc | 2 ++ jaxlib/xla/dlpack.cc | 1 + jaxlib/xla/dlpack.h | 1 + 4 files changed, 6 insertions(+) diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD index 5b532c1dc501..c861ed06e5be 100644 --- a/jaxlib/xla/BUILD +++ b/jaxlib/xla/BUILD @@ -220,6 +220,7 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@nanobind", + "@xla//third_party/python_runtime:headers", "@xla//xla:shape_util", "@xla//xla:util", "@xla//xla/hlo/ir:hlo", @@ -264,6 +265,7 @@ cc_library( "@xla//xla:shape_util", "@xla//xla:status_macros", "@xla//xla:util", + "@xla//xla:xla_data_proto_cc", "@xla//xla/pjrt:exceptions", "@xla//xla/pjrt:pjrt_client", "@xla//xla/pjrt:pjrt_common", diff --git a/jaxlib/xla/custom_call_sharding.cc b/jaxlib/xla/custom_call_sharding.cc index 3cb53b438e09..00accd85aefd 100644 --- a/jaxlib/xla/custom_call_sharding.cc +++ b/jaxlib/xla/custom_call_sharding.cc @@ -15,6 +15,8 @@ limitations under the License. #include "jaxlib/xla/custom_call_sharding.h" +#include + #include #include #include diff --git a/jaxlib/xla/dlpack.cc b/jaxlib/xla/dlpack.cc index d1cb91114b05..c8d02e679036 100644 --- a/jaxlib/xla/dlpack.cc +++ b/jaxlib/xla/dlpack.cc @@ -58,6 +58,7 @@ limitations under the License. #include "xla/tsl/platform/logging.h" #include "xla/tsl/platform/statusor.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" namespace nb = nanobind; diff --git a/jaxlib/xla/dlpack.h b/jaxlib/xla/dlpack.h index 46b0954105f7..7fffdc345d79 100644 --- a/jaxlib/xla/dlpack.h +++ b/jaxlib/xla/dlpack.h @@ -25,6 +25,7 @@ limitations under the License. #include "jaxlib/xla/nb_class_ptr.h" #include "jaxlib/xla/py_client.h" #include "xla/python/ifrt/device.h" +#include "xla/xla_data.pb.h" namespace xla { From 477b10825ad88c25ff3730cc068ea265d5f54dfa Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 6 Apr 2025 10:10:14 -0700 Subject: [PATCH 402/483] Automated Code Change PiperOrigin-RevId: 744480338 --- examples/jax_cpp/BUILD | 2 ++ examples/jax_cpp/main.cc | 2 ++ 2 files changed, 4 insertions(+) diff --git a/examples/jax_cpp/BUILD b/examples/jax_cpp/BUILD index b3cb995aae21..86f3129c9876 100644 --- a/examples/jax_cpp/BUILD +++ b/examples/jax_cpp/BUILD @@ -21,6 +21,7 @@ cc_binary( srcs = ["main.cc"], tags = ["manual"], deps = [ + "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:platform_port", @@ -33,6 +34,7 @@ cc_binary( "@xla//xla/pjrt/plugin/xla_cpu:cpu_client_options", "@xla//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client", "@xla//xla/service:hlo_module_config", + "@xla//xla/service:hlo_proto_cc", "@xla//xla/tools:hlo_module_loader", ], ) diff --git a/examples/jax_cpp/main.cc b/examples/jax_cpp/main.cc index 5d1190ff1f2c..8deea5448fec 100644 --- a/examples/jax_cpp/main.cc +++ b/examples/jax_cpp/main.cc @@ -41,6 +41,7 @@ limitations under the License. #include #include +#include "absl/log/log.h" #include "absl/status/statusor.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_module.h" @@ -50,6 +51,7 @@ limitations under the License. #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/plugin/xla_cpu/cpu_client_options.h" #include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h" +#include "xla/service/hlo.pb.h" #include "xla/service/hlo_module_config.h" #include "xla/tools/hlo_module_loader.h" #include "tsl/platform/init_main.h" From 3f083caef59b224809e10808c4c306383307c2dc Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 6 Apr 2025 10:10:22 -0700 Subject: [PATCH 403/483] Automated Code Change PiperOrigin-RevId: 744480358 --- examples/ffi/src/jax_ffi_example/rms_norm.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/ffi/src/jax_ffi_example/rms_norm.cc b/examples/ffi/src/jax_ffi_example/rms_norm.cc index 819f3b9f868d..bcfc1eb67aa4 100644 --- a/examples/ffi/src/jax_ffi_example/rms_norm.cc +++ b/examples/ffi/src/jax_ffi_example/rms_norm.cc @@ -16,8 +16,6 @@ limitations under the License. #include #include #include -#include -#include #include #include From d1c7ba4335d09a5523e42fc8280ec94413dae31a Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 6 Apr 2025 10:10:59 -0700 Subject: [PATCH 404/483] Automated Code Change PiperOrigin-RevId: 744480452 --- jaxlib/mosaic/gpu/custom_call.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index 465551e2903b..38a388224d65 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -41,6 +41,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/CodeGen.h" From 7874d79f56f99f9366039883d304c659c84f1c47 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 6 Apr 2025 10:32:55 -0700 Subject: [PATCH 405/483] Automated Code Change PiperOrigin-RevId: 744483310 --- jaxlib/xla/BUILD | 1 + jaxlib/xla/callback.cc | 1 + jaxlib/xla/py_socket_transfer.cc | 1 + jaxlib/xla/to_ifrt_sharding.cc | 1 + 4 files changed, 4 insertions(+) diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD index c861ed06e5be..35f344046828 100644 --- a/jaxlib/xla/BUILD +++ b/jaxlib/xla/BUILD @@ -732,6 +732,7 @@ cc_library( "@xla//xla/python/transfer:socket_bulk_transport", "@xla//xla/python/transfer:streaming", "@xla//xla/python/transfer:streaming_ifrt", + "@xla//xla/python/transfer:transfer_socket_proto_cc", "@xla//xla/tsl/concurrency:ref_count", "@xla//xla/tsl/platform:statusor", ], diff --git a/jaxlib/xla/callback.cc b/jaxlib/xla/callback.cc index 6f5644c3b0c7..b5519ed3bee3 100644 --- a/jaxlib/xla/callback.cc +++ b/jaxlib/xla/callback.cc @@ -41,6 +41,7 @@ limitations under the License. #include "xla/python/nb_numpy.h" #include "xla/tsl/platform/statusor.h" #include "xla/tsl/python/lib/core/numpy.h" +#include "xla/xla_data.pb.h" namespace nb = nanobind; diff --git a/jaxlib/xla/py_socket_transfer.cc b/jaxlib/xla/py_socket_transfer.cc index 55d84fd71bb7..4aa40cf66087 100644 --- a/jaxlib/xla/py_socket_transfer.cc +++ b/jaxlib/xla/py_socket_transfer.cc @@ -58,6 +58,7 @@ limitations under the License. #include "xla/python/transfer/socket_bulk_transport.h" #include "xla/python/transfer/streaming.h" #include "xla/python/transfer/streaming_ifrt.h" +#include "xla/python/transfer/transfer_socket.pb.h" #include "xla/python/types.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/platform/statusor.h" diff --git a/jaxlib/xla/to_ifrt_sharding.cc b/jaxlib/xla/to_ifrt_sharding.cc index 2a7c6707e766..52879cfa9fbe 100644 --- a/jaxlib/xla/to_ifrt_sharding.cc +++ b/jaxlib/xla/to_ifrt_sharding.cc @@ -38,6 +38,7 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tsl/platform/statusor.h" +#include "xla/xla_data.pb.h" namespace xla { From 8a6efa317d2c104ca7905a6a4d6e521a9b9ebe4c Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Sun, 6 Apr 2025 13:35:12 -0700 Subject: [PATCH 406/483] Fix deadlock when computing cached Sharding::type() values. C++ static initialization acquires an internal mutex. It is unsafe to call into Python code while holding that mutex, e.g., see the deadlock in https://gist.github.com/vfdev-5/826ef16c6cbc9f4d85466e8a348c3b5a However, in this case, there's a simpler thing we can do: eagerly initialize the ::type() values during module initialization, rather than on-demand. PiperOrigin-RevId: 744508279 --- jaxlib/xla/BUILD | 1 + jaxlib/xla/sharding.cc | 38 +++++++++++++++++++++++++++++++++++++- jaxlib/xla/sharding.h | 32 ++++++++++++++++---------------- 3 files changed, 54 insertions(+), 17 deletions(-) diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD index 35f344046828..8602652cbd8a 100644 --- a/jaxlib/xla/BUILD +++ b/jaxlib/xla/BUILD @@ -565,6 +565,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:cord", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", diff --git a/jaxlib/xla/sharding.cc b/jaxlib/xla/sharding.cc index b6b58b0600ad..ff1539764864 100644 --- a/jaxlib/xla/sharding.cc +++ b/jaxlib/xla/sharding.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include #include #include #include @@ -28,6 +29,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" #include "nanobind/nanobind.h" #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep @@ -242,7 +244,7 @@ NamedSharding::NamedSharding(nb::object mesh, nb::object spec, // TODO(phawkins): this leaks a reference to the check_pspec function. // A better way to fix this would be to move PartitionSpec and this check into // C++. - nb::object* check_pspec = [](){ + nb::object* check_pspec = []() { static absl::Mutex mu; static nb::object* output = nullptr; { @@ -262,6 +264,13 @@ NamedSharding::NamedSharding(nb::object mesh, nb::object spec, (*check_pspec)(mesh_, spec_, manual_axes_); } +/*static*/ PyObject* NamedSharding::type_ = nullptr; + +/*static*/ void NamedSharding::InitializeType() { + // Intentionally leaks a reference. + type_ = nanobind::type().inc_ref().ptr(); +} + SingleDeviceSharding::SingleDeviceSharding(nb::object device, nb::object memory_kind) : Sharding(/*num_devices=*/1), @@ -273,6 +282,13 @@ SingleDeviceSharding::SingleDeviceSharding(nb::object device, CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_); } +/*static*/ PyObject* SingleDeviceSharding::type_ = nullptr; + +/*static*/ void SingleDeviceSharding::InitializeType() { + // Intentionally leaks a reference. + type_ = nanobind::type().inc_ref().ptr(); +} + SingleDeviceSharding::SingleDeviceSharding( xla::nb_class_ptr client, xla::ifrt::DeviceListRef device_list, nb::object memory_kind) @@ -295,6 +311,15 @@ PmapSharding::PmapSharding(xla::nb_numpy_ndarray devices, xla::make_nb_class(nb::tuple(flat_devices)); } +/*static*/ PyObject* PmapSharding::type_ = nullptr; + +// /*static*/ nanobind::handle PmapSharding::type() { return type_; } + +/*static*/ void PmapSharding::InitializeType() { + // Intentionally leaks a reference. + type_ = nanobind::type().inc_ref().ptr(); +} + GSPMDSharding::GSPMDSharding(nb::sequence devices, xla::HloSharding op_sharding, nb::object memory_kind, nb::object device_list) : Sharding(/*num_devices=*/nb::len(devices.ptr())), @@ -316,6 +341,13 @@ GSPMDSharding::GSPMDSharding(nb::sequence devices, xla::HloSharding op_sharding, CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_); } +/*static*/ PyObject* GSPMDSharding::type_ = nullptr; + +/*static*/ void GSPMDSharding::InitializeType() { + // Intentionally leaks a reference. + type_ = nanobind::type().inc_ref().ptr(); +} + void RegisterSharding(nb::module_& m) { nb::class_(m, "Sharding").def(nb::init<>()); @@ -334,6 +366,7 @@ void RegisterSharding(nb::module_& m) { .def_prop_ro("_internal_device_list", [](const NamedSharding& s) { return xla::ValueOrThrow(s.internal_device_list()); }); + NamedSharding::InitializeType(); nb::class_(m, "SingleDeviceSharding", nb::dynamic_attr()) @@ -343,6 +376,7 @@ void RegisterSharding(nb::module_& m) { .def_prop_ro("_memory_kind", &SingleDeviceSharding::memory_kind) .def_prop_ro("_internal_device_list", &SingleDeviceSharding::internal_device_list); + SingleDeviceSharding::InitializeType(); nb::class_(m, "PmapSharding", nb::dynamic_attr()) .def( @@ -357,6 +391,7 @@ void RegisterSharding(nb::module_& m) { .def_prop_ro("sharding_spec", &PmapSharding::sharding_spec) .def_prop_ro("_internal_device_list", &PmapSharding::internal_device_list); + PmapSharding::InitializeType(); nb::class_(m, "GSPMDSharding", nb::dynamic_attr()) .def(nb::init(), @@ -372,6 +407,7 @@ void RegisterSharding(nb::module_& m) { .def_prop_ro("_memory_kind", &GSPMDSharding::memory_kind) .def_prop_ro("_internal_device_list", &GSPMDSharding::internal_device_list); + GSPMDSharding::InitializeType(); } } // namespace jax diff --git a/jaxlib/xla/sharding.h b/jaxlib/xla/sharding.h index 698ff2ca9ca8..4b602bd14324 100644 --- a/jaxlib/xla/sharding.h +++ b/jaxlib/xla/sharding.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef JAXLIB_XLA_SHARDING_H_ #define JAXLIB_XLA_SHARDING_H_ +#include + #include #include #include @@ -84,10 +86,8 @@ class NamedSharding : public Sharding { return logical_device_ids_; } - static nanobind::handle type() { - static auto type = nanobind::type(); - return type; - } + static nanobind::handle type() { return type_; } + static void InitializeType(); absl::StatusOr> internal_device_list() const { if (internal_device_list_) { @@ -105,6 +105,7 @@ class NamedSharding : public Sharding { nanobind::object manual_axes_; nanobind::object logical_device_ids_; std::optional> internal_device_list_; + static PyObject* type_; }; class SingleDeviceSharding : public Sharding { @@ -120,10 +121,8 @@ class SingleDeviceSharding : public Sharding { const nanobind::object& device() const { return device_; } const nanobind::object& memory_kind() const { return memory_kind_; } - static nanobind::handle type() { - static auto type = nanobind::type(); - return type; - } + static nanobind::handle type() { return type_; } + static void InitializeType(); xla::nb_class_ptr internal_device_list() const { return internal_device_list_; @@ -133,6 +132,8 @@ class SingleDeviceSharding : public Sharding { nanobind::object device_; nanobind::object memory_kind_; xla::nb_class_ptr internal_device_list_; + + static PyObject* type_; }; // The C++ implementation of jax.PmapSharding in python. It contains a few key @@ -147,10 +148,8 @@ class PmapSharding : public Sharding { const ShardingSpec& sharding_spec() const { return sharding_spec_; } - static nanobind::handle type() { - static auto type = nanobind::type(); - return type; - } + static nanobind::handle type() { return type_; } + static void InitializeType(); xla::nb_class_ptr internal_device_list() const { return internal_device_list_; @@ -160,6 +159,7 @@ class PmapSharding : public Sharding { xla::nb_numpy_ndarray devices_; ShardingSpec sharding_spec_; xla::nb_class_ptr internal_device_list_; + static PyObject* type_; }; class GSPMDSharding : public Sharding { @@ -184,10 +184,8 @@ class GSPMDSharding : public Sharding { return *hash_; } - static nanobind::handle type() { - static auto type = nanobind::type(); - return type; - } + static nanobind::handle type() { return type_; } + static void InitializeType(); const xla::HloSharding& hlo_sharding() const { return hlo_sharding_; } @@ -234,6 +232,8 @@ class GSPMDSharding : public Sharding { nanobind::object memory_kind_; std::optional hash_; xla::nb_class_ptr internal_device_list_; + + static PyObject* type_; }; void RegisterSharding(nanobind::module_& m); From cccc34dc2334040e58eeb6131f2ac7f1470a8f62 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Sun, 6 Apr 2025 23:37:20 -0700 Subject: [PATCH 407/483] Raise an error if the type passed to `axis_types` argument of `Mesh` and `AbstractMesh` is not `jax.sharding.AxisType`. PiperOrigin-RevId: 744602037 --- jax/_src/mesh.py | 13 +++++++++---- tests/array_test.py | 10 ++++++++++ 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index 00859f9b3d74..8db4445542d0 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -111,12 +111,16 @@ class AxisType(enum.Enum): def __repr__(self): return self.name -def _normalize_axis_types(axis_names, axis_types): +def _normalize_axis_types(axis_names, axis_types, name): axis_types = ((AxisType.Auto,) * len(axis_names) if axis_types is None else axis_types) if not isinstance(axis_types, tuple): - assert isinstance(axis_types, AxisType), axis_types axis_types = (axis_types,) + + if not all(isinstance(a, AxisType) for a in axis_types): + raise TypeError( + f"axis_types passed to {name} must be of type `jax.sharding.AxisType`." + f" Got {axis_types} of type {tuple(type(a) for a in axis_types)}") if len(axis_names) != len(axis_types): raise ValueError( "Number of axis names should match the number of axis_types. Got" @@ -256,7 +260,7 @@ def __new__(cls, devices: np.ndarray | Sequence[xc.Device], f"devices.ndim == {devices.ndim} and " f"len(axis_names) == {len(axis_names)}.") - axis_types = _normalize_axis_types(axis_names, axis_types) + axis_types = _normalize_axis_types(axis_names, axis_types, 'Mesh') key = (axis_names, devices.shape, tuple(devices.flat), axis_types) val = _mesh_object_dict.get(key, None) @@ -440,7 +444,8 @@ def __init__(self, axis_sizes: tuple[int, ...], axis_names: tuple[str, ...], self.axis_sizes = axis_sizes self.axis_names = axis_names self._size = math.prod(self.axis_sizes) if self.axis_sizes else 0 - self._axis_types = _normalize_axis_types(self.axis_names, axis_types) + self._axis_types = _normalize_axis_types( + self.axis_names, axis_types, 'AbstractMesh') self._hash = hash((self.axis_sizes, self.axis_names, self._axis_types)) def __hash__(self): diff --git a/tests/array_test.py b/tests/array_test.py index 76aa1093ede3..f097497cef51 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -1377,6 +1377,16 @@ def test_mesh_axis_types_mismatch(self): jax.sharding.AbstractMesh((2, 1), ('x', 'y'), axis_types=jax.sharding.AxisType.Auto) + with self.assertRaisesRegex(TypeError, "axis_types.*must be of type"): + AbstractMesh((2,), ('x',), axis_types=("explicit",)) + + with self.assertRaisesRegex(TypeError, "axis_types.*must be of type"): + AbstractMesh((2,), ('x',), axis_types="explicit") + + with self.assertRaisesRegex(TypeError, "axis_types.*must be of type"): + AbstractMesh((2, 2), ('x', 'y'), + axis_types=("explicit", AxisType.Explicit)) + def test_make_mesh_axis_types(self): Auto, Explicit, Manual = AxisType.Auto, AxisType.Explicit, AxisType.Manual From 90cfa99a6868df89ea96923aae4338c123bfd242 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Mon, 7 Apr 2025 00:51:13 -0700 Subject: [PATCH 408/483] [Mosaic GPU] Support Slice and Transpose in the Pallas WGMMA lowering This change also fixes the transpose handling in the lowering and completely removes the use of the TransposeTransform. Instead we rely on strides. If we don't discover any issues with this, we will remove the transpose transform also from the mlir dialect. PiperOrigin-RevId: 744618241 --- jax/_src/pallas/mosaic_gpu/primitives.py | 54 ++++++++--- .../mosaic/gpu/dialect_lowering.py | 94 ++++++++++++++----- .../mosaic/gpu/transform_inference.py | 81 ++++++++++++---- tests/mosaic/gpu_dialect_test.py | 22 +++++ tests/pallas/mosaic_gpu_test.py | 47 +++++++--- 5 files changed, 229 insertions(+), 69 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 8c04bd23cf22..8bd67e705cf0 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -903,15 +903,25 @@ def _wgmma_lowering( transforms_leaves, [a_transforms_tree.num_leaves] ) a_transforms = a_transforms_tree.unflatten(a_transforms_leaves) - a, a_transforms = lowering._handle_transforms(a, a_transforms) + a, a_transforms = lowering._handle_transforms( + a, a_transforms, handle_transposes=False, handle_reshapes=False + ) match a_transforms: case (gpu_core.UnswizzleRef(lhs_swizzle), gpu_core.UntileRef(tiling)): - swizzle_elems = lhs_swizzle // a_aval.dtype.itemsize - if tiling != (8, swizzle_elems): - raise NotImplementedError("WGMMA lhs tiling does not fit swizzle") + lhs_transpose = False + case ( + gpu_core.UnswizzleRef(lhs_swizzle), + gpu_core.UntileRef(tiling), + gpu_core.TransposeRef((1, 0)), + ): + lhs_transpose = True case _: raise ValueError(f"WGMMA lhs has unsupported transforms: {a_transforms}.") + swizzle_elems = lhs_swizzle // a_aval.dtype.itemsize + if tiling != (8, swizzle_elems): + raise NotImplementedError("WGMMA lhs tiling does not fit swizzle") else: + lhs_transpose = False b_transforms_leaves = transforms_leaves # type: ignore if not isinstance(a, mgpu.FragmentedArray): raise ValueError( @@ -949,8 +959,6 @@ def _wgmma_lowering( f" {rhs_tiling}." ) - # TODO(cperivol): Find a generic way to move this reshape into - # _handle_transforms. high_dims = [d // t for d, t in util.safe_zip(new_shape, rhs_tiling)] b = mgpu.memref_reshape(b, (*high_dims, *rhs_tiling)) rhs_transpose = False @@ -964,6 +972,8 @@ def _wgmma_lowering( if rhs_tiling != (8, swizzle_elems): raise NotImplementedError("WGMMA rhs tiling does not fit swizzle") + if lhs_transpose: + a = mgpu.memref_transpose(a, (1, 0, 3, 2)) if rhs_transpose: b = mgpu.memref_transpose(b, (1, 0, 3, 2)) new_acc = mgpu.wgmma(acc, a, b, swizzle=rhs_swizzle) @@ -981,23 +991,37 @@ def _wgmma_warpgroup_lowering( a_transforms_tree, b_transforms_tree, ): - del ctx, transforms_leaves # Unused. + del ctx # Unused. + if a_transforms_tree is not None: - match a_transforms_tree: - case gpu_core.TransposeRef((1, 0)): - raise NotImplementedError("WGMMA lhs transpose not supported.") + a_transforms_leaves, b_transforms_leaves = util.split_list( + transforms_leaves, [a_transforms_tree.num_leaves] + ) + a_transforms = a_transforms_tree.unflatten(a_transforms_leaves) + a, a_transforms = lowering._handle_transforms(a, a_transforms) + match a_transforms: + case (gpu_core.TransposeRef((1, 0)),): + a = mgpu.memref_transpose(a, (1, 0)) + case (): + pass case _: raise ValueError( - f"WGMMA lhs has unsupported transforms: {a_transforms_tree}." + f"WGMMA lhs has unsupported transforms: {a_transforms}." ) + else: + b_transforms_leaves = transforms_leaves # type: ignore if b_transforms_tree is not None: - match b_transforms_tree: - case gpu_core.TransposeRef((1, 0)): - raise NotImplementedError("WGMMA rhs transpose not supported.") + b_transforms = b_transforms_tree.unflatten(b_transforms_leaves) + b, b_transforms = lowering._handle_transforms(b, b_transforms) + match b_transforms: + case (gpu_core.TransposeRef((1, 0)),): + b = mgpu.memref_transpose(b, (1, 0)) + case (): + pass case _: raise ValueError( - f"WGMMA rhs has unsupported transforms: {b_transforms_tree}." + f"WGMMA rhs has unsupported transforms: {b_transforms}." ) new_acc = mgpu.dialect.wgmma(acc, a, b) diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index f00cff9a500c..3deb53646ce4 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -355,7 +355,7 @@ def _vector_load_op_lowering_rule( ) ref_ty = ir.MemRefType(vector_load_op.base.type) _check_transforms_and_swizzle_are_supported(ref_ty, transforms, swizzle) - transformed_ref = transform_memref(vector_load_op.base, transforms) + transformed_ref = reinterpret_smem_ref(vector_load_op.base, transforms) fragmented_array = fa.FragmentedArray.load_tiled( transformed_ref, swizzle=swizzle, @@ -397,7 +397,7 @@ def _vector_store_op_lowering_rule( ref_ty = ir.MemRefType(vector_store_op.base.type) _check_transforms_and_swizzle_are_supported(ref_ty, transforms, swizzle) fragmented_array.store_tiled( - transform_memref(vector_store_op.base, transforms), swizzle + reinterpret_smem_ref(vector_store_op.base, transforms), swizzle ) elif (isinstance(fragmented_array.layout, fa.WGStridedFragLayout) or isinstance(fragmented_array.layout, fa.WGSplatFragLayout)): @@ -510,32 +510,78 @@ def swizzle_and_transforms_from_transforms_attr( return swizzle or mgpu.SwizzlingMode.kNoSwizzle, tuple(gmem_transforms) -def transform_memref( - mem_ref: ir.Value, transforms: tuple[launch_context.MemRefTransform, ...] +def _is_memref_transposed(mem_ref_type: ir.MemRefType) -> bool: + strides, _ = mem_ref_type.get_strides_and_offset() + prev_stride = math.inf + for stride in strides: + if stride > prev_stride: + return True + prev_stride = stride + return False + + +def reinterpret_smem_ref( + ref: ir.Value, + transforms: tuple[launch_context.MemRefTransform, ...], ) -> ir.Value: - """Reinterprets the memref to one where the shape is transformed as given.""" - if not transforms: - return mem_ref + """Applies transforms on the ref, and makes sure that their effect is + propagated appropriately on the strides. - mem_ref_type = ir.MemRefType(mem_ref.type) - if mem_ref_type.memory_space != ir.Attribute.parse( - "#gpu.address_space" - ): - raise ValueError(f"Only workgroup memory is supported but got {mem_ref}.") + This function is used any time we lower from a dialect SMEM ref (2D for wgmma) + with given transforms to a "physical" SMEM ref (4D for wgmma) that is fully + transformed and transposed as needed. + """ + ref_ty = ir.MemRefType(ref.type) + transposed = _is_memref_transposed(ref_ty) + if not transforms and not transposed: + return ref + + if ref_ty.memory_space != ir.Attribute.parse("#gpu.address_space"): + raise ValueError(f"Only workgroup memory is supported but got {ref}.") + + shape = ref_ty.shape + if transposed: + if len(shape) != 2: + raise NotImplementedError( + f"Only 2D shapes can be transposed, but got {shape}" + ) + strides, _ = ref_ty.get_strides_and_offset() + if strides[0] != 1 or strides[1] != shape[0]: + raise NotImplementedError( + f"Only contiguous 2D memrefs can be transposed, but got {ref_ty}" + ) - shape = mem_ref_type.shape for t in transforms: - shape = t.transform_shape(shape) + shape = list(t.transform_shape(shape)) + + if transposed: + # The expected output is a transposed ref and `shape` is already transposed. + # We need to compute the correct strides to match the shape. + if len(shape) == 2: + minor_to_major_stride_order = (1, 0) + elif len(shape) == 4: + minor_to_major_stride_order = (2, 3, 0, 1) + else: + raise NotImplementedError( + f"Expected a 2D or 4D shape after transforms, but got {shape}" + ) + strides = [1]*len(shape) + for i in minor_to_major_stride_order[1:]: + strides[i] = strides[i-1] * shape[i-1] + layout = ir.StridedLayoutAttr.get(0, strides) + else: + layout = None - memref_new_type = ir.MemRefType.get( + new_ref_ty = ir.MemRefType.get( shape, - mem_ref_type.element_type, - memory_space=mem_ref_type.memory_space, + ref_ty.element_type, + memory_space=ref_ty.memory_space, + layout=layout, ) - ms = utils.WORKGROUP_NVPTX_ADDRESS_SPACE - ptr = utils.memref_ptr(mem_ref, memory_space=ms) - return utils.ptr_as_memref(ptr, memref_new_type, ptr_memory_space=ms) + ptr = utils.memref_ptr(ref, memory_space=ms) + ref = utils.ptr_as_memref(ptr, new_ref_ty, ptr_memory_space=ms) + return ref @_register_lowering(mgpu.AsyncLoadOp) @@ -569,7 +615,7 @@ def _mgpu_async_load_op_lowering_rule( # TODO(dasenov): Add support for the remaining op properties. ctx.launch_context.async_copy( src_ref=load_op.source, - dst_ref=transform_memref(load_op.destination, transforms), + dst_ref=reinterpret_smem_ref(load_op.destination, transforms), gmem_slice=tuple(gmem_slice), barrier=barrier, arrive=False, @@ -610,7 +656,7 @@ def _mgpu_async_store_op_lowering_rule( # TODO(dasenov): Add support for the remaining op properties. ctx.launch_context.async_copy( - src_ref=transform_memref(store_op.source, transforms), + src_ref=reinterpret_smem_ref(store_op.source, transforms), dst_ref=store_op.destination, gmem_slice=tuple(gmem_slice), swizzle=swizzle, @@ -840,7 +886,7 @@ def _mgpu_wgmma_op_lowering_rule( _check_transforms_and_swizzle_are_supported( ref_ty, b_transforms, b_swizzle, minimum_swizzle ) - b_operand = transform_memref(wgmma_op.b, b_transforms) + b_operand = reinterpret_smem_ref(wgmma_op.b, b_transforms) if ir.VectorType.isinstance(wgmma_op.a.type): a_operand = _fragmented_array_from_ir(wgmma_op.a, wgmma_layout) @@ -857,7 +903,7 @@ def _mgpu_wgmma_op_lowering_rule( f"Non-matching swizzles of operands a and b in WGMMA: {a_swizzle} !=" f" {b_swizzle}" ) - a_operand = transform_memref(wgmma_op.a, a_transforms) + a_operand = reinterpret_smem_ref(wgmma_op.a, a_transforms) new_acc = wgmma.wgmma(acc, a_operand, b_operand, swizzle=b_swizzle) diff --git a/jax/experimental/mosaic/gpu/transform_inference.py b/jax/experimental/mosaic/gpu/transform_inference.py index 3438a654f90a..6026cb216166 100644 --- a/jax/experimental/mosaic/gpu/transform_inference.py +++ b/jax/experimental/mosaic/gpu/transform_inference.py @@ -98,11 +98,8 @@ def infer_transforms_for_wgmma_ref(ref_ty: ir.MemRefType) -> ir.ArrayAttr: element_bytewidth = utils.bytewidth(ref_ty.element_type) strides, _ = ref_ty.get_strides_and_offset() - - if strides[0] < strides[1]: - raise NotImplementedError("Transpositions aren't handled yet.") - - minor_dim = ref_ty.shape[1] + transposed = strides[0] < strides[1] + minor_dim = ref_ty.shape[0 if transposed else 1] major_tiling = 8 # Try tiling with all swizzling modes starting from the largest one. @@ -118,12 +115,14 @@ def infer_transforms_for_wgmma_ref(ref_ty: ir.MemRefType) -> ir.ArrayAttr: break else: # No valid tile transform can be inferred. - raise ValueError( - f"{ref_ty.shape} is not a valid WGMMA shape" - ) + raise ValueError(f"{ref_ty.shape} is not a valid WGMMA shape") + if transposed: + tiling = (minor_tiling, major_tiling) + else: + tiling = (major_tiling, minor_tiling) return ir.ArrayAttr.get([ - mgpu.TileTransformAttr.get((major_tiling, minor_tiling)), + mgpu.TileTransformAttr.get(tiling), mgpu.SwizzleTransformAttr.get(minor_tiling * element_bytewidth), ]) @@ -255,6 +254,24 @@ def _infer_dynamic_smem_transforms( return None +def _get_tile_and_swizzle_transforms( + transforms: ir.ArrayAttr | None, +) -> tuple[ir.Attribute, ir.Attribute]: + if transforms is None: + return + + if len(transforms) == 2: + tile_transform, swizzle_transform = transforms + if not ( + mgpu.TileTransformAttr.isinstance(tile_transform) + and mgpu.SwizzleTransformAttr.isinstance(swizzle_transform) + ): + raise NotImplementedError(f"Unsupported transforms {transforms}.") + return tile_transform, swizzle_transform + else: + raise NotImplementedError(f"Unsupported transforms {transforms}.") + + # This is used by Pallas' "_handle_indexing" memory transform. @partial(_add_transform_inference_rule, memref.SubViewOp) def _infer_memref_subview_transforms( @@ -285,15 +302,7 @@ def _infer_memref_subview_transforms( # - We only propagate transforms if they consist of a single tile transform # and a single swizzle transform. # TODO(bchetioui): implement more complex propagation rules. - if len(transforms) == 2: - tile_transform, swizzle_transform = transforms - if not ( - mgpu.TileTransformAttr.isinstance(tile_transform) - and mgpu.SwizzleTransformAttr.isinstance(swizzle_transform) - ): - raise NotImplementedError(f"Can't propagate transforms {transforms}.") - else: - raise NotImplementedError(f"Can't propagate transforms {transforms}.") + tile_transform, _ = _get_tile_and_swizzle_transforms(transforms) # Check swizzle transform propagation. strides, _ = ir.MemRefType.get_strides_and_offset(op.source.type) @@ -318,6 +327,42 @@ def _infer_memref_subview_transforms( return [transforms], [transforms] +@partial(_add_transform_inference_rule, memref.TransposeOp) +def _infer_memref_transpose_transforms( + op: memref.TransposeOp, +) -> OptionalTransforms: + in_ty = ir.MemRefType(op.in_.type) + if len(in_ty.shape) != 2: + raise NotImplementedError(f"Only 2D memrefs are supported, got {in_ty}") + in_strides, _ = in_ty.get_strides_and_offset() + out_strides, _ = ir.MemRefType(op.result.type).get_strides_and_offset() + transpose = in_strides != out_strides + + users = list(op.result.uses) + if len(users) != 1: + raise NotImplementedError( + f"Only memref.transpose with a single use are supported, got {op}" + ) + + op_operand_use = users[0] + consumer = op_operand_use.owner + op_user = consumer.operands[op_operand_use.operand_number] + out_transforms = inference_utils.in_transforms_for_operand(consumer, op_user) + + in_transforms = [] + if not transpose: + in_transforms = out_transforms + else: + tile_transform, swizzle_transform = _get_tile_and_swizzle_transforms( + out_transforms + ) + transposed_tiling = mgpu.TileTransformAttr(tile_transform).tiling[::-1] + in_transforms.append(mgpu.TileTransformAttr.get(transposed_tiling)) + in_transforms.append(swizzle_transform) + + return [ir.ArrayAttr.get(in_transforms)], [out_transforms] + + # `memref.load` is used to load barrier phases---the rule needn't do anything # interesting, but we need to have it in order to avoid crashing on it. @partial(_add_transform_inference_rule, memref.LoadOp) diff --git a/tests/mosaic/gpu_dialect_test.py b/tests/mosaic/gpu_dialect_test.py index bc94d72dc0d8..7e211abb955a 100644 --- a/tests/mosaic/gpu_dialect_test.py +++ b/tests/mosaic/gpu_dialect_test.py @@ -34,6 +34,7 @@ from jax.experimental.mosaic import gpu as mgpu from jax.experimental.mosaic.gpu import layouts from jax.experimental.mosaic.gpu import utils as mgpu_utils +from jax.experimental.mosaic.gpu import dialect_lowering as lowering _cext = mgpu.dialect._cext if mgpu.dialect is not None else None @@ -906,6 +907,27 @@ def body(vec1, vec2, ref): with self.assertRaisesRegex(ir.MLIRError, error): self.module.operation.verify() + def test_memref_transforms_with_transpose(self): + with ir.InsertionPoint(self.module.body): + ty_in = ir.MemRefType.get( + (64, 128), + ir.BF16Type.get(), + memory_space=ir.Attribute.parse("#gpu.address_space"), + ) + ref = memref.alloc(ty_in, [], []) + + ref = mgpu_utils.memref_transpose(ref, (1, 0)) + # This tiling is applied to the transposed memref. + transforms = [mgpu.TileTransform(tiling=(16, 32))] + + ref_transformed = lowering.reinterpret_smem_ref(ref, transforms) + ty_transformed = ir.MemRefType(ref_transformed.type) + self.assertEqual(ty_transformed.shape, [8, 2, 16, 32]) + strides, _ = ty_transformed.get_strides_and_offset() + self.assertEqual(strides, [512, 4096, 1, 16]) + + + if __name__ == "__main__": parameterized.absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 809c9c8fcaeb..10b4f3de60ad 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1551,7 +1551,8 @@ def scope(acc_ref): acc_ini = jnp.ones((64, 64), dtype=jnp.float16) np.testing.assert_array_equal(kernel(acc_ini), jnp.full((64, 64), 5, dtype=jnp.float16)) - def test_realistic_matmul(self): + @parameterized.product(lhs_transpose=[False, True], rhs_transpose=[False, True]) + def test_realistic_matmul(self, lhs_transpose, rhs_transpose): dtype = jnp.float16 swizzle = 128 elems_128b = swizzle // jnp.dtype(dtype).itemsize @@ -1561,7 +1562,11 @@ def test_realistic_matmul(self): m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n def kernel(a_ref, b_ref, o_ref, acc_ref): # Make sure tiling does not alter the shape of references + if lhs_transpose: + a_ref = plgpu.transpose_ref(a_ref, (1, 0)) assert a_ref.shape == (tile_m, tile_k) + if rhs_transpose: + b_ref = plgpu.transpose_ref(b_ref, (1, 0)) assert b_ref.shape == (tile_k, tile_n) assert o_ref.shape == acc_ref.shape == (tile_m, tile_n) plgpu.wgmma(acc_ref, a_ref, b_ref) @@ -1572,17 +1577,31 @@ def _epilogue(): plgpu.wgmma_wait(1) # We don't await the last WGMMA, hence delay_release=1 key1, key2 = jax.random.split(jax.random.key(42), 2) - a = jax.random.uniform(key1, shape=(m, k), dtype=dtype) - b = jax.random.uniform(key2, shape=(k, n), dtype=dtype) + a_shape = (k, m) if lhs_transpose else (m, k) + a = jax.random.uniform(key1, shape=a_shape, dtype=dtype) + b_shape = (n, k) if rhs_transpose else (k, n) + b = jax.random.uniform(key2, shape=b_shape, dtype=dtype) - lhs_spec = pl.BlockSpec( - (tile_m, tile_k), - lambda m, n, k: (m, k), - ) - rhs_spec = pl.BlockSpec( - (tile_k, tile_n), - lambda m, n, k: (k, n), - ) + if lhs_transpose: + lhs_spec = pl.BlockSpec( + (tile_k, tile_m), + lambda m, n, k: (k, m), + ) + else: + lhs_spec = pl.BlockSpec( + (tile_m, tile_k), + lambda m, n, k: (m, k), + ) + if rhs_transpose: + rhs_spec = pl.BlockSpec( + (tile_n, tile_k), + lambda m, n, k: (n, k), + ) + else: + rhs_spec = pl.BlockSpec( + (tile_k, tile_n), + lambda m, n, k: (k, n), + ) out_spec = pl.BlockSpec( (tile_m, tile_n), lambda m, n, k: (m, n), @@ -1627,7 +1646,11 @@ def _epilogue(): delay_release=1, ), )(a, b) - np.testing.assert_allclose(res, a @ b, rtol=1e-3) + np.testing.assert_allclose( + res, + (a.T if lhs_transpose else a) @ (b.T if rhs_transpose else b), + rtol=1e-3, + ) @parameterized.parameters(jnp.float16, jnp.float32) def test_wgmma(self, dtype): From 245194ffa13f4a7f38e7ab1a30aa46a2d29af3f5 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 7 Apr 2025 02:40:01 -0700 Subject: [PATCH 409/483] Use `contextlib.nullcontext` instead of `trivial_ctx` I removed `trivial_ctx` from the public `jax.interpreters.partial_eval` submodule without going through a deprecation cycle, because it is highly unlikely anyone is using it. PiperOrigin-RevId: 744645764 --- jax/_src/interpreters/partial_eval.py | 6 ++---- jax/interpreters/partial_eval.py | 1 - 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 0a8e3b7824ff..532eb0f80029 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -15,7 +15,7 @@ from collections import namedtuple from collections.abc import Callable, Sequence, Hashable -from contextlib import contextmanager +import contextlib from functools import partial import itertools as it import operator as op @@ -1236,14 +1236,12 @@ def _default_res_aval_updater( params: dict[str, Any], aval: AbstractValue) -> AbstractValue: return aval -@contextmanager -def trivial_ctx(_): yield def call_partial_eval_custom_rule( jaxpr_param_name: str, params_updater: ParamsUpdater, saveable: Callable[..., RematCases_], unks_in: list[bool], inst_in: list[bool], eqn: JaxprEqn, *, res_aval: ResAvalUpdater = _default_res_aval_updater, - ctx = trivial_ctx, + ctx = contextlib.nullcontext, ) -> tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], list[Var]]: jaxpr = eqn.params[jaxpr_param_name] with ctx(eqn.params): diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index b546d774a2e9..a2d988f6bea3 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -81,7 +81,6 @@ trace_to_subjaxpr_nounits as trace_to_subjaxpr_nounits, trace_to_subjaxpr_nounits_fwd as trace_to_subjaxpr_nounits_fwd, tracers_to_jaxpr as tracers_to_jaxpr, - trivial_ctx as trivial_ctx, ) From 695ee8f3d1cc47cb1da286e900b1434d1f0951a2 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 7 Apr 2025 05:48:16 -0400 Subject: [PATCH 410/483] Fix a race in pjit under free threading. Fixes https://github.com/jax-ml/jax/issues/27767 --- .github/workflows/tsan.yaml | 2 +- jaxlib/xla/pjit.cc | 19 ++++++++++--------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/.github/workflows/tsan.yaml b/.github/workflows/tsan.yaml index 1bdb36b2cd03..4c28eab528e4 100644 --- a/.github/workflows/tsan.yaml +++ b/.github/workflows/tsan.yaml @@ -13,7 +13,7 @@ on: - main paths: - '**/workflows/tsan.yaml' - - '**/workflows/tsan-suppressions.txt' + - '**/workflows/tsan-suppressions*.txt' jobs: tsan: diff --git a/jaxlib/xla/pjit.cc b/jaxlib/xla/pjit.cc index 503e8ef23f4b..50bdc750d3a4 100644 --- a/jaxlib/xla/pjit.cc +++ b/jaxlib/xla/pjit.cc @@ -245,16 +245,17 @@ std::shared_ptr PjitFunctionCache::DefaultCache() { std::shared_ptr cache = std::make_shared(&self->lru_list_); auto callback = nb::cpp_function([self, key{std::move(key)}](nb::handle weakref) { - std::unique_ptr value; - { - nb::ft_object_guard lock(self); - auto it = self->functions_.find(key); - if (it == self->functions_.end()) { - return; - } - value = std::move(it->second); - self->functions_.erase(it); + nb::ft_object_guard lock(self); + auto it = self->functions_.find(key); + if (it == self->functions_.end()) { + return; } + // Remove the value from the map before destroying it. Destroying + // the value may release `lock` since it may call arbitrary Python + // code. + std::unique_ptr value = std::move(it->second); + self->functions_.erase(it); + value.reset(); }); PyObject* weakref = PyWeakref_NewRef(function.ptr(), callback.ptr()); if (weakref) { From ce7dc85104813f153e42f546d698f4147a00795d Mon Sep 17 00:00:00 2001 From: George Necula Date: Mon, 7 Apr 2025 11:53:20 +0200 Subject: [PATCH 411/483] [export] Add support for serializing functions with PRNG keys as inputs/outputs This introduces version 4 of serialization, fully backwards compatible with versions 2 and 3. Fixes: #24143 --- jax/_src/export/serialization.fbs | 6 +++++- jax/_src/export/serialization.py | 7 +++++++ jax/_src/export/serialization_generated.py | 7 +++++-- jax/_src/prng.py | 2 +- tests/export_test.py | 12 ++++++++++++ 5 files changed, 30 insertions(+), 4 deletions(-) diff --git a/jax/_src/export/serialization.fbs b/jax/_src/export/serialization.fbs index 7d3e342f1879..01cfa9944dfd 100644 --- a/jax/_src/export/serialization.fbs +++ b/jax/_src/export/serialization.fbs @@ -45,7 +45,7 @@ enum AbstractValueKind: byte { } enum DType: byte { - // Last used id: 22 + // Last used id: 29 bool = 0, i8 = 1, i16 = 2, @@ -76,6 +76,10 @@ enum DType: byte { f8_e5m2fnuz = 21, f8_e8m0fnu = 25, f4_e2m1fn = 26, + + key_fry = 27, + key_rbg = 28, + key_unsafe_rbg = 29, } table AbstractValue { diff --git a/jax/_src/export/serialization.py b/jax/_src/export/serialization.py index 94c0baf642b6..3d878cccc701 100644 --- a/jax/_src/export/serialization.py +++ b/jax/_src/export/serialization.py @@ -31,6 +31,7 @@ from jax._src import core from jax._src import dtypes from jax._src import effects +from jax._src import prng from jax._src import tree_util from jax._src.export import serialization_generated as ser_flatbuf from jax._src.export import _export @@ -48,6 +49,8 @@ # Version 2, Dec 16th, 2023, adds the f0 dtype. # Version 3, October 16th, 2024, adds serialization for namedtuple and custom types # This version is backwards compatible with Version 2. +# Version 4, April 7th, 2025, adds serialization for PRNGs key types. +# This version is backwards compatible with Version 2 and 3. _SERIALIZATION_VERSION = 2 def serialize(exp: _export.Exported, vjp_order: int = 0) -> bytearray: @@ -361,6 +364,10 @@ def _deserialize_pytreedef_to_pytree(p: ser_flatbuf.PyTreeDef): dtypes._float8_e4m3_dtype: ser_flatbuf.DType.f8_e4m3, dtypes._float8_e8m0fnu_dtype: ser_flatbuf.DType.f8_e8m0fnu, dtypes._float4_e2m1fn_dtype: ser_flatbuf.DType.f4_e2m1fn, + + prng.KeyTy(prng.prngs["threefry2x32"]): ser_flatbuf.DType.key_fry, + prng.KeyTy(prng.prngs["rbg"]): ser_flatbuf.DType.key_rbg, + prng.KeyTy(prng.prngs["unsafe_rbg"]): ser_flatbuf.DType.key_unsafe_rbg, } _dtype_kind_to_dtype = { diff --git a/jax/_src/export/serialization_generated.py b/jax/_src/export/serialization_generated.py index b1fc13333777..34211c1ebe54 100644 --- a/jax/_src/export/serialization_generated.py +++ b/jax/_src/export/serialization_generated.py @@ -53,16 +53,19 @@ class DType(object): bf16 = 14 i4 = 15 ui4 = 16 - f8_e3m4 = 24 - f8_e4m3 = 23 f8_e4m3b11fnuz = 17 f8_e4m3fn = 18 f8_e4m3fnuz = 19 f8_e5m2 = 20 f8_e5m2fnuz = 21 f0 = 22 + f8_e4m3 = 23 + f8_e3m4 = 24 f8_e8m0fnu = 25 f4_e2m1fn = 26 + key_fry = 27 + key_rbg = 28 + key_unsafe_rbg = 29 class ShardingKind(object): diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 35282cb716cb..1dc7e9c0df0e 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -113,7 +113,7 @@ def pprint(self): ])))) -prngs = {} +prngs: dict[str, PRNGImpl] = {} def register_prng(impl: PRNGImpl): if impl.name in prngs: diff --git a/tests/export_test.py b/tests/export_test.py index 2264fbdd997b..26157f2f6a79 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -421,6 +421,18 @@ def f(x1, x2): self.assertEqual(tree_util.tree_structure(res2), tree_util.tree_structure(res)) + @jtu.parameterized_filterable( + kwargs=[dict(impl=p) + for p in ("rbg", "unsafe_rbg", "threefry2x32")]) + def test_prng_keys(self, *, impl): + + key = jax.random.key(42, impl=impl) + @jax.jit + def f(key): + return key + exp_f = get_exported(jax.jit(f))(key) + self.assertEqual(f(key), exp_f.call(key)) + def test_error_wrong_intree(self): def f(a_b_pair, *, c): return jnp.sin(a_b_pair[0]) + jnp.cos(a_b_pair[1]) + c From 075d88febc58913ead503c452e0eeafd317fee6f Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 7 Apr 2025 03:08:09 -0700 Subject: [PATCH 412/483] Fix some test timeouts PiperOrigin-RevId: 744652508 --- tests/BUILD | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/BUILD b/tests/BUILD index 23d59e8d549a..63969bc935da 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -947,7 +947,7 @@ jax_multiplatform_test( }, shard_count = { "cpu": 40, - "gpu": 40, + "gpu": 50, "tpu": 40, }, tags = ["noasan"], # Times out @@ -1132,10 +1132,11 @@ jax_multiplatform_test( backend_tags = { "cpu": [ "noasan", # Times out under asan - "notsan", # Times out under asan + "notsan", # Times out under tsan ], "tpu": [ - "noasan", # Times out under asan. + "noasan", # Times out under asan + "notsan", # Times out under tsan ], }, shard_count = { From 6e93fa34f32c2e57bc3b65948f26eb27d180b9bd Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 7 Apr 2025 03:38:43 -0700 Subject: [PATCH 413/483] Removed unused deprecations PiperOrigin-RevId: 744659794 --- jax/_src/deprecations.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/jax/_src/deprecations.py b/jax/_src/deprecations.py index 37f2f0264782..6c39c893a111 100644 --- a/jax/_src/deprecations.py +++ b/jax/_src/deprecations.py @@ -127,12 +127,10 @@ def warn(deprecation_id: str, message: str, stacklevel: int) -> None: register('jax-dlpack-import-legacy') register('jax-nn-one-hot-float-input') register("jax-numpy-astype-complex-to-real") -register("jax-numpy-array-none") register('jax-numpy-clip-args') register('jax-numpy-linalg-matrix_rank-tol') register('jax-numpy-linalg-pinv-rcond') register('jax-numpy-quantile-interpolation') register('jax-numpy-reduction-non-boolean-where') register('jax-numpy-trimzeros-not-1d-array') -register('pallas-gpu-triton') register('jax-scipy-special-sph-harm') From 9b850a9e9413db077ef74ef6672b9eb36c388fb4 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Mon, 7 Apr 2025 03:39:29 -0700 Subject: [PATCH 414/483] [Mosaic GPU] Delete mentions of `WGMMARowFragLayout` in `layouts.py`. PiperOrigin-RevId: 744659986 --- jax/experimental/mosaic/gpu/layouts.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/jax/experimental/mosaic/gpu/layouts.py b/jax/experimental/mosaic/gpu/layouts.py index cb94c3eaf749..78c9f670881b 100644 --- a/jax/experimental/mosaic/gpu/layouts.py +++ b/jax/experimental/mosaic/gpu/layouts.py @@ -171,15 +171,6 @@ def to_layout_attr( ) -_wgmma_row_fragmented_layout_attr_pattern = re.compile( - r"^#mosaic_gpu.WGMMARowFragLayout$" -) - - -def is_wgmma_row_fragmented_layout(attr: ir.Attribute) -> bool: - return bool(_wgmma_row_fragmented_layout_attr_pattern.search(str(attr))) - - def from_layout_attr( attr: ir.Attribute, ) -> ( @@ -194,8 +185,6 @@ def from_layout_attr( return from_strided_fragmented_layout_attr(attr) elif is_tiled_layout(attr): return from_tiled_layout_attr(attr) - elif is_wgmma_row_fragmented_layout(attr): - return fa.WGMMARowFragLayout() else: raise NotImplementedError( f"Unsupported layout for conversion from MLIR attribute: {attr}" From 4596ee3cc5e5970c9f250ad0136a55c6caa3ded0 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 7 Apr 2025 04:12:51 -0700 Subject: [PATCH 415/483] Add a missing jaxlib version check in Pallas TPU lowering PiperOrigin-RevId: 744668747 --- jax/_src/pallas/mosaic/lowering.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 87e06f486366..2669d73691c1 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -45,6 +45,7 @@ from jax._src.interpreters import partial_eval as pe from jax._src.lax import lax as lax_internal from jax._src.lax.control_flow import for_loop +from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith from jax._src.lib.mlir.dialects import func @@ -3579,13 +3580,16 @@ def _dma_start_lowering_rule( sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, sem_transforms) if device_id is not None: device_id = _device_id_to_logical(ctx, device_id, device_id_type) + priority_kwarg = {"priority": priority} + if jaxlib_version < (0, 5, 4): + priority_kwarg = {} tpu.enqueue_dma( src_ref, dst_ref, sem, source_semaphore=src_sem, device_id=device_id, - priority=priority, + **priority_kwarg, ) return [] From 153fa228943bedb2420c3e7961a5781e5eef6319 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 7 Apr 2025 04:14:17 -0700 Subject: [PATCH 416/483] Add more TSAN skips to avoid timeouts PiperOrigin-RevId: 744669093 --- tests/BUILD | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/BUILD b/tests/BUILD index 63969bc935da..eb6ff81f5d68 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -597,7 +597,10 @@ jax_multiplatform_test( srcs = ["lax_scipy_special_functions_test.py"], backend_tags = { "gpu": ["noasan"], # Times out. - "cpu": ["noasan"], # Times out. + "cpu": [ + "noasan", + "notsan", + ], # Times out. }, shard_count = { "cpu": 20, From c2aa811cd6e196a64c3572194e5aa86e4b65f7da Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 7 Apr 2025 05:49:47 -0700 Subject: [PATCH 417/483] `jex.core.Var` is no longer ordered This behavior was only needed for kfac_jax which has been updated *not* to rely on variable ordering. PiperOrigin-RevId: 744691114 --- jax/_src/core.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index 9a5a6061cc5e..14ed19d4d441 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -412,7 +412,6 @@ def new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info=None, _var_counter = it.count() -@total_ordering class Var: __slots__ = ["count", "suffix", "aval"] @@ -425,11 +424,6 @@ def __init__(self, suffix: str, aval: AbstractValue): self.suffix = suffix self.aval = aval - # TODO(phawkins, mattjj): remove ordering of variables. JAX itself does not - # care about variable ordering, but the downstream package kfac_jax does. - def __lt__(self, other): - return self.count < other.count - def __repr__(self): return f'Var(id={id(self)}){self.suffix}:{self.aval.str_short()}' From 5c0f8858466d31e3678a9bbd16b6c305a38b0aec Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 7 Apr 2025 06:28:35 -0700 Subject: [PATCH 418/483] Update XLA dependency to use revision http://github.com/openxla/xla/commit/145f836bd5175dc5dd262f716a0c59af2b0297a0. PiperOrigin-RevId: 744700775 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 1a7522bda0fa..a8a93026378f 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "3889bec6b7f48e304953a485b713e9982dff0441" -XLA_SHA256 = "f23bb226d334f933cd5e6ebc4b20dec9ad879137763975546120ddf582a472b8" +XLA_COMMIT = "145f836bd5175dc5dd262f716a0c59af2b0297a0" +XLA_SHA256 = "bd19d8a1d25468696809a69ef3984bb00ef432e3fe9c05116b9c114dc7c83fa2" def repo(): tf_http_archive( From 83572e17bd7ac833d1e346d91bad27dc4572aad8 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Mon, 7 Apr 2025 07:03:04 -0700 Subject: [PATCH 419/483] [Mosaic GPU] Add missing to/from tiled layout attributes with replicated lane dimensions. PiperOrigin-RevId: 744708476 --- jax/experimental/mosaic/gpu/layouts.py | 24 +++++++++++++++++++++--- jaxlib/mosaic/dialect/gpu/mosaic_gpu.td | 11 +++++++++++ 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/jax/experimental/mosaic/gpu/layouts.py b/jax/experimental/mosaic/gpu/layouts.py index 78c9f670881b..d9b1a01a24b5 100644 --- a/jax/experimental/mosaic/gpu/layouts.py +++ b/jax/experimental/mosaic/gpu/layouts.py @@ -107,15 +107,25 @@ def to_tiled_layout_attr( ) -> ir.Attribute: """Constructs a #mosaic_gpu.TiledLayout attribute from a TiledLayout.""" + def _lane_dim_str(d: int | fa.Replicated) -> str: + if isinstance(d, fa.Replicated): + return f"#mosaic_gpu.Replicated" + return str(d) + tile_str = lambda tile: "[" + ", ".join(str(d) for d in tile) + "]" tiling = "[" + ", ".join(tile_str(tile) for tile in layout.tiling.tiles) + "]" + lane_dims = "[" + ",".join(_lane_dim_str(d) for d in layout.lane_dims) + "]" + return ir.Attribute.parse( f"#mosaic_gpu.TiledLayout<{tiling}, warp_dim={layout.warp_dim}," - f" lane_dims={list(layout.lane_dims)}, vector_dim={layout.vector_dim}>" + f" lane_dims={lane_dims}, vector_dim={layout.vector_dim}>" ) _list_of_lists_delimiter = re.compile(r"\]\s*,\s*\[") +_replicated_pattern = re.compile( + r"^#mosaic_gpu.Replicated<\s*times\s*=\s*(?P\d+)\s*>\s*$" +) def from_tiled_layout_attr( @@ -133,6 +143,12 @@ def from_tiled_layout_attr( f"Expected a #mosaic_gpu.TiledLayout attribute, got {attr}" ) + def _lane_dim(lane_dim_str: str) -> int | fa.Replicated: + match = _replicated_pattern.fullmatch(lane_dim_str) + if match: + return fa.Replicated(int(match.group("times"))) + return int(lane_dim_str) + tiling_str = match.group("tiling") tile_strings = [] if len(tiling_str) > 2: @@ -141,8 +157,10 @@ def from_tiled_layout_attr( return fa.TiledLayout( tiling=fa.Tiling(tiles), warp_dim=int(match.group("warp_dim")), - lane_dims=tuple(int(s) for s in match.group("lane_dims").split(",")), - vector_dim=int(match.group("vector_dim")) + lane_dims=tuple( + _lane_dim(s) for s in match.group("lane_dims").split(",") + ), + vector_dim=int(match.group("vector_dim")), ) diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td index 6b934b951d93..36f9f6f374e5 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td @@ -142,6 +142,17 @@ def MosaicGPU_WGSplatFragLayout : AttrDef { + let summary = "Indicates a replicated dimension in a tiled layout."; + let description = [{ + See mosaic/gpu/fragmented_array.py -> Replicated for more details. + }]; + + let parameters = (ins "int":$times); + let mnemonic = "Replicated"; + let assemblyFormat = "`<` `times` `=` $times `>`"; +} + def MosaicGPU_TiledLayout : AttrDef { let summary = "A layout derived from a tiling expression."; let description = [{ From 70485e31b96a395a58de765b9b6a9260feb9d775 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 7 Apr 2025 07:19:23 -0700 Subject: [PATCH 420/483] Remove accidental exports jax.interpreters.mlir.{hlo,func_dialect}. These are available via jax.extend.mlir.dialects. No deprecation period because jax.interpreters.mlir is not a stable API. PiperOrigin-RevId: 744712537 --- CHANGELOG.md | 3 +++ jax/_src/cudnn/fused_attention_stablehlo.py | 4 ++-- jax/_src/cudnn/fusion.py | 4 ++-- jax/interpreters/mlir.py | 2 -- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ffd197b390d0..7f19fcb189ed 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * Implemented host callback handlers for CPU and GPU devices using XLA's FFI and removed existing CPU/GPU handlers using XLA's custom call. * All APIs in `jax.lib.xla_extension` are now deprecated. + * `jax.interpreters.mlir.hlo` and `jax.interpreters.mlir.func_dialect`, + which were accidental exports, have been removed. If needed, they are + available from `jax.extend.mlir`. * Several previously-deprecated APIs have been removed, including: * From `jax.lib.xla_client`: `FftType`, `PaddingType`, `dtype_to_etype`, and `shape_from_pyval`. diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py index c7e7c83f30f8..d901ed875ceb 100644 --- a/jax/_src/cudnn/fused_attention_stablehlo.py +++ b/jax/_src/cudnn/fused_attention_stablehlo.py @@ -28,8 +28,8 @@ from jax._src import xla_bridge from jax.interpreters import mlir from jax.interpreters import xla -from jax.interpreters.mlir import hlo -from jax.interpreters.mlir import ir +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import hlo import jax.numpy as jnp from jax.sharding import NamedSharding, PartitionSpec diff --git a/jax/_src/cudnn/fusion.py b/jax/_src/cudnn/fusion.py index f320672463cb..355b33e1509c 100644 --- a/jax/_src/cudnn/fusion.py +++ b/jax/_src/cudnn/fusion.py @@ -16,8 +16,8 @@ import jax from jax._src import core as jax_core from jax.interpreters import mlir -from jax.interpreters.mlir import hlo -from jax.interpreters.mlir import ir +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import hlo diff --git a/jax/interpreters/mlir.py b/jax/interpreters/mlir.py index 8a615be968a6..a0505c74f883 100644 --- a/jax/interpreters/mlir.py +++ b/jax/interpreters/mlir.py @@ -43,8 +43,6 @@ flatten_ir_values as flatten_lowering_ir_args, # TODO(phawkins): remove me # noqa: F401 flatten_ir_values as flatten_ir_values, unflatten_ir_values_like_types as unflatten_ir_values_like_types, - func_dialect as func_dialect, - hlo as hlo, i32_attr as i32_attr, i64_attr as i64_attr, ir as ir, From 412f88e2234c4b82f18e43bdad7bf64a32ff94a5 Mon Sep 17 00:00:00 2001 From: Ayaka Date: Mon, 7 Apr 2025 07:20:00 -0700 Subject: [PATCH 421/483] Temporarily skip JaxNumpyErrorTests in multi-thread environments PiperOrigin-RevId: 744712701 --- tests/jax_numpy_error_test.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/jax_numpy_error_test.py b/tests/jax_numpy_error_test.py index a38e7d5509f9..dba277289f42 100644 --- a/tests/jax_numpy_error_test.py +++ b/tests/jax_numpy_error_test.py @@ -30,6 +30,12 @@ class JaxNumpyErrorTests(jtu.JaxTestCase): + def setUp(self): + # TODO(b/408148001): Fix thread safety issue. + if jtu.TEST_NUM_THREADS.value > 1: + self.skipTest("Test does not work with multiple threads") + super().setUp() + @parameterized.product(jit=[True, False]) def test_set_error_if_nan(self, jit): def f(x): From a099b285307508efad12a015d6f6d9d13ae49077 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Mon, 7 Apr 2025 07:39:14 -0700 Subject: [PATCH 422/483] Reverts 735cec18cb2f8dff2aea5e503fd886a37aee094e PiperOrigin-RevId: 744717457 --- jaxlib/cuda/BUILD | 1 - jaxlib/gpu/py_client_gpu.cc | 88 ++++++++++++-------------------- jaxlib/rocm/BUILD | 1 - jaxlib/xla/BUILD | 1 - jaxlib/xla/py_client_cpu.cc | 87 +++++++++----------------------- tests/python_callback_test.py | 94 +++++++++++++++-------------------- 6 files changed, 95 insertions(+), 177 deletions(-) diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index d7035c92b24a..be7ac6116d2f 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -637,7 +637,6 @@ cc_library( "@xla//third_party/python_runtime:headers", # buildcleaner: keep "@xla//xla:comparison_util", "@xla//xla:shape_util", - "@xla//xla:util", "@xla//xla:xla_data_proto_cc", "@xla//xla/ffi:ffi_api", "@xla//xla/ffi/api:ffi", diff --git a/jaxlib/gpu/py_client_gpu.cc b/jaxlib/gpu/py_client_gpu.cc index e3aec51d8d25..861ffce3e749 100644 --- a/jaxlib/gpu/py_client_gpu.cc +++ b/jaxlib/gpu/py_client_gpu.cc @@ -44,7 +44,6 @@ limitations under the License. #include "xla/python/types.h" #include "xla/shape_util.h" #include "xla/xla_data.pb.h" -#include "xla/util.h" namespace nb = nanobind; @@ -82,7 +81,8 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, auto arg = args.get(i); auto ptype = static_cast(arg->element_type()); // TODO(b/395428868): Remove this check once we support subbyte types. - if (ptype == xla::S1 || ptype == xla::U1) { + if (ptype == xla::S1 || ptype == xla::S2 || ptype == xla::S4 || + ptype == xla::U1 || ptype == xla::U2 || ptype == xla::U4) { return xla::ffi::Error(xla::ffi::ErrorCode::kUnimplemented, absl::StrFormat("Unsupported primitive type: %s", PrimitiveType_Name(ptype))); @@ -112,6 +112,9 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, PyTuple_SET_ITEM(host_input_arrays.ptr(), i, nb::none().inc_ref().ptr()); continue; } + nb::capsule base(host_input_buffers[i], [](void* ptr) noexcept { + delete[] static_cast(ptr); + }); auto maybe_dtype = PrimitiveTypeToNbDtype(ptype); if (!maybe_dtype.ok()) { return xla::ffi::Error::Internal(maybe_dtype.status().ToString()); @@ -119,23 +122,6 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, auto dtype = maybe_dtype.value(); auto dims = absl::Span(arg->dimensions().begin(), arg->dimensions().size()); - // TODO(b/402422886): Remove this once we form Jax arrays directly instead - // of packing/unpacking to/from numpy arrays. - // We pass in data using default numpy layout i.e., std::nullopt. - size_t bits_per_element = xla::primitive_util::BitWidth(ptype); - if (bits_per_element == 2 || bits_per_element == 4) { - // NOTE(dsuo): FFI argument and return buffers are sized assuming - // minimum 1-byte element sizes, even if the data itself is packed. - size_t packed_size = arg->size_bytes() * bits_per_element / 8; - auto buffer = xla::UnpackIntN( - bits_per_element, static_cast(host_input_buffers[i]), - packed_size); - delete[] static_cast(host_input_buffers[i]); - host_input_buffers[i] = buffer.release(); - } - nb::capsule base(host_input_buffers[i], [](void* ptr) noexcept { - delete[] static_cast(ptr); - }); auto array = xla::nb_numpy_ndarray(dtype, dims, std::nullopt, host_input_buffers[i], base); array.attr("flags").attr("writeable") = nb::bool_(false); @@ -160,7 +146,8 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, auto ret = rets.get(i).value(); auto ptype = static_cast(ret->element_type()); // TODO(b/395428868): Remove this check once we support subbyte types. - if (ptype == xla::S1 || ptype == xla::U1) { + if (ptype == xla::S1 || ptype == xla::S2 || ptype == xla::S4 || + ptype == xla::U1 || ptype == xla::U2 || ptype == xla::U4) { return xla::ffi::Error(xla::ffi::ErrorCode::kUnimplemented, absl::StrFormat("Unsupported primitive type: %s", PrimitiveType_Name(ptype))); @@ -181,45 +168,32 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, } auto expected_shape = maybe_expected_shape.value(); auto expected_strides = xla::ByteStridesForShape(expected_shape); - - const void* data = array.data(); - size_t size_bytes = array.size() * array.itemsize(); - if (strides != expected_strides) { - xla::TransposePlan::Options options; - options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); - options.dims = absl::Span( - reinterpret_cast(array.shape()), array.ndim()); - absl::InlinedVector reversed_layout; - reversed_layout.resize(expected_shape.dimensions().size()); - absl::c_reverse_copy(expected_shape.layout().minor_to_major(), - reversed_layout.begin()); - options.permutation = reversed_layout; - options.input_layout = xla::TransposePlan::Striding{strides}; - auto maybe_plan = transpose_cache->cache.GetOrCreate(options); - if (!maybe_plan.ok()) { - return xla::ffi::Error::Internal(maybe_plan.status().ToString()); - } - auto plan = maybe_plan.value(); - void* temp = new char[size_bytes]; - temp_buffers.push_back(temp); - plan->Execute(data, temp); - data = temp; + if (strides == expected_strides) { + auto gpu_res = + gpuMemcpyAsync(ret->untyped_data(), array.data(), ret->size_bytes(), + gpuMemcpyHostToDevice, stream); + CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; + continue; } - - // TODO(b/402422886): Remove this once we form Jax arrays directly instead - // of packing/unpacking to/from numpy arrays. - std::unique_ptr buffer; - size_t bits_per_element = xla::primitive_util::BitWidth(ptype); - if (bits_per_element == 2 || bits_per_element == 4) { - // NOTE(dsuo): FFI arguments and return buffers are sized assuming - // minimum 1-byte element sizes, even if the data itself is packed. - buffer = xla::PackIntN(bits_per_element, static_cast(data), - size_bytes); - data = buffer.get(); - size_bytes = (size_bytes * bits_per_element) / 8; + void* temp = new char[ret->size_bytes()]; + temp_buffers.push_back(temp); + xla::TransposePlan::Options options; + options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); + options.dims = absl::Span( + reinterpret_cast(array.shape()), array.ndim()); + absl::InlinedVector reversed_layout; + reversed_layout.resize(expected_shape.dimensions().size()); + absl::c_reverse_copy(expected_shape.layout().minor_to_major(), + reversed_layout.begin()); + options.permutation = reversed_layout; + options.input_layout = xla::TransposePlan::Striding{strides}; + auto maybe_plan = transpose_cache->cache.GetOrCreate(options); + if (!maybe_plan.ok()) { + return xla::ffi::Error::Internal(maybe_plan.status().ToString()); } - - auto gpu_res = gpuMemcpyAsync(ret->untyped_data(), data, size_bytes, + auto plan = maybe_plan.value(); + plan->Execute(array.data(), temp); + auto gpu_res = gpuMemcpyAsync(ret->untyped_data(), temp, ret->size_bytes(), gpuMemcpyHostToDevice, stream); CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; } diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index 5893af26de85..94d75d9c19ae 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -539,7 +539,6 @@ cc_library( "@xla//third_party/python_runtime:headers", # buildcleaner: keep "@xla//xla:comparison_util", "@xla//xla:shape_util", - "@xla//xla:util", "@xla//xla:xla_data_proto_cc", "@xla//xla/ffi:ffi_api", "@xla//xla/ffi/api:ffi", diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD index 8602652cbd8a..c5d151b2cd3b 100644 --- a/jaxlib/xla/BUILD +++ b/jaxlib/xla/BUILD @@ -640,7 +640,6 @@ cc_library( "@nanobind", "@xla//third_party/python_runtime:headers", # buildcleaner: keep "@xla//xla:shape_util", - "@xla//xla:util", "@xla//xla:xla_data_proto_cc", "@xla//xla/ffi:ffi_api", "@xla//xla/ffi/api:ffi", diff --git a/jaxlib/xla/py_client_cpu.cc b/jaxlib/xla/py_client_cpu.cc index fef6a54aab2d..ac4e7bee5680 100644 --- a/jaxlib/xla/py_client_cpu.cc +++ b/jaxlib/xla/py_client_cpu.cc @@ -41,7 +41,6 @@ limitations under the License. #include "xla/python/nb_numpy.h" #include "xla/python/types.h" #include "xla/shape_util.h" -#include "xla/util.h" #include "xla/xla_data.pb.h" namespace nb = nanobind; @@ -80,7 +79,8 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks, auto arg = args.get(i); auto ptype = static_cast(arg->element_type()); // TODO(b/395428868): Remove this check once we support subbyte types. - if (ptype == S1 || ptype == U1) { + if (ptype == S1 || ptype == S2 || ptype == S4 || ptype == U1 || + ptype == U2 || ptype == U4) { return ffi::Error(ffi::ErrorCode::kUnimplemented, absl::StrFormat("Unsupported primitive type: %s", PrimitiveType_Name(ptype))); @@ -96,20 +96,9 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks, auto dtype = maybe_dtype.value(); auto dims = absl::Span(arg->dimensions().begin(), arg->dimensions().size()); - // TODO(b/402422886): Remove this once we form Jax arrays directly instead - std::unique_ptr buffer; - const void* data = arg->untyped_data(); - size_t bits_per_element = xla::primitive_util::BitWidth(ptype); - if (bits_per_element == 2 || bits_per_element == 4) { - // NOTE(dsuo): FFI argument and return buffers are sized assuming - size_t packed_size = arg->size_bytes() * bits_per_element / 8; - buffer = xla::UnpackIntN(bits_per_element, static_cast(data), - packed_size); - data = buffer.get(); - } // We pass in data using default numpy layout i.e., std::nullopt. auto array = - nb_numpy_ndarray(dtype, dims, std::nullopt, data); + nb_numpy_ndarray(dtype, dims, std::nullopt, arg.value().untyped_data()); array.attr("flags").attr("writeable") = nb::bool_(false); PyTuple_SET_ITEM(nb_args.ptr(), i, array.release().ptr()); } @@ -130,8 +119,9 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks, for (size_t i = 0; i < rets.size(); ++i) { auto ret = rets.get(i).value(); auto ptype = static_cast(ret->element_type()); - // TODO(b/402422886): Remove this once we form Jax arrays directly instead - if (ptype == S1 || ptype == U1) { + // TODO(b/395428868): Remove this check once we support subbyte types. + if (ptype == S1 || ptype == S2 || ptype == S4 || ptype == U1 || + ptype == U2 || ptype == U4) { return ffi::Error(ffi::ErrorCode::kUnimplemented, absl::StrFormat("Unsupported primitive type: %s", PrimitiveType_Name(ptype))); @@ -151,55 +141,26 @@ ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks, } auto expected_shape = maybe_expected_shape.value(); auto expected_strides = ByteStridesForShape(expected_shape); - - const void* data = array.data(); - std::unique_ptr buffer; - size_t bits_per_element = xla::primitive_util::BitWidth(ptype); - size_t size_bytes = array.size() * array.itemsize(); - if (strides != expected_strides) { - xla::TransposePlan::Options options; - options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); - options.dims = absl::Span( - reinterpret_cast(array.shape()), array.ndim()); - absl::InlinedVector reversed_layout; - reversed_layout.resize(expected_shape.dimensions().size()); - absl::c_reverse_copy(expected_shape.layout().minor_to_major(), - reversed_layout.begin()); - options.permutation = reversed_layout; - options.input_layout = xla::TransposePlan::Striding{strides}; - auto maybe_plan = transpose_cache->cache.GetOrCreate(options); - if (!maybe_plan.ok()) { - return ffi::Error::Internal(maybe_plan.status().ToString()); - } - auto plan = maybe_plan.value(); - if (bits_per_element == 2 || bits_per_element == 4) { - // NOTE(dsuo): If the data needs to be unpacked, don't use return buffer - // supplied by FFI directly. - buffer = std::make_unique(size_bytes); - plan->Execute(data, buffer.get()); - data = buffer.get(); - } else { - plan->Execute(data, ret->untyped_data()); - data = ret->untyped_data(); - } - } - - // TODO(b/402422886): Remove this once we form Jax arrays directly instead - // of packing/unpacking to/from numpy arrays. - if (bits_per_element == 2 || bits_per_element == 4) { - // NOTE(dsuo): FFI arguments and return buffers are sized assuming - // minimum 1-byte element sizes, even if the data itself is packed. - buffer = xla::PackIntN(bits_per_element, static_cast(data), - size_bytes); - data = buffer.get(); - size_bytes = (size_bytes * bits_per_element) / 8; + if (strides == expected_strides) { + std::memcpy(ret->untyped_data(), array.data(), ret->size_bytes()); + continue; } - - // Copy data to output buffer if haven't already or modified the data to - // write back. - if (data != ret->untyped_data()) { - std::memcpy(ret->untyped_data(), data, size_bytes); + xla::TransposePlan::Options options; + options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); + options.dims = absl::Span( + reinterpret_cast(array.shape()), array.ndim()); + absl::InlinedVector reversed_layout; + reversed_layout.resize(expected_shape.dimensions_size()); + absl::c_reverse_copy(expected_shape.layout().minor_to_major(), + reversed_layout.begin()); + options.permutation = reversed_layout; + options.input_layout = xla::TransposePlan::Striding{strides}; + auto maybe_plan = transpose_cache->cache.GetOrCreate(options); + if (!maybe_plan.ok()) { + return ffi::Error::Internal(maybe_plan.status().ToString()); } + auto plan = maybe_plan.value(); + plan->Execute(array.data(), ret->untyped_data()); } return ffi::Error::Success(); diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 34ab20c05644..a8442b4a1356 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -586,15 +586,10 @@ def fun(x): self.assertAllClose(2 * x, fun(x)) self.assertEqual(count(), 1) - @parameterized.parameters("int2", "int4", "uint2", "uint4", "float4_e2m1fn") + @parameterized.parameters("int2", "int4", "uint2", "uint4") def test_subbyte_operands(self, dtype: str): - if jaxlib_extension_version < 323: - self.skipTest("Requires jaxlib_extension_version >= 323.") - if "2" in dtype and jtu.test_device_matches(["tpu"]): - self.skipTest( - "TODO(dsuo): TPU callbacks send SIGABRT for int2, uint2, and" - " float4_e2m1fn." - ) + if jaxlib_extension_version <= 321: + self.skipTest("Requires jaxlib_extension_version >= 322.") def get(x): return x def f(x): @@ -605,17 +600,21 @@ def f(x): ) return y x = np.arange(8, dtype=dtype) - np.testing.assert_array_equal(jax.jit(f)(x), np.arange(8, dtype=dtype)) + # TODO(b/395428868): Remove this check once we support subbyte types. + if jtu.test_device_matches(["tpu"]): + if "2" in dtype: + self.skipTest("TODO(dsuo): TPU callbacks send SIGABRT for int2/uint2.") + np.testing.assert_array_equal(jax.jit(f)(x), np.arange(8, dtype=dtype)) + else: + with self.assertRaisesRegex( + Exception, "Unsupported primitive type" + ): + _ = jax.jit(f)(x) - @parameterized.parameters("int2", "int4", "uint2", "uint4", "float4_e2m1fn") + @parameterized.parameters("int2", "int4", "uint2", "uint4") def test_subbyte_results(self, dtype: str): - if jaxlib_extension_version < 323: - self.skipTest("Requires jaxlib_extension_version >= 323.") - if "2" in dtype and jtu.test_device_matches(["tpu"]): - self.skipTest( - "TODO(dsuo): TPU callbacks send SIGABRT for int2, uint2, and" - " float4_e2m1fn." - ) + if jaxlib_extension_version <= 321: + self.skipTest("Requires jaxlib_extension_version >= 322.") def get(): return np.arange(8, dtype=dtype) @@ -626,43 +625,16 @@ def f(): ) return y - np.testing.assert_array_equal(jax.jit(f)(), np.arange(8, dtype=dtype)) - - @parameterized.parameters("int2", "int4", "uint2", "uint4", "float4_e2m1fn") - def test_non_default_stride_subbyte_results(self, dtype: str): - if jaxlib_extension_version < 323: - self.skipTest("Requires jaxlib_extension_version >= 323.") - if "2" in dtype and jtu.test_device_matches(["tpu"]): - self.skipTest( - "TODO(dsuo): TPU callbacks send SIGABRT for int2, uint2, and" - " float4_e2m1fn." - ) - x = jnp.arange(24, dtype=dtype).reshape(2, 3, 4) - def callback(x): - return np.asfortranarray(x) - - @jax.jit - def f(x): - return jax.pure_callback( - callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x - ) - - result = f(x) - np.testing.assert_array_equal(x, result) - - def test_non_default_stride(self): - x = jnp.arange(24, dtype=jnp.float32).reshape(2, 3, 4) - def callback(x): - return np.asfortranarray(x) - - @jax.jit - def f(x): - return jax.pure_callback( - callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x - ) - - result = f(x) - np.testing.assert_array_equal(x, result) + # TODO(b/395428868): Remove this check once we support subbyte types. + if jtu.test_device_matches(["tpu"]): + if "2" in dtype: + self.skipTest("TODO(dsuo): TPU callbacks send SIGABRT for int2/uint2.") + np.testing.assert_array_equal(jax.jit(f)(), np.arange(8, dtype=dtype)) + else: + with self.assertRaisesRegex( + Exception, "Unsupported primitive type" + ): + _ = jax.jit(f)() class PureCallbackTest(jtu.JaxTestCase): @@ -1136,6 +1108,20 @@ def fun(x): result += fun(jnp.ones((500, 500), jnp.complex64))[1] jax.block_until_ready(result) # doesn't deadlock + def test_non_default_stride(self): + x = jnp.arange(24, dtype=jnp.float32).reshape(2, 3, 4) + def callback(x): + return np.asfortranarray(x) + + @jax.jit + def f(x): + return jax.pure_callback( + callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x + ) + + result = f(x) + np.testing.assert_array_equal(x, result) + class IOCallbackTest(jtu.JaxTestCase): From 5a3fc606d47148cf47a96172ff67b1535182968b Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 7 Apr 2025 07:57:17 -0700 Subject: [PATCH 423/483] Deprecate public export of mlir.custom_call. PiperOrigin-RevId: 744722183 --- CHANGELOG.md | 2 ++ jax/_src/cudnn/fused_attention_stablehlo.py | 11 +++++------ jax/interpreters/mlir.py | 22 ++++++++++++++++++++- 3 files changed, 28 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7f19fcb189ed..21398b31cafb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * `jax.interpreters.mlir.hlo` and `jax.interpreters.mlir.func_dialect`, which were accidental exports, have been removed. If needed, they are available from `jax.extend.mlir`. + * `jax.interpreters.mlir.custom_call` is deprecated. The APIs provided by + {mod}`jax.ffi` should be used instead. * Several previously-deprecated APIs have been removed, including: * From `jax.lib.xla_client`: `FftType`, `PaddingType`, `dtype_to_etype`, and `shape_from_pyval`. diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py index d901ed875ceb..818bc018cdf5 100644 --- a/jax/_src/cudnn/fused_attention_stablehlo.py +++ b/jax/_src/cudnn/fused_attention_stablehlo.py @@ -24,10 +24,9 @@ from jax._src import dispatch from jax._src.custom_partitioning import custom_partitioning from jax._src.interpreters import batching +from jax._src.interpreters import mlir from jax._src.lib import cuda_versions from jax._src import xla_bridge -from jax.interpreters import mlir -from jax.interpreters import xla from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo import jax.numpy as jnp @@ -1018,7 +1017,7 @@ def sharded_impl(*args): _dot_product_attention_fwd_p = core.Primitive("dot_product_attention_fwd") _dot_product_attention_fwd_p.multiple_results = True _dot_product_attention_fwd_p.def_impl( - functools.partial(xla.apply_primitive, _dot_product_attention_fwd_p) + functools.partial(dispatch.apply_primitive, _dot_product_attention_fwd_p) ) _dot_product_attention_fwd_p.def_abstract_eval( _dot_product_attention_fwd_abstract @@ -1043,7 +1042,7 @@ def sharded_impl(*args): _dot_product_attention_bwd_p = core.Primitive("dot_product_attention_bwd") _dot_product_attention_bwd_p.multiple_results = True _dot_product_attention_bwd_p.def_impl( - functools.partial(xla.apply_primitive, _dot_product_attention_bwd_p) + functools.partial(dispatch.apply_primitive, _dot_product_attention_bwd_p) ) _dot_product_attention_bwd_p.def_abstract_eval( _dot_product_attention_bwd_abstract @@ -1604,7 +1603,7 @@ def _dot_product_attention_fp8_bwd_partition( _dot_product_attention_fp8_fwd_p = core.Primitive("dot_product_attention_fp8_fwd") _dot_product_attention_fp8_fwd_p.multiple_results = True _dot_product_attention_fp8_fwd_p.def_impl( - functools.partial(xla.apply_primitive, _dot_product_attention_fp8_fwd_p) + functools.partial(dispatch.apply_primitive, _dot_product_attention_fp8_fwd_p) ) _dot_product_attention_fp8_fwd_p.def_abstract_eval( _dot_product_attention_fp8_fwd_abstract @@ -1629,7 +1628,7 @@ def _dot_product_attention_fp8_bwd_partition( _dot_product_attention_fp8_bwd_p = core.Primitive("dot_product_attention_fp8_bwd") _dot_product_attention_fp8_bwd_p.multiple_results = True _dot_product_attention_fp8_bwd_p.def_impl( - functools.partial(xla.apply_primitive, _dot_product_attention_fp8_bwd_p) + functools.partial(dispatch.apply_primitive, _dot_product_attention_fp8_bwd_p) ) _dot_product_attention_fp8_bwd_p.def_abstract_eval( _dot_product_attention_fp8_bwd_abstract diff --git a/jax/interpreters/mlir.py b/jax/interpreters/mlir.py index a0505c74f883..10c8d1e9e671 100644 --- a/jax/interpreters/mlir.py +++ b/jax/interpreters/mlir.py @@ -33,7 +33,7 @@ aval_to_ir_type as aval_to_ir_type, aval_to_ir_types as aval_to_ir_types, core_call_lowering as core_call_lowering, - custom_call as custom_call, + custom_call as _custom_call, dense_bool_elements as dense_bool_elements, dense_bool_array as dense_bool_array, dense_int_array as dense_int_array, @@ -77,3 +77,23 @@ from jax._src.callback import ( emit_python_callback as emit_python_callback, ) + +_deprecations = { + # Added Apr 7 2025 + "custom_call": ( + "mlir.custom_call is deprecated; use the APIs provided by jax.ffi instead.", + _custom_call, + ) +} + +import typing as _typing + +if _typing.TYPE_CHECKING: + custom_call = _custom_call +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del _typing +del _custom_call From dbc3bcd3cebdacc3e0ef8ef717807cac170635eb Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Fri, 4 Apr 2025 08:47:59 -0400 Subject: [PATCH 424/483] Apply forwarding in pjit linearization rule to avoid intermediate copies. --- jax/_src/interpreters/ad.py | 9 +++++++++ jax/_src/pjit.py | 18 +++++++++++++----- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index e47e518a11f2..1824c39f03fe 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -194,6 +194,11 @@ def new_arg(trace, primal_aval, nz): tangent_trace.invalidate() if attrs_tracked: raise NotImplementedError("TODO: attrs") + tangent_jaxpr, used_consts, _ = pe.dce_jaxpr_consts( + tangent_jaxpr, [True] * len(tangent_jaxpr.outvars), + [False] * len(tangent_jaxpr.constvars) + [True] * len(tangent_jaxpr.invars)) + tangent_consts = [c for c, used in zip(tangent_consts, used_consts) if used] + residuals_and_primals = (*tangent_consts, *out_primals) residuals_and_primals = map(primal_trace.to_jaxpr_tracer, residuals_and_primals) primal_jaxpr, primal_consts, attrs_tracked = primal_trace.to_jaxpr(residuals_and_primals, debug_info) @@ -871,6 +876,10 @@ def make_zero(aval): for (r, nz) in zip(out_tangents, out_nzs) if nz] in_tracers = [t for t, nz in zip(tangent_args, nonzeros) if nz] jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers, jvp.debug_info) + jaxpr, used_consts, _ = pe.dce_jaxpr_consts( + jaxpr, [True] * len(jaxpr.outvars), + [False] * len(jaxpr.constvars) + [True] * len(jaxpr.invars)) + out_consts = [c for used, c in zip(used_consts, out_consts) if used] def linearized(residuals, *tangents): nz_tangents_in = [t for (t, nz) in zip(tangents, nonzeros) if nz] diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index af744ae5db96..8c3c5101eb51 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -2076,14 +2076,22 @@ def _pjit_linearization(nzs, *primals_in, jaxpr, donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs): primal_jaxpr, num_residuals, nzs_out, tangent_jaxpr = ad.linearize_jaxpr(jaxpr, nzs) - # constvars will become residuals. Move them to the end of the ordinary args. res_shardings = (UNSPECIFIED,) * num_residuals res_layouts = (None,) * num_residuals res_donated = (False,) * num_residuals + + in_fwd = pe._jaxpr_forwarding(primal_jaxpr.jaxpr) + in_fwd, _ = split_list(in_fwd, [num_residuals]) + keep = tuple(f is None for f in in_fwd) + (True,) * len(out_shardings) + primal_jaxpr = pe.prune_closed_jaxpr_outputs(primal_jaxpr, keep) + num_residuals = sum(f is None for f in in_fwd) + def tangent_fun(consts_, *tangents): + consts_it = iter(consts_) + res = [next(consts_it) if f is None else primals_in[f] for f in in_fwd] + assert next(consts_it, None) is None tangents_nz = _filter_zeros(nzs, tangents) - assert len(consts_) == num_residuals - nz_tangents_out = pjit_p.bind(*(*tangents_nz, *consts_), + nz_tangents_out = pjit_p.bind(*(*tangents_nz, *res), jaxpr=tangent_jaxpr, in_shardings=_filter_zeros(nzs, in_shardings) + res_shardings, out_shardings=_filter_zeros(nzs_out, out_shardings), @@ -2106,9 +2114,9 @@ def _filter_zeros(is_nz_l, l): ans = pjit_p.bind(*primals_in, jaxpr=primal_jaxpr, in_shardings=in_shardings, - out_shardings=(*res_shardings, *out_shardings), + out_shardings=(*res_shardings[:num_residuals], *out_shardings), in_layouts=in_layouts, - out_layouts=(*res_layouts, *out_layouts), + out_layouts=(*res_layouts[:num_residuals], *out_layouts), donated_invars=donated_invars, ctx_mesh=ctx_mesh, name=name, From ff00fa91cecd2b21f866559c5fd07061a335899a Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 7 Apr 2025 09:48:06 -0700 Subject: [PATCH 425/483] Removed unused `jax_remat_opt_barrier` config option It defaults to True and is not flipped to False by any internal JAX users. PiperOrigin-RevId: 744754343 --- jax/_src/ad_checkpoint.py | 86 ++++---------------- jax/_src/config.py | 7 -- jax/experimental/jax2tf/jax2tf.py | 11 ++- jax/experimental/jax2tf/tests/jax2tf_test.py | 6 +- 4 files changed, 21 insertions(+), 89 deletions(-) diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index c2868cf7c078..e5390be4cfe0 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -757,89 +757,34 @@ def _has_effects(effects) -> bool: return bool({e for e in effects if not isinstance(e, core.NamedAxisEffect)}) -def remat_expansion(*args, jaxpr: core.Jaxpr, prevent_cse: bool, - differentiated: bool, is_gpu_platform: bool = False, - **_): +def remat_expansion( + *args, jaxpr: core.Jaxpr, prevent_cse: bool, differentiated: bool, **_ +): assert not jaxpr.constvars if differentiated and prevent_cse: - if config.remat_opt_barrier.value: - translation_rule = _remat_translation_using_opt_barrier - elif is_gpu_platform: - translation_rule = _remat_translation_using_while - else: - translation_rule = _remat_translation_using_cond + translation_rule = _remat_translation_using_opt_barrier else: translation_rule = lambda *args, jaxpr: core.eval_jaxpr(jaxpr, (), *args) return api.named_call(translation_rule, name="checkpoint")(*args, jaxpr=jaxpr) + def _remat_translation_using_opt_barrier(*args, jaxpr: core.Jaxpr): args = lax_internal.optimization_barrier(args) return core.eval_jaxpr(jaxpr, (), *args) -# TODO(mattjj): add core utility for 'create dummy value for this type'? -def _dummy_like(aval: core.AbstractValue) -> Any: - if aval is core.abstract_token: - return lax_internal.create_token() - elif isinstance(aval, (core.ShapedArray, core.DShapedArray)): - return lax_internal.broadcast(lax_internal.empty(aval.dtype), aval.shape) # type: ignore - else: - raise ValueError(aval) - -def _remat_translation_using_while(*args, jaxpr: core.Jaxpr): - # Implements: - # for(counter=0, result=0; counter < rng(1, 2); counter ++) { - # result = eval_jaxpr(*args) - # } - # The loop carry is a tuple: (counter, result, args) - from jax._src.lax import control_flow as lax_control_flow - - avals_out = tuple(v.aval for v in jaxpr.outvars) - carry_init = (np.int32(0), tuple(map(_dummy_like, avals_out)), args) - def cond(carry): - counter, _, _ = carry - unif = lax_internal.rng_uniform(np.int32(1), np.int32(2), shape=()) - return counter < unif - - def body(carry): - counter, _, args = carry - results = core.eval_jaxpr(jaxpr, (), *args) - return (counter + 1, tuple(results), args) - - carry_res = lax_control_flow.while_loop(cond, body, carry_init) - return carry_res[1] - -def _remat_translation_using_cond(*args, jaxpr: core.Jaxpr): - # Implements: - # if(rng(0, 1) < 2) - # return eval_jaxpr(*args) - # else: - # return 0 - from jax._src.lax import control_flow as lax_control_flow - - avals_out = tuple(v.aval for v in jaxpr.outvars) - - def remat_comp(*args): - return tuple(core.eval_jaxpr(jaxpr, (), *args)) - def dummy_comp(*args): - return tuple(map(_dummy_like, avals_out)) - - unif = lax_internal.rng_uniform(np.float32(0), np.float32(1), shape=()) - return lax_control_flow.cond(unif < np.float32(2), remat_comp, dummy_comp, *args) - -def _remat_lowering(ctx, *args, jaxpr: core.Jaxpr, prevent_cse: bool, - differentiated: bool, policy, is_gpu_platform=False): + +def _remat_lowering( + ctx, + *args, + jaxpr: core.Jaxpr, + prevent_cse: bool, + differentiated: bool, + policy, +): jaxpr_args: Sequence[mlir.IrValues] if differentiated and prevent_cse: - # If we're using the loop or cond lowerings, use the slower lower_fun - # based path. - if not config.remat_opt_barrier.value: - return mlir.lower_fun(remat_expansion, multiple_results=True)( - ctx, *args, jaxpr=jaxpr, prevent_cse=prevent_cse, - differentiated=differentiated, policy=policy, - is_gpu_platform=is_gpu_platform) - arg_types = map(mlir.aval_to_ir_type, ctx.avals_in) flat_args = mlir.flatten_ir_values(args) barrier_op = hlo.OptimizationBarrierOp(flat_args) @@ -853,9 +798,8 @@ def _remat_lowering(ctx, *args, jaxpr: core.Jaxpr, prevent_cse: bool, ctx.set_tokens_out(tokens_out) return outs + mlir.register_lowering(remat_p, _remat_lowering) -mlir.register_lowering(remat_p, partial(_remat_lowering, is_gpu_platform=True), - platform="gpu") def checkpoint_name(x, name): diff --git a/jax/_src/config.py b/jax/_src/config.py index b4a12dcc1762..aca6d8e2c938 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1512,13 +1512,6 @@ def _update_disable_jit_thread_local(val): help=('Attempt constant folding during staging.'), include_in_jit_key=True) -# This flag is temporary during rollout of the remat barrier. -# TODO(parkers): Remove if there are no complaints. -remat_opt_barrier = bool_state( - name='jax_remat_opt_barrier', - default=True, - help=('Enables using optimization-barrier op for lowering remat.')) - enable_remat_opt_pass = bool_state( name='jax_compiler_enable_remat_pass', default=True, diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 492e070de1af..ce57bdad5311 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -3173,12 +3173,11 @@ def select_one_carry(new_c: TfVal, c: TfVal, c_aval: core.ShapedArray) -> TfVal: lax_control_flow._scan_impl, extra_name_stack="scan") -tf_impl_with_avals[ad_checkpoint.remat_p] = \ - _convert_jax_impl(partial(ad_checkpoint.remat_expansion, - # TODO: jax2tf cannot discriminate by platform - is_gpu_platform=False), - multiple_results=True, - extra_name_stack="checkpoint") +tf_impl_with_avals[ad_checkpoint.remat_p] = _convert_jax_impl( + ad_checkpoint.remat_expansion, + multiple_results=True, + extra_name_stack="checkpoint", +) tf_impl[ad_checkpoint.name_p] = lambda x, *, name: x diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index bea2b76cb7cf..b40b1a6d5571 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -832,11 +832,7 @@ def f(x1): arg = np.array(3.) f_tf = jax2tf.convert(jax.grad(remat_f)) f_tf_hlo = self.TfToHlo(f_tf, arg) - if config.remat_opt_barrier.value: - self.assertRegex(f_tf_hlo, r"opt-barrier") - else: - self.assertRegex(f_tf_hlo, - r'transpose/jax2tf_f_/jvp/checkpoint/cond/branch_1_fun/Sin') + self.assertRegex(f_tf_hlo, r"opt-barrier") def test_remat_free_var(self): def f(x): From 51c224c446df943b565058144b23eeaf8966009d Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 7 Apr 2025 09:49:28 -0700 Subject: [PATCH 426/483] Removed deprecated `jax.core.{full_lower,jaxpr_as_fun,lattice_join}` PiperOrigin-RevId: 744754730 --- CHANGELOG.md | 4 ++-- jax/_src/core.py | 5 ----- jax/core.py | 13 ++++--------- 3 files changed, 6 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 21398b31cafb..beacd477390f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -49,8 +49,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * From `jax.core`: `AxisSize`, `ClosedJaxpr`, `EvalTrace`, `InDBIdx`, `InputType`, `Jaxpr`, `JaxprEqn`, `Literal`, `MapPrimitive`, `OpaqueTraceState`, `OutDBIdx`, `Primitive`, `Token`, `TRACER_LEAK_DEBUGGER_WARNING`, `Var`, `concrete_aval`, - `dedup_referents`, `escaped_tracer_error`, `extend_axis_env_nd`, `get_referent`, - `join_effects`, `leaked_tracer_error`, `maybe_find_leaked_tracers`, `raise_to_shaped`, + `dedup_referents`, `escaped_tracer_error`, `extend_axis_env_nd`, `full_lower`, `get_referent`, `jaxpr_as_fun`, `join_effects`, `lattice_join`, + `leaked_tracer_error`, `maybe_find_leaked_tracers`, `raise_to_shaped`, `raise_to_shaped_mappings`, `reset_trace_state`, `str_eqn_compact`, `substitute_vars_in_output_ty`, `typecompat`, and `used_axis_names_jaxpr`. Most have no public replacement, though a few are available at {mod}`jax.extend.core`. diff --git a/jax/_src/core.py b/jax/_src/core.py index 14ed19d4d441..9f80842a38ff 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1496,11 +1496,6 @@ def _jaxpr_type_to_callable_annotation(jaxpr: Jaxpr) -> InputType: for v in jaxpr.invars] return tuple(out) -# TODO(dougalm): Deprecate. This is here for backwards compat. -def lattice_join(x, y): - assert typematch(x, y) - return x - # For use in typing annotations to denote either a Tracer or a `valid_jaxtype`. Value = Any diff --git a/jax/core.py b/jax/core.py index 688fa14d9ccf..9702798d9af9 100644 --- a/jax/core.py +++ b/jax/core.py @@ -97,13 +97,11 @@ "typecheck": ("jax.core.typecheck is deprecated.", _src_core.typecheck), "typematch": ("jax.core.typematch is deprecated.", _src_core.typematch), # Added 2024-12-10 - "full_lower": ("jax.core.full_lower is deprecated. It is a no-op as of JAX v0.4.36.", - _src_core.full_lower), - "jaxpr_as_fun": ("jax.core.jaxpr_as_fun is deprecated. Use jax.extend.core.jaxpr_as_fun instead, " + "full_lower": ("jax.core.full_lower is deprecated. It is a no-op as of JAX v0.4.36.", None), + "jaxpr_as_fun": ("jax.core.jaxpr_as_fun was removed in JAX v0.6.0. Use jax.extend.core.jaxpr_as_fun instead, " "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", - _src_core.jaxpr_as_fun), - "lattice_join": ("jax.core.lattice_join is deprecated. It is a no-op as of JAX v0.4.36.", - _src_core.lattice_join), + None), + "lattice_join": ("jax.core.lattice_join is deprecated. It is a no-op as of JAX v0.4.36.", None), # Finalized 2025-03-25 for JAX v0.6.0; remove after 2025-06-25 "AxisSize": ("jax.core.AxisSize was removed in JAX v0.6.0.", None), "ClosedJaxpr": ("jax.core.ClosedJaxpr was removed in JAX v0.6.0. Use jax.extend.core.ClosedJaxpr instead, " @@ -152,10 +150,7 @@ axis_frame = _src_core.axis_frame call_p = _src_core.call_p closed_call_p = _src_core.closed_call_p - full_lower = _src_core.full_lower get_type = _src_core.get_aval - jaxpr_as_fun = _src_core.jaxpr_as_fun - lattice_join = _src_core.lattice_join trace_state_clean = _src_core.trace_state_clean typecheck = _src_core.typecheck typematch = _src_core.typematch From 855829e1bcf2fbdbe183469350108de50b4cf872 Mon Sep 17 00:00:00 2001 From: Jacob Burnim Date: Mon, 7 Apr 2025 10:51:46 -0700 Subject: [PATCH 427/483] Add int4, uint4 to test_util.suppported_types To increase test coverage for these types. PiperOrigin-RevId: 744777880 --- jax/_src/test_util.py | 14 ++++++++++---- jaxlib/xla/py_values.cc | 2 ++ jaxlib/xla/xla_client.py | 2 +- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index c3f7fb4c4139..5a2eaabd0f02 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -57,6 +57,7 @@ from jax._src import mesh as mesh_lib from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm from jax._src.interpreters import mlir +from jax._src.lib import jaxlib_extension_version from jax._src.numpy.util import promote_dtypes, promote_dtypes_inexact from jax._src.public_test_util import ( # noqa: F401 _assert_numpy_allclose, _check_dtypes_match, _default_tolerance, _dtype, check_close, check_grads, @@ -376,10 +377,13 @@ def device_under_test(): def supported_dtypes(): if device_under_test() == "tpu": - types = {np.bool_, np.int8, np.int16, np.int32, np.uint8, np.uint16, - np.uint32, _dtypes.bfloat16, np.float16, np.float32, np.complex64, + types = {np.bool_, _dtypes.int4, np.int8, np.int16, np.int32, + _dtypes.uint4, np.uint8, np.uint16, np.uint32, + _dtypes.bfloat16, np.float16, np.float32, np.complex64, _dtypes.float8_e4m3fn, _dtypes.float8_e4m3b11fnuz, _dtypes.float8_e5m2} + if jaxlib_extension_version < 327: + types -= {_dtypes.int4, _dtypes.uint4} elif device_under_test() == "gpu": types = {np.bool_, np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, @@ -389,10 +393,12 @@ def supported_dtypes(): elif device_under_test() == "METAL": types = {np.int32, np.uint32, np.float32} else: - types = {np.bool_, np.int8, np.int16, np.int32, np.int64, - np.uint8, np.uint16, np.uint32, np.uint64, + types = {np.bool_, _dtypes.int4, np.int8, np.int16, np.int32, np.int64, + _dtypes.uint4, np.uint8, np.uint16, np.uint32, np.uint64, _dtypes.bfloat16, np.float16, np.float32, np.float64, np.complex64, np.complex128} + if jaxlib_extension_version < 327: + types -= {_dtypes.int4, _dtypes.uint4} if not config.enable_x64.value: types -= {np.uint64, np.int64, np.float64, np.complex128} return types diff --git a/jaxlib/xla/py_values.cc b/jaxlib/xla/py_values.cc index e13a38197c0a..709f3cb3b2ef 100644 --- a/jaxlib/xla/py_values.cc +++ b/jaxlib/xla/py_values.cc @@ -684,10 +684,12 @@ absl::StatusOr PyArgSignatureOfValue(nb::handle arg, // float64_dt and complex128_dt which are taken care of in previous if // blocks. (*p)[dtypes.np_bool.ptr()] = numpy_array_handler; + (*p)[dtypes.np_int4.ptr()] = numpy_array_handler; (*p)[dtypes.np_int8.ptr()] = numpy_array_handler; (*p)[dtypes.np_int16.ptr()] = numpy_array_handler; (*p)[dtypes.np_int32.ptr()] = numpy_array_handler; (*p)[dtypes.np_int64.ptr()] = np_int_handler; + (*p)[dtypes.np_uint4.ptr()] = numpy_array_handler; (*p)[dtypes.np_uint8.ptr()] = numpy_array_handler; (*p)[dtypes.np_uint16.ptr()] = numpy_array_handler; (*p)[dtypes.np_uint32.ptr()] = numpy_array_handler; diff --git a/jaxlib/xla/xla_client.py b/jaxlib/xla/xla_client.py index 58e0cb070e29..fa31d1764de2 100644 --- a/jaxlib/xla/xla_client.py +++ b/jaxlib/xla/xla_client.py @@ -50,7 +50,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version. -_version = 326 +_version = 327 # An internal increasing version number for protecting jaxlib code against # ifrt changes. From 7239487ccc635e2374073c167b773d0627b070a9 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 7 Apr 2025 18:06:56 +0000 Subject: [PATCH 428/483] Bump medyagh/setup-minikube from 0.0.18 to 0.0.19 Bumps [medyagh/setup-minikube](https://github.com/medyagh/setup-minikube) from 0.0.18 to 0.0.19. - [Release notes](https://github.com/medyagh/setup-minikube/releases) - [Commits](https://github.com/medyagh/setup-minikube/compare/d8c0eb871f6f455542491d86a574477bd3894533...cea33675329b799adccc9526aa5daccc26cd5052) --- updated-dependencies: - dependency-name: medyagh/setup-minikube dependency-version: 0.0.19 dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- .github/workflows/k8s.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/k8s.yaml b/.github/workflows/k8s.yaml index a96ce1ead26c..1042388fe9c6 100644 --- a/.github/workflows/k8s.yaml +++ b/.github/workflows/k8s.yaml @@ -38,7 +38,7 @@ jobs: path: jax - name: Start Minikube cluster - uses: medyagh/setup-minikube@d8c0eb871f6f455542491d86a574477bd3894533 # ratchet:medyagh/setup-minikube@v0.0.18 + uses: medyagh/setup-minikube@cea33675329b799adccc9526aa5daccc26cd5052 # ratchet:medyagh/setup-minikube@v0.0.19 - name: Install K8s Jobset run: | From fcf5115fdbd216a44c2daf4860fc241cb8ac4f8b Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Mon, 7 Apr 2025 11:10:52 -0700 Subject: [PATCH 429/483] [Pallas Fuser] Add output_fusion_mask support Currently, the fusion API assumes by default that all of the outputs of a @fuse-decorated function are computed jointly in one big output fusion. For example, in the following snippet ```python @fuse def f(x, y): z1, z2 = fusable_f(x, y) return g(z1, z2) ``` it assumes that `g` is a single function that operates on z1 and z2 jointly. However, in practice, the fusable may want two separate output fusions: ```python @fuse def f(x, y): z1, z2 = fusable_f(x, y) return g1(z1), g2(z2) ``` This is a special case of the general function but the fusable may not be materializing z1 and z2 at the same time so may not be able to compute this efficiently with a single function g. By decorating a fusable with an output fusion prefix (in the above example `(True, True)`), the fusable will now be given a pair of functions `g1` and `g2` if the output fusion is "separable". For example, we'd error for the following example: ```python @fuse def f(x, y): z1, z2 = fusable_f(x, y) return z1 + z2 ``` because z1 and z2 interact with each other in the output fusion. The rationale for providing a PyTree prefix (as opposed to a more general mechanism) is that the fusable can group its outputs into subtrees that it can identify with the output prefix. This does restrict the types of output groups that are possible (outputs must be part of the same shared subtree, as opposed to arbitrarily scattered throughput the output pytree), but this is an okay restriction because the fusable author is responsible for the grouping and can always construct it that way. PiperOrigin-RevId: 744784770 --- jax/_src/pallas/fuser/BUILD | 1 + jax/_src/pallas/fuser/fusable.py | 59 +++--- jax/_src/pallas/fuser/jaxpr_fusion.py | 203 ++++++++++++++++++--- tests/pallas/BUILD | 21 +++ tests/pallas/fusion_test.py | 232 ++++++++++++++++++++++++ tests/pallas/tpu_fusable_matmul_test.py | 9 +- 6 files changed, 469 insertions(+), 56 deletions(-) create mode 100644 tests/pallas/fusion_test.py diff --git a/jax/_src/pallas/fuser/BUILD b/jax/_src/pallas/fuser/BUILD index 66bbac33aabb..8339ad6705ff 100644 --- a/jax/_src/pallas/fuser/BUILD +++ b/jax/_src/pallas/fuser/BUILD @@ -99,6 +99,7 @@ pytype_strict_library( "//jax:core", "//jax:partial_eval", "//jax:tree_util", + "//jax:util", ], ) diff --git a/jax/_src/pallas/fuser/fusable.py b/jax/_src/pallas/fuser/fusable.py index b075c6d136c9..aa2ea0843c0a 100644 --- a/jax/_src/pallas/fuser/fusable.py +++ b/jax/_src/pallas/fuser/fusable.py @@ -13,6 +13,7 @@ # limitations under the License. """Fusable primitive.""" +from typing import Any import jax from jax._src import api_util @@ -40,32 +41,38 @@ def _make_trivial_fusion(x: jax.Array) -> fusion_lib.Fusion: ) -def fusable(f): - def wrapper(*args): - def wrapped(*args): - in_fusions = tree_util.tree_map(_make_trivial_fusion, args) - return f(*in_fusions, None) - - flat_args, in_tree = tree_util.tree_flatten(args) - debug_info = api_util.debug_info('fusable', wrapped, args, {}) - flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs( - lu.wrap_init(wrapped, debug_info=debug_info), in_tree - ) - flat_avals = [_get_aval(x) for x in flat_args] - jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals) - out_tree = out_tree_thunk() - out = fusable_p.bind( - *consts, - *flat_args, - jaxpr=jaxpr, - num_consts=len(consts), - in_tree=in_tree, - out_tree=out_tree, - func=f, - ) - return tree_util.tree_unflatten(out_tree, out) - - return wrapper +def fusable(f=None, *, output_fusion_prefix: Any = True): + def decorator(f): + def wrapper(*args): + def wrapped(*args): + in_fusions = tree_util.tree_map(_make_trivial_fusion, args) + return f(*in_fusions, None) + + flat_args, in_tree = tree_util.tree_flatten(args) + debug_info = api_util.debug_info('fusable', wrapped, args, {}) + flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs( + lu.wrap_init(wrapped, debug_info=debug_info), in_tree + ) + flat_avals = [_get_aval(x) for x in flat_args] + jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals) + out_tree = out_tree_thunk() + out = fusable_p.bind( + *consts, + *flat_args, + jaxpr=jaxpr, + num_consts=len(consts), + in_tree=in_tree, + out_tree=out_tree, + func=f, + output_fusion_prefix=output_fusion_prefix, + ) + return tree_util.tree_unflatten(out_tree, out) + + return wrapper + + if f is not None: + return decorator(f) + return decorator @fusable_p.def_impl diff --git a/jax/_src/pallas/fuser/jaxpr_fusion.py b/jax/_src/pallas/fuser/jaxpr_fusion.py index 3d36b8f3e2fd..649037e18092 100644 --- a/jax/_src/pallas/fuser/jaxpr_fusion.py +++ b/jax/_src/pallas/fuser/jaxpr_fusion.py @@ -14,15 +14,15 @@ """Fuses a function.""" +from collections.abc import Sequence +import functools from typing import Any - import jax from jax._src import api_util from jax._src import core as jax_core from jax._src import linear_util as lu from jax._src import tree_util from jax._src.interpreters import partial_eval as pe - from jax._src.pallas.fuser import fusable_dtype from jax._src.pallas.fuser import fusion as fusion_lib from jax._src.pallas.fuser.fusable import fusable_p @@ -73,9 +73,9 @@ def wrapper(*args, **kwargs): _fusable: dict[jax_core.Primitive, Any] = {} -def construct_fusion( +def _construct_fusion_jaxpr( candidate_values, jaxpr: jax_core.Jaxpr, outvars, *invars, **kwargs -) -> fusion_lib.Fusion: +): flat_outvars, out_tree = tree_util.tree_flatten(outvars) flat_invars, in_tree = tree_util.tree_flatten((invars, kwargs)) new_jaxpr_no_dce = jaxpr.replace( @@ -94,12 +94,6 @@ def construct_fusion( c for used, c in zip(used_consts, candidate_values, strict=True) if used ) kernel_in_tree = tree_util.tree_structure((invars, kwargs)) - - def _fn(*args, **kwargs): - flat_args, _ = tree_util.tree_flatten((args, kwargs)) - out_flat = jax_core.eval_jaxpr(new_jaxpr, new_values, *flat_args) - return tree_util.tree_unflatten(out_tree, out_flat) - flat_in_type = [ jax.ShapeDtypeStruct(x.aval.shape, x.aval.dtype) for x in flat_invars ] @@ -108,9 +102,158 @@ def _fn(*args, **kwargs): out_tree, [jax.ShapeDtypeStruct(x.aval.shape, x.aval.dtype) for x in flat_outvars], ) + return new_jaxpr, new_values, in_type, out_type, out_tree + + +def construct_fusion( + candidate_values, jaxpr: jax_core.Jaxpr, outvars, *invars, **kwargs +) -> fusion_lib.Fusion: + new_jaxpr, new_values, in_type, out_type, out_tree = _construct_fusion_jaxpr( + candidate_values, jaxpr, outvars, *invars, **kwargs + ) + + def _fn(*args, **kwargs): + flat_args, _ = tree_util.tree_flatten((args, kwargs)) + out_flat = jax_core.eval_jaxpr(new_jaxpr, new_values, *flat_args) + return tree_util.tree_unflatten(out_tree, out_flat) + return fusion_lib.Fusion(_fn, in_type, out_type) +def _find_downstream( + jaxpr: jax_core.Jaxpr, in_used: Sequence[bool] +) -> tuple[bool, ...]: + # TODO(sharadmv): We use partial_eval to query downstream dependencies which + # is not an officially sanctioned way to do so, since PE is really used for + # AD. In the future, we should have a special Jaxpr API that queries this. + _, _, out_used, *_ = pe.partial_eval_jaxpr_custom( + jaxpr, + in_unknowns=in_used, + in_inst=in_used, + ensure_out_unknowns=False, + ensure_out_inst=False, + saveable=lambda *_, **__: False, + ) + return tuple(out_used) + + +def _construct_output_permutation( + used: list[tuple[bool, ...]], +) -> list[int]: + order = [] + for u in used: + true_vals = [i for i in range(len(u)) if u[i]] + order.extend(true_vals) + return [order.index(i) for i in range(len(order))] + + +def _construct_output_fusions( + candidate_values, + jaxpr, + out_tree, + fusion_eqn_index, + fusion_eqn_outvars, # Flat list of vars output by the fusable eqn + fusion_eqn_out_tree, # Tree structure of the fusable eqn outputs + output_fusion_prefix, # Pytree defining output groups +): + # 1. Create jaxpr_out: represents computation *after* the fusable + # Inputs: fusion_eqn_outvars + # Outputs: jaxpr.outvars + jaxpr_out, all_values, _, _, _ = _construct_fusion_jaxpr( + candidate_values, + jaxpr.replace( + eqns=jaxpr.eqns[:fusion_eqn_index] + + jaxpr.eqns[fusion_eqn_index + 1 :] + ), + tree_util.tree_unflatten(out_tree, jaxpr.outvars), # Original outputs + tree_util.tree_unflatten( + fusion_eqn_out_tree, fusion_eqn_outvars + ), # Fusable outputs as inputs + ) + + # 2. Group fusable outputs based on the mask + unflat_fusable_outvars = jax.tree.unflatten( + fusion_eqn_out_tree, fusion_eqn_outvars + ) + partial_flat = jax.tree.structure(output_fusion_prefix).flatten_up_to( + unflat_fusable_outvars + ) + + # 3. Calculate dependencies and check disjointness + downstream_outputs_used_masks = [] # List of bool tuples, one per group + already_used_final_outputs = set() # Indices of final outputs already claimed + for outvars_group in partial_flat: + # Identify vars in this group + used_fusable_outvars = set(jax.tree.leaves(outvars_group)) + # Create mask for jaxpr_out inputs corresponding to this group + in_used_mask = [ + True if v in used_fusable_outvars else False for v in jaxpr_out.invars + ] + # Trace dependencies through jaxpr_out to find which final outputs are affected + downstream_used_mask = _find_downstream( + jaxpr_out, in_used_mask + ) # Mask for jaxpr_out.outvars (== jaxpr.outvars) + + # Check for overlap in final output usage across groups + for i, used in enumerate(downstream_used_mask): + if used: + if i in already_used_final_outputs: + raise ValueError( + "Outputs must be disjoint in order to use separate output fusions" + ) + already_used_final_outputs.add(i) + downstream_outputs_used_masks.append(downstream_used_mask) + + # 4. Construct output permutation needed to restore original output order + output_permutation = _construct_output_permutation( + downstream_outputs_used_masks + ) + + # Construct fusions for each group by DCEing the jaxpr_out + output_fusions = [] + for i, outvars_group in enumerate(partial_flat): + flat_group_vars, _ = tree_util.tree_flatten(outvars_group) + downstream_used_mask = downstream_outputs_used_masks[i] + + used_jaxpr_invars = [False] * len(all_values) + [ + v in flat_group_vars for v in jaxpr_out.invars + ] + jaxpr_out_for_group, used_consts, _ = pe.dce_jaxpr_consts( + jaxpr_out, downstream_used_mask, instantiate=used_jaxpr_invars + ) + values_for_jaxpr = tuple( + c for used, c in zip(used_consts, all_values, strict=True) if used + ) + + def _fn(jaxpr, vals, *args, **kwargs): + flat_args, _ = tree_util.tree_flatten((args, kwargs)) + out_flat = jax_core.eval_jaxpr(jaxpr, vals, *flat_args) + return tuple(out_flat) + + fn = functools.partial(_fn, jaxpr_out_for_group, values_for_jaxpr) + in_type = jax.tree.map( + lambda v: jax.ShapeDtypeStruct(v.aval.shape, v.aval.dtype), # pytype: disable=attribute-error + outvars_group, + ) + out_type = tuple( + jax.ShapeDtypeStruct(v.aval.shape, v.aval.dtype) # pytype: disable=attribute-error + for v in jaxpr_out_for_group.outvars + ) + fusion = fusion_lib.Fusion( + fn, + (in_type, {}), + out_type, + ) + output_fusions.append(fusion) + + return ( + tree_util.tree_unflatten( + tree_util.tree_structure(output_fusion_prefix), output_fusions + ), + output_permutation, + ) + + def fuse_jaxpr( jaxpr: jax_core.Jaxpr, out_tree: tree_util.PyTreeDef, consts, *args ): @@ -125,6 +268,15 @@ def fuse_jaxpr( raise ValueError("No fusable eqn found") fusion_eqn = jaxpr.eqns[fusion_eqn_index] + # Now let's check if we need to do any fusion at all, e.g. do the outputs of + # the jaxpr have any dependence on the fusion at all? We can DCE the jaxpr + # with all the inputs and outputs to check if there is a dependence. + dced_jaxpr, _ = pe.dce_jaxpr(jaxpr, [True] * len(jaxpr.outvars), + instantiate=True) + if not any(eqn.primitive is fusable_p for eqn in dced_jaxpr.eqns): + # Short circuit if there is nothing to fuse. + return jax_core.eval_jaxpr(dced_jaxpr, consts, *args) + candidate_values = [*consts, *args] # Construct fusions for non-constant inputs to the fusable. @@ -141,21 +293,20 @@ def fuse_jaxpr( in_fusions = tree_util.tree_unflatten( fusion_eqn.params["in_tree"], in_fusions_flat ) - out_fusion = construct_fusion( + output_fusions, output_permutation = _construct_output_fusions( candidate_values, - jaxpr.replace( - eqns=jaxpr.eqns[:fusion_eqn_index] - + jaxpr.eqns[fusion_eqn_index + 1 :] - ), - tree_util.tree_unflatten(out_tree, jaxpr.outvars), - tree_util.tree_unflatten( - fusion_eqn.params["out_tree"], fusion_eqn.outvars - ), + jaxpr, + out_tree, + fusion_eqn_index, + fusion_eqn.outvars, + fusion_eqn.params["out_tree"], + fusion_eqn.params["output_fusion_prefix"], ) - # Run the fusable. - out = fusion_eqn.params["func"](*in_fusions, out_fusion) - - # Now return the flattened output (the fuse_jaxpr caller should unflatten). - out_flat = tree_util.tree_leaves(out) - assert len(out_flat) == len(jaxpr.outvars) - return out_flat + out = fusion_eqn.params["func"](*in_fusions, output_fusions) + flat_out = jax.tree.leaves(out) + permuted_out = [flat_out[i] for i in output_permutation] + assert len(permuted_out) == len(jaxpr.outvars), ( + len(permuted_out), + len(jaxpr.outvars), + ) + return permuted_out diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 1ea05c700938..ba5d9d5f4ae7 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -680,6 +680,27 @@ jax_multiplatform_test( ] + py_deps("absl/testing") + py_deps("numpy"), ) +jax_multiplatform_test( + name = "fusion_test", + srcs = [ + "fusion_test.py", + ], + disable_configs = [ + "cpu", + "cpu_shardy", + ], + enable_backends = ["cpu"], + tags = [ + "noasan", + "nomsan", + "notsan", + ], + deps = [ + "//jax:pallas", + "//jax:pallas_fuser", + ] + py_deps("absl/testing") + py_deps("numpy"), +) + jax_multiplatform_test( name = "tpu_fusable_matmul_test", srcs = ["tpu_fusable_matmul_test.py"], diff --git a/tests/pallas/fusion_test.py b/tests/pallas/fusion_test.py new file mode 100644 index 000000000000..2edcf78f1aba --- /dev/null +++ b/tests/pallas/fusion_test.py @@ -0,0 +1,232 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest +import jax +from jax._src import test_util as jtu +from jax.experimental.pallas import fuser +import jax.numpy as jnp +import numpy as np + +jax.config.parse_flags_with_absl() + + +class FusionTest(jtu.JaxTestCase): + + def test_basic_fusion(self): + + @jax.jit + @fuser.fuse + @fuser.fusable + def f(x_fn, y_fn): + x = x_fn() + if y_fn is None: + y_fn = lambda x: x + return y_fn(x) + + x = jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + np.testing.assert_array_equal(f(x), x) + + def test_separate_output_fusions_trivial(self): + + @fuser.fusable(output_fusion_prefix=(True, True)) + def f(x_fn, y_fn, z_fns): + x = x_fn() + y = y_fn() + if z_fns is None: + z_fns = lambda x: x, lambda x: x + z_fn1, z_fn2 = z_fns + return z_fn1(x), z_fn2(y) + + @jax.jit + @fuser.fuse + def g(x, y): + x, y = f(x, y) + return x, y * 2 + + x = jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + y = jax.random.normal(jax.random.key(1), (1, 128), dtype=jnp.float32) + x_out, y_out = g(x, y) + np.testing.assert_array_equal(x_out, x) + np.testing.assert_array_equal(y_out, y * 2) + + def test_separate_output_fusions_should_error_if_not_disjoint(self): + + @fuser.fusable(output_fusion_prefix=(True, True)) + def f(x_fn, y_fn, z_fns): + x = x_fn() + y = y_fn() + if z_fns is None: + z_fns = lambda x: x, lambda x: x + z_fn1, z_fn2 = z_fns + return z_fn1(x), z_fn2(y) + + @jax.jit + @fuser.fuse + def g(x, y): + x_res, y_res = f(x, y) + return x_res + y_res + + x = jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + y = jax.random.normal(jax.random.key(1), (128, 128), dtype=jnp.float32) + + with self.assertRaisesRegex( + ValueError, + "Outputs must be disjoint in order to use separate output fusions", + ): + g(x, y) + + def test_separate_output_fusions_allows_permute(self): + + @fuser.fusable(output_fusion_prefix=(True, True)) + def f(x_fn, y_fn, z_fns): + x = x_fn() + y = y_fn() + if z_fns is None: + z_fns = lambda x: x, lambda x: x + z_fn1, z_fn2 = z_fns + return z_fn1(x), z_fn2(y) + + @jax.jit + @fuser.fuse + def g(x, y): + x_res, y_res = f(x, y) + return y_res * 2, x_res + + x = jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + y = jax.random.normal(jax.random.key(1), (1, 128), dtype=jnp.float32) + y_out, x_out = g(x, y) + np.testing.assert_array_equal(x_out, x) + np.testing.assert_array_equal(y_out, y * 2) + + def test_separate_output_fusions_with_nesting(self): + + @fuser.fusable(output_fusion_prefix=(True, True)) + def f(x_fn, y_fn, z_fns): + x = x_fn() + y = y_fn() + if z_fns is None: + z_fns = lambda x: x, lambda x: x + z_fn1, z_fn2 = z_fns + return z_fn1(x), z_fn2(y) + + @jax.jit + @fuser.fuse + def g(x, y): + x_res, y_res = f(x, y) + return (x_res * 2, x_res + x_res), y_res + + x = jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + y = jax.random.normal(jax.random.key(1), (1, 128), dtype=jnp.float32) + (x1_out, x2_out), y_out = g(x, y) + np.testing.assert_array_equal(x1_out, x * 2) + np.testing.assert_array_equal(x2_out, x + x) + np.testing.assert_array_equal(y_out, y) + + def test_separate_output_fusions_with_nesting_and_permutation(self): + + @fuser.fusable(output_fusion_prefix=(True, True)) + def f(x_fn, y_fn, z_fns): + x = x_fn() + y = y_fn() + if z_fns is None: + z_fns = lambda x: x, lambda x: x + z_fn1, z_fn2 = z_fns + return z_fn1(x), z_fn2(y) + + @jax.jit + @fuser.fuse + def g(x, y): + x_res, y_res = f(x, y) + return y_res, (x_res * 2, x_res + x_res) + + x = jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + y = jax.random.normal(jax.random.key(1), (1, 128), dtype=jnp.float32) + y_out, (x1_out, x2_out) = g(x, y) + np.testing.assert_array_equal(x1_out, x * 2) + np.testing.assert_array_equal(x2_out, x + x) + np.testing.assert_array_equal(y_out, y) + + def test_separate_output_fusions_with_deep_output_mask(self): + + @fuser.fusable(output_fusion_prefix=(True, (True, True))) + def f(x_fn, y_fn, z_fn, o_fns): + x = x_fn() + y = y_fn() + z = z_fn() + if o_fns is None: + o_fns = lambda x: x, (lambda x: x, lambda x: x) + o_fn1, (o_fn2, o_fn3) = o_fns + return o_fn1(x), (o_fn2(y), o_fn3(z)) + + @jax.jit + @fuser.fuse + def g(x, y, z): + x_res, (y_res, z_res) = f(x, y, z) + return (x_res * 2, (y_res, z_res + z_res)) + + x = jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + y = jax.random.normal(jax.random.key(1), (1, 128), dtype=jnp.float32) + z = jax.random.normal(jax.random.key(1), (128, 1), dtype=jnp.float32) + x_out, (y_out, z_out) = g(x, y, z) + np.testing.assert_array_equal(x_out, x * 2) + np.testing.assert_array_equal(y_out, y) + np.testing.assert_array_equal(z_out, z + z) + + def test_separate_output_fusions_with_reused_value(self): + @fuser.fusable(output_fusion_prefix=(True, True)) + def f(x_fn, y_fn, z_fns): + x = x_fn() + y = y_fn() + if z_fns is None: + z_fns = lambda x: x, lambda x: x + z_fn1, z_fn2 = z_fns + return z_fn1(x), z_fn2(y) + + @jax.jit + @fuser.fuse + def g(x, y, a): + x_res, y_res = f(x, y) + return y_res + a, (x_res * 2, x_res + x_res + a) + + x = jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + y = jax.random.normal(jax.random.key(1), (1, 128), dtype=jnp.float32) + a = jax.random.normal(jax.random.key(1), (1, 128), dtype=jnp.float32) + y_out, (x1_out, x2_out) = g(x, y, a) + np.testing.assert_array_equal(x1_out, x * 2) + np.testing.assert_array_equal(x2_out, x + x + a) + np.testing.assert_array_equal(y_out, y + a) + + def test_empty_fusion(self): + @fuser.fusable + def f(x_fn, y_fn): + x = x_fn() + if y_fn is None: + y_fn = lambda x: x + return y_fn(x) + + @jax.jit + @fuser.fuse + def g(x, a): + _ = f(x) + return a + + x = jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + a = jax.random.normal(jax.random.key(1), (128, 128), dtype=jnp.float32) + y_out = g(x, a) + np.testing.assert_array_equal(y_out, a) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/tpu_fusable_matmul_test.py b/tests/pallas/tpu_fusable_matmul_test.py index 5ee372ce92ab..93523b174774 100644 --- a/tests/pallas/tpu_fusable_matmul_test.py +++ b/tests/pallas/tpu_fusable_matmul_test.py @@ -71,7 +71,8 @@ def _(): def _(): acc = acc_ref[...].astype(out_dtype) z_values = jax.tree.map(lambda ref: ref.get(), z_value_refs) - o_ref[...] = z_fn(pids, scalar_prefetch, z_values, acc) + out = z_fn(pids, scalar_prefetch, z_values, acc) + jax.tree.map(lambda ref, x: ref.set(x), o_ref, out) def _fusable_matmul( @@ -174,12 +175,12 @@ def z_index_map(i, j, k, *_): y_value_block_specs, z_value_block_specs, ], - out_specs=z_out_block_spec, + out_specs=[z_out_block_spec], ), compiler_params=pltpu.TPUCompilerParams( dimension_semantics=dimension_semantics, ), - out_shape=z_out_type, + out_shape=[z_out_type], interpret=interpret, debug=debug, )( @@ -187,7 +188,7 @@ def z_index_map(i, j, k, *_): x_values, y_values, z_values, - ) + )[0] def fusable_matmul( From e1e37f8d5e80597dbe7a3b447f8c29ba1575ee55 Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Mon, 7 Apr 2025 11:43:26 -0700 Subject: [PATCH 430/483] [Mosaic TPU] FWD compatibility needs to keep previous version at least one month. PiperOrigin-RevId: 744796256 --- jax/_src/tpu_custom_call.py | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index f84db206f4d1..cbec7f873156 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -64,15 +64,22 @@ ) -# This tracks the latest Mosaic IR version with a monthly delay. -FWD_COMPAT_IR_VERSION = 4 -DEFAULT_IR_VERSION = None -# TODO(jevinjiang): Remove this once both jaxlib and libtpu are up to date. -if is_cloud_tpu_older_than(2025, 4, 5) or jax.version._version_as_tuple( - jax.lib.__version__ -) < (0, 5, 4): - FWD_COMPAT_IR_VERSION = 3 - DEFAULT_IR_VERSION = 3 +# Controls the IR serialization version. Upon incrementing the +# default version in jaxlib/mosaic/dialect/tpu/transforms/serde.cc we must +# continue to use the old serialization version when in forward compatibility +# mode: for 1 month when exporting, or when using old cloud TPU. +# +# This can be achieved by adding: +# if ctx.is_forward_compat() or is_cloud_tpu_older_than(): +# return +# return None +# +# We should also add a TODO to remove the conditional one month later. +def get_ir_version(ctx: mlir.LoweringRuleContext) -> int | None: + # TODO(jevinjiang): remove the forward compatibility check after 2025-05-05. + if ctx.is_forward_compat() or is_cloud_tpu_older_than(2025, 4, 5): + return 3 + return None tpu_custom_call_p = core.Primitive("tpu_custom_call") @@ -679,9 +686,7 @@ def lower_module_to_custom_call( serialization_format=serialization_format, output_memory_spaces=output_memory_spaces, kernel_name=kernel_name, - ir_version=FWD_COMPAT_IR_VERSION - if ctx.is_forward_compat() - else DEFAULT_IR_VERSION, + ir_version=get_ir_version(ctx), ) return _tpu_custom_call_lowering( ctx, From 05ca0233914b429381aeb1758e96ed73c6c434c1 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Mon, 7 Apr 2025 19:12:32 +0000 Subject: [PATCH 431/483] [shard-map] in eager shmap, handle all rep rule output cases By convention, rep_rules can return three kinds of thing: 1. a sequence (tuple or list), 2. a single set, or 3. a single None. Even rules for primitives with multiple results can return single objects rather than sequences; the reason is that it's convenient not ot have to infer the number of outputs for higher-order primitives. In the latter two cases we rely on the caller (in this case, ShardMapTrace.process_primitive) to 'broadcast' the singleton result to a list of results equal to the number of outputs. Previously, the code was checking `if type(out_rep) is set`, which doesn't handle case 3. (We briefly tried another fix direction where we don't allow case 3, because we don't have case 3 in the upcoming VMA type system which replaces this stuff. But until that lands the easiest fix is just to handle all cases correctly.) fixes #26148, fixes #27673 Co-authored-by: Justin Fu --- jax/experimental/shard_map.py | 3 ++- tests/shard_map_test.py | 13 +++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index ef3751c96901..3a46f444fb1b 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -958,7 +958,8 @@ def process_primitive(self, prim, tracers, params): rep_rule = _check_rules.get(prim, partial(_rule_missing, prim)) out_rep = rep_rule(self.mesh, *in_rep, **params) if self.check else set() if prim.multiple_results: - out_rep = [out_rep] * len(out_vals) if type(out_rep) is set else out_rep + out_rep = (out_rep if isinstance(out_rep, (list, tuple)) + else [out_rep] * len(out_vals)) return map(partial(ShardMapTracer, self), out_rep, out_vals) return ShardMapTracer(self, out_rep, out_vals) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 520cc02638df..3a4c3ea9779c 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -685,6 +685,19 @@ def f3(): f3() jax.jit(f3)() + def test_multiple_result_primitive_with_none_sharding(self): + # https://github.com/jax-ml/jax/issues/27673 + xs = jnp.arange(20).reshape(2, 10) + mesh = jtu.create_mesh((2,), ("i",)) + y = shard_map( + lambda x: jnp.split(x.squeeze(), 2), + mesh=mesh, + in_specs=(None,), + out_specs=P("i"), + )(xs) + expected = jnp.repeat(xs, 2, axis=0).reshape(2, 2, 10) + self.assertArraysEqual(y, expected) + def test_vmap_spmd_axis_name(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) From dc00f9bdaea77094cb6cdf959d99e61efbd87268 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 7 Apr 2025 13:56:07 -0400 Subject: [PATCH 432/483] Apply output forwarding in lin rule for pjit. --- jax/_src/interpreters/partial_eval.py | 1 + jax/_src/pjit.py | 45 +++++++++++++++++++++------ 2 files changed, 37 insertions(+), 9 deletions(-) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 532eb0f80029..21be93ee485e 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -1876,6 +1876,7 @@ def invalidate(self): # avoid cyclic refs self.frame.tracers = [] self.frame.constid_to_tracer = {} + self.frame.constvar_to_val = {} def to_jaxpr_tracer(self, x): as_local_var = self.frame.tracer_to_var.get(id(x)) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 8c3c5101eb51..641456eca15b 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -2079,19 +2079,44 @@ def _pjit_linearization(nzs, *primals_in, jaxpr, res_shardings = (UNSPECIFIED,) * num_residuals res_layouts = (None,) * num_residuals res_donated = (False,) * num_residuals + primal_out_shardings = res_shardings + tuple(out_shardings) + primal_out_layouts = res_layouts + tuple(out_layouts) + def keep_where(l, should_keep): + return tuple(x for x, keep in zip(l, should_keep) if keep) + + # Input-to-output forwarding. in_fwd = pe._jaxpr_forwarding(primal_jaxpr.jaxpr) - in_fwd, _ = split_list(in_fwd, [num_residuals]) - keep = tuple(f is None for f in in_fwd) + (True,) * len(out_shardings) + in_fwd_res, in_fwd_primal = split_list(in_fwd, [num_residuals]) + in_fwd = in_fwd_res + [ + fwd if isinstance(os, UnspecifiedValue) and ol is None else None + for os, ol, fwd in zip(out_shardings, out_layouts, in_fwd_primal) + ] + del in_fwd_res, in_fwd_primal + keep = [f is None for f in in_fwd] primal_jaxpr = pe.prune_closed_jaxpr_outputs(primal_jaxpr, keep) - num_residuals = sum(f is None for f in in_fwd) + primal_out_shardings = keep_where(primal_out_shardings, keep) + primal_out_layouts = keep_where(primal_out_layouts, keep) + kept_res, _ = split_list(keep, [num_residuals]) + num_kept_residuals = sum(kept_res) + del keep, kept_res + + # Output-to-output forwarding. + num_out_primals = len(primal_jaxpr.jaxpr.outvars) - num_kept_residuals + res_vars, out_vars = split_list(primal_jaxpr.jaxpr.outvars, [num_kept_residuals]) + idx_map = {id(v): i for i, v in enumerate(out_vars)} + offset = sum(id(v) not in idx_map for v in res_vars) + idx_map = {k: v + offset for k, v in idx_map.items()} + out_fwd = [idx_map.get(id(v)) for v in res_vars] + [None] * num_out_primals + keep = [f is None for f in out_fwd] + primal_jaxpr = pe.prune_closed_jaxpr_outputs(primal_jaxpr, keep) + primal_out_shardings = keep_where(primal_out_shardings, keep) + primal_out_layouts = keep_where(primal_out_layouts, keep) + del keep def tangent_fun(consts_, *tangents): - consts_it = iter(consts_) - res = [next(consts_it) if f is None else primals_in[f] for f in in_fwd] - assert next(consts_it, None) is None tangents_nz = _filter_zeros(nzs, tangents) - nz_tangents_out = pjit_p.bind(*(*tangents_nz, *res), + nz_tangents_out = pjit_p.bind(*tangents_nz, *consts_, jaxpr=tangent_jaxpr, in_shardings=_filter_zeros(nzs, in_shardings) + res_shardings, out_shardings=_filter_zeros(nzs_out, out_shardings), @@ -2114,15 +2139,17 @@ def _filter_zeros(is_nz_l, l): ans = pjit_p.bind(*primals_in, jaxpr=primal_jaxpr, in_shardings=in_shardings, - out_shardings=(*res_shardings[:num_residuals], *out_shardings), + out_shardings=primal_out_shardings, in_layouts=in_layouts, - out_layouts=(*res_layouts[:num_residuals], *out_layouts), + out_layouts=primal_out_layouts, donated_invars=donated_invars, ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused, inline=inline, compiler_options_kvs=compiler_options_kvs) + ans = subs_list(out_fwd, ans, ans) + ans = subs_list(in_fwd, primals_in, ans) residuals_ans, primal_ans = split_list(ans, [num_residuals]) return primal_ans, nzs_out, residuals_ans, tangent_fun From 23b63cd5e0c8f7ab337443ff18d7069d0b8b1afb Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 7 Apr 2025 12:50:03 -0700 Subject: [PATCH 433/483] Update XLA dependency to use revision http://github.com/openxla/xla/commit/77635006f6a898f71f19db360e9b4485aa5106da. PiperOrigin-RevId: 744819336 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index a8a93026378f..d4df9ee38034 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "145f836bd5175dc5dd262f716a0c59af2b0297a0" -XLA_SHA256 = "bd19d8a1d25468696809a69ef3984bb00ef432e3fe9c05116b9c114dc7c83fa2" +XLA_COMMIT = "77635006f6a898f71f19db360e9b4485aa5106da" +XLA_SHA256 = "d2a63a3cd2f354cd07699f30e7b5c16c7513e686e498b8ad712fb577ab677121" def repo(): tf_http_archive( From 522add2cccf1dc17cb0bb874468b7d87aebf32ef Mon Sep 17 00:00:00 2001 From: Michael Hudgins Date: Mon, 7 Apr 2025 13:10:44 -0700 Subject: [PATCH 434/483] [CI] Temporarily disable TPU v6 due to runner issues PiperOrigin-RevId: 744825924 --- .github/workflows/cloud-tpu-ci-nightly.yml | 11 +---------- .github/workflows/wheel_tests_nightly_release.yml | 11 ++--------- 2 files changed, 3 insertions(+), 19 deletions(-) diff --git a/.github/workflows/cloud-tpu-ci-nightly.yml b/.github/workflows/cloud-tpu-ci-nightly.yml index fd799a3f70b5..b50b07d5cc4a 100644 --- a/.github/workflows/cloud-tpu-ci-nightly.yml +++ b/.github/workflows/cloud-tpu-ci-nightly.yml @@ -27,18 +27,9 @@ jobs: jaxlib-version: ["head", "pypi_latest", "nightly", "nightly+oldest_supported_libtpu"] tpu: [ {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}, - {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}, - {type: "v6e-8", cores: "8", runner: "linux-x86-ct6e-180-8tpu"} + {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"} ] python-version: ["3.10"] - # Exclude v6e-8 tests for nightly+oldest_supported_libtpu and pypi_latest for resource constraints. - exclude: - - tpu: - type: "v6e-8" - jaxlib-version: "nightly+oldest_supported_libtpu" - - tpu: - type: "v6e-8" - jaxlib-version: "pypi_latest" name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu.type }})" env: LIBTPU_OLDEST_VERSION_DATE: 20241205 diff --git a/.github/workflows/wheel_tests_nightly_release.yml b/.github/workflows/wheel_tests_nightly_release.yml index 6fd48d016bd0..132aad577d50 100644 --- a/.github/workflows/wheel_tests_nightly_release.yml +++ b/.github/workflows/wheel_tests_nightly_release.yml @@ -80,26 +80,19 @@ jobs: tpu-specs: [ # {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}, - {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}, - {type: "v6e-8", cores: "8", runner: "linux-x86-ct6e-180-8tpu"} + {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"} ] libtpu-version-type: ["pypi_latest", "nightly", "oldest_supported_libtpu"] exclude: - libtpu-version-type: ${{ startsWith(github.ref_name, 'release/') && 'nightly' }} - libtpu-version-type: ${{ !startsWith(github.ref_name, 'release/') && 'pypi_latest' }} - # Run a single Python version for v4-8 and v6e-8 + # Run a single Python version for v4-8 - tpu-specs: type: "v4-8" python: "3.10" - tpu-specs: type: "v4-8" python: "3.11" - - tpu-specs: - type: "v6e-8" - python: "3.10" - - tpu-specs: - type: "v6e-8" - python: "3.11" # Run min and max Python versions for v5e-8 - tpu-specs: type: "v5e-8" From e1b057287967c477834fb9b62006f1e03dc763b1 Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Mon, 7 Apr 2025 13:23:45 -0700 Subject: [PATCH 435/483] [mgpu] Allow bf16 printing PiperOrigin-RevId: 744830111 --- jax/experimental/mosaic/gpu/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 28534cf4025b..47401440fac2 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -135,7 +135,7 @@ def _debug_scalar_ty_format(arg): return "%llu", arg if ir.F32Type.isinstance(arg.type): return "%f", arg - if ir.F16Type.isinstance(arg.type): + if ir.BF16Type.isinstance(arg.type) or ir.F16Type.isinstance(arg.type): arg = arith.extf(ir.F32Type.get(), arg) return "%f", arg raise NotImplementedError(f"Can't print the type {arg.type}") From b6e4b93851c75b0eea375cbf43771ddb094c547b Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Mon, 7 Apr 2025 13:49:28 -0700 Subject: [PATCH 436/483] Add jaxlib_extension_version guard against explicit copying in jax.device_put. PiperOrigin-RevId: 744838237 --- jax/_src/dispatch.py | 3 ++- tests/pjit_test.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index d205f860b214..baab6d519291 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -44,6 +44,7 @@ from jax._src.interpreters import pxla from jax._src.interpreters import xla from jax._src.layout import DeviceLocalLayout, Layout +from jax._src.lib import jaxlib_extension_version from jax._src.lib import xla_client as xc from jax._src.mesh import AbstractMesh, Mesh from jax._src.monitoring import record_event_duration_secs, record_event_time_span @@ -495,7 +496,7 @@ def _device_put_sharding_impl(x, aval, device, copy): return _DeferredShardArg(x, x.sharding, aval, x.committed, copy) elif is_single_device_sharding(x.sharding): device = x.sharding._device_assignment[0] if device is None else device - if copy == CopySemantics.COPY: + if copy == CopySemantics.COPY and jaxlib_extension_version >= 327: return xc.batched_device_put(aval, SingleDeviceSharding(device), [x], [device], True, True) return pxla.batched_device_put(aval, SingleDeviceSharding(device), [x], diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 2570c6090351..2db75be18475 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -1401,7 +1401,7 @@ def test_zero_literal_equality(self): self.assertIn("stablehlo.constant dense<-0.000000e+00>", ir) def test_device_put_copy_donate(self): - if jaxlib_extension_version < 323: + if jaxlib_extension_version < 327: raise unittest.SkipTest("Copy not supported in device put.") x = np.arange(1000) y = jax.device_put(x, device=jax.devices()[0], may_alias=False, donate=False) From 9a3e94dec519c4a5dbf4549be9cee983d9b63cb8 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Mon, 7 Apr 2025 21:00:59 +0000 Subject: [PATCH 437/483] [shard-map] add while_map rep rule fixes #27664 --- jax/_src/util.py | 9 ++----- jax/experimental/shard_map.py | 38 ++++++++++++++++++++++++++ tests/shard_map_test.py | 51 +++++++++++++++++++++++++++++++++++ 3 files changed, 91 insertions(+), 7 deletions(-) diff --git a/jax/_src/util.py b/jax/_src/util.py index b3f7becee7eb..30da28522840 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -497,13 +497,8 @@ def __eq__(self, other): self.args == other.args and self.kwargs == other.kwargs) def __hash__(self): - return hash( - ( - self.f.__code__, - self.args, - tuple(sorted(self.kwargs.items(), key=lambda kv: kv[0])), - ), - ) + kwargs = tuple(sorted(self.kwargs.items(), key=lambda kv: kv[0])) + return hash((self.f.__code__, self.args, kwargs)) def __call__(self, *args, **kwargs): return self.f(*self.args, *args, **self.kwargs, **kwargs) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 3a46f444fb1b..7c4a8c6e2542 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -1403,6 +1403,44 @@ def _scan_rewrite(mesh, in_rep, *args, jaxpr, num_consts, num_carry, **params): *args, jaxpr=jaxpr_, num_consts=num_consts, num_carry=num_carry, **params) return out_vals, out_rep +@register_check(control_flow.loops.while_p) +def _while_check(mesh, *in_rep, body_jaxpr, cond_nconsts, body_nconsts, **_): + _, bconst_rep, carry_rep_in = split_list(in_rep, [cond_nconsts, body_nconsts]) + carry_rep_out = _check_rep(mesh, body_jaxpr.jaxpr, [*bconst_rep, *carry_rep_in]) + if tuple(carry_rep_in) != tuple(carry_rep_out): + raise Exception("Scanwhile_loopcarry input and output got mismatched " + "replication types {carry_rep_in} and {carry_rep_out}. " + "Please open an issue at " + "https://github.com/jax-ml/jax/issues, and as a temporary " + "workaround pass the check_rep=False argument to shard_map") + return carry_rep_out + +@register_rewrite(control_flow.loops.while_p) +def _while_rewrite(mesh, in_rep, *args, cond_jaxpr, body_jaxpr, cond_nconsts, + body_nconsts): + # while while isn't transposable, we insert pbroadcasts for consistent carry + cconst_rep, bconst_rep, carry_rep_in = split_list(in_rep, [cond_nconsts, body_nconsts]) + num_carry = len(args) - cond_nconsts - body_nconsts + for _ in range(1 + num_carry): + in_rep_ = [*bconst_rep, *carry_rep_in] + _, carry_rep_out = _replication_rewrite_nomatch(mesh, body_jaxpr, in_rep_) + if tuple(carry_rep_in) == tuple(carry_rep_out): + break + carry_rep_in = map(op.and_, carry_rep_in, carry_rep_out) + else: + assert False, "Fixpoint not reached" + + cond_jaxpr_, _ = _replication_rewrite_nomatch( + mesh, cond_jaxpr, (*cconst_rep, *carry_rep_in)) + body_jaxpr_ = _replication_rewrite_match( + mesh, body_jaxpr, (*bconst_rep, *carry_rep_in), carry_rep_out) + args_ = [pbroadcast(x, tuple(n for n in src if n not in dst)) + if src - dst else x for x, src, dst in zip(args, in_rep, in_rep_)] + out_vals = control_flow.loops.while_p.bind( + *args_, cond_jaxpr=cond_jaxpr_, body_jaxpr=body_jaxpr_, + cond_nconsts=cond_nconsts, body_nconsts=body_nconsts) + return out_vals, carry_rep_out + @register_check(control_flow.conditionals.cond_p) def _cond_rule(mesh, *in_rep, branches): _, *args_rep = in_rep diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 3a4c3ea9779c..daf95ebbd50b 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -1016,6 +1016,57 @@ def body(c, _): shard_map(g, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), out_specs=[P(None), P(None), P(('x', 'y'))])(x, x, x) + def test_while_rep_rule(self): + mesh = jtu.create_mesh((2, 2,), ('x', 'y')) + + def f(x, y, z): + x, y, z = x.sum(), y.sum(), z.sum() + def cond(c): + i, *_ = c + return i < 5 + def body(c): + i, c, *cs = c + return (i + 1, *cs, c) + _, *out = jax.lax.while_loop(cond, body, (0, x, y, z)) + return [jnp.expand_dims(a, 0) for a in out] + + x = jnp.arange(4) + + # doesn't crash, because out_spec assumes no replication (and there is none) + shard_map(f, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + out_specs=P(('x', 'y')))(x, x, x) + + # does crash, because output incorrectly promises replication + with self.assertRaisesRegex(ValueError, "require replication"): + shard_map(f, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + out_specs=P('x'))(x, x, x) + with self.assertRaisesRegex(ValueError, "require replication"): + shard_map(f, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + out_specs=P('y'))(x, x, x) + with self.assertRaisesRegex(ValueError, "require replication"): + shard_map(f, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + out_specs=P(None))(x, x, x) + + def g(x, y, z): + x, y, z = x.sum(), y.sum(), z.sum() + def cond(c): + i, *_ = c + return i < 1 + def body(c): + i, *cs = c + return (i + 1, *cs) + _, *out = jax.lax.while_loop(cond, body, (0, x, y, z)) + return [jnp.expand_dims(a, 0) for a in out] + + # doesn't crash, because everything matches + shard_map(g, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + out_specs=[P(None), P('x'), P(('x', 'y'))])(x, x, x) + + # does crash, because the second guy is wrong + with self.assertRaisesRegex(ValueError, "require replication"): + shard_map(g, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + out_specs=[P(None), P(None), P(('x', 'y'))])(x, x, x) + def test_cond_rep_rule(self): mesh = jtu.create_mesh((2, 2,), ('x', 'y')) x = jnp.arange(4) From d3cfff057fadfe173bc9410300ef37de09031d3a Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 7 Apr 2025 14:08:35 -0700 Subject: [PATCH 438/483] jax.numpy: support __jax_array__ in remaining APIs --- jax/_src/numpy/lax_numpy.py | 25 +++++++++++++++---------- tests/array_extensibility_test.py | 25 +++++++++++++------------ 2 files changed, 28 insertions(+), 22 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index d4226617030b..503dca1784c4 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -911,11 +911,11 @@ def histogram(a: ArrayLike, bins: ArrayLike = 10, Array(True, dtype=bool) """ if weights is None: - util.check_arraylike("histogram", a, bins) + a, _ = util.ensure_arraylike("histogram", a, bins) a, = util.promote_dtypes_inexact(a) weights = ones_like(a) else: - util.check_arraylike("histogram", a, bins, weights) + a, _, weights = util.ensure_arraylike("histogram", a, bins, weights) if np.shape(a) != np.shape(weights): raise ValueError("weights should have the same shape as a.") a, weights = util.promote_dtypes_inexact(a, weights) @@ -1005,7 +1005,7 @@ def histogram2d(x: ArrayLike, y: ArrayLike, bins: ArrayLike | list[ArrayLike] = >>> jnp.allclose(normed_sum, 1.0) Array(True, dtype=bool) """ - util.check_arraylike("histogram2d", x, y) + x, y = util.ensure_arraylike("histogram2d", x, y) try: N = len(bins) # type: ignore[arg-type] except TypeError: @@ -1077,10 +1077,10 @@ def histogramdd(sample: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10, Array(True, dtype=bool) """ if weights is None: - util.check_arraylike("histogramdd", sample) + sample = util.ensure_arraylike("histogramdd", sample) sample, = util.promote_dtypes_inexact(sample) else: - util.check_arraylike("histogramdd", sample, weights) + sample, weights = util.ensure_arraylike("histogramdd", sample, weights) if np.shape(weights) != np.shape(sample)[:1]: raise ValueError("should have one weight for each sample.") sample, weights = util.promote_dtypes_inexact(sample, weights) @@ -2424,7 +2424,7 @@ def expand_dims(a: ArrayLike, axis: int | Sequence[int]) -> Array: [2], [3]]]], dtype=int32) """ - util.check_arraylike("expand_dims", a) + a = util.ensure_arraylike("expand_dims", a) axis = _ensure_index_tuple(axis) return lax.expand_dims(a, axis) @@ -4371,7 +4371,7 @@ def pad_func(row: Array, pad_width: tuple[int, int], Array([-10, -10, 2, 3, 4, 10, 10], dtype=int32) """ - util.check_arraylike("pad", array) + array = util.ensure_arraylike("pad", array) pad_width = _broadcast_to_pairs(pad_width, np.ndim(array), "pad_width") if pad_width and not all(core.is_dim(p[0]) and core.is_dim(p[1]) for p in pad_width): @@ -6988,8 +6988,10 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *, Array([[1, 1, 2, 2, 2, 2, 2], [3, 3, 4, 4, 4, 4, 4]], dtype=int32) """ - arr = util.ensure_arraylike("repeat", a) - core.is_dim(repeats) or util.check_arraylike("repeat", repeats) + if core.is_dim(repeats): + arr = util.ensure_arraylike("repeat", a) + else: + arr, repeats = util.ensure_arraylike("repeat", a, repeats) if axis is None: arr = arr.ravel() @@ -7828,7 +7830,7 @@ def diag_indices_from(arr: ArrayLike) -> tuple[Array, ...]: Array([0, 1], dtype=int32), Array([0, 1], dtype=int32)) """ - util.check_arraylike("diag_indices_from", arr) + arr = util.ensure_arraylike("diag_indices_from", arr) nd = np.ndim(arr) if not np.ndim(arr) >= 2: raise ValueError("input array must be at least 2-d") @@ -8244,6 +8246,9 @@ def delete( # Case 3: obj is an array # NB: pass both arrays to check for appropriate error message. util.check_arraylike("delete", a, obj) + # Can't use ensure_arraylike here because obj may be static. + if hasattr(obj, "__jax_array__"): + obj = obj.__jax_array__() # Case 3a: unique integer indices; delete in a JIT-compatible way if issubdtype(_dtype(obj), np.integer) and assume_unique_indices: diff --git a/tests/array_extensibility_test.py b/tests/array_extensibility_test.py index 14fcc18ca7a5..7c2681a2100e 100644 --- a/tests/array_extensibility_test.py +++ b/tests/array_extensibility_test.py @@ -81,6 +81,7 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: Bool = ShapeDtype(bool) Int = ShapeDtype(int) +UInt = ShapeDtype('uint32') Uint8 = ShapeDtype('uint8') Float = ShapeDtype(float) Complex = ShapeDtype(complex) @@ -280,18 +281,18 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.correlate, Float[7], Float[3]), NumPyAPI.sig(jnp.cos, Float[5]), NumPyAPI.sig(jnp.cosh, Float[5]), - # NumPyAPI.sig(np.count_nonzero, [float], [(10,)]), - # NumPyAPI.sig(np.cov, [float], [(10,)]), - # NumPyAPI.sig(np.cross, [float, float], [(3,), (3,)]), + NumPyAPI.sig(jnp.count_nonzero, Float[10]), + NumPyAPI.sig(jnp.cov, Float[10]), + NumPyAPI.sig(jnp.cross, Float[3], Float[3]), NumPyAPI.sig(jnp.cumprod, Float[5]), NumPyAPI.sig(jnp.cumsum, Float[5]), NumPyAPI.sig(jnp.cumulative_prod, Float[5]), NumPyAPI.sig(jnp.cumulative_sum, Float[5]), NumPyAPI.sig(jnp.deg2rad, Float[5]), NumPyAPI.sig(jnp.degrees, Float[5]), - # NumPyAPI.sig(jnp.delete, Float[5], Int[()]), + NumPyAPI.sig(jnp.delete, Float[5], Int[()]), NumPyAPI.sig(jnp.diag, Float[5]), - # NumPyAPI.sig(jnp.diag_indices_from, Float[5, 5]), + NumPyAPI.sig(jnp.diag_indices_from, Float[5, 5]), NumPyAPI.sig(jnp.diagflat, Float[5]), NumPyAPI.sig(jnp.diagonal, Float[5, 5]), NumPyAPI.sig(jnp.diff, Float[5]), @@ -306,7 +307,7 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.equal, Float[5], Float[5]), NumPyAPI.sig(jnp.exp, Float[5]), NumPyAPI.sig(jnp.exp2, Float[5]), - # NumPyAPI.sig(jnp.expand_dims, Float[5], axis=0), + NumPyAPI.sig(jnp.expand_dims, Float[5], axis=0), NumPyAPI.sig(jnp.expm1, Float[5]), NumPyAPI.sig(jnp.extract, Bool[5], Float[5]), NumPyAPI.sig(jnp.fabs, Float[5]), @@ -332,11 +333,11 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.greater, Float[5], Float[5]), NumPyAPI.sig(jnp.greater_equal, Float[5], Float[5]), NumPyAPI.sig(jnp.heaviside, Float[5], Float[5]), - # NumPyAPI.sig(jnp.histogram, Float[5]), + NumPyAPI.sig(jnp.histogram, Float[5]), NumPyAPI.sig(jnp.histogram2d, Float[5], Float[5]), NumPyAPI.sig(jnp.histogram_bin_edges, Float[5]), - # NumPyAPI.sig(jnp.histogramdd, Float[5, 3]), - # NumPyAPI.sig(jnp.hsplit, Float[3, 5], Int[1]), + NumPyAPI.sig(jnp.histogramdd, Float[5, 3]), + NumPyAPI.sig(jnp.hsplit, Float[3, 6], indices_or_sections=2), NumPyAPI.sig(jnp.hstack, (Float[5], Float[5])), NumPyAPI.sig(jnp.hypot, Float[5], Float[5]), NumPyAPI.sig(jnp.i0, Float[5]), @@ -411,7 +412,7 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.ones_like, Float[5]), NumPyAPI.sig(jnp.outer, Float[5], Float[5]), NumPyAPI.sig(jnp.packbits, Int[5]), - # NumPyAPI.sig(jnp.pad, Float[5], pad_width=2), + NumPyAPI.sig(jnp.pad, Float[5], pad_width=2), NumPyAPI.sig(jnp.partition, Float[5], kth=3), NumPyAPI.sig(jnp.percentile, Float[5], q=75), NumPyAPI.sig(jnp.permute_dims, Float[3, 5], axes=(1, 0)), @@ -437,11 +438,11 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct: NumPyAPI.sig(jnp.rad2deg, Float[5]), NumPyAPI.sig(jnp.radians, Float[5]), NumPyAPI.sig(jnp.ravel, Float[5]), - # NumPyAPI.sig(jnp.ravel_multi_index, Int[2, 5], dims=(2, 3)), + NumPyAPI.sig(jnp.ravel_multi_index, [Uint8[5], Uint8[5]], dims=(8, 9)), NumPyAPI.sig(jnp.real, Complex[5]), NumPyAPI.sig(jnp.reciprocal, Float[5]), NumPyAPI.sig(jnp.remainder, Float[5], Float[5]), - # NumPyAPI.sig(jnp.repeat, Float[5], Int[5]), + NumPyAPI.sig(jnp.repeat, Float[5], Int[5]), NumPyAPI.sig(jnp.reshape, Float[6], shape=(2, 3)), NumPyAPI.sig(jnp.resize, Float[6], new_shape=(2, 3)), NumPyAPI.sig(jnp.right_shift, Int[5], Int[5]), From db11efab3be59105b2ac2ccce7281fda30438f1d Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 7 Apr 2025 14:43:34 -0700 Subject: [PATCH 439/483] Migrate jaxlib to use a single common .so file for all C++ dependencies. The idea is to move all of the jaxlib contents into a single .so file, and have all of the other Python extensions be tiny stubs that reexport part of the larger .so file. This has two main benefits: * it reduces the size of the jaxlib wheel, by about 70-80MB when installed. The benefit of the change is that it avoid duplication between the MLIR CAPI code and the copy of MLIR in XLA. * it gives us flexibility to split and merge Python extensions as we see fit. Issue https://github.com/jax-ml/jax/issues/11225 PiperOrigin-RevId: 744855997 --- .bazelrc | 1 + jaxlib/BUILD | 48 +++++++++- jaxlib/jax_common.json | 8 ++ jaxlib/libjax_common.lds | 7 ++ jaxlib/libjax_common_darwin.lds | 1 + jaxlib/mlir/_mlir_libs/BUILD.bazel | 136 ++++++++++++--------------- jaxlib/mlir/_mlir_libs/triton_ext.cc | 10 ++ jaxlib/pyinit_stub.c | 28 ++++++ jaxlib/pywrap.bzl | 89 ++++++++++++++++++ jaxlib/setup.py | 2 + jaxlib/tools/BUILD.bazel | 3 +- jaxlib/tools/build_wheel.py | 44 ++++----- jaxlib/triton/BUILD | 5 +- jaxlib/xla/BUILD | 3 +- 14 files changed, 281 insertions(+), 104 deletions(-) create mode 100644 jaxlib/jax_common.json create mode 100644 jaxlib/libjax_common.lds create mode 100644 jaxlib/libjax_common_darwin.lds create mode 100644 jaxlib/pyinit_stub.c create mode 100644 jaxlib/pywrap.bzl diff --git a/.bazelrc b/.bazelrc index 422363644578..0c359e039c89 100644 --- a/.bazelrc +++ b/.bazelrc @@ -98,6 +98,7 @@ build:windows --incompatible_strict_action_env=true # ############################################################################# build:nonccl --define=no_nccl_support=true +build --repo_env USE_PYWRAP_RULES=1 build:posix --copt=-fvisibility=hidden build:posix --copt=-Wno-sign-compare build:posix --cxxopt=-std=c++17 diff --git a/jaxlib/BUILD b/jaxlib/BUILD index c8114b48835f..d195bda41f32 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -20,6 +20,12 @@ load( "py_library_providing_imports_info", "pytype_library", ) +load( + "//jaxlib:pywrap.bzl", + "nanobind_pywrap_extension", + "pywrap_binaries", + "pywrap_library", +) load("//jaxlib:symlink_files.bzl", "symlink_files") licenses(["notice"]) @@ -51,6 +57,7 @@ py_library_providing_imports_info( lib_rule = pytype_library, deps = [ ":cpu_feature_guard", + ":jax", ":utils", "//jaxlib/cpu:_lapack", "//jaxlib/mlir", @@ -98,6 +105,44 @@ exports_files([ "setup.py", ]) +pywrap_library( + name = "jax", + common_lib_def_files_or_filters = { + "jaxlib/jax_common": "jax_common.json", + }, + common_lib_version_scripts = { + "jaxlib/jax_common": select({ + "@bazel_tools//src/conditions:windows": None, + "@bazel_tools//src/conditions:darwin": "libjax_common_darwin.lds", + "//conditions:default": "libjax_common.lds", + }), + }, + deps = [ + ":utils", + "//jaxlib/mlir/_mlir_libs:_chlo", + "//jaxlib/mlir/_mlir_libs:_mlir", + "//jaxlib/mlir/_mlir_libs:_mlirDialectsGPU", + "//jaxlib/mlir/_mlir_libs:_mlirDialectsLLVM", + "//jaxlib/mlir/_mlir_libs:_mlirDialectsNVGPU", + "//jaxlib/mlir/_mlir_libs:_mlirDialectsSparseTensor", + "//jaxlib/mlir/_mlir_libs:_mlirGPUPasses", + "//jaxlib/mlir/_mlir_libs:_mlirHlo", + "//jaxlib/mlir/_mlir_libs:_mlirSparseTensorPasses", + "//jaxlib/mlir/_mlir_libs:_mosaic_gpu_ext", + "//jaxlib/mlir/_mlir_libs:_sdy", + "//jaxlib/mlir/_mlir_libs:_stablehlo", + "//jaxlib/mlir/_mlir_libs:_tpu_ext", + "//jaxlib/mlir/_mlir_libs:_triton_ext", + "//jaxlib/mlir/_mlir_libs:register_jax_dialects", + "//jaxlib/xla:xla_extension", + ], +) + +pywrap_binaries( + name = "jaxlib_binaries", + dep = ":jax", +) + cc_library( name = "absl_status_casters", hdrs = ["absl_status_casters.h"], @@ -170,10 +215,9 @@ nanobind_extension( ], ) -nanobind_extension( +nanobind_pywrap_extension( name = "utils", srcs = ["utils.cc"], - module_name = "utils", deps = [ "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", diff --git a/jaxlib/jax_common.json b/jaxlib/jax_common.json new file mode 100644 index 000000000000..61a2c9313897 --- /dev/null +++ b/jaxlib/jax_common.json @@ -0,0 +1,8 @@ +{ + "global": [ + "Wrapped_PyInit_*" + ], + "local": [ + "*" + ] +} diff --git a/jaxlib/libjax_common.lds b/jaxlib/libjax_common.lds new file mode 100644 index 000000000000..6130415a8d26 --- /dev/null +++ b/jaxlib/libjax_common.lds @@ -0,0 +1,7 @@ +{ + global: + Wrapped_PyInit_*; + + local: + *; +}; diff --git a/jaxlib/libjax_common_darwin.lds b/jaxlib/libjax_common_darwin.lds new file mode 100644 index 000000000000..aed9a1d7512a --- /dev/null +++ b/jaxlib/libjax_common_darwin.lds @@ -0,0 +1 @@ +*Wrapped_PyInit_* diff --git a/jaxlib/mlir/_mlir_libs/BUILD.bazel b/jaxlib/mlir/_mlir_libs/BUILD.bazel index fb94837cff37..6599e50695d4 100644 --- a/jaxlib/mlir/_mlir_libs/BUILD.bazel +++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel @@ -15,10 +15,9 @@ load( "//jaxlib:jax.bzl", "if_windows", - "nanobind_extension", - "py_extension", "windows_cc_shared_mlir_library", ) +load("//jaxlib:pywrap.bzl", "nanobind_pywrap_extension") load("//jaxlib:symlink_files.bzl", "symlink_inputs") package( @@ -44,7 +43,7 @@ LINKOPTS = select({ ], }) -py_extension( +nanobind_pywrap_extension( name = "_mlir", srcs = [ "@llvm-project//mlir:lib/Bindings/Python/MainModule.cpp", @@ -52,14 +51,13 @@ py_extension( copts = COPTS, linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:MLIRBindingsPythonCoreNoCAPI", + "@llvm-project//mlir:MLIRBindingsPythonCore", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@nanobind", ], ) -py_extension( +nanobind_pywrap_extension( name = "_mlirDialectsGPU", srcs = [ "@llvm-project//mlir:lib/Bindings/Python/DialectGPU.cpp", @@ -67,15 +65,14 @@ py_extension( copts = COPTS, linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPIGPUHeaders", - "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:CAPIGPU", + "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@nanobind", ], ) -py_extension( +nanobind_pywrap_extension( name = "_mlirGPUPasses", srcs = [ "@llvm-project//mlir:lib/Bindings/Python/GPUPasses.cpp", @@ -83,14 +80,13 @@ py_extension( copts = COPTS, linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPIGPUHeaders", + "@llvm-project//mlir:CAPIGPU", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@nanobind", ], ) -py_extension( +nanobind_pywrap_extension( name = "_mlirDialectsNVGPU", srcs = [ "@llvm-project//mlir:lib/Bindings/Python/DialectNVGPU.cpp", @@ -98,15 +94,14 @@ py_extension( copts = COPTS, linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPIIRHeaders", - "@llvm-project//mlir:CAPINVGPUHeaders", + "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:CAPINVGPU", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@nanobind", ], ) -py_extension( +nanobind_pywrap_extension( name = "_mlirDialectsLLVM", srcs = [ "@llvm-project//mlir:lib/Bindings/Python/DialectLLVM.cpp", @@ -114,15 +109,14 @@ py_extension( copts = COPTS, linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPIIRHeaders", - "@llvm-project//mlir:CAPILLVMHeaders", + "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:CAPILLVM", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@nanobind", ], ) -py_extension( +nanobind_pywrap_extension( name = "_mlirDialectsSparseTensor", srcs = [ "@llvm-project//mlir:lib/Bindings/Python/DialectSparseTensor.cpp", @@ -130,14 +124,13 @@ py_extension( copts = COPTS, linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPISparseTensorHeaders", + "@llvm-project//mlir:CAPISparseTensor", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@nanobind", ], ) -py_extension( +nanobind_pywrap_extension( name = "_mlirSparseTensorPasses", srcs = [ "@llvm-project//mlir:lib/Bindings/Python/SparseTensorPasses.cpp", @@ -145,22 +138,20 @@ py_extension( copts = COPTS, linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPISparseTensorHeaders", + "@llvm-project//mlir:CAPISparseTensor", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@nanobind", ], ) -py_extension( +nanobind_pywrap_extension( name = "_mosaic_gpu_ext", srcs = ["mosaic_gpu_ext.cc"], copts = COPTS, linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi_headers", - "@llvm-project//mlir:CAPIIRHeaders", + "//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi", + "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeadersAndDeps", "@nanobind", ], @@ -171,17 +162,16 @@ py_extension( # :jaxlib_mlir_capi_shared_library). This ensures that the RPATH works correctly # across platforms. It's not clear if Windows supports RPATH-like functionality # across different directories at all. -py_extension( +nanobind_pywrap_extension( name = "_tpu_ext", srcs = ["tpu_ext.cc"], copts = COPTS, linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "//jaxlib/mosaic:tpu_dialect_capi_headers", + "//jaxlib/mosaic:tpu_dialect_capi", "@com_google_absl//absl/log:check", "@llvm-project//llvm:Support", - "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeadersAndDeps", "@nanobind", "@xla//xla/python:nb_numpy", @@ -190,7 +180,7 @@ py_extension( ) # This target contains the extension and it's Python dependencies, which are not -# supported by the `py_extension`/`nanobind_extension` macros. +# supported by the `nanobind_pywrap_extension`/`nanobind_extension` macros. py_library( name = "_tpu_ext_lib", deps = [ @@ -200,19 +190,22 @@ py_library( ], ) -nanobind_extension( +nanobind_pywrap_extension( name = "_triton_ext", srcs = ["triton_ext.cc"], copts = COPTS, linkopts = LINKOPTS, pytype_srcs = ["_triton_ext.pyi"], deps = [ - ":jaxlib_mlir_capi_shared_library", - "//jaxlib/triton:triton_dialect_capi_headers", - "@llvm-project//mlir:CAPIIRHeaders", - "@llvm-project//mlir:MLIRBindingsPythonNanobindHeadersAndDeps", "@nanobind", - ], + ] + if_windows( + [], + [ + "//jaxlib/triton:triton_dialect_capi", + "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:MLIRBindingsPythonNanobindHeadersAndDeps", + ], + ), ) symlink_inputs( @@ -235,7 +228,7 @@ cc_library( hdrs = ["jaxlib_mlir_capi_shims.h"], deps = [ "@llvm-project//mlir:BuiltinToLLVMIRTranslation", - "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:GPUPipelines", "@llvm-project//mlir:GPUToLLVMIRTranslation", "@llvm-project//mlir:LLVMToLLVMIRTranslation", @@ -250,34 +243,33 @@ cc_library( name = "jaxlib_mlir_capi_shims_hdrs", hdrs = ["jaxlib_mlir_capi_shims.h"], deps = [ - "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:CAPIIR", ], ) # JAX-specific registrations. -py_extension( +nanobind_pywrap_extension( name = "register_jax_dialects", srcs = ["register_jax_dialects.cc"], copts = COPTS, linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "//jaxlib/mosaic/gpu:mlir_capi_headers", - "@llvm-project//mlir:CAPIArithHeaders", - "@llvm-project//mlir:CAPIGPUHeaders", - "@llvm-project//mlir:CAPIIRHeaders", - "@llvm-project//mlir:CAPILLVMHeaders", - "@llvm-project//mlir:CAPIMathHeaders", - "@llvm-project//mlir:CAPIMemRefHeaders", - "@llvm-project//mlir:CAPINVGPUHeaders", - "@llvm-project//mlir:CAPINVVMHeaders", - "@llvm-project//mlir:CAPISCFHeaders", - "@llvm-project//mlir:CAPITransformsHeaders", - "@llvm-project//mlir:CAPIVectorHeaders", + "//jaxlib/mosaic/gpu:mlir_capi", + "@llvm-project//mlir:CAPIArith", + "@llvm-project//mlir:CAPIGPU", + "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:CAPILLVM", + "@llvm-project//mlir:CAPIMath", + "@llvm-project//mlir:CAPIMemRef", + "@llvm-project//mlir:CAPINVGPU", + "@llvm-project//mlir:CAPINVVM", + "@llvm-project//mlir:CAPISCF", + "@llvm-project//mlir:CAPITransforms", + "@llvm-project//mlir:CAPIVector", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@local_config_python//:headers", "@nanobind", - "@shardy//shardy/integrations/c:sdy_capi_headers", + "@shardy//shardy/integrations/c:sdy_capi", ], ) @@ -285,7 +277,7 @@ py_extension( # MHLO Extensions ##---------------------------------------------------------------------------## -py_extension( +nanobind_pywrap_extension( name = "_mlirHlo", srcs = [ "@xla//xla/mlir_hlo:bindings/python/MlirHloModule.cc", @@ -293,12 +285,11 @@ py_extension( copts = COPTS, linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@local_config_python//:headers", "@nanobind", - "@xla//xla/mlir_hlo:CAPIHeaders", + "@xla//xla/mlir_hlo:CAPI", ], ) @@ -306,7 +297,7 @@ py_extension( # Shardy Extensions ##---------------------------------------------------------------------------## -py_extension( +nanobind_pywrap_extension( name = "_sdy", srcs = [ "@shardy//shardy/integrations/python/ir:sdy_module.cc", @@ -314,13 +305,12 @@ py_extension( copts = COPTS, linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:IR", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@local_config_python//:headers", "@nanobind", - "@shardy//shardy/integrations/c:sdy_capi_headers", + "@shardy//shardy/integrations/c:sdy_capi", ], ) @@ -328,7 +318,7 @@ py_extension( # Stablehlo Extensions ##---------------------------------------------------------------------------## -py_extension( +nanobind_pywrap_extension( name = "_chlo", srcs = [ "@stablehlo//:chlo_py_api_files", @@ -336,16 +326,15 @@ py_extension( copts = COPTS, linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@local_config_python//:headers", "@nanobind", - "@stablehlo//:chlo_capi_headers", + "@stablehlo//:chlo_capi", ], ) -py_extension( +nanobind_pywrap_extension( name = "_stablehlo", srcs = [ "@stablehlo//:stablehlo_py_api_files", @@ -353,13 +342,12 @@ py_extension( copts = COPTS, linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", "@llvm-project//llvm:Support", - "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@local_config_python//:headers", "@nanobind", - "@stablehlo//:stablehlo_capi_headers", + "@stablehlo//:stablehlo_capi", ], ) diff --git a/jaxlib/mlir/_mlir_libs/triton_ext.cc b/jaxlib/mlir/_mlir_libs/triton_ext.cc index 7fba7e1dfe80..687ceec4cd33 100644 --- a/jaxlib/mlir/_mlir_libs/triton_ext.cc +++ b/jaxlib/mlir/_mlir_libs/triton_ext.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#ifndef _WIN32 + #include #include @@ -74,3 +76,11 @@ NB_MODULE(_triton_ext, m) { return encoding; }); } + +#else // _WIN32 + +#include "nanobind/nanobind.h" + +NB_MODULE(_triton_ext, m) {} + +#endif // _WIN32 diff --git a/jaxlib/pyinit_stub.c b/jaxlib/pyinit_stub.c new file mode 100644 index 000000000000..7fc873d9ae0e --- /dev/null +++ b/jaxlib/pyinit_stub.c @@ -0,0 +1,28 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Stub that reexports Wrapped_PyInit_module as PyInit_module. + +extern void* Wrapped_PyInit_@MODULE_NAME@(); + +#if defined(WIN32) || defined(_WIN32) +#define EXPORT_SYMBOL __declspec(dllexport) +#else +#define EXPORT_SYMBOL __attribute__ ((visibility("default"))) +#endif + +EXPORT_SYMBOL void* PyInit_@MODULE_NAME@() { + return Wrapped_PyInit_@MODULE_NAME@(); +} diff --git a/jaxlib/pywrap.bzl b/jaxlib/pywrap.bzl new file mode 100644 index 000000000000..75324e01907a --- /dev/null +++ b/jaxlib/pywrap.bzl @@ -0,0 +1,89 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Wrappers around pywrap rules for JAX.""" + +load("@bazel_skylib//rules:expand_template.bzl", "expand_template") +load( + "@xla//third_party/py/rules_pywrap:pywrap.impl.bzl", + "pybind_extension", + _pywrap_binaries = "pywrap_binaries", + _pywrap_library = "pywrap_library", +) + +pywrap_library = _pywrap_library +pywrap_binaries = _pywrap_binaries + +def nanobind_pywrap_extension( + name, + srcs = [], + deps = [], + pytype_srcs = [], + pytype_deps = [], + copts = [], + linkopts = [], + visibility = None): + # buildifier: disable=function-docstring-args + "Python extension rule using nanobind and the pywrap rules." + module_name = name + lib_name = name + "_pywrap_library" + src_cc_name = name + "_pywrap_stub.c" + + # We put the entire contents of the extension in a single cc_library, which will become part of + # the common pywrap library. All the contents of all extensions will end up in the common + # library. + native.cc_library( + name = lib_name, + srcs = srcs, + copts = copts, + deps = deps, + local_defines = [ + "PyInit_{}=Wrapped_PyInit_{}".format(module_name, module_name), + ], + visibility = ["//visibility:private"], + ) + + # We build a small stub library as the extension that forwards to the PyInit_... symbol from the + # common pywrap library. + expand_template( + name = name + "_pywrap_stub", + testonly = True, + out = src_cc_name, + substitutions = { + "@MODULE_NAME@": module_name, + }, + template = "//jaxlib:pyinit_stub.c", + visibility = ["//visibility:private"], + ) + + # Despite its name "pybind_extension" has nothing to do with pybind. It is the Python extension + # rule from the pywrap rules. + pybind_extension( + name = name, + srcs = [src_cc_name], + deps = [":" + lib_name], + linkopts = linkopts, + visibility = visibility, + default_deps = [], + common_lib_packages = [ + "jaxlib", + ], + ) + + # Create a py_library with the type stubs as data, on which wheel builds can depend. + native.py_library( + name = name + "_type_stubs", + data = pytype_srcs, + deps = pytype_deps, + ) diff --git a/jaxlib/setup.py b/jaxlib/setup.py index b3a37a25f1b2..5bd010525c96 100644 --- a/jaxlib/setup.py +++ b/jaxlib/setup.py @@ -76,6 +76,8 @@ def has_ext_modules(self): package_data={ 'jaxlib': [ '*.so', + '*.dylib', + '*.dll', '*.pyd*', 'py.typed', 'cpu/*', diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index 79a1f7e7089d..3ea09802dfaf 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -64,9 +64,10 @@ py_binary( "LICENSE.txt", "//jaxlib", "//jaxlib:README.md", + "//jaxlib:jaxlib_binaries", "//jaxlib:setup.py", "//jaxlib/xla:xla_client.py", - "//jaxlib/xla:xla_extension", + "//jaxlib/xla:xla_extension_type_stubs", "@xla//xla/ffi/api:api.h", "@xla//xla/ffi/api:c_api.h", "@xla//xla/ffi/api:ffi.h", diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index fcc811789c19..bab1c6014ff4 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -60,11 +60,11 @@ r = runfiles.Create() - def _is_mac(): return platform.system() == "Darwin" +soext = "dll" if build_utils.is_windows() else ("dylib" if _is_mac() else "so") pyext = "pyd" if build_utils.is_windows() else "so" @@ -135,7 +135,7 @@ def verify_mac_libraries_dont_reference_chkstack(): if not _is_mac(): return nm = subprocess.run( - ["nm", "-g", r.Rlocation("__main__/jaxlib/xla/xla_extension.so")], + ["nm", "-g", r.Rlocation(f"__main__/jaxlib/xla_extension.{pyext}")], capture_output=True, text=True, check=False, @@ -186,6 +186,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu): src_files=[ f"__main__/jaxlib/cpu_feature_guard.{pyext}", f"__main__/jaxlib/utils.{pyext}", + "__main__/jaxlib/jax_common.dll" if build_utils.is_windows() else f"__main__/jaxlib/libjax_common.{soext}", "__main__/jaxlib/lapack.py", "__main__/jaxlib/hlo_helpers.py", "__main__/jaxlib/gpu_prng.py", @@ -198,7 +199,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu): "__main__/jaxlib/plugin_support.py", "__main__/jaxlib/version.py", "__main__/jaxlib/xla/xla_client.py", - f"__main__/jaxlib/xla/xla_extension.{pyext}", + f"__main__/jaxlib/xla_extension.{pyext}", ], ) # This file is required by PEP-561. It marks jaxlib as package containing @@ -311,38 +312,31 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu): ) - if build_utils.is_windows(): - capi_so = "__main__/jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi.dll" - else: - so_ext = "dylib" if _is_mac() else "so" - capi_so = f"__main__/jaxlib/mlir/_mlir_libs/libjaxlib_mlir_capi.{so_ext}" - mlir_libs_dir = jaxlib_dir / "mlir" / "_mlir_libs" copy_runfiles( dst_dir=mlir_libs_dir, src_files=[ - capi_so, "__main__/jaxlib/mlir/_mlir_libs/__init__.py", - f"__main__/jaxlib/mlir/_mlir_libs/_mlir.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_chlo.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mlirHlo.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mosaic_gpu_ext.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_tpu_ext.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_sdy.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_stablehlo.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/register_jax_dialects.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsGPU.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsLLVM.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsNVGPU.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mlirGPUPasses.{pyext}", + f"__main__/jaxlib/_mlir.{pyext}", + f"__main__/jaxlib/_chlo.{pyext}", + f"__main__/jaxlib/_mlirHlo.{pyext}", + f"__main__/jaxlib/_mlirDialectsSparseTensor.{pyext}", + f"__main__/jaxlib/_mlirSparseTensorPasses.{pyext}", + f"__main__/jaxlib/_mosaic_gpu_ext.{pyext}", + f"__main__/jaxlib/_tpu_ext.{pyext}", + f"__main__/jaxlib/_sdy.{pyext}", + f"__main__/jaxlib/_stablehlo.{pyext}", + f"__main__/jaxlib/register_jax_dialects.{pyext}", + f"__main__/jaxlib/_mlirDialectsGPU.{pyext}", + f"__main__/jaxlib/_mlirDialectsLLVM.{pyext}", + f"__main__/jaxlib/_mlirDialectsNVGPU.{pyext}", + f"__main__/jaxlib/_mlirGPUPasses.{pyext}", ] + ( [] if build_utils.is_windows() else [ - f"__main__/jaxlib/mlir/_mlir_libs/_triton_ext.{pyext}", + f"__main__/jaxlib/_triton_ext.{pyext}", "__main__/jaxlib/mlir/_mlir_libs/_triton_ext.pyi", ] ), diff --git a/jaxlib/triton/BUILD b/jaxlib/triton/BUILD index 99cddd9e6381..64410fdfeb00 100644 --- a/jaxlib/triton/BUILD +++ b/jaxlib/triton/BUILD @@ -35,7 +35,10 @@ pytype_strict_library( "//jaxlib/mlir:ir", ] + if_windows( [], - ["//jaxlib/mlir/_mlir_libs:_triton_ext"], + [ + "//jaxlib/mlir/_mlir_libs:_triton_ext", + "//jaxlib/mlir/_mlir_libs:_triton_ext_type_stubs", + ], ), ) diff --git a/jaxlib/xla/BUILD b/jaxlib/xla/BUILD index c5d151b2cd3b..a299629c3ba5 100644 --- a/jaxlib/xla/BUILD +++ b/jaxlib/xla/BUILD @@ -24,6 +24,7 @@ load( "py_strict_test", "pytype_strict_library", ) +load("//jaxlib:pywrap.bzl", "nanobind_pywrap_extension") licenses(["notice"]) @@ -39,7 +40,7 @@ package_group( ], ) -nanobind_extension( +nanobind_pywrap_extension( name = "xla_extension", srcs = ["xla.cc"], pytype_deps = py_deps(["numpy"]), From 96e63eaee8a4f741eca6e30ebbed805df825e6bf Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 7 Apr 2025 14:46:38 -0700 Subject: [PATCH 440/483] jnp.linalg: add symmetrize_input argument & docs --- jax/_src/numpy/linalg.py | 25 ++++++++++++++++++------- tests/linalg_test.py | 2 +- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index 23f2a58b09f6..146bbbda0213 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -72,8 +72,8 @@ def _symmetrize(x: Array) -> Array: return (x + _H(x)) / 2 @export -@partial(jit, static_argnames=['upper']) -def cholesky(a: ArrayLike, *, upper: bool = False) -> Array: +@partial(jit, static_argnames=['upper', 'symmetrize_input']) +def cholesky(a: ArrayLike, *, upper: bool = False, symmetrize_input: bool = True) -> Array: """Compute the Cholesky decomposition of a matrix. JAX implementation of :func:`numpy.linalg.cholesky`. @@ -98,6 +98,10 @@ def cholesky(a: ArrayLike, *, upper: bool = False) -> Array: Must have shape ``(..., N, N)``. upper: if True, compute the upper Cholesky decomposition `U`. if False (default), compute the lower Cholesky decomposition `L`. + symmetrize_input: if True (default) then input is symmetrized, which leads + to better behavior under automatic differentiation. Note that when this + is set to True, both the upper and lower triangles of the input will + be used in computing the decomposition. Returns: array of shape ``(..., N, N)`` representing the Cholesky decomposition @@ -135,7 +139,7 @@ def cholesky(a: ArrayLike, *, upper: bool = False) -> Array: """ a = ensure_arraylike("jnp.linalg.cholesky", a) a, = promote_dtypes_inexact(a) - L = lax_linalg.cholesky(a) + L = lax_linalg.cholesky(a, symmetrize_input=symmetrize_input) return L.mT.conj() if upper else L @@ -821,7 +825,9 @@ def eigh(a: ArrayLike, UPLO: str | None = None, UPLO: specifies whether the calculation is done with the lower triangular part of ``a`` (``'L'``, default) or the upper triangular part (``'U'``). symmetrize_input: if True (default) then input is symmetrized, which leads - to better behavior under automatic differentiation. + to better behavior under automatic differentiation. Note that when this + is set to True, both the upper and lower triangles of the input will + be used in computing the decomposition. Returns: A namedtuple ``(eigenvalues, eigenvectors)`` where @@ -863,8 +869,9 @@ def eigh(a: ArrayLike, UPLO: str | None = None, @export -@partial(jit, static_argnames=('UPLO',)) -def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array: +@partial(jit, static_argnames=('UPLO', 'symmetrize_input')) +def eigvalsh(a: ArrayLike, UPLO: str | None = 'L', *, + symmetrize_input: bool = True) -> Array: """ Compute the eigenvalues of a Hermitian matrix. @@ -875,6 +882,10 @@ def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array: or symmetric (if real) matrix. UPLO: specifies whether the calculation is done with the lower triangular part of ``a`` (``'L'``, default) or the upper triangular part (``'U'``). + symmetrize_input: if True (default) then input is symmetrized, which leads + to better behavior under automatic differentiation. Note that when this + is set to True, both the upper and lower triangles of the input will + be used in computing the decomposition. Returns: An array of shape ``(..., M)`` containing the eigenvalues, sorted in @@ -894,7 +905,7 @@ def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array: """ a = ensure_arraylike("jnp.linalg.eigvalsh", a) a, = promote_dtypes_inexact(a) - w, _ = eigh(a, UPLO) + w, _ = eigh(a, UPLO, symmetrize_input=symmetrize_input) return w diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 20c998d6a685..1670f1ee4abd 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -96,7 +96,7 @@ def args_maker(): a = rng(factor_shape, dtype) return [np.matmul(a, jnp.conj(T(a)))] - jnp_fun = partial(jnp.linalg.cholesky, upper=upper) + jnp_fun = partial(jnp.linalg.cholesky, upper=upper, symmetrize_input=True) def np_fun(x, upper=upper): # Upper argument added in NumPy 2.0.0 From b18dc1dfd7668859e07c5c823d14154564647708 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Mon, 7 Apr 2025 14:46:53 -0700 Subject: [PATCH 441/483] [Mosaic GPU] Add scaffolding for a new lowering "axis" (UserThreadSemantics), in addition to the existing ThreadSemantics (renamed to LoweringSemantics). UserThreadSemantics controls the thread semantics of the Pallas user's code, whereas LoweringSemantics controls the level at which Mosaic GPU emits code. PiperOrigin-RevId: 744857085 --- jax/_src/pallas/mosaic_gpu/core.py | 16 +- jax/_src/pallas/mosaic_gpu/lowering.py | 199 ++++++++++-------- .../mosaic_gpu/pallas_call_registration.py | 4 +- jax/_src/pallas/mosaic_gpu/primitives.py | 72 ++++--- jax/experimental/mosaic/gpu/__init__.py | 2 +- jax/experimental/mosaic/gpu/core.py | 10 +- jax/experimental/pallas/mosaic_gpu.py | 2 +- tests/mosaic/gpu_test.py | 8 +- tests/pallas/mosaic_gpu_test.py | 53 ++--- tests/pallas/ops_test.py | 2 +- 10 files changed, 205 insertions(+), 163 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 444fe6e50f88..d964a8a90144 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -91,7 +91,7 @@ class GPUCompilerParams(pallas_core.CompilerParams): delay_release: int = 0 profile_space: int = 0 profile_dir: str = "" - thread_semantics: mgpu.core.ThreadSemantics = mgpu.core.ThreadSemantics.Lane + lowering_semantics: mgpu.core.LoweringSemantics = mgpu.core.LoweringSemantics.Lane def __post_init__(self): if bool(self.profile_space) ^ bool(self.profile_dir): @@ -142,6 +142,20 @@ def get_ref_aval(self) -> pallas_core.TransformedRef | AbstractMemoryRef: return self(()).get_ref_aval() +class PrimitiveSemantics(enum.Enum): + """Thread semantics for a primitives at the Pallas user-level.""" + + Warp = enum.auto() + Warpgroup = enum.auto() + + +# Convenience constants for (lowering, primitive) thread semantics pairs. +LANExWG_SEMANTICS = ( + mgpu.LoweringSemantics.Lane, PrimitiveSemantics.Warpgroup) +WGxWG_SEMANTICS = ( + mgpu.LoweringSemantics.Warpgroup, PrimitiveSemantics.Warpgroup) + + def kernel( body: Callable[..., None], out_shape: object, diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 2423e1c1a2a7..f7bdbccc1ad6 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -88,13 +88,13 @@ def _align_to(x: int, alignment: int): @dataclasses.dataclass(frozen=True, kw_only=True) class ResourceEstimatorContext: axis_names: _AxisNames - thread_semantics: mgpu.ThreadSemantics + lowering_semantics: mgpu.LoweringSemantics @property def arrival_multiplier(self) -> int: return ( WARPGROUP_SIZE - if self.thread_semantics == mgpu.ThreadSemantics.Lane + if self.lowering_semantics == mgpu.LoweringSemantics.Lane else 1 ) @@ -308,7 +308,8 @@ class ModuleContext: name_stack: source_info_util.NameStack traceback_caches: mlir.TracebackCaches squashed_dims: tuple[int, ...] - thread_semantics: mgpu.ThreadSemantics + lowering_semantics: mgpu.LoweringSemantics + primitive_semantics: gpu_core.PrimitiveSemantics def reserve_barrier(self, barrier: mgpu.Barrier) -> mgpu.BarrierRef: """Reserves a barrier. @@ -367,7 +368,7 @@ def scratch_view( smem = ir.Attribute.parse("#gpu.address_space") i8 = ir.IntegerType.get_signless(8) i32 = ir.IntegerType.get_signless(32) - if self.thread_semantics == mgpu.ThreadSemantics.Lane: + if self.lowering_semantics == mgpu.LoweringSemantics.Lane: smem_base = gpu_dialect.dynamic_shared_memory( ir.MemRefType.get((mgpu_utils.DYNAMIC,), i8, memory_space=smem) ) @@ -383,7 +384,7 @@ def scratch_view( # The below code emission relies on the assumption that the first scratch # operand provided by Mosaic GPU always begins at the beginning of # dynamic SMEM. Mosaic GPU is expected to uphold that invariant. - if self.thread_semantics == mgpu.ThreadSemantics.Lane: + if self.lowering_semantics == mgpu.LoweringSemantics.Lane: view = memref_dialect.view( scratch_ty, smem_base, _as_index(off), [] ) @@ -416,7 +417,7 @@ class LoweringRuleContext: def estimator_ctx(self) -> ResourceEstimatorContext: return ResourceEstimatorContext( axis_names=self.module_ctx.axis_names, - thread_semantics=self.module_ctx.thread_semantics, + lowering_semantics=self.module_ctx.lowering_semantics, ) @@ -703,8 +704,8 @@ def lower_jaxpr_to_module( debug_info = jaxpr.debug_info params = compiler_params.get("mosaic_gpu", {}) approx_math = params.get("approx_math", False) - thread_semantics = params.get( - "thread_semantics", mgpu_core.ThreadSemantics.Lane + lowering_semantics = params.get( + "lowering_semantics", mgpu_core.LoweringSemantics.Lane ) if len(cluster) < 3: @@ -732,7 +733,7 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): else: tmem_cols = 0 - if thread_semantics == mgpu.ThreadSemantics.Lane: + if lowering_semantics == mgpu.LoweringSemantics.Lane: single_lane_predicate = mgpu.single_thread_predicate(per_block=False) else: # Warpgroup semantics do not have a single lane predicate. single_lane_predicate = None @@ -752,7 +753,8 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): name_stack=source_info_util.NameStack(), traceback_caches=mlir.TracebackCaches(), squashed_dims=squashed_dims, - thread_semantics=thread_semantics, + lowering_semantics=lowering_semantics, + primitive_semantics=gpu_core.PrimitiveSemantics.Warpgroup, ) del runtime_smem, grouped_barriers, runtime_barriers @@ -762,7 +764,7 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): rs = _estimate_resources( ResourceEstimatorContext( - axis_names=axis_names, thread_semantics=thread_semantics + axis_names=axis_names, lowering_semantics=lowering_semantics ), jaxpr, ) @@ -801,7 +803,7 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): ) ) - if thread_semantics == mgpu.ThreadSemantics.Warpgroup: + if lowering_semantics == mgpu.LoweringSemantics.Warpgroup: # Run Python lowering passes. The remaining passes will be run in C++ in # jax/jaxlib/mosaic/gpu/custom_call.cc mgpu.infer_layout(module) # pytype: disable=attribute-error @@ -820,17 +822,21 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): mosaic_lowering_rules = { # Lowering rules when using Mosaic GPU lane semantics. - mgpu.ThreadSemantics.Lane: {} , + (mgpu.LoweringSemantics.Lane, gpu_core.PrimitiveSemantics.Warpgroup): {} , # Lowering rules when using Mosaic GPU warpgroup semantics. - mgpu.ThreadSemantics.Warpgroup: {}, + (mgpu.LoweringSemantics.Warpgroup, + gpu_core.PrimitiveSemantics.Warpgroup): {}, } def register_lowering_rule( - primitive: jax_core.Primitive, thread_semantics: mgpu.ThreadSemantics + primitive: jax_core.Primitive, + lowering_semantics: mgpu.LoweringSemantics, + primitive_semantics: gpu_core.PrimitiveSemantics = gpu_core.PrimitiveSemantics.Warpgroup, ): def deco(fn): - mosaic_lowering_rules[thread_semantics][primitive] = fn + mosaic_lowering_rules[ + (lowering_semantics, primitive_semantics)][primitive] = fn return fn return deco @@ -866,7 +872,7 @@ def write_env(var: jax_core.Var, val, require_value: bool = True): # TODO(apaszke): Handle other avals (refs, etc.). if isinstance(aval := var.aval, jax_core.ShapedArray): # TODO(apaszke): Clarify the type invariants for lane semantics? - if module_ctx.thread_semantics == mgpu.ThreadSemantics.Warpgroup: + if module_ctx.lowering_semantics == mgpu.LoweringSemantics.Warpgroup: # Shaped arrays must be vectors if and only if their shape is non-empty. # Those with empty shapes should be represented by their scalar type. mlir_dtype = mgpu_utils.dtype_to_ir_type(aval.dtype) @@ -903,10 +909,13 @@ def write_env(var: jax_core.Var, val, require_value: bool = True): ) loc = mlir._source_info_to_location(module_ctx, eqn.primitive, source_info) with source_info_util.user_context(eqn.source_info.traceback), loc: - if eqn.primitive not in mosaic_lowering_rules[module_ctx.thread_semantics]: + if eqn.primitive not in mosaic_lowering_rules[ + (module_ctx.lowering_semantics, module_ctx.primitive_semantics)]: raise NotImplementedError( "Unimplemented primitive in Pallas Mosaic GPU lowering: " - f"{eqn.primitive.name}. " + f"{eqn.primitive.name} for lowering semantics " + f"{module_ctx.lowering_semantics} and user thread semantics " + f"{module_ctx.primitive_semantics}. " "Please file an issue on https://github.com/jax-ml/jax/issues." ) new_local_name_stack = [scope.name for scope in eqn.source_info.name_stack.stack] @@ -918,7 +927,9 @@ def write_env(var: jax_core.Var, val, require_value: bool = True): wrapper_stack = contextlib.ExitStack() wrapper_stack.enter_context(launch_ctx.named_region(name)) named_regions.append(wrapper_stack) - rule = mosaic_lowering_rules[module_ctx.thread_semantics][eqn.primitive] + rule = mosaic_lowering_rules[ + (module_ctx.lowering_semantics, module_ctx.primitive_semantics) + ][eqn.primitive] rule_ctx = LoweringRuleContext( module_ctx, launch_ctx, @@ -947,8 +958,9 @@ def write_env(var: jax_core.Var, val, require_value: bool = True): return map(read_env, jaxpr.outvars) -@register_lowering_rule(primitives.program_id_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(primitives.program_id_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(primitives.program_id_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule( + primitives.program_id_p, mgpu.LoweringSemantics.Warpgroup) def _program_id_lowering_rule(ctx: LoweringRuleContext, axis): if ctx.module_ctx.program_ids is None: raise NotImplementedError("pl.program_id() is not supported in this context") @@ -1015,8 +1027,9 @@ def lowering_rule(ctx: LoweringRuleContext, *args, **params): return lowering_rule -@register_lowering_rule(primitives.num_programs_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(primitives.num_programs_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(primitives.num_programs_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule( + primitives.num_programs_p, mgpu.LoweringSemantics.Warpgroup) def _num_programs_lowering_rule(ctx: LoweringRuleContext, axis): del ctx # Unused. return arith_dialect.index_cast( @@ -1089,7 +1102,7 @@ def _ndindexer_indices(indexer: indexing.NDIndexer) -> tuple[gpu_core.Index, ... return tuple(indices) -@register_lowering_rule(sp.get_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(sp.get_p, mgpu.LoweringSemantics.Lane) def _get_lowering_rule(ctx: LoweringRuleContext, x_ref, *leaves, tree): if isinstance(x_ref, tcgen05.TMEMRef): transforms = jax.tree.unflatten(tree, leaves) @@ -1132,7 +1145,7 @@ def _get_lowering_rule(ctx: LoweringRuleContext, x_ref, *leaves, tree): raise NotImplementedError(f"Unsupported transforms: {transforms}") -@register_lowering_rule(sp.get_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(sp.get_p, mgpu.LoweringSemantics.Warpgroup) def _get_lowering_rule_wg(ctx: LoweringRuleContext, x_smem, *leaves, tree): if not isinstance(x_smem, ir.Value) and ir.MemRefType.isinstance(x_smem): raise TypeError(f"Can only load from references (got {x_smem}).") @@ -1157,7 +1170,7 @@ def _get_lowering_rule_wg(ctx: LoweringRuleContext, x_smem, *leaves, tree): return memref_dialect.load(x_smem, []) -@register_lowering_rule(sp.swap_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(sp.swap_p, mgpu.LoweringSemantics.Lane) def _swap_lowering_rule( ctx: LoweringRuleContext, x_smem, value, *leaves, tree ): @@ -1222,7 +1235,7 @@ def _swap_lowering_rule( raise NotImplementedError(f"Unsupported transforms: {transforms}") -@register_lowering_rule(sp.swap_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(sp.swap_p, mgpu.LoweringSemantics.Warpgroup) def _swap_lowering_rule_wg( ctx: LoweringRuleContext, x_smem, value, *leaves, tree ): @@ -1253,8 +1266,8 @@ def _swap_lowering_rule_wg( return old_value -@register_lowering_rule(pjit.pjit_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(pjit.pjit_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(pjit.pjit_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(pjit.pjit_p, mgpu.LoweringSemantics.Warpgroup) def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **kwargs): if jaxpr.consts: raise NotImplementedError @@ -1263,7 +1276,7 @@ def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **kwargs): ) -@register_lowering_rule(lax.slice_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.slice_p, mgpu.LoweringSemantics.Lane) def _slice_lowering_rule( ctx: LoweringRuleContext, x, limit_indices, start_indices, strides ): @@ -1273,8 +1286,8 @@ def _slice_lowering_rule( return x[tuple(slice(b, e) for b, e in zip(start_indices, limit_indices))] -@register_lowering_rule(lax.select_n_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.select_n_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.select_n_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.select_n_p, mgpu.LoweringSemantics.Warpgroup) def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, *cases): if len(cases) != 2: raise NotImplementedError( @@ -1283,7 +1296,7 @@ def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, *cases): ) pred_aval, *cases_avals = ctx.avals_in [out_aval] = ctx.avals_out - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: pred = _ensure_fa(pred, pred_aval.dtype) cases = _bcast(*cases, *cases_avals, out_aval) # ``select`` expects the first case to be the true branch, but ``select_n`` @@ -1301,7 +1314,7 @@ def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, *cases): return arith_dialect.select(pred, *reversed(cases)) -@register_lowering_rule(lax.broadcast_in_dim_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.broadcast_in_dim_p, mgpu.LoweringSemantics.Lane) def _broadcast_in_dim_lowering_rule( ctx: LoweringRuleContext, x: mgpu.FragmentedArray, @@ -1331,7 +1344,8 @@ def _broadcast_in_dim_lowering_rule( return x.broadcast(shape) -@register_lowering_rule(lax.broadcast_in_dim_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule( + lax.broadcast_in_dim_p, mgpu.LoweringSemantics.Warpgroup) def _broadcast_in_dim_lowering_rule_wg( ctx: LoweringRuleContext, x: ir.Value, @@ -1351,7 +1365,7 @@ def _broadcast_in_dim_lowering_rule_wg( ) -@register_lowering_rule(lax.convert_element_type_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.convert_element_type_p, mgpu.LoweringSemantics.Lane) def _convert_element_type_lowering_rule( ctx: LoweringRuleContext, x, *, new_dtype, weak_type, sharding ): @@ -1362,7 +1376,8 @@ def _convert_element_type_lowering_rule( ) -@register_lowering_rule(lax.convert_element_type_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule( + lax.convert_element_type_p, mgpu.LoweringSemantics.Warpgroup) def _convert_element_type_lowering_rule_wg( ctx: LoweringRuleContext, x, *, new_dtype, weak_type, sharding ): @@ -1454,12 +1469,12 @@ def convert(ty, x): return convert(ty, x) -mosaic_lowering_rules[mgpu.ThreadSemantics.Lane].update({ +mosaic_lowering_rules[gpu_core.LANExWG_SEMANTICS].update({ lax.neg_p: lambda ctx, x: -x, lax.not_p: lambda ctx, x: ~x, }) -mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup].update({ +mosaic_lowering_rules[gpu_core.WGxWG_SEMANTICS].update({ lax.neg_p: _lower_fun(lambda x: jnp.subtract(0, x), multiple_results=False), lax.not_p: _lower_fun( lambda x: jnp.astype(jnp.bitwise_xor(jnp.astype(x, int), -1), jnp.dtype(x)), multiple_results=False, @@ -1472,7 +1487,7 @@ def _binary_op_lowering_rule(ctx: LoweringRuleContext, x, y, *, impl): return impl(x, y) -mosaic_lowering_rules[mgpu.ThreadSemantics.Lane].update({ +mosaic_lowering_rules[gpu_core.LANExWG_SEMANTICS].update({ lax.add_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x + y), lax.sub_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x - y), lax.mul_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x * y), @@ -1533,7 +1548,7 @@ def _binary_op_lowering_rule_wg( arith_dialect.minimumf, ), ]: - mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup][op] = partial( + mosaic_lowering_rules[gpu_core.WGxWG_SEMANTICS][op] = partial( _binary_op_lowering_rule_wg, si_impl=si_impl, ui_impl=ui_impl, @@ -1552,7 +1567,7 @@ def _binary_boolean_op_lowering_rule_wg( (lax.or_p, arith_dialect.ori), (lax.xor_p, arith_dialect.xori), ]: - mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup][op] = partial( + mosaic_lowering_rules[gpu_core.WGxWG_SEMANTICS][op] = partial( _binary_boolean_op_lowering_rule_wg, impl=impl, ) @@ -1585,7 +1600,7 @@ def _comparison_lowering_rule_wg( (lax.gt_p, CmpIPred.sgt, CmpIPred.ugt, CmpFPred.OGT), (lax.ge_p, CmpIPred.sge, CmpIPred.uge, CmpFPred.OGE), ]: - mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup][op] = partial( + mosaic_lowering_rules[gpu_core.WGxWG_SEMANTICS][op] = partial( _comparison_lowering_rule_wg, si_pred=si_pred, ui_pred=ui_pred, @@ -1593,7 +1608,7 @@ def _comparison_lowering_rule_wg( ) -@register_lowering_rule(lax.div_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.div_p, mgpu.LoweringSemantics.Lane) def _div_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out) if ir.FloatType.isinstance(x.mlir_dtype): @@ -1601,19 +1616,19 @@ def _div_lowering_rule(ctx: LoweringRuleContext, x, y): return x // y -@register_lowering_rule(lax.integer_pow_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.integer_pow_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.integer_pow_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.integer_pow_p, mgpu.LoweringSemantics.Warpgroup) def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, y): if y != 2: raise NotImplementedError return _square_lowering_rule(ctx, x) -@register_lowering_rule(lax.square_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.square_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.square_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.square_p, mgpu.LoweringSemantics.Warpgroup) def _square_lowering_rule(ctx: LoweringRuleContext, x): [x_aval] = ctx.avals_in - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: x = _ensure_fa(x, x_aval.dtype) return x * x if jnp.issubdtype(x_aval.dtype, jnp.integer): @@ -1623,13 +1638,13 @@ def _square_lowering_rule(ctx: LoweringRuleContext, x): raise NotImplementedError(f"Unsupported dtype {x_aval.dtype}") -@register_lowering_rule(lax.rsqrt_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.rsqrt_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.rsqrt_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.rsqrt_p, mgpu.LoweringSemantics.Warpgroup) def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x, accuracy): if accuracy is not None: raise NotImplementedError("Not implemented: accuracy") [x_aval] = ctx.avals_in - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: return _ensure_fa(x, x_aval.dtype).rsqrt(approx=ctx.module_ctx.approx_math) fastmath = ( arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None @@ -1639,13 +1654,13 @@ def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x, accuracy): ) -@register_lowering_rule(lax.tanh_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.tanh_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.tanh_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.tanh_p, mgpu.LoweringSemantics.Warpgroup) def _tanh_lowering_rule(ctx: LoweringRuleContext, x, accuracy): if accuracy is not None: raise NotImplementedError("Not implemented: accuracy") [x_aval] = ctx.avals_in - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: return _ensure_fa(x, x_aval.dtype).tanh(approx=ctx.module_ctx.approx_math) fastmath = ( arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None @@ -1659,21 +1674,21 @@ def _logistic(x, accuracy): return 1.0 / (1 + lax.exp(-x)) -mosaic_lowering_rules[mgpu.ThreadSemantics.Lane][lax.logistic_p] = _lower_fun( +mosaic_lowering_rules[gpu_core.LANExWG_SEMANTICS][lax.logistic_p] = _lower_fun( _logistic, multiple_results=False ) -mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup][lax.logistic_p] = ( +mosaic_lowering_rules[gpu_core.WGxWG_SEMANTICS][lax.logistic_p] = ( _lower_fun(_logistic, multiple_results=False) ) -@register_lowering_rule(lax.exp_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.exp_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.exp_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.exp_p, mgpu.LoweringSemantics.Warpgroup) def _exp_lowering_rule(ctx: LoweringRuleContext, x, accuracy): if accuracy is not None: raise NotImplementedError("Not implemented: accuracy") [x_aval] = ctx.avals_in - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: return _ensure_fa(x, x_aval.dtype).exp(approx=ctx.module_ctx.approx_math) fastmath = ( arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None @@ -1681,13 +1696,13 @@ def _exp_lowering_rule(ctx: LoweringRuleContext, x, accuracy): return math_dialect.exp(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath) -@register_lowering_rule(lax.exp2_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.exp2_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.exp2_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.exp2_p, mgpu.LoweringSemantics.Warpgroup) def _exp2_lowering_rule(ctx: LoweringRuleContext, x, accuracy): if accuracy is not None: raise NotImplementedError("Not implemented: accuracy") [x_aval] = ctx.avals_in - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: return _ensure_fa(x, x_aval.dtype).exp2(approx=ctx.module_ctx.approx_math) fastmath = ( arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None @@ -1695,13 +1710,13 @@ def _exp2_lowering_rule(ctx: LoweringRuleContext, x, accuracy): return math_dialect.exp2(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath) -@register_lowering_rule(lax.log_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.log_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.log_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.log_p, mgpu.LoweringSemantics.Warpgroup) def _log_lowering_rule(ctx: LoweringRuleContext, x, accuracy): if accuracy is not None: raise NotImplementedError("Not implemented: accuracy") [x_aval] = ctx.avals_in - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: return _ensure_fa(x, x_aval.dtype).log(approx=ctx.module_ctx.approx_math) fastmath = ( arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None @@ -1709,7 +1724,7 @@ def _log_lowering_rule(ctx: LoweringRuleContext, x, accuracy): return math_dialect.log(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath) -@register_lowering_rule(lax.reduce_sum_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.reduce_sum_p, mgpu.LoweringSemantics.Lane) def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes): [x_aval] = ctx.avals_in match x.layout: @@ -1729,7 +1744,7 @@ def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes): raise NotImplementedError(f"Unsupported layout {x.layout}") -@register_lowering_rule(lax.reduce_max_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.reduce_max_p, mgpu.LoweringSemantics.Lane) def _reduce_max_lowering_rule(ctx: LoweringRuleContext, x, *, axes): [x_aval] = ctx.avals_in match x.layout: @@ -1770,7 +1785,7 @@ def _reduce_lowering_rule_wg( return vector_dialect.MultiDimReductionOp(kind, x, acc, axes) -@register_lowering_rule(lax.reduce_sum_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.reduce_sum_p, mgpu.LoweringSemantics.Warpgroup) def _reduce_sum_lowering_rule_wg(ctx: LoweringRuleContext, x, *, axes): op = _reduce_lowering_rule_wg( vector_dialect.CombiningKind.ADD, 0, ctx, x, axes=axes @@ -1781,7 +1796,7 @@ def _reduce_sum_lowering_rule_wg(ctx: LoweringRuleContext, x, *, axes): return op.result -@register_lowering_rule(lax.reduce_max_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.reduce_max_p, mgpu.LoweringSemantics.Warpgroup) def _reduce_max_lowering_rule_wg(ctx: LoweringRuleContext, x, *, axes): [x_aval] = ctx.avals_in if jnp.issubdtype(x_aval.dtype, jnp.floating): @@ -1822,8 +1837,8 @@ def _resolve_cluster_axis(axis_names: _AxisNames | None, axis_name: str): return gpu_dialect.Dimension(axis_names.cluster.index(axis_name)) -@register_lowering_rule(lax.axis_index_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.axis_index_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.axis_index_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.axis_index_p, mgpu.LoweringSemantics.Warpgroup) def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): axis_names = ctx.module_ctx.axis_names if not axis_names: @@ -1883,7 +1898,7 @@ def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): ) -@register_lowering_rule(primitives.debug_print_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(primitives.debug_print_p, mgpu.LoweringSemantics.Lane) def _debug_print_lowering_rule( ctx: LoweringRuleContext, *args, @@ -1911,7 +1926,7 @@ def _debug_print_lowering_rule( return () -@register_lowering_rule(primitives.debug_print_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(primitives.debug_print_p, mgpu.LoweringSemantics.Warpgroup) def _debug_print_lowering_rule_wg( ctx: LoweringRuleContext, *args, @@ -1925,8 +1940,8 @@ def _debug_print_lowering_rule_wg( return () -@register_lowering_rule(primitives.run_scoped_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(primitives.run_scoped_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(primitives.run_scoped_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(primitives.run_scoped_p, mgpu.LoweringSemantics.Warpgroup) def _run_scoped_lowering_rule( ctx: LoweringRuleContext, *consts, jaxpr: jax_core.Jaxpr ): @@ -1937,7 +1952,7 @@ def _run_scoped_lowering_rule( aval = v.aval if isinstance(aval, gpu_core.WGMMAAbstractAccumulatorRef): dtype = mlir.dtype_to_ir_type(aval.dtype) - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: input_refs.append(mgpu.WGMMAAccumulator.zero(*aval.shape, dtype)) else: zero = arith_dialect.constant(dtype, ir.FloatAttr.get(dtype, 0.0)) @@ -2018,7 +2033,7 @@ def _run_scoped_lowering_rule( return outs -@register_lowering_rule(discharge.run_state_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(discharge.run_state_p, mgpu.LoweringSemantics.Lane) def _run_state_lowering_rule( ctx: LoweringRuleContext, *args, @@ -2090,7 +2105,7 @@ def as_values(vals, avals): _ensure = ( _ensure_fa - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane else _ensure_ir_value ) return [v if a else _ensure(v, av) for a, v, av in zip(is_acc, vals, avals)] @@ -2110,8 +2125,8 @@ def loop(loop_index, body_args): return loop.results -@register_lowering_rule(lax.scan_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.scan_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.scan_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.scan_p, mgpu.LoweringSemantics.Warpgroup) def _scan_lowering_rule( ctx: LoweringRuleContext, *args, @@ -2199,8 +2214,8 @@ def _lower_while_via_fori( return ub, ub, *for_out -@register_lowering_rule(lax.while_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.while_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.while_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.while_p, mgpu.LoweringSemantics.Warpgroup) def _while_lowering_rule( ctx: LoweringRuleContext, *args, @@ -2224,7 +2239,7 @@ def _while_lowering_rule( _is_acc = lambda x: isinstance(x, mgpu.WGMMAAccumulator) _ensure = _ensure_ir_value - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: _ensure = lambda v, aval: v if _is_acc(v) else _ensure_fa(v, aval.dtype) # If we fail conversion to fori, fallback to an ordinary while loop. @@ -2276,8 +2291,8 @@ def _while_lowering_rule( return carry_treedef.unflatten(list(while_op.results)) -@register_lowering_rule(lax.cond_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.cond_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.cond_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.cond_p, mgpu.LoweringSemantics.Warpgroup) def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches): index_aval, *_arg_avals = ctx.avals_in @@ -2334,9 +2349,9 @@ def _yielded_values(outs, avals): return treedef.unflatten(list(switch_op.results)) -@register_lowering_rule(lax.bitcast_convert_type_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.bitcast_convert_type_p, mgpu.LoweringSemantics.Lane) @register_lowering_rule( - lax.bitcast_convert_type_p, mgpu.ThreadSemantics.Warpgroup + lax.bitcast_convert_type_p, mgpu.LoweringSemantics.Warpgroup ) def _bitcast_convert_type_lowering_rule( ctx: LoweringRuleContext, x, *, new_dtype @@ -2352,7 +2367,7 @@ def _bitcast_convert_type_lowering_rule( " have different widths" ) - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Warpgroup: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Warpgroup: x = _ensure_ir_value(x, x_aval.dtype) return arith_dialect.bitcast( ir.VectorType.get(x_aval.shape, dst_elem_type), x @@ -2368,7 +2383,7 @@ def _bitcast_convert_type_lowering_rule( ) -@register_lowering_rule(lax.optimization_barrier_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.optimization_barrier_p, mgpu.LoweringSemantics.Lane) def _optimization_barrier_lowering(ctx: LoweringRuleContext, *args): result = mgpu.optimization_barrier( *(_ensure_fa(arg, aval.dtype) for arg, aval in zip(args, ctx.avals_in)) @@ -2377,7 +2392,7 @@ def _optimization_barrier_lowering(ctx: LoweringRuleContext, *args): @register_lowering_rule( - lax.optimization_barrier_p, mgpu.ThreadSemantics.Warpgroup + lax.optimization_barrier_p, mgpu.LoweringSemantics.Warpgroup ) def _optimization_barrier_lowering_wg(ctx: LoweringRuleContext, *args): result = mgpu.dialect.optimization_barrier([ diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py index ff3c4f89d30c..eb15aff21235 100644 --- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py @@ -58,8 +58,8 @@ def pallas_call_lowering( print(f"The grid mapping for pallas_call {debug_info.func_src_info}:") print(grid_mapping) - thread_semantics = compiler_params.get("mosaic_gpu", {}).get( - "thread_semantics", mgpu.ThreadSemantics.Lane + lowering_semantics = compiler_params.get("mosaic_gpu", {}).get( + "lowering_semantics", mgpu.LoweringSemantics.Lane ) mgpu.dialect.register_dialect(ctx.module_context.context) # pytype: disable=attribute-error diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 8bd67e705cf0..a37b018860d7 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -75,7 +75,7 @@ def _load_abstract_eval(src, *avals_flat, args_tree, layout, optimized): {state.ReadEffect(0)}, ) -@lowering.register_lowering_rule(load_p, mgpu.ThreadSemantics.Lane) +@lowering.register_lowering_rule(load_p, mgpu.LoweringSemantics.Lane) def _load_p_lowering_rule( ctx: lowering.LoweringRuleContext, x_ref, *leaves, args_tree, layout, optimized ): @@ -215,9 +215,10 @@ def _copy_smem_to_gmem_pp_eqn( jax_core.pp_eqn_rules[copy_smem_to_gmem_p] = _copy_smem_to_gmem_pp_eqn -@lowering.register_lowering_rule(copy_smem_to_gmem_p, mgpu.ThreadSemantics.Lane) @lowering.register_lowering_rule( - copy_smem_to_gmem_p, mgpu.ThreadSemantics.Warpgroup + copy_smem_to_gmem_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule( + copy_smem_to_gmem_p, mgpu.LoweringSemantics.Warpgroup ) def _copy_smem_to_gmem_lowering( ctx: lowering.LoweringRuleContext, @@ -236,7 +237,7 @@ def _copy_smem_to_gmem_lowering( else: predicate = None - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: if predicate is not None: assert ctx.module_ctx.single_wg_lane_predicate is not None predicate = arith_dialect.andi( @@ -253,7 +254,7 @@ def _copy_smem_to_gmem_lowering( dst_transforms = dst_transforms_treedef.unflatten(flat_dst_transforms) src, src_transforms = lowering._handle_transforms(src, src_transforms, handle_transposes=False) copy_params = _extract_gmem_copy_params(dst_transforms) | _extract_smem_copy_params(src_transforms) - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: ctx.launch_ctx.async_copy( src_ref=src, dst_ref=dst, @@ -439,9 +440,10 @@ def _copy_gmem_to_smem_pp_eqn( jax_core.pp_eqn_rules[copy_gmem_to_smem_p] = _copy_gmem_to_smem_pp_eqn -@lowering.register_lowering_rule(copy_gmem_to_smem_p, mgpu.ThreadSemantics.Lane) @lowering.register_lowering_rule( - copy_gmem_to_smem_p, mgpu.ThreadSemantics.Warpgroup + copy_gmem_to_smem_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule( + copy_gmem_to_smem_p, mgpu.LoweringSemantics.Warpgroup ) def _copy_gmem_to_smem_lowering( ctx: lowering.LoweringRuleContext, @@ -488,7 +490,7 @@ def _copy_gmem_to_smem_lowering( f" dtype={dst_ty.element_type})" ) bytes = bits // 8 - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: if bytes % WARPGROUP_SIZE: raise NotImplementedError("Only aligned copies are supported") # We arrive uniformly from each thread in the WG, so we need to divide the @@ -630,8 +632,8 @@ def _barrier_arrive_pp_eqn( jax_core.pp_eqn_rules[barrier_arrive_p] = _barrier_arrive_pp_eqn -@lowering.register_lowering_rule(barrier_arrive_p, mgpu.ThreadSemantics.Lane) -@lowering.register_lowering_rule(barrier_arrive_p, mgpu.ThreadSemantics.Warpgroup) +@lowering.register_lowering_rule(barrier_arrive_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule(barrier_arrive_p, mgpu.LoweringSemantics.Warpgroup) def _barrier_arrive_lowering( ctx: lowering.LoweringRuleContext, barrier, @@ -688,8 +690,8 @@ def _barrier_wait_pp_eqn( jax_core.pp_eqn_rules[barrier_wait_p] = _barrier_wait_pp_eqn -@lowering.register_lowering_rule(barrier_wait_p, mgpu.ThreadSemantics.Lane) -@lowering.register_lowering_rule(barrier_wait_p, mgpu.ThreadSemantics.Warpgroup) +@lowering.register_lowering_rule(barrier_wait_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule(barrier_wait_p, mgpu.LoweringSemantics.Warpgroup) def _barrier_wait_lowering( ctx: lowering.LoweringRuleContext, barrier, @@ -726,9 +728,10 @@ def _wait_smem_to_gmem_abstract_eval(n, *, wait_read_only): return (), {gpu_core._memory_effect} -@lowering.register_lowering_rule(wait_smem_to_gmem_p, mgpu.ThreadSemantics.Lane) @lowering.register_lowering_rule( - wait_smem_to_gmem_p, mgpu.ThreadSemantics.Warpgroup + wait_smem_to_gmem_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule( + wait_smem_to_gmem_p, mgpu.LoweringSemantics.Warpgroup ) def _wait_smem_to_gmem_lowering( ctx: lowering.LoweringRuleContext, n, *, wait_read_only @@ -759,8 +762,9 @@ def _commit_group_abstract_eval(): return (), {gpu_core._memory_effect} -@lowering.register_lowering_rule(commit_group_p, mgpu.ThreadSemantics.Lane) -@lowering.register_lowering_rule(commit_group_p, mgpu.ThreadSemantics.Warpgroup) +@lowering.register_lowering_rule(commit_group_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule( + commit_group_p, mgpu.LoweringSemantics.Warpgroup) def _commit_group_lowering(ctx: lowering.LoweringRuleContext): del ctx # Unused. nvvm_dialect.cp_async_bulk_commit_group() @@ -886,7 +890,7 @@ def _wgmma_ref_discharge(in_avals, out_avals, *args, **kwargs): wgmma_p = jax_core.Primitive("wgmma") -@lowering.register_lowering_rule(wgmma_p, mgpu.ThreadSemantics.Lane) +@lowering.register_lowering_rule(wgmma_p, mgpu.LoweringSemantics.Lane) def _wgmma_lowering( ctx: lowering.LoweringRuleContext, acc, @@ -981,7 +985,7 @@ def _wgmma_lowering( return new_acc -@lowering.register_lowering_rule(wgmma_p, mgpu.ThreadSemantics.Warpgroup) +@lowering.register_lowering_rule(wgmma_p, mgpu.LoweringSemantics.Warpgroup) def _wgmma_warpgroup_lowering( ctx: lowering.LoweringRuleContext, acc, @@ -1052,8 +1056,8 @@ def wgmma_wait_effectful_abstract_eval(_): return [], {gpu_core._wgmma_pipeline_effect} -@lowering.register_lowering_rule(wgmma_wait_p, mgpu.ThreadSemantics.Lane) -@lowering.register_lowering_rule(wgmma_wait_p, mgpu.ThreadSemantics.Warpgroup) +@lowering.register_lowering_rule(wgmma_wait_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule(wgmma_wait_p, mgpu.LoweringSemantics.Warpgroup) def _wgmma_wait_lowering(ctx: lowering.LoweringRuleContext, allow_groups): del ctx nvvm_dialect.wgmma_wait_group_sync_aligned(allow_groups) @@ -1085,16 +1089,16 @@ def _wgmma_accumulator_deref_discharge(in_avals, out_avals, acc): @lowering.register_lowering_rule( - wgmma_accumulator_deref_p, mgpu.ThreadSemantics.Lane + wgmma_accumulator_deref_p, mgpu.LoweringSemantics.Lane ) @lowering.register_lowering_rule( - wgmma_accumulator_deref_p, mgpu.ThreadSemantics.Warpgroup + wgmma_accumulator_deref_p, mgpu.LoweringSemantics.Warpgroup ) def _wgmma_accumulator_deref_lowering(ctx: lowering.LoweringRuleContext, acc): nvvm_dialect.wgmma_wait_group_sync_aligned(0) return ( acc.value - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane else acc ) @@ -1156,7 +1160,7 @@ def _layout_cast_abstract_eval(x, new_layout): return x -@lowering.register_lowering_rule(layout_cast_p, mgpu.ThreadSemantics.Lane) +@lowering.register_lowering_rule(layout_cast_p, mgpu.LoweringSemantics.Lane) def _layout_cast_lowering(ctx: lowering.LoweringRuleContext, x, *, new_layout): del ctx # Unused. return x.to_layout(new_layout.to_mgpu()) @@ -1177,8 +1181,10 @@ def _set_max_registers_abstract_eval(n, *, action): return (), {gpu_core._memory_effect} -@lowering.register_lowering_rule(set_max_registers_p, mgpu.ThreadSemantics.Lane) -@lowering.register_lowering_rule(set_max_registers_p, mgpu.ThreadSemantics.Warpgroup) +@lowering.register_lowering_rule( + set_max_registers_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule( + set_max_registers_p, mgpu.LoweringSemantics.Warpgroup) def _set_max_registers_lowering( ctx: lowering.LoweringRuleContext, n, *, action ): @@ -1206,8 +1212,9 @@ def _commit_smem_abstract_eval(): return (), {gpu_core._memory_effect} -@lowering.register_lowering_rule(commit_smem_p, mgpu.ThreadSemantics.Lane) -@lowering.register_lowering_rule(commit_smem_p, mgpu.ThreadSemantics.Warpgroup) +@lowering.register_lowering_rule(commit_smem_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule( + commit_smem_p, mgpu.LoweringSemantics.Warpgroup) def _commit_smem_lowering(ctx: lowering.LoweringRuleContext): # TODO(bchetioui): add primitive for commit smem to mosaic_gpu dialect. mgpu.commit_shared() @@ -1227,7 +1234,8 @@ def _broadcasted_iota_abstract_eval(dtype, shape, dimension, layout): return jax_core.ShapedArray(shape, dtype) -@lowering.register_lowering_rule(broadcasted_iota_p, mgpu.ThreadSemantics.Lane) +@lowering.register_lowering_rule( + broadcasted_iota_p, mgpu.LoweringSemantics.Lane) def _broadcasted_iota_lowering( ctx: lowering.LoweringRuleContext, dtype, shape, dimension, layout ): @@ -1309,8 +1317,8 @@ def _jaxpr_call_pp_eqn( jax_core.pp_eqn_rules[jaxpr_call_p] = _jaxpr_call_pp_eqn -@lowering.register_lowering_rule(jaxpr_call_p, mgpu.ThreadSemantics.Lane) -@lowering.register_lowering_rule(jaxpr_call_p, mgpu.ThreadSemantics.Warpgroup) +@lowering.register_lowering_rule(jaxpr_call_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule(jaxpr_call_p, mgpu.LoweringSemantics.Warpgroup) def _jaxpr_call_lowering_rule( ctx: lowering.LoweringRuleContext, *flat_args, @@ -1507,7 +1515,7 @@ def _inline_mgpu_abstract_eval( def _inline_mgpu_discharge(*args, **kwargs): raise NotImplementedError("inline_mgpu_p does not support discharge.") -@lowering.register_lowering_rule(inline_mgpu_p, mgpu.ThreadSemantics.Lane) +@lowering.register_lowering_rule(inline_mgpu_p, mgpu.LoweringSemantics.Lane) def _inline_mgpu_lowering_rule( ctx: lowering.LoweringRuleContext, *flat_args, diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index a4f1e0a9cfe0..8ecc5b9fd8da 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -23,7 +23,7 @@ Barrier as Barrier, ClusterBarrier as ClusterBarrier, TMABarrier as TMABarrier, - ThreadSemantics as ThreadSemantics, + LoweringSemantics as LoweringSemantics, TMEM as TMEM, Union as Union, as_gpu_kernel as as_gpu_kernel, diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index 43b93e7da023..e822ea5f3ebf 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -209,7 +209,7 @@ def _count_buffer_bytes(shape_dtype: jax.ShapeDtypeStruct) -> int: return math.prod(shape_dtype.shape) * np.dtype(shape_dtype.dtype).itemsize -class ThreadSemantics(enum.Enum): +class LoweringSemantics(enum.Enum): """Semantics for the kernel's instruction stream.""" Lane = enum.auto() @@ -595,7 +595,7 @@ def as_gpu_kernel( module_name: str = "unknown", kernel_name: str | None = None, ir_version: int | None = None, - thread_semantics: ThreadSemantics = ThreadSemantics.Lane, + thread_semantics: LoweringSemantics = LoweringSemantics.Lane, ): if isinstance(in_shape, list): in_shape = tuple(in_shape) @@ -609,7 +609,7 @@ def as_gpu_kernel( ) ) - if thread_semantics == ThreadSemantics.Warpgroup and dialect is not None: + if thread_semantics == LoweringSemantics.Warpgroup and dialect is not None: # Run Python lowering passes. The remaining passes will be run in C++ in # jax/jaxlib/mosaic/gpu/custom_call.cc layout_inference.infer_layout(module) # pytype: disable=attribute-error @@ -669,7 +669,7 @@ def as_torch_gpu_kernel( cluster: tuple[int, int, int] = (1, 1, 1), module_name: str = "unknown", kernel_name: str | None = None, - thread_semantics: ThreadSemantics = ThreadSemantics.Lane, + lowering_semantics: LoweringSemantics = LoweringSemantics.Lane, ): try: import torch @@ -692,7 +692,7 @@ def as_torch_gpu_kernel( ) ) - if thread_semantics == ThreadSemantics.Warpgroup and dialect is not None: + if lowering_semantics == LoweringSemantics.Warpgroup and dialect is not None: # Run Python lowering passes. The remaining passes will be run in C++ in # jax/jaxlib/mosaic/gpu/custom_call.cc layout_inference.infer_layout(module) # pytype: disable=attribute-error diff --git a/jax/experimental/pallas/mosaic_gpu.py b/jax/experimental/pallas/mosaic_gpu.py index 1d3bebbc3757..85e512d03290 100644 --- a/jax/experimental/pallas/mosaic_gpu.py +++ b/jax/experimental/pallas/mosaic_gpu.py @@ -52,7 +52,7 @@ from jax._src.pallas.mosaic_gpu.primitives import wait_smem_to_gmem as wait_smem_to_gmem from jax._src.pallas.mosaic_gpu.primitives import wgmma as wgmma from jax._src.pallas.mosaic_gpu.primitives import wgmma_wait as wgmma_wait -from jax.experimental.mosaic.gpu.core import ThreadSemantics as ThreadSemantics +from jax.experimental.mosaic.gpu.core import LoweringSemantics as LoweringSemantics #: Alias of :data:`jax.experimental.pallas.mosaic_gpu.GPUMemorySpace.GMEM`. diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 80e5380d165b..f0930f5de8cc 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -2602,7 +2602,7 @@ def add(ctx, a, b, result, smem): in_shape=(jax_shape, jax_shape), out_shape=jax_shape, smem_scratch_shape=[], - thread_semantics=mgpu.ThreadSemantics.Warpgroup, + thread_semantics=mgpu.LoweringSemantics.Warpgroup, ) x = self.prng.uniform(-1, 1, shape).astype(dtype) @@ -2747,7 +2747,7 @@ def add( jax_shape_sliced, core.TMABarrier(1), ], - thread_semantics=mgpu.ThreadSemantics.Warpgroup, + thread_semantics=mgpu.LoweringSemantics.Warpgroup, ) x = self.prng.uniform(-1, 1, test_case.shape).astype(dtype) @@ -2846,7 +2846,7 @@ def add( spec, core.TMABarrier(1), ], - thread_semantics=mgpu.ThreadSemantics.Warpgroup, + thread_semantics=mgpu.LoweringSemantics.Warpgroup, ) x = self.prng.uniform(-1, 1, spec.shape).astype(dtype) @@ -2994,7 +2994,7 @@ def matmul( result_jax_shape, core.TMABarrier(1), ], - thread_semantics=mgpu.ThreadSemantics.Warpgroup, + thread_semantics=mgpu.LoweringSemantics.Warpgroup, ) prng_key = jax.random.key(1234) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 10b4f3de60ad..f0f3bdf41c32 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -28,6 +28,7 @@ from jax import lax from jax._src import test_util as jtu from jax._src.pallas import pallas_call +from jax._src.pallas.mosaic_gpu import core as gpu_core from jax._src.pallas.mosaic_gpu import lowering as mgpu_lowering from jax._src.pallas.mosaic_gpu import pipeline as mgpu_pipeline from jax._src.pallas.mosaic_gpu import primitives as mgpu_primitives @@ -64,14 +65,14 @@ def _sum_same_dtype(x): class PallasTestMetaclass(parameterized.TestGeneratorMetaclass): - def __new__(mcs, *args, thread_semantics=plgpu.ThreadSemantics.Lane): + def __new__(mcs, *args, lowering_semantics=plgpu.LoweringSemantics.Lane): cls = super().__new__(mcs, *args) - cls.THREAD_SEMANTICS = thread_semantics + cls.LOWERING_SEMANTICS = lowering_semantics return cls class PallasTest(jtu.JaxTestCase, metaclass=PallasTestMetaclass): - THREAD_SEMANTICS: ClassVar[plgpu.ThreadSemantics] + LOWERING_SEMANTICS: ClassVar[plgpu.LoweringSemantics] def setUp(self): if not jtu.is_cuda_compute_capability_at_least("9.0"): @@ -83,20 +84,20 @@ def setUp(self): super().setUp() def skip_if_wg_semantics(self): - if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Warpgroup: + if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Warpgroup: self.skipTest("Not supported under WG semantics") def kernel(self, *args, **kwargs): compiler_params = dataclasses.replace( kwargs.pop("compiler_params", plgpu.GPUCompilerParams()), - thread_semantics=self.THREAD_SEMANTICS, + lowering_semantics=self.LOWERING_SEMANTICS, ) return plgpu.kernel(*args, compiler_params=compiler_params, **kwargs) def pallas_call(self, *args, **kwargs): compiler_params = dataclasses.replace( kwargs.pop("compiler_params", plgpu.GPUCompilerParams()), - thread_semantics=self.THREAD_SEMANTICS, + lowering_semantics=self.LOWERING_SEMANTICS, ) return pl.pallas_call(*args, compiler_params=compiler_params, **kwargs) @@ -1503,7 +1504,7 @@ def kernel(x_ref, y_ref, o_ref): class PallasCallWGTest( - PallasCallTest, thread_semantics=plgpu.ThreadSemantics.Warpgroup + PallasCallTest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup ): ... @@ -1513,10 +1514,14 @@ def test_missing_primitive_lowerings_are_tracked(self): # enable warpgroup semantics by default (assuming we haven't overspecialized # lowerings). rules = mgpu_lowering.mosaic_lowering_rules - wg_lowered_primitives = set(rules[plgpu.ThreadSemantics.Warpgroup]) - lane_lowered_primitives = set(rules[plgpu.ThreadSemantics.Lane]) - - actual_missing_primitives = lane_lowered_primitives - wg_lowered_primitives + wg_wg_lowered_primitives = set( + rules[(plgpu.LoweringSemantics.Warpgroup, + gpu_core.PrimitiveSemantics.Warpgroup)]) + lane_wg_lowered_primitives = set(rules[ + (plgpu.LoweringSemantics.Lane, gpu_core.PrimitiveSemantics.Warpgroup)]) + + actual_missing_primitives = (lane_wg_lowered_primitives - + wg_wg_lowered_primitives) expected_missing_primitives = { mgpu_primitives.inline_mgpu_p, mgpu_primitives.broadcasted_iota_p, @@ -1607,7 +1612,7 @@ def _epilogue(): lambda m, n, k: (m, n), ) - if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Lane: + if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane: lhs_spec = plgpu.GPUBlockSpec( lhs_spec.block_shape, lhs_spec.index_map, @@ -1715,7 +1720,7 @@ def scope(acc_ref): b = jax.random.uniform(key2, shape=(128, 192), dtype=jnp.float16) transforms = () - if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Lane: + if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane: transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) res = self.pallas_call( kernel, @@ -1768,7 +1773,7 @@ def scope(acc_ref): b = jax.random.uniform(key2, shape=(2, 128, 192), dtype=jnp.float16) transforms = () - if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Lane: + if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane: transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) res = self.pallas_call( @@ -1797,7 +1802,7 @@ def scope(acc_ref): a = jax.random.uniform(key1, shape=(64, 128), dtype=jnp.float16) b = jax.random.uniform(key2, shape=(128, 128), dtype=jnp.float16) transforms = () - if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Lane: + if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane: transforms = ( plgpu.TilingTransform((8, elems_128b)), plgpu.SwizzleTransform(128), @@ -1820,7 +1825,7 @@ def scope(acc_ref): class PallasCallSm90AWGTest( - PallasCallSm90ATest, thread_semantics=plgpu.ThreadSemantics.Warpgroup + PallasCallSm90ATest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup ): ... @@ -1851,7 +1856,7 @@ def kernel(y_ref, tmem_ref, smem_ref): class PallasCallSm100AWGTest( - PallasCallSm100ATest, thread_semantics=plgpu.ThreadSemantics.Warpgroup + PallasCallSm100ATest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup ): ... @@ -2117,7 +2122,7 @@ def kernel_body(_, x_smem, o_smem): class PipelineWGTest( - PipelineTest, thread_semantics=plgpu.ThreadSemantics.Warpgroup + PipelineTest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup ): ... @@ -2137,7 +2142,7 @@ def test_realistic_matmul(self): m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n transforms = () - if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Lane: + if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane: transforms = ( plgpu.TilingTransform((8, elems_128b)), plgpu.SwizzleTransform(128), @@ -2190,7 +2195,7 @@ def kernel_body(_, a_smem, b_smem): class PipelineSm90AWGTest( - PipelineSm90ATest, thread_semantics=plgpu.ThreadSemantics.Warpgroup + PipelineSm90ATest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup ): ... @@ -2354,7 +2359,7 @@ def tiled_acc_kernel(_, x_smem, carry): class WarpSpecializedPipelineWGTest( WarpSpecializedPipelineTest, - thread_semantics=plgpu.ThreadSemantics.Warpgroup, + lowering_semantics=plgpu.LoweringSemantics.Warpgroup, ): ... @@ -2612,7 +2617,7 @@ def body(step, _): class CoreMapWGTest( - CoreMapTest, thread_semantics=plgpu.ThreadSemantics.Warpgroup + CoreMapTest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup ): ... @@ -2747,7 +2752,7 @@ def body(i_ref1, i_ref2, o_ref, sem_ref): class ExamplesWGTest( - ExamplesTest, thread_semantics=plgpu.ThreadSemantics.Warpgroup + ExamplesTest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup ): ... @@ -2789,7 +2794,7 @@ def do_wgmma(acc_ref): class ExamplesSm90AWGTest( - ExamplesSm90ATest, thread_semantics=plgpu.ThreadSemantics.Warpgroup + ExamplesSm90ATest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup ): ... diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index aeb0ba1cca1a..ff02c334f45c 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -296,7 +296,7 @@ def pallas_call(cls, *args, **kwargs): if jtu.test_device_matches(["cuda"]) and use_mosaic_gpu: assert plgpu_mgpu is not None compiler_params = plgpu_mgpu.GPUCompilerParams( - thread_semantics=plgpu_mgpu.ThreadSemantics.Warpgroup + lowering_semantics=plgpu_mgpu.LoweringSemantics.Warpgroup ) kwargs["compiler_params"] = compiler_params From 64e4bf26324ddfeb957233f6370b74acdbb80e5f Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 7 Apr 2025 14:49:09 -0700 Subject: [PATCH 442/483] Relax jax dependency constraints to be able to install RC wheels Also, add a job to the release test workflow that verifies that the release wheels can be installed. TESTED: 1. Full release: https://github.com/jax-ml/jax/actions/runs/14315832784 2. jax only release: https://github.com/jax-ml/jax/actions/runs/14316157252 PiperOrigin-RevId: 744857804 --- .../workflows/wheel_tests_nightly_release.yml | 98 +++++++++++++++++-- jax/version.py | 6 ++ setup.py | 19 ++-- tests/version_test.py | 16 +++ 4 files changed, 127 insertions(+), 12 deletions(-) diff --git a/.github/workflows/wheel_tests_nightly_release.yml b/.github/workflows/wheel_tests_nightly_release.yml index 132aad577d50..f6d2aa9b97c6 100644 --- a/.github/workflows/wheel_tests_nightly_release.yml +++ b/.github/workflows/wheel_tests_nightly_release.yml @@ -1,12 +1,14 @@ # CI - Wheel Tests (Nightly/Release) # -# This workflow builds JAX artifacts and runs CPU/CUDA tests. +# This workflow is used to test the JAX wheels that was built by internal CI jobs. # -# It orchestrates the following: -# 1. run-pytest-cpu: Calls the `pytest_cpu.yml` workflow which downloads the jaxlib wheel that was +# 1. run-pytest-cpu: Calls the `pytest_cpu.yml` workflow which downloads the JAX wheels that was # built by internal CI jobs and runs CPU tests. -# 2. run-pytest-cuda: Calls the `pytest_cuda.yml` workflow which downloads the jaxlib and CUDA -# artifacts that were built by internal CI jobs and runs the CUDA tests. +# 2. run-pytest-cuda: Calls the `pytest_cuda.yml` workflow which downloads the JAX wheels that was +# built by internal CI jobs and runs CUDA tests. +# 3. run-pytest-tpu: Calls the `pytest_tpu.yml` workflow which downloads the JAX wheels that was +# built by internal CI jobs and runs TPU tests. +# 4. verify-release-wheels-install: Verifies that JAX's release wheels can be installed. name: CI - Wheel Tests (Nightly/Release) on: @@ -106,4 +108,88 @@ jobs: run-full-tpu-test-suite: "1" libtpu-version-type: ${{ matrix.libtpu-version-type }} download-jax-only-from-gcs: ${{inputs.download-jax-only-from-gcs}} - gcs_download_uri: ${{inputs.gcs_download_uri}} \ No newline at end of file + gcs_download_uri: ${{inputs.gcs_download_uri}} + + verify-release-wheels-install: + if: ${{ startsWith(github.ref_name, 'release/')}} + defaults: + run: + # Set the shell to bash as GitHub actions runs with /bin/sh by default + shell: bash + runs-on: linux-x86-n2-16 + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + python: ["3.10", "3.13", "3.13-nogil"] + container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" + + # Verifies that JAX's release wheels can be installed + name: "Verify release wheels install (Python ${{ matrix.python }})" + + env: + PYTHON: "python${{ matrix.python }}" + + steps: + - name: Download release wheels from GCS + run: | + mkdir -p $(pwd)/dist + final_gcs_download_uri=${{ inputs.gcs_download_uri }} + + # Get the major and minor version of Python. + # E.g if python=3.10, then python_major_minor=310 + # E.g if python=3.13-nogil, then python_major_minor=313t + python_major_minor=${{ matrix.python }} + python_major_minor=$(echo "${python_major_minor//-nogil/t}" | tr -d '.') + python_major_minor="cp${python_major_minor%t}-cp${python_major_minor}-" + + gsutil -m cp -r "${final_gcs_download_uri}"/jax*py3*none*any.whl $(pwd)/dist/ + + jax_wheel=$(ls dist/jax*py3*none*any.whl 2>/dev/null) + echo "JAX_WHEEL=$jax_wheel" >> $GITHUB_ENV + + if [[ "${{ inputs.download-jax-only-from-gcs }}" != "1" ]]; then + gsutil -m cp -r "${final_gcs_download_uri}/jaxlib*${python_major_minor}*linux*x86_64*.whl" $(pwd)/dist/ + gsutil -m cp -r "${final_gcs_download_uri}/jax*cuda*plugin*${python_major_minor}*linux*x86_64*.whl" $(pwd)/dist/ + gsutil -m cp -r "${final_gcs_download_uri}/jax*cuda*pjrt*linux*x86_64*.whl" $(pwd)/dist/ + + jaxlib_wheel=$(ls dist/jaxlib*${python_major_minor}*linux*x86_64*.whl 2>/dev/null) + jax_cuda_plugin_wheel=$(ls dist/jax*cuda*plugin*${python_major_minor}*linux*x86_64*.whl 2>/dev/null) + jax_cuda_pjrt_wheel=$(ls dist/jax*cuda*pjrt*linux*x86_64*.whl 2>/dev/null) + + echo "JAXLIB_WHEEL=$jaxlib_wheel" >> $GITHUB_ENV + echo "JAX_CUDA_PLUGIN_WHEEL=$jax_cuda_plugin_wheel" >> $GITHUB_ENV + echo "JAX_CUDA_PJRT_WHEEL=$jax_cuda_pjrt_wheel" >> $GITHUB_ENV + fi + - name: Verify JAX CPU packages can be installed + run: | + $PYTHON -m uv venv ~/test_cpu && source ~/test_cpu/bin/activate + if [[ "${{ inputs.download-jax-only-from-gcs }}" == "1" ]]; then + uv pip install $JAX_WHEEL + else + uv pip install $JAX_WHEEL $JAXLIB_WHEEL + fi + - name: Verify JAX TPU packages can be installed + run: | + $PYTHON -m uv venv ~/test_tpu && source ~/test_tpu/bin/activate + + if [[ "${{ inputs.download-jax-only-from-gcs }}" == "1" ]]; then + uv pip install $JAX_WHEEL[tpu] + else + uv pip install $JAX_WHEEL[tpu] $JAXLIB_WHEEL + fi + - name: Verify JAX CUDA packages can be installed (Nvidia Pip Packages) + run: | + $PYTHON -m uv venv ~/test_cuda_pip && source ~/test_cuda_pip/bin/activate + if [[ "${{ inputs.download-jax-only-from-gcs }}" == "1" ]]; then + uv pip install $JAX_WHEEL[cuda] + else + uv pip install $JAX_WHEEL[cuda] $JAXLIB_WHEEL $JAX_CUDA_PJRT_WHEEL $JAX_CUDA_PLUGIN_WHEEL[with-cuda] + fi + - name: Verify JAX CUDA packages can be installed (CUDA local) + run: | + $PYTHON -m uv venv ~/test_cuda_local && source ~/test_cuda_local/bin/activate + if [[ "${{ inputs.download-jax-only-from-gcs }}" == "1" ]]; then + uv pip install $JAX_WHEEL[cuda12-local] + else + uv pip install $JAX_WHEEL $JAXLIB_WHEEL $JAX_CUDA_PJRT_WHEEL $JAX_CUDA_PLUGIN_WHEEL + fi \ No newline at end of file diff --git a/jax/version.py b/jax/version.py index 6ed6a5fda600..21662d078f7f 100644 --- a/jax/version.py +++ b/jax/version.py @@ -93,6 +93,12 @@ def _get_version_for_build() -> str: return _version_from_git_tree(_version) or _version_from_todays_date(_version) +def _is_prerelease() -> bool: + """Determine if this is a pre-release ("rc" wheels) build.""" + rc_version = os.getenv("WHEEL_VERSION_SUFFIX", "") + return True if rc_version.startswith("rc") else False + + def _write_version(fname: str) -> None: """Used by setup.py to write the specified version info into the source tree.""" release_version = _get_version_for_build() diff --git a/setup.py b/setup.py index bdaeb624bf38..629836b30862 100644 --- a/setup.py +++ b/setup.py @@ -38,6 +38,13 @@ def load_version_module(pkg_path): _cmdclass = _version_module._get_cmdclass(project_name) _minimum_jaxlib_version = _version_module._minimum_jaxlib_version +# If this is a pre-release ("rc" wheels), append "rc0" to +# _minimum_jaxlib_version and _current_jaxlib_version so that we are able to +# install the rc wheels. +if _version_module._is_prerelease(): + _minimum_jaxlib_version += "rc0" + _current_jaxlib_version += "rc0" + with open('README.md', encoding='utf-8') as f: _long_description = f.read() @@ -81,32 +88,32 @@ def load_version_module(pkg_path): ], 'cuda': [ - f"jaxlib=={_current_jaxlib_version}", + f"jaxlib>={_current_jaxlib_version},<={_jax_version}", f"jax-cuda12-plugin[with_cuda]>={_current_jaxlib_version},<={_jax_version}", ], 'cuda12': [ - f"jaxlib=={_current_jaxlib_version}", + f"jaxlib>={_current_jaxlib_version},<={_jax_version}", f"jax-cuda12-plugin[with_cuda]>={_current_jaxlib_version},<={_jax_version}", ], # Deprecated alias for cuda12, kept to avoid breaking users who wrote # cuda12_pip in their CI. 'cuda12_pip': [ - f"jaxlib=={_current_jaxlib_version}", + f"jaxlib>={_current_jaxlib_version},<={_jax_version}", f"jax-cuda12-plugin[with_cuda]>={_current_jaxlib_version},<={_jax_version}", ], # Target that does not depend on the CUDA pip wheels, for those who want # to use a preinstalled CUDA. 'cuda12_local': [ - f"jaxlib=={_current_jaxlib_version}", - f"jax-cuda12-plugin=={_current_jaxlib_version}", + f"jaxlib>={_current_jaxlib_version},<={_jax_version}", + f"jax-cuda12-plugin>={_current_jaxlib_version},<={_jax_version}", ], # ROCm support for ROCm 6.0 and above. 'rocm': [ - f"jaxlib=={_current_jaxlib_version}", + f"jaxlib>={_current_jaxlib_version},<={_jax_version}", f"jax-rocm60-plugin>={_current_jaxlib_version},<={_jax_version}", ], diff --git a/tests/version_test.py b/tests/version_test.py index b78e61ae024c..14da82df2e3e 100644 --- a/tests/version_test.py +++ b/tests/version_test.py @@ -143,6 +143,7 @@ def testBuildVersionFromEnvironment(self): JAX_NIGHTLY=None, JAXLIB_NIGHTLY=None): with assert_no_subprocess_call(): version = jax.version._get_version_for_build() + self.assertFalse(jax.version._is_prerelease()) self.assertEqual(version, base_version) self.assertValidVersion(version) @@ -150,6 +151,7 @@ def testBuildVersionFromEnvironment(self): JAX_NIGHTLY=None, JAXLIB_NIGHTLY=None): with assert_no_subprocess_call(): version = jax.version._get_version_for_build() + self.assertFalse(jax.version._is_prerelease()) self.assertEqual(version, base_version) self.assertValidVersion(version) @@ -183,6 +185,20 @@ def testBuildVersionFromEnvironment(self): ): with assert_no_subprocess_call(): version = jax.version._get_version_for_build() + self.assertTrue(jax.version._is_prerelease()) + self.assertEqual(version, f"{base_version}rc0") + self.assertValidVersion(version) + + with jtu.set_env( + JAX_RELEASE=None, + JAXLIB_RELEASE="1", + JAX_NIGHTLY=None, + JAXLIB_NIGHTLY=None, + WHEEL_VERSION_SUFFIX="rc0", + ): + with assert_no_subprocess_call(): + version = jax.version._get_version_for_build() + self.assertTrue(jax.version._is_prerelease()) self.assertEqual(version, f"{base_version}rc0") self.assertValidVersion(version) From 3a3c145039c8d1b41946b0683cdcf601b29bd3f9 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Mon, 7 Apr 2025 21:30:25 +0000 Subject: [PATCH 443/483] [shard-map] canonicalize rep=None to be rep={all possible axes} None is meant to represent the same thing as {replicated over all possible axes}. But without this canonicalization, we could compare None as not equal to {all possible axes}. fixes #26621 Unrelated: in several places, including the _check_rep path, we don't handle partial auto correctly, since we treat {all possible axes} as {all mesh axes}, but actually it should be more like {all mesh axes} - auto. We'll leave that fix for a follow-up... --- jax/experimental/shard_map.py | 7 ++++--- tests/shard_map_test.py | 22 ++++++++++++++++++++++ 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 3a46f444fb1b..17d909b5629c 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -646,7 +646,7 @@ def _check_rep(mesh: Mesh, jaxpr: core.Jaxpr, in_rep: Sequence[RepType] env: dict[core.Var, RepType] = {} def read(x: core.Atom) -> RepType: - return env[x] if type(x) is core.Var else None + return env[x] if type(x) is core.Var else set(mesh.axis_names) def write(v: core.Var, val: RepType) -> None: env[v] = val @@ -942,7 +942,7 @@ def to_val_rep_pair(self, val): raise Exception(f"Shouldn't have any non-shard_map tracers: {val}") else: val_ = _unmatch_spec(self.mesh, {}, val, self.context_mesh) - return val_, None + return val_, set(self.mesh.axis_names) - set(self.auto) def process_primitive(self, prim, tracers, params): in_vals, in_rep = unzip2(map(self.to_val_rep_pair, tracers)) @@ -1008,6 +1008,7 @@ class ShardMapTracer(core.Tracer): val: JaxType def __init__(self, trace, rep, val): + rep = set(trace.mesh.axis_names) - set(trace.auto) if rep is None else rep self._trace = trace self.rep = rep self.val = val @@ -2151,7 +2152,7 @@ def _efficient_transpose_rewrite(fun, mesh, in_names, out_names_thunk): def _efficient_transpose_rewrite_nomatch(f, store, mesh, in_reps, *args): with core.take_current_trace() as parent: tag = core.TraceTag() - t = RewriteTrace(parent_trace = parent, tag = tag, mesh=mesh) + t = RewriteTrace(parent_trace=parent, tag=tag, mesh=mesh) in_tracers = map(partial(RewriteTracer, t), in_reps, args) with core.set_current_trace(t): ans = f(*in_tracers) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 3a4c3ea9779c..62395a8750ab 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -2747,6 +2747,28 @@ def f(x): f(x) # doesn't crash + def test_rep_none_canonicalization(self): + # https://github.com/jax-ml/jax/issues/26621 + N = 8 + xs = jnp.ones((8, N), dtype=jnp.int32) + variables = jax.random.normal(jax.random.key(1), (N, N), jnp.complex64) + mesh = jtu.create_mesh((2,), ('i',)) + in_specs = (P(), P("i"),) + out_specs = P("i") + + variables = jax.lax.with_sharding_constraint(variables, NamedSharding(mesh, P())) + xs = jax.lax.with_sharding_constraint(xs, NamedSharding(mesh, P('i'))) + + def fun(v, xs): + # Commenting this single line below makes everything work + v = jax.scipy.linalg.expm(v) + v = v.sum() + return v * xs.sum(axis=-1).astype(v.dtype) + + res = fun(variables, xs) + fun_shard_map = shard_map(fun, mesh=mesh, in_specs=in_specs, out_specs=out_specs) + res = fun_shard_map(variables, xs) # don't crash + class FunSpec(NamedTuple): name: str From 48a9ad07968795357814c0b02e7abb52cae10786 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 7 Apr 2025 15:07:30 -0700 Subject: [PATCH 444/483] Reverts 006a6a63feb64bf9984526030ba008186d69d2b4 PiperOrigin-RevId: 744864022 --- jax/_src/lax/parallel.py | 2 +- jax/_src/pallas/mosaic/lowering.py | 16 +++++----------- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index e533672a1d9b..6ed4dddfcc21 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -1906,8 +1906,8 @@ def _axis_index_lowering(ctx, *, axis_name): ctx.module_context.axis_env)] def _axis_index_effectful_abstract_eval(*, axis_name): - axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name effect = {core.NamedAxisEffect(axis_name)} + axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name _check_axis_names(axis_name) mesh = get_abstract_mesh() sharding = NamedSharding(mesh, P()) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 2669d73691c1..98ff98759c8c 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -518,17 +518,11 @@ def has_communication(self) -> bool: nonlocal_axis_names = set() def _get_nonlocal_axis_names(jaxpr: jax_core.Jaxpr): return { - e.name - for e in jaxpr.effects - if isinstance(e, jax_core.NamedAxisEffect) - and ( - not self.grid_names - or all( - name not in self.grid_names - for name in tree_util.tree_leaves(e.name) - ) - ) - } + e.name + for e in jaxpr.effects + if isinstance(e, jax_core.NamedAxisEffect) + and (not self.grid_names or e.name not in self.grid_names) + } nonlocal_axis_names.update(_get_nonlocal_axis_names(self.jaxpr)) for bm in self.block_mappings: if bm is not None: From 2944e3b2a64d26f3cadbda4694486c21979a7229 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 7 Apr 2025 15:27:10 -0700 Subject: [PATCH 445/483] Removed `data_dependent_tracing_fallback` config option No internal code needs it any more. PiperOrigin-RevId: 744870756 --- CHANGELOG.md | 3 +++ jax/_src/config.py | 6 ------ jax/_src/core.py | 8 +------- jax/_src/pjit.py | 4 +--- 4 files changed, 5 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index beacd477390f..3aae0f432121 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * {func}`jax.numpy.array` no longer accepts `None`. This behavior was deprecated since November 2023 and is now removed. + * Removed the `config.jax_data_dependent_tracing_fallback` config option, + which was added temporarily in v0.4.36 to allow users to opt out of the + new "stackless" tracing machinery. * Changes * The minimum CuDNN version is v9.8. diff --git a/jax/_src/config.py b/jax/_src/config.py index aca6d8e2c938..8aa4ee343664 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1099,12 +1099,6 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: ' transpose rewrite machinery in shard_map'), include_in_jit_key=True) -data_dependent_tracing_fallback = bool_state( - name='jax_data_dependent_tracing_fallback', - default=False, - help=('When True, falls back to trace dispatch based on data dependence ' - 'instead of throwing an escaped tracer error.')) - softmax_custom_jvp = bool_state( name='jax_softmax_custom_jvp', default=False, diff --git a/jax/_src/core.py b/jax/_src/core.py index 9f80842a38ff..8d32b9370091 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -497,9 +497,7 @@ def bind(self, *args, **params): def _true_bind(self, *args, **params): for arg in args: - if (isinstance(arg, Tracer) - and not arg._trace.is_valid() - and not config.data_dependent_tracing_fallback.value): + if isinstance(arg, Tracer) and not arg._trace.is_valid(): raise escaped_tracer_error(arg) # TODO: figure out how to handle function arguments # assert (not config.enable_checks.value or @@ -1015,10 +1013,6 @@ def process_primitive(self, primitive, args, params): else: # TODO(dougalm): delete. this shouldn't be necessary args = map(full_lower, args) - if config.data_dependent_tracing_fallback.value: - for arg in args: - if isinstance(arg, Tracer): - return primitive.bind_with_trace(arg._trace, args, params) check_eval_args(args) return primitive.impl(*args, **params) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 8c3c5101eb51..cf4d13530b74 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -186,9 +186,7 @@ def _python_pjit_helper(fun: Callable, jit_info: PjitInfo, *args, **kwargs): args_flat = [*init_states, *args_flat] try: - if (core.trace_state_clean() and - not config.debug_key_reuse.value and - not config.data_dependent_tracing_fallback.value): + if core.trace_state_clean() and not config.debug_key_reuse.value: args_flat = map(core.full_lower, args_flat) core.check_eval_args(args_flat) out_flat, compiled, profiler = _pjit_call_impl_python(*args_flat, **p.params) From 0a72e856cfb8984ba4883d9449bc8928adebe535 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 7 Apr 2025 16:20:58 -0700 Subject: [PATCH 446/483] Add **experimental** `with_dll_constraint` API. This is for cases when the users wants to let SPMD decide the sharding. But this is a contradiction since layouts apply to device local shape and without knowing the sharding, you can't decide the layout. But there are cases where you don't care what the sharding is, you just want to force a row-major layout (for example). **This API should only be used for those cases**. PiperOrigin-RevId: 744888557 --- jax/_src/layout.py | 4 ++ jax/_src/pjit.py | 44 +++++++++++++++++++ .../jax2tf/tests/primitives_test.py | 2 + jax/experimental/layout.py | 5 ++- tests/BUILD | 1 + tests/layout_test.py | 33 +++++++++++++- 6 files changed, 87 insertions(+), 2 deletions(-) diff --git a/jax/_src/layout.py b/jax/_src/layout.py index 5309f0b1fd9c..8d4f8acd5327 100644 --- a/jax/_src/layout.py +++ b/jax/_src/layout.py @@ -127,6 +127,10 @@ def __init__(self, device_local_layout: LayoutOptions = None, self.device_local_layout = device_local_layout self.sharding = sharding + @property + def dll(self): + return self.device_local_layout + def __repr__(self): return (f'Layout(device_local_layout={self.device_local_layout},' f' sharding={self.sharding})') diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index cf4d13530b74..0c8d7393b98c 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -2942,6 +2942,50 @@ def use_explicit_axes(*axes): with mesh_lib.use_abstract_mesh(new_mesh): yield +# -------------------- with_dll_constraint -------------------- + +def with_dll_constraint(x, layouts): + x_flat, tree = tree_flatten(x) + layouts_flat = tuple(flatten_axes("with_dll_constraint layouts", tree, + layouts)) + if any(not isinstance(l, DeviceLocalLayout) for l in layouts_flat): + raise ValueError( + 'layouts passed to `with_dll_constraint` must be of type' + f' `DeviceLocalLayout`. Got {[type(l) for l in layouts_flat]}') + check_aval_layout_compatibility( + layouts_flat, x_flat, ("",) * len(layouts_flat), + "with_dll_constraint arguments") + outs = [dll_constraint_p.bind(xf, layout=l) + for xf, l in zip(x_flat, layouts_flat)] + return tree_unflatten(tree, outs) + +dll_constraint_p = core.Primitive('dll_constraint') +dll_constraint_p.def_abstract_eval(lambda x, **_: x) +ad.deflinear2(dll_constraint_p, + lambda ct, _, **params: (dll_constraint_p.bind(ct, **params),)) + +def _dll_constraint_impl(x, *, layout): + if not isinstance(x, xc.ArrayImpl): + raise ValueError( + 'with_dll_constraint in eager mode can only be applied to' + f' jax.Arrays. Got {type(x)}') + if x.layout.device_local_layout == layout: # type: ignore + return x + return api.jit(_identity_fn, out_shardings=Layout(layout, x.sharding))(x) +dll_constraint_p.def_impl(_dll_constraint_impl) + +def _dll_constraint_hlo_lowering(ctx, x_node, *, layout): + aval, = ctx.avals_in + out_aval, = ctx.avals_out + return [mlir.wrap_with_layout_op(ctx, x_node, out_aval, layout, aval)] +mlir.register_lowering(dll_constraint_p, + _dll_constraint_hlo_lowering) + +def _dll_constraint_batcher(axis_data, vals_in, dims_in, layout): + raise NotImplementedError +batching.fancy_primitive_batchers[dll_constraint_p] = _dll_constraint_batcher +batching.skippable_batchers[dll_constraint_p] = lambda _: () + # -------------------- helpers -------------------- def get_unconstrained_dims(sharding: NamedSharding): diff --git a/jax/experimental/jax2tf/tests/primitives_test.py b/jax/experimental/jax2tf/tests/primitives_test.py index 1ccd009f157c..0156465e339a 100644 --- a/jax/experimental/jax2tf/tests/primitives_test.py +++ b/jax/experimental/jax2tf/tests/primitives_test.py @@ -174,6 +174,8 @@ def test_primitive_coverage(self): continue if p.name == "sharding_constraint": continue + if p.name == "dll_constraint": + continue if p.name == "mesh_cast": continue if p.name == "reshard": diff --git a/jax/experimental/layout.py b/jax/experimental/layout.py index ed9f8931938e..aa114a2803e8 100644 --- a/jax/experimental/layout.py +++ b/jax/experimental/layout.py @@ -14,5 +14,8 @@ from jax._src.layout import ( DeviceLocalLayout as DeviceLocalLayout, - Layout as Layout + Layout as Layout, +) +from jax._src.pjit import ( + with_dll_constraint as with_dll_constraint, ) diff --git a/tests/BUILD b/tests/BUILD index eb6ff81f5d68..0a58ee52d88c 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -325,6 +325,7 @@ jax_multiplatform_test( }, enable_configs = [ "tpu_v3_2x2_shardy", + "tpu_v3_2x2", ], tags = ["multiaccelerator"], deps = [ diff --git a/tests/layout_test.py b/tests/layout_test.py index b9062b8d21dc..ae10013a5f60 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -21,9 +21,10 @@ import jax.numpy as jnp from jax.sharding import NamedSharding, PartitionSpec as P, SingleDeviceSharding from jax._src import config -from jax._src.layout import Layout, DeviceLocalLayout as DLL from jax._src import test_util as jtu from jax._src.util import safe_zip +from jax.experimental.layout import (with_dll_constraint, Layout, + DeviceLocalLayout as DLL) from jax.experimental.compute_on import compute_on config.parse_flags_with_absl() @@ -744,6 +745,36 @@ def f(x): self.assertArraysEqual(out, np_inp * 2) self.assertEqual(out.layout, out_layout) + def test_with_dll_constraint(self): + if not jtu.test_device_matches(['tpu']): + self.skipTest('Only works for TPU') + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + shape = (16, 128) + s = NamedSharding(mesh, P('x')) + np_inp = np.arange(math.prod(shape)).reshape(shape) + arr = jax.device_put(np_inp, s) + + # Create a custom layout instead of using `arr.layout` to test the API. + custom_dll = DLL(major_to_minor=arr.layout.dll.major_to_minor[::-1]) + + def f(x): + y = x.T + # Constrain `y` to the original layout of `arr` because without it, + # the layout of `y` would be the transpose of `arr`. + y = with_dll_constraint(y, custom_dll) + return y * 2 + + f(arr) # doesn't crash + + f = jax.jit(f) + out = f(arr) + self.assertEqual(out.layout.device_local_layout.major_to_minor, + custom_dll.major_to_minor) + self.assertArraysEqual(out, np_inp.T * 2) + + lowered_text = f.lower(arr).as_text() + self.assertIn('LayoutConstraint', lowered_text) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From 84e04fe60838db9deb4bd040f8bb678424fd439d Mon Sep 17 00:00:00 2001 From: Rachel Han Date: Mon, 7 Apr 2025 16:24:17 -0700 Subject: [PATCH 447/483] Add custom pretty print rule for the unary ops with accuracy s.t. accuracy is not printed if it's None. PiperOrigin-RevId: 744889524 --- jax/_src/api.py | 13 ++- jax/_src/lax/lax.py | 19 +++- tests/api_test.py | 147 ++++++++++++++++--------------- tests/core_test.py | 8 +- tests/pmap_test.py | 16 ++-- tests/unary_ops_accuracy_test.py | 24 ++++- 6 files changed, 132 insertions(+), 95 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index fb10245c30e9..d338e2d70700 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2268,16 +2268,13 @@ def make_jaxpr( >>> print(f(3.0)) -0.83602 >>> jax.make_jaxpr(f)(3.0) - { lambda ; a:f32[]. let - b:f32[] = cos[accuracy=None] a - c:f32[] = sin[accuracy=None] b - in (c,) } + { lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) } >>> jax.make_jaxpr(jax.grad(f))(3.0) { lambda ; a:f32[]. let - b:f32[] = cos[accuracy=None] a - c:f32[] = sin[accuracy=None] a - _:f32[] = sin[accuracy=None] b - d:f32[] = cos[accuracy=None] b + b:f32[] = cos a + c:f32[] = sin a + _:f32[] = sin b + d:f32[] = cos b e:f32[] = mul 1.0 d f:f32[] = neg e g:f32[] = mul f c diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 13511641558c..7ca73603c14b 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -4066,6 +4066,11 @@ def _nary_lower_hlo( out = op(*args, result_accuracy=accuracy_attr(accuracy)) return [mlir.lower_with_sharding_in_types(ctx, out, aval_out)] +def _unary_with_accuracy_pp_rule(eqn, context, settings): + params = dict(eqn.params) + if 'accuracy' in params and params['accuracy'] is None: + del params['accuracy'] + return core._pp_eqn(eqn.replace(params=params), context, settings) _float = {np.floating} _complex = {np.complexfloating} @@ -4128,6 +4133,7 @@ def _round_lower(ctx, x, *, rounding_method): ad.defjvp2(exp_p, lambda g, ans, x, **kwargs: mul(g, ans)) mlir.register_lowering(exp_p, partial(_nary_lower_hlo, hlo.exponential)) batching.ragged_prop_rules[exp_p] = batching.ragged_mask_elementwise_rule +core.pp_eqn_rules[exp_p] = _unary_with_accuracy_pp_rule exp2_p = standard_unop(_float | _complex, 'exp2') ad.defjvp2( @@ -4145,19 +4151,23 @@ def _exp2_lower(ctx, x, accuracy): ] mlir.register_lowering(exp2_p, _exp2_lower) +core.pp_eqn_rules[exp2_p] = _unary_with_accuracy_pp_rule log_p = standard_unop(_float | _complex, 'log') ad.defjvp(log_p, lambda g, x, **kwargs: div(g, x)) mlir.register_lowering(log_p, partial(_nary_lower_hlo, hlo.log)) +core.pp_eqn_rules[log_p] = _unary_with_accuracy_pp_rule expm1_p = standard_unop(_float | _complex, 'expm1') ad.defjvp2(expm1_p, lambda g, ans, x, **kwargs: mul(g, add(ans, _one(ans)))) mlir.register_lowering(expm1_p, partial(_nary_lower_hlo, hlo.exponential_minus_one)) +core.pp_eqn_rules[expm1_p] = _unary_with_accuracy_pp_rule log1p_p = standard_unop(_float | _complex, 'log1p') ad.defjvp(log1p_p, lambda g, x, **kwargs: div(g, add(x, _one(x)))) mlir.register_lowering(log1p_p, partial(_nary_lower_hlo, hlo.log_plus_one)) +core.pp_eqn_rules[log1p_p] = _unary_with_accuracy_pp_rule tanh_p = standard_unop(_float | _complex, 'tanh') ad.defjvp2( @@ -4165,6 +4175,7 @@ def _exp2_lower(ctx, x, accuracy): lambda g, ans, x, **kwargs: mul(add(g, mul(g, ans)), sub(_one(x), ans)), ) mlir.register_lowering(tanh_p, partial(_nary_lower_hlo, hlo.tanh)) +core.pp_eqn_rules[tanh_p] = _unary_with_accuracy_pp_rule logistic_p = standard_unop(_float | _complex, 'logistic') ad.defjvp2( @@ -4174,13 +4185,13 @@ def _exp2_lower(ctx, x, accuracy): # TODO(phawkins): switch to LogisticOp lowering; debug numerical problems. # mlir.register_lowering(logistic_p, partial(_nary_lower_hlo, hlo.logistic)) - def logistic_impl(x, accuracy): one = _const(x, 1) return div(one, add(one, exp(neg(x)))) mlir.register_lowering(logistic_p, mlir.lower_fun(logistic_impl, multiple_results=False)) +core.pp_eqn_rules[logistic_p] = _unary_with_accuracy_pp_rule def _sin_complex(x): # use expm1 instead of exp to avoid cancellation when abs(x) is small @@ -4219,6 +4230,7 @@ def _sin_p_lin(nzs, x, accuracy): ad.defjvp(sin_p, lambda g, x, accuracy: mul(g, cos(x, accuracy=accuracy))) ad.primitive_linearizations[sin_p] = _sin_p_lin mlir.register_lowering(sin_p, _sin_lowering) +core.pp_eqn_rules[sin_p] = _unary_with_accuracy_pp_rule batching.ragged_prop_rules[sin_p] = batching.ragged_mask_elementwise_rule def _cos_complex(x): @@ -4244,10 +4256,12 @@ def _cos_lowering(ctx, x, accuracy): cos_p, lambda g, x, accuracy: neg(mul(g, sin(x, accuracy=accuracy))) ) mlir.register_lowering(cos_p, _cos_lowering) +core.pp_eqn_rules[cos_p] = _unary_with_accuracy_pp_rule tan_p = standard_unop(_float | _complex, 'tan') ad.defjvp2(tan_p, lambda g, ans, x, **kwargs: mul(g, add(_const(x, 1), square(ans)))) mlir.register_lowering(tan_p, partial(_nary_lower_hlo, hlo.tan)) +core.pp_eqn_rules[tan_p] = _unary_with_accuracy_pp_rule asin_p = standard_unop(_float | _complex, 'asin') ad.defjvp(asin_p, lambda g, x: mul(g, rsqrt(sub(_const(x, 1), square(x))))) @@ -4365,6 +4379,7 @@ def _abs_jvp_rule(g, ans, x): sqrt_p = standard_unop(_float | _complex, 'sqrt') ad.defjvp2(sqrt_p, lambda g, ans, x, **kwargs: mul(g, div(_const(x, 0.5), ans))) mlir.register_lowering(sqrt_p, partial(_nary_lower_hlo, hlo.sqrt)) +core.pp_eqn_rules[sqrt_p] = _unary_with_accuracy_pp_rule rsqrt_p = standard_unop(_float | _complex, 'rsqrt') ad.defjvp2( @@ -4372,6 +4387,7 @@ def _abs_jvp_rule(g, ans, x): lambda g, ans, x, **kwargs: mul(g, mul(_const(x, -0.5), div(ans, x))), ) mlir.register_lowering(rsqrt_p, partial(_nary_lower_hlo, hlo.rsqrt)) +core.pp_eqn_rules[rsqrt_p] = _unary_with_accuracy_pp_rule cbrt_p = standard_unop(_float, 'cbrt') ad.defjvp2( @@ -4381,6 +4397,7 @@ def _abs_jvp_rule(g, ans, x): ), ) mlir.register_lowering(cbrt_p, partial(_nary_lower_hlo, hlo.cbrt)) +core.pp_eqn_rules[cbrt_p] = _unary_with_accuracy_pp_rule square_p = standard_unop(_int | _float | _complex, 'square') diff --git a/tests/api_test.py b/tests/api_test.py index 83264f10e033..440fea1b059c 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -5093,11 +5093,11 @@ def f_yesremat(x): jaxpr = api.make_jaxpr(api.linearize(f_yesremat, 4.)[1])(1.) scan_eqn, = jaxpr.jaxpr.eqns - self.assertIn(' cos[', str(scan_eqn.params['jaxpr'])) + self.assertIn(' cos ', str(scan_eqn.params['jaxpr'])) jaxpr = api.make_jaxpr(api.vjp(f_yesremat, 4.)[1])(1.) scan_eqn, = jaxpr.jaxpr.eqns - self.assertIn(' cos[', str(scan_eqn.params['jaxpr'])) + self.assertIn(' cos ', str(scan_eqn.params['jaxpr'])) @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} @@ -5450,9 +5450,9 @@ def f(x): ('new_remat', new_checkpoint), ] for policy_name, policy, in_jaxpr2, not_in_jaxpr2 in [ - ('save_anything', lambda *_, **__: True, [], [' sin[', ' cos[[ ']), - ('save_nothing', lambda *_, **__: False, [' sin[', ' cos['], []), - ('save_sin', lambda p, *_, **__: str(p) == 'sin', [' cos['], [' sin[']), + ('save_anything', lambda *_, **__: True, [], [' sin ', ' cos ']), + ('save_nothing', lambda *_, **__: False, [' sin ', ' cos '], []), + ('save_sin', lambda p, *_, **__: str(p) == 'sin', [' cos '], [' sin ']), ]) def test_remat_custom_policy(self, remat, policy, in_jaxpr2, not_in_jaxpr2): for square in [lambda x: x * x, api.jit(lambda x: x * x)]: @@ -5482,8 +5482,8 @@ def test_remat_custom_policy_save_cos(self, remat): policy=save_cos) _, f_lin = api.linearize(f, 1.) jaxpr_text = str(f_lin.func.args[0]) - self.assertNotIn(' sin[', jaxpr_text) - self.assertNotIn(' cos[', jaxpr_text) + self.assertNotIn(' sin ', jaxpr_text) + self.assertNotIn(' cos ', jaxpr_text) jtu.check_grads(f, (3.,), order=2, modes=['fwd', 'rev']) @parameterized.named_parameters( @@ -5505,7 +5505,7 @@ def f(x): _, f_lin = api.linearize(f, jnp.ones((2, 2))) jaxpr_text = str(f_lin.func.args[0]) - self.assertEqual(jaxpr_text.count(' sin['), 2) + self.assertEqual(jaxpr_text.count(' sin '), 2) self.assertEqual(jaxpr_text.count(' dot_'), 6) jtu.check_grads(f, (jnp.ones((2, 2)),), order=2, modes=['fwd', 'rev']) @@ -5528,7 +5528,7 @@ def f(x): _, f_lin = api.linearize(f, jnp.ones((2, 2))) jaxpr_text = str(f_lin.func.args[0]) - self.assertEqual(jaxpr_text.count(' sin['), 2) + self.assertEqual(jaxpr_text.count(' sin '), 2) self.assertEqual(jaxpr_text.count(' dot_general'), 6) jtu.check_grads(f, (jnp.ones((2, 2)),), order=2, modes=['fwd', 'rev']) @@ -5551,7 +5551,7 @@ def f(x): _, f_lin = api.linearize(f, jnp.ones((3, 2, 2))) jaxpr_text = str(f_lin.func.args[0]) - self.assertEqual(jaxpr_text.count(' sin['), 2) + self.assertEqual(jaxpr_text.count(' sin '), 2) self.assertEqual(jaxpr_text.count(' dot_general'), 9) jtu.check_grads(f, (jnp.ones((3, 2, 2)),), order=2, modes=['fwd', 'rev']) @@ -5575,7 +5575,7 @@ def f(x): _, f_lin = api.linearize(f, jnp.ones((2, 2))) jaxpr_text = str(f_lin.func.args[0]) - self.assertEqual(jaxpr_text.count(' sin['), 2) + self.assertEqual(jaxpr_text.count(' sin '), 2) self.assertEqual(jaxpr_text.count(' dot_'), 6) jtu.check_grads(f, (jnp.ones((2, 2)),), order=2, modes=['fwd', 'rev']) @@ -5599,8 +5599,8 @@ def body(x, _): return f(x), None # Two sine calls in the backward pass because while we don't save sines # within the (rematted) body function, we can save the scan carry, which # effectively saves one sine. Three cosines for the Jacobian coefficients. - self.assertEqual(jaxpr_text.count(' sin['), 2) - self.assertEqual(jaxpr_text.count(' cos['), 3) + self.assertEqual(jaxpr_text.count(' sin '), 2) + self.assertEqual(jaxpr_text.count(' cos '), 3) # Six calls to dot_general in the backward pass because we save the primal # matmuls and only compure the backward pass ones (two for each primal one). self.assertEqual(jaxpr_text.count(' dot_'), 6) @@ -5906,8 +5906,9 @@ def test_remat_of_scan(self, remat): jtu.check_grads(remat(f), (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(remat(f), 4.)[1])(1.) - self.assertIn(' sin[', str(jaxpr)) - self.assertIn(' cos[', str(jaxpr)) + print("debug jaxpr: ", str(jaxpr)) + self.assertIn(' sin ', str(jaxpr)) + self.assertIn(' cos ', str(jaxpr)) @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} @@ -5952,8 +5953,8 @@ def body(x, _): return f(x), None jaxpr = f_vjp.args[0].func.args[1] jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 3) - self.assertEqual(jaxpr_text.count(' cos['), 3) + self.assertEqual(jaxpr_text.count(' sin '), 3) + self.assertEqual(jaxpr_text.count(' cos '), 3) # Six calls to dot_general in the backward pass because we save the primal # matmuls and only compute the backward pass ones (two for each primal one). self.assertEqual(jaxpr_text.count(' dot_'), 6) @@ -5970,8 +5971,8 @@ def test_remat_of_scan_policy(self): jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 0) - self.assertEqual(jaxpr_text.count(' cos['), 0) + self.assertEqual(jaxpr_text.count(' sin '), 0) + self.assertEqual(jaxpr_text.count(' cos '), 0) def test_remat_of_scan_funky_custom_jvp(self): def scan_apply(f, x): @@ -5994,40 +5995,40 @@ def sin_jvp(primals, tangents): jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 0) - self.assertEqual(jaxpr_text.count(' cos['), 0) + self.assertEqual(jaxpr_text.count(' sin '), 0) + self.assertEqual(jaxpr_text.count(' cos '), 0) save_sin = lambda prim, *_, **__: str(prim) == 'sin' f = new_checkpoint(partial(scan_apply, sin), policy=save_sin) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 0) - self.assertEqual(jaxpr_text.count(' cos['), 1) + self.assertEqual(jaxpr_text.count(' sin '), 0) + self.assertEqual(jaxpr_text.count(' cos '), 1) f = new_checkpoint(partial(scan_apply, sin), policy=jax.checkpoint_policies.everything_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 0) - self.assertEqual(jaxpr_text.count(' cos['), 0) + self.assertEqual(jaxpr_text.count(' sin '), 0) + self.assertEqual(jaxpr_text.count(' cos '), 0) f = new_checkpoint(partial(scan_apply, sin), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 1) # +1 b/c dce fixed point - self.assertEqual(jaxpr_text.count(' cos['), 1) + self.assertEqual(jaxpr_text.count(' sin '), 1) # +1 b/c dce fixed point + self.assertEqual(jaxpr_text.count(' cos '), 1) f = new_checkpoint(lambda x: scan_apply(sin, scan_apply(sin, x)), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 2) # +1 b/c dce fixed point - self.assertEqual(jaxpr_text.count(' cos['), 2) + self.assertEqual(jaxpr_text.count(' sin '), 2) # +1 b/c dce fixed point + self.assertEqual(jaxpr_text.count(' cos '), 2) def test_remat_of_scan_funky_custom_jvp2(self): # Like the above test but instead of using jit inside custom_jvp, use scan. @@ -6052,40 +6053,40 @@ def sin_jvp(primals, tangents): jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 1) # +1 b/c dce fixed point - self.assertEqual(jaxpr_text.count(' cos['), 0) + self.assertEqual(jaxpr_text.count(' sin '), 1) # +1 b/c dce fixed point + self.assertEqual(jaxpr_text.count(' cos '), 0) save_sin = lambda prim, *_, **__: str(prim) == 'sin' f = new_checkpoint(partial(scan_apply, sin), policy=save_sin) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 0) - self.assertEqual(jaxpr_text.count(' cos['), 1) + self.assertEqual(jaxpr_text.count(' sin '), 0) + self.assertEqual(jaxpr_text.count(' cos '), 1) f = new_checkpoint(partial(scan_apply, sin), policy=jax.checkpoint_policies.everything_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 0) - self.assertEqual(jaxpr_text.count(' cos['), 0) + self.assertEqual(jaxpr_text.count(' sin '), 0) + self.assertEqual(jaxpr_text.count(' cos '), 0) f = new_checkpoint(partial(scan_apply, sin), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 1) # +1 b/c dce fixed point - self.assertEqual(jaxpr_text.count(' cos['), 1) + self.assertEqual(jaxpr_text.count(' sin '), 1) # +1 b/c dce fixed point + self.assertEqual(jaxpr_text.count(' cos '), 1) f = new_checkpoint(lambda x: scan_apply(sin, scan_apply(sin, x)), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 2) # +1 b/c dce fixed point - self.assertEqual(jaxpr_text.count(' cos['), 2) + self.assertEqual(jaxpr_text.count(' sin '), 2) # +1 b/c dce fixed point + self.assertEqual(jaxpr_text.count(' cos '), 2) @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} @@ -6100,8 +6101,8 @@ def test_remat_of_cond(self, remat): jtu.check_grads(remat(f), (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(remat(f), 4.)[1])(1.) - self.assertNotIn(' sin[', str(jaxpr)) - self.assertIn(' cos[', str(jaxpr)) + self.assertNotIn(' sin ', str(jaxpr)) + self.assertIn(' cos ', str(jaxpr)) true_fn = lambda c: jnp.sin(jnp.sin(c)) false_fn = lambda c: c @@ -6109,8 +6110,8 @@ def test_remat_of_cond(self, remat): jtu.check_grads(remat(f), (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(remat(f), 4.)[1])(1.) - self.assertIn(' sin[', str(jaxpr)) - self.assertIn(' cos[', str(jaxpr)) + self.assertIn(' sin ', str(jaxpr)) + self.assertIn(' cos ', str(jaxpr)) @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} @@ -6150,8 +6151,8 @@ def f(x): _, f_vjp = api.vjp(f, jnp.ones((5, 5))) jaxpr_text = str(f_vjp.args[0].func.args[1]) - self.assertEqual(jaxpr_text.count(' sin['), 2) - self.assertEqual(jaxpr_text.count(' cos['), 3) + self.assertEqual(jaxpr_text.count(' sin '), 2) + self.assertEqual(jaxpr_text.count(' cos '), 3) # Five calls to dot_general in the backward pass because we have two for # each forward-pass dot, except for the first which only has one (as we are # differentiating with respect to only W and not x). @@ -6181,8 +6182,8 @@ def f(x): jaxpr = f_vjp.args[0].func.args[1] jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 2) - self.assertEqual(jaxpr_text.count(' cos['), 3) + self.assertEqual(jaxpr_text.count(' sin '), 2) + self.assertEqual(jaxpr_text.count(' cos '), 3) self.assertEqual(jaxpr_text.count(' dot_'), 5) jtu.check_grads(api.jit(f), (jnp.ones((5, 5)),), order=2, @@ -6196,8 +6197,8 @@ def test_remat_of_cond_policy(self): jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 0) - self.assertEqual(jaxpr_text.count(' cos['), 0) + self.assertEqual(jaxpr_text.count(' sin '), 0) + self.assertEqual(jaxpr_text.count(' cos '), 0) def test_remat_of_cond_funky_custom_jvp(self): def cond_apply(f, x): @@ -6219,40 +6220,40 @@ def sin_jvp(primals, tangents): jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 0) - self.assertEqual(jaxpr_text.count(' cos['), 0) + self.assertEqual(jaxpr_text.count(' sin '), 0) + self.assertEqual(jaxpr_text.count(' cos '), 0) save_sin = lambda prim, *_, **__: str(prim) == 'sin' f = new_checkpoint(partial(cond_apply, sin), policy=save_sin) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 0) - self.assertEqual(jaxpr_text.count(' cos['), 1) + self.assertEqual(jaxpr_text.count(' sin '), 0) + self.assertEqual(jaxpr_text.count(' cos '), 1) f = new_checkpoint(partial(cond_apply, sin), policy=jax.checkpoint_policies.everything_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 0) - self.assertEqual(jaxpr_text.count(' cos['), 0) + self.assertEqual(jaxpr_text.count(' sin '), 0) + self.assertEqual(jaxpr_text.count(' cos '), 0) f = new_checkpoint(partial(cond_apply, sin), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 0) - self.assertEqual(jaxpr_text.count(' cos['), 1) + self.assertEqual(jaxpr_text.count(' sin '), 0) + self.assertEqual(jaxpr_text.count(' cos '), 1) f = new_checkpoint(lambda x: cond_apply(sin, cond_apply(sin, x)), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 1) - self.assertEqual(jaxpr_text.count(' cos['), 2) + self.assertEqual(jaxpr_text.count(' sin '), 1) + self.assertEqual(jaxpr_text.count(' cos '), 2) def test_remat_of_cond_funky_custom_jvp2(self): # Like the above test but instead of using jit inside custom_jvp, use cond. @@ -6276,40 +6277,40 @@ def sin_jvp(primals, tangents): jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 0) - self.assertEqual(jaxpr_text.count(' cos['), 0) + self.assertEqual(jaxpr_text.count(' sin '), 0) + self.assertEqual(jaxpr_text.count(' cos '), 0) save_sin = lambda prim, *_, **__: str(prim) == 'sin' f = new_checkpoint(partial(cond_apply, sin), policy=save_sin) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 0) - self.assertEqual(jaxpr_text.count(' cos['), 1) + self.assertEqual(jaxpr_text.count(' sin '), 0) + self.assertEqual(jaxpr_text.count(' cos '), 1) f = new_checkpoint(partial(cond_apply, sin), policy=jax.checkpoint_policies.everything_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 0) - self.assertEqual(jaxpr_text.count(' cos['), 0) + self.assertEqual(jaxpr_text.count(' sin '), 0) + self.assertEqual(jaxpr_text.count(' cos '), 0) f = new_checkpoint(partial(cond_apply, sin), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 0) - self.assertEqual(jaxpr_text.count(' cos['), 1) + self.assertEqual(jaxpr_text.count(' sin '), 0) + self.assertEqual(jaxpr_text.count(' cos '), 1) f = new_checkpoint(lambda x: cond_apply(sin, cond_apply(sin, x)), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 1) - self.assertEqual(jaxpr_text.count(' cos['), 2) + self.assertEqual(jaxpr_text.count(' sin '), 1) + self.assertEqual(jaxpr_text.count(' cos '), 2) @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} @@ -6334,8 +6335,8 @@ def f(x): self.assertArraysAllClose(y_dot, expected, check_dtypes=False) jaxpr = api.make_jaxpr(jax.linearize(remat(f), 4.)[1])(1.) - self.assertIn(' sin[', str(jaxpr)) - self.assertIn(' cos[', str(jaxpr)) + self.assertIn(' sin ', str(jaxpr)) + self.assertIn(' cos ', str(jaxpr)) def test_remat_of_while_loop_policy(self): def cond_fn(carry): @@ -6352,8 +6353,8 @@ def f(x): save_cos = lambda prim, *_, **__: str(prim) == 'cos' g = new_checkpoint(f, policy=save_cos) jaxpr = api.make_jaxpr(jax.linearize(g, 4.)[1])(1.) - self.assertIn(' sin[', str(jaxpr)) - self.assertIn(' cos[', str(jaxpr)) + self.assertIn(' sin ', str(jaxpr)) + self.assertIn(' cos ', str(jaxpr)) @jtu.thread_unsafe_test() # logging isn't thread-safe def test_remat_residual_logging(self): diff --git a/tests/core_test.py b/tests/core_test.py index 03d6355cb257..8ab24dbe51f6 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -474,8 +474,8 @@ def new_jaxpr(): # jaxpr is: # # { lambda ; a. - # let b = sin[accuracy=None] a - # c = cos[accuracy=None] a + # let b = sin a + # c = cos a # d = add b c # in (d,) } # @@ -487,7 +487,7 @@ def new_jaxpr(): self.assertRaisesRegex( core.JaxprTypeError, r"Value for variable 'b' inconsistently typed as f32\[\] " - r"for let-binder of type i32\[\]\n\nin equation:\n\nb:i32\[\] = sin\[accuracy=None] a", + r"for let-binder of type i32\[\]\n\nin equation:\n\nb:i32\[\] = sin\ a", lambda: core.check_jaxpr(jaxpr)) jaxpr = new_jaxpr() @@ -496,7 +496,7 @@ def new_jaxpr(): self.assertRaisesRegex( core.JaxprTypeError, r"Value for variable 'b' inconsistently typed as f32\[\] " - r"for let-binder of type f32\[2,3\]\n\nin equation:\n\nb:f32\[2,3\] = sin\[accuracy=None] a", + r"for let-binder of type f32\[2,3\]\n\nin equation:\n\nb:f32\[2,3\] = sin\ a", lambda: core.check_jaxpr(jaxpr)) def test_jaxpr_dropvar_from_jit_call(self): diff --git a/tests/pmap_test.py b/tests/pmap_test.py index d40293501edf..af2d03e2945d 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -2082,8 +2082,8 @@ def test_remat_of_pmap(self, remat): x = jnp.arange(1.) jaxpr = jax.make_jaxpr(jax.linearize(f, x)[1])(x) - self.assertIn(' sin[', str(jaxpr)) - self.assertIn(' cos[', str(jaxpr)) + self.assertIn(' sin ', str(jaxpr)) + self.assertIn(' cos ', str(jaxpr)) @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} @@ -2100,24 +2100,24 @@ def test_remat_of_pmap_policy(self, remat): _, f_vjp = jax.vjp(f, x) jaxpr = f_vjp.args[0].func.args[1] jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 0) - self.assertEqual(jaxpr_text.count(' cos['), 0) + self.assertEqual(jaxpr_text.count(' sin '), 0) + self.assertEqual(jaxpr_text.count(' cos '), 0) save_sin = lambda prim, *_, **__: str(prim) == 'sin' f = remat(g, policy=save_sin) _, f_vjp = jax.vjp(f, x) jaxpr = f_vjp.args[0].func.args[1] jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 0) - self.assertEqual(jaxpr_text.count(' cos['), 2) + self.assertEqual(jaxpr_text.count(' sin '), 0) + self.assertEqual(jaxpr_text.count(' cos '), 2) save_nothing = lambda prim, *_, **__: False f = remat(g, policy=save_nothing) _, f_vjp = jax.vjp(f, x) jaxpr = f_vjp.args[0].func.args[1] jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin['), 1) - self.assertEqual(jaxpr_text.count(' cos['), 2) + self.assertEqual(jaxpr_text.count(' sin '), 1) + self.assertEqual(jaxpr_text.count(' cos '), 2) def test_axis_name_shadowing_with_vmap(self): # vmap-of-pmap with mismatched axis sizes diff --git a/tests/unary_ops_accuracy_test.py b/tests/unary_ops_accuracy_test.py index fb370ab96923..289e33a404f2 100644 --- a/tests/unary_ops_accuracy_test.py +++ b/tests/unary_ops_accuracy_test.py @@ -253,7 +253,7 @@ def f(x, y): @parameterized.named_parameters( *generate_test_cases(["exp", "expm1", "exp2"]) ) - def test_diff_grad(self, op, x, tp, **kwargs): + def test_diff_grad(self, op, x, tp, **kwargs): @jax.jit def f_default(x): default_op = op(x, accuracy=tp.low) @@ -368,6 +368,28 @@ def test_low_tol(self, op, x, **kwargs): ): op(x, accuracy=lax.Tolerance(atol=1e-60, rtol=1e-60, ulps=0)) + def test_accuracy_jaxpr(self): + # Since accuracy is not set, the jaxpr should not contain "accuracy". + self.assertNotIn( + "accuracy", + str( + jax.make_jaxpr(lambda x: lax.exp(x, accuracy=None))( + np.arange(4.0, dtype=np.float32) + ) + ), + ) + # Set accuracy. + self.assertIn( + "accuracy", + str( + jax.make_jaxpr( + lambda x: lax.exp( + x, accuracy=lax.Tolerance(atol=1e-60, rtol=1e-60, ulps=0) + ) + )(np.arange(4.0, dtype=np.float32)) + ), + ) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From ca6e470d2f4c9e583532a1c413277e79ea2b7852 Mon Sep 17 00:00:00 2001 From: Zac Cranko Date: Mon, 7 Apr 2025 23:30:31 +0000 Subject: [PATCH 448/483] harden cache against jaxlib ver --- jax/_src/cache_key.py | 11 +++++++---- tests/cache_key_test.py | 4 ++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/jax/_src/cache_key.py b/jax/_src/cache_key.py index e4b6e7a2669c..6fe3d8819d3c 100644 --- a/jax/_src/cache_key.py +++ b/jax/_src/cache_key.py @@ -110,6 +110,10 @@ def get( bytes(jaxlib_version_str.encode("utf-8")) ), ), + ( + "backend version", + lambda hash_obj: _hash_platform(hash_obj, backend) + ), ( "XLA flags", lambda hash_obj: _hash_xla_flags(hash_obj, get_flag_prefixes()), @@ -126,7 +130,7 @@ def get( ), ( "accelerator_config", - lambda hash_obj: _hash_accelerator_config(hash_obj, devices, backend), + lambda hash_obj: _hash_accelerator_config(hash_obj, devices), ), ( "compression", @@ -220,7 +224,7 @@ def _hash_devices(hash_obj, devices: np.ndarray) -> None: _hash_string(hash_obj, device.device_kind) -def _hash_accelerator_config(hash_obj, accelerators: np.ndarray, backend): +def _hash_accelerator_config(hash_obj, accelerators: np.ndarray): accelerator_devices = [] for accelerator in accelerators.flat: accelerator_devices.append(accelerator) @@ -233,9 +237,8 @@ def _hash_accelerator_config(hash_obj, accelerators: np.ndarray, backend): # PjRtTopologyDescription as yet. logger.info("get (_hash_accelerator_config): unable to hash " "accelerator config, falling back to hashing " - "devices + platform: %s (type %s)", ex, type(ex)) + "devices %s (type %s)", ex, type(ex)) _hash_devices(hash_obj, accelerators) - _hash_platform(hash_obj, backend) # LINT.IfChange(xla_flags) xla_flags_to_exclude_from_cache_key = [ diff --git a/tests/cache_key_test.py b/tests/cache_key_test.py index a908d260d560..fd3e7706260a 100644 --- a/tests/cache_key_test.py +++ b/tests/cache_key_test.py @@ -83,9 +83,9 @@ def test_hash_accelerator_devices(self): self.assertEqual(dev_hash1, dev_hash2) acc_hash1 = self.get_hashed_value( - cache_key._hash_accelerator_config, devices, xla_bridge.get_backend()) + cache_key._hash_accelerator_config, devices) acc_hash2 = self.get_hashed_value( - cache_key._hash_accelerator_config, devices, xla_bridge.get_backend()) + cache_key._hash_accelerator_config, devices) self.assertEqual(acc_hash1, acc_hash2) def test_hash_platform(self): From 31589960ff30816e57977f7aaa7c04f97ba9cac6 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 7 Apr 2025 17:31:04 -0700 Subject: [PATCH 449/483] Migrate custom_call filecheck to use internal custom_call since the external one is deprecated. PiperOrigin-RevId: 744908555 --- tests/filecheck/custom_call.filecheck.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/filecheck/custom_call.filecheck.py b/tests/filecheck/custom_call.filecheck.py index c6af4235ebb4..27cc904e59d8 100644 --- a/tests/filecheck/custom_call.filecheck.py +++ b/tests/filecheck/custom_call.filecheck.py @@ -19,7 +19,7 @@ from absl import app import jax -from jax.interpreters import mlir +from jax._src.interpreters import mlir from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import func as func_dialect import numpy as np From 86de4783bb472af6a2ef17e61bd926097aa525eb Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 7 Apr 2025 19:25:34 -0700 Subject: [PATCH 450/483] Remove unused function jax._src.interpreters.mlir.xla_computation_to_mlir_module. PiperOrigin-RevId: 744934776 --- jax/_src/interpreters/mlir.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index a112063ce3ae..65d9dbe5791f 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -2757,11 +2757,6 @@ def cached_lowering(ctx, *args, **params): return cached_lowering -def xla_computation_to_mlir_module(xla_computation: xc.XlaComputation - ) -> ir.Module: - module_str = xc._xla.mlir.xla_computation_to_mlir_module(xla_computation) - return ir.Module.parse(module_str) - def merge_mlir_modules(dst_module: ir.Module, sym_name: str, src_module: ir.Module, From bb515aa74f24b7688b5a8f612990893d7da3654e Mon Sep 17 00:00:00 2001 From: Alex Pivovarov Date: Mon, 7 Apr 2025 20:00:46 -0700 Subject: [PATCH 451/483] Address previous FP8-related TODOs in jaxlib/XLA. The ml_dtype requirement in JAX was updated to version 0.5.0+ (on Mar 20, 2025) - commit 4b7ead4 This update allows us to address previous FP8-related TODOs in jaxlib/XLA. PiperOrigin-RevId: 744943824 --- jaxlib/xla/py_values.cc | 21 +++++++++++++++------ jaxlib/xla/xla.cc | 7 +++---- jaxlib/xla/xla_client.py | 11 +++++------ jaxlib/xla/xla_client.pyi | 9 ++++----- jaxlib/xla/xla_client_test.py | 32 +++++++++++++++++++++++++++----- 5 files changed, 54 insertions(+), 26 deletions(-) diff --git a/jaxlib/xla/py_values.cc b/jaxlib/xla/py_values.cc index 709f3cb3b2ef..90dd77209694 100644 --- a/jaxlib/xla/py_values.cc +++ b/jaxlib/xla/py_values.cc @@ -694,16 +694,25 @@ absl::StatusOr PyArgSignatureOfValue(nb::handle arg, (*p)[dtypes.np_uint16.ptr()] = numpy_array_handler; (*p)[dtypes.np_uint32.ptr()] = numpy_array_handler; (*p)[dtypes.np_uint64.ptr()] = np_uint64_handler; - // TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. - // (*p)[dtypes.np_float4_e2m1fn.ptr()] = numpy_array_handler; - // (*p)[dtypes.np_float8_e3m4.ptr()] = numpy_array_handler; - // (*p)[dtypes.np_float8_e4m3.ptr()] = numpy_array_handler; - // (*p)[dtypes.np_float8_e8m0fnu.ptr()] = numpy_array_handler; + // TODO(upwind): Explore if we can remove std::optional for these types + // in xla/python/types.h and xla/python/types.cc + if (dtypes.np_float4_e2m1fn.has_value()) { + (*p)[dtypes.np_float4_e2m1fn->ptr()] = numpy_array_handler; + } + if (dtypes.np_float8_e3m4.has_value()) { + (*p)[dtypes.np_float8_e3m4->ptr()] = numpy_array_handler; + } + if (dtypes.np_float8_e4m3.has_value()) { + (*p)[dtypes.np_float8_e4m3->ptr()] = numpy_array_handler; + } (*p)[dtypes.np_float8_e4m3fn.ptr()] = numpy_array_handler; (*p)[dtypes.np_float8_e4m3b11fnuz.ptr()] = numpy_array_handler; - (*p)[dtypes.np_float8_e5m2.ptr()] = numpy_array_handler; (*p)[dtypes.np_float8_e4m3fnuz.ptr()] = numpy_array_handler; + (*p)[dtypes.np_float8_e5m2.ptr()] = numpy_array_handler; (*p)[dtypes.np_float8_e5m2fnuz.ptr()] = numpy_array_handler; + if (dtypes.np_float8_e8m0fnu.has_value()) { + (*p)[dtypes.np_float8_e8m0fnu->ptr()] = numpy_array_handler; + } (*p)[dtypes.np_float16.ptr()] = numpy_array_handler; (*p)[dtypes.np_bfloat16.ptr()] = numpy_array_handler; (*p)[dtypes.np_float32.ptr()] = numpy_array_handler; diff --git a/jaxlib/xla/xla.cc b/jaxlib/xla/xla.cc index e460a1773e94..660e62bd8019 100644 --- a/jaxlib/xla/xla.cc +++ b/jaxlib/xla/xla.cc @@ -208,15 +208,14 @@ NB_MODULE(xla_extension, m) { .value("U64", U64) .value("F16", F16) .value("F4E2M1FN", F4E2M1FN) - // TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. - // .value("F8E3M4", F8E3M4) - // .value("F8E4M3", F8E4M3) - .value("F8E8M0FNU", F8E8M0FNU) + .value("F8E3M4", F8E3M4) + .value("F8E4M3", F8E4M3) .value("F8E4M3FN", F8E4M3FN) .value("F8E4M3B11FNUZ", F8E4M3B11FNUZ) .value("F8E4M3FNUZ", F8E4M3FNUZ) .value("F8E5M2", F8E5M2) .value("F8E5M2FNUZ", F8E5M2FNUZ) + .value("F8E8M0FNU", F8E8M0FNU) .value("BF16", BF16) .value("F32", F32) .value("F64", F64) diff --git a/jaxlib/xla/xla_client.py b/jaxlib/xla/xla_client.py index fa31d1764de2..637d7d060aa2 100644 --- a/jaxlib/xla/xla_client.py +++ b/jaxlib/xla/xla_client.py @@ -260,16 +260,15 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): PrimitiveType.U16: np.dtype('uint16'), PrimitiveType.U32: np.dtype('uint32'), PrimitiveType.U64: np.dtype('uint64'), - # TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. - # PrimitiveType.F4E2M1FN: np.dtype(float4_e2m1fn), - # PrimitiveType.F8E3M4: np.dtype(float8_e3m4), - # PrimitiveType.F8E4M3: np.dtype(float8_e4m3), - # PrimitiveType.F8E8M0FNU: np.dtype(float8_e8m0fnu), + PrimitiveType.F4E2M1FN: np.dtype(float4_e2m1fn), + PrimitiveType.F8E3M4: np.dtype(float8_e3m4), + PrimitiveType.F8E4M3: np.dtype(float8_e4m3), PrimitiveType.F8E4M3FN: np.dtype(float8_e4m3fn), PrimitiveType.F8E4M3B11FNUZ: np.dtype(float8_e4m3b11fnuz), - PrimitiveType.F8E5M2: np.dtype(float8_e5m2), PrimitiveType.F8E4M3FNUZ: np.dtype(float8_e4m3fnuz), + PrimitiveType.F8E5M2: np.dtype(float8_e5m2), PrimitiveType.F8E5M2FNUZ: np.dtype(float8_e5m2fnuz), + PrimitiveType.F8E8M0FNU: np.dtype(float8_e8m0fnu), PrimitiveType.BF16: np.dtype(bfloat16), PrimitiveType.F16: np.dtype('float16'), PrimitiveType.F32: np.dtype('float32'), diff --git a/jaxlib/xla/xla_client.pyi b/jaxlib/xla/xla_client.pyi index b182eb65ba60..382858d2a6d0 100644 --- a/jaxlib/xla/xla_client.pyi +++ b/jaxlib/xla/xla_client.pyi @@ -63,16 +63,15 @@ _ifrt_version: int mlir_api_version: int bfloat16: type[numpy.generic] -# TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. -# float4_e2m1fn: type[numpy.generic] -# float8_e3m4: type[numpy.generic] -# float8_e4m3: type[numpy.generic] -# float8_e8m0fnu: type[numpy.generic] +float4_e2m1fn: type[numpy.generic] +float8_e3m4: type[numpy.generic] +float8_e4m3: type[numpy.generic] float8_e4m3fn: type[numpy.generic] float8_e4m3b11fnuz: type[numpy.generic] float8_e4m3fnuz: type[numpy.generic] float8_e5m2: type[numpy.generic] float8_e5m2fnuz: type[numpy.generic] +float8_e8m0fnu: type[numpy.generic] XLA_ELEMENT_TYPE_TO_DTYPE: dict[PrimitiveType, numpy.dtype] _NameValueMapping = Mapping[str, Union[str, int, list[int], float, bool]] diff --git a/jaxlib/xla/xla_client_test.py b/jaxlib/xla/xla_client_test.py index 7de905d9ec41..9c6625610ca6 100644 --- a/jaxlib/xla/xla_client_test.py +++ b/jaxlib/xla/xla_client_test.py @@ -48,12 +48,12 @@ float4_e2m1fn = ml_dtypes.float4_e2m1fn float8_e3m4 = ml_dtypes.float8_e3m4 float8_e4m3 = ml_dtypes.float8_e4m3 -float8_e8m0fnu = ml_dtypes.float8_e8m0fnu float8_e4m3fn = ml_dtypes.float8_e4m3fn float8_e4m3fnuz = ml_dtypes.float8_e4m3fnuz float8_e4m3b11fnuz = ml_dtypes.float8_e4m3b11fnuz float8_e5m2 = ml_dtypes.float8_e5m2 float8_e5m2fnuz = ml_dtypes.float8_e5m2fnuz +float8_e8m0fnu = ml_dtypes.float8_e8m0fnu ops = xla_client.ops xla_computation_to_mlir_module = ( xla_client._xla.mlir.xla_computation_to_mlir_module) @@ -178,10 +178,17 @@ def TestFactory(xla_backend, # TODO(zhangqiaorjc): test fp8 types when XLA support is complete. # standard_dtypes is only used for BufferProtocolTest so we only test fp8 # round trip tests. - fp8_dtypes = [float8_e4m3b11fnuz, float8_e4m3fn, float8_e5m2] + fp8_dtypes = [ + float8_e3m4, + float8_e4m3, + float8_e4m3fn, + float8_e4m3b11fnuz, + float8_e5m2, + float8_e8m0fnu, + ] standard_dtypes += fp8_dtypes - # TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. - # standard_dtypes += [float4_e2m1fn, float8_e3m4, float8_e4m3, float8_e8m0fnu] + # TODO(upwind): testRoundTrip and testLiveBuffers fail for float4_e2m1fn type + # standard_dtypes += [float4_e2m1fn] dlpack_dtypes = int_dtypes + float_dtypes + [np.bool_] + complex_dtypes class ComputationTest(parameterized.TestCase): @@ -1228,9 +1235,19 @@ def testStandardTypes(self): for dtype in standard_dtypes: if dtype == np.complex128: continue + # float8_e8m0fnu is not supported on TPU. + if dtype == float8_e8m0fnu and self.backend.platform == "tpu": + continue # float8_e4m3b11fnuz not supported on some TPU backends. if ( - dtype in [float8_e5m2fnuz, float8_e4m3fnuz, float8_e4m3b11fnuz] + dtype + in [ + float8_e3m4, + float8_e4m3, + float8_e4m3fnuz, + float8_e4m3b11fnuz, + float8_e5m2fnuz, + ] and self.backend.platform == "tpu" ): if self.backend.platform_version.find("TPU") == -1: @@ -2253,6 +2270,11 @@ def testFft(self): "dtype": dtype, } for dtype in float_dtypes + fp8_dtypes) def testNextAfter(self, dtype): + if dtype == float8_e8m0fnu: + # TODO(b/409114865): Test fails with Mismatched elements error. + self.skipTest("b/409114865: Test fails with Mismatched elements error") + if dtype in [float8_e3m4, float8_e4m3] and self.backend.platform == "tpu": + self.skipTest("TPU doesn't support float8_e3m4 or float8_e4m3") if dtype == np.float64 and self.backend.platform == "tpu": self.skipTest("TPU doesn't support float64") if dtype == bfloat16 and self.backend.platform == "tpu": From 51dbcd4dad3acf6f83943d1febbb7d5c773c7f59 Mon Sep 17 00:00:00 2001 From: George Necula Date: Tue, 8 Apr 2025 00:09:27 -0700 Subject: [PATCH 452/483] [export] Add backwards compatibility test for annotate_device_placement. This enables exporting functions that use memory kinds to place data in different memories. jax-fixit PiperOrigin-RevId: 745008959 --- jax/_src/export/_export.py | 1 + .../annotate_data_placement.py | 73 +++++++++++++++++++ .../export_back_compat_test_util.py | 21 ++++-- tests/export_back_compat_test.py | 32 +++++++- 4 files changed, 118 insertions(+), 9 deletions(-) create mode 100644 jax/_src/internal_test_util/export_back_compat_test_data/annotate_data_placement.py diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index 90cc0c186ad1..4315c948bb5c 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -1082,6 +1082,7 @@ def _check_lowering(lowering) -> None: *_CPU_FFI_KERNELS, *_GPU_FFI_KERNELS, "Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape", + "annotate_device_placement", "cu_threefry2x32_ffi", # Triton IR does not guarantee stability. # "__gpu$xla.gpu.triton", diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/annotate_data_placement.py b/jax/_src/internal_test_util/export_back_compat_test_data/annotate_data_placement.py new file mode 100644 index 000000000000..bf70df2cdb3a --- /dev/null +++ b/jax/_src/internal_test_util/export_back_compat_test_data/annotate_data_placement.py @@ -0,0 +1,73 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa + +import datetime +from numpy import array, float32, int32 + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2025_04_07_tpu = dict( + testdata_version=1, + platform='tpu', + custom_call_targets=['annotate_device_placement'], + serialized_date=datetime.date(2025, 4, 7), + inputs=(array([0.], dtype=float32), array([0.], dtype=float32)), + expected_outputs=(array([0.], dtype=float32),), + mlir_module_text=r""" +#loc1 = loc("x") +#loc2 = loc("y") +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<1xf32> {mhlo.memory_kind = "device", mhlo.sharding = "{maximal device=0}"} loc("x"), %arg1: tensor<1xf32> {mhlo.memory_kind = "pinned_host", mhlo.sharding = "{maximal device=0}"} loc("y")) -> (tensor<1xf32> {jax.result_info = "result", mhlo.memory_kind = "pinned_host", mhlo.sharding = "{maximal device=0}"}) { + %0 = stablehlo.add %arg0, %arg1 : tensor<1xf32> loc(#loc4) + %1 = stablehlo.custom_call @annotate_device_placement(%0) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "pinned_host"}} : (tensor<1xf32>) -> tensor<1xf32> loc(#loc) + return %1 : tensor<1xf32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":878:13) +#loc4 = loc("jit(func)/jit(main)/add"(#loc3)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.9.3\x00\x01\x1b\x05\x01\x05\x0b\x01\x03\x0b\x03\t\x0f\x13\x17\x1b\x03oQ\x0b\x01%\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0b\x13\x0b\x03-\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0f#\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x01\x05\x0b\x0f\x03\x07\x13\x1b\x07\x02\n\x02\x1f\x11\x03\x05\x03\x07\x07\t\x0b\x03\r\x03\x05\x0f\x11\x01\x00\x05\x11\x05\x13\x05\x15\x1d\x13\x01\x05\x17\x1d\x17\x01\x05\x19\x1d\x1b\x1d\x05\x1b\x17\x1f\xba\r\x1b\x05\x1d\x03\x03#E\x05\x1f\x03\x01\x1d!\x1d#\x1d%\x1d'\x03\x0515\r\x05'3)+\x1d)\r\x05'-)+#\x07\x03\x03;\r\x07=?'-)+\x1d+\x1d-\x1d/\x1d1\r\x03G-\x1d3\x0b\x03\x1d5\x1d7\x05\x03\x01\t\x01\x02\x02)\x03\x05\t\x11\x05\x05\x05\x03\x05\t\x04e\x05\x01Q\x01\x05\x01\x07\x04S\x03\x01\x05\x03P\x01\x03\x07\x04?\x03\t\x0f\x05\x0b\x11\x0b\x15\x00\x05\x06\x19\x03\x05\x05\x01\x03\x07G\x01!\x05\x03\x05\x03\x05\t\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00\x9a\x0695\x03-\x0f\x0b\x0f!\x0f\x19'\x1d#3i1\x05\x05\x13%)9\x15\x1f\x0f\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00add_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00x\x00y\x00jit(func)/jit(main)/add\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.frontend_attributes\x00mhlo.memory_kind\x00mhlo.sharding\x00{maximal device=0}\x00pinned_host\x00device\x00jax.result_info\x00result\x00main\x00public\x00_xla_buffer_placement\x00\x00annotate_device_placement\x00\x08'\x07\x05\x1f\x01\x0b/79AC\x11IKM%O%%%", + xla_call_module_version=9, + nr_devices=1, +) # End paste + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2025_04_07_cuda = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['annotate_device_placement'], + serialized_date=datetime.date(2025, 4, 7), + inputs=(array([0.], dtype=float32), array([0.], dtype=float32)), + expected_outputs=(array([0.], dtype=float32),), + mlir_module_text=r""" +#loc1 = loc("x") +#loc2 = loc("y") +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<1xf32> {mhlo.memory_kind = "device", mhlo.sharding = "{maximal device=0}"} loc("x"), %arg1: tensor<1xf32> {mhlo.memory_kind = "pinned_host", mhlo.sharding = "{maximal device=0}"} loc("y")) -> (tensor<1xf32> {jax.result_info = "result", mhlo.memory_kind = "pinned_host", mhlo.sharding = "{maximal device=0}"}) { + %0 = stablehlo.add %arg0, %arg1 : tensor<1xf32> loc(#loc4) + %1 = stablehlo.custom_call @annotate_device_placement(%0) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "pinned_host"}} : (tensor<1xf32>) -> tensor<1xf32> loc(#loc) + return %1 : tensor<1xf32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":878:13) +#loc4 = loc("jit(func)/jit(main)/add"(#loc3)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.9.3\x00\x01\x1b\x05\x01\x05\x0b\x01\x03\x0b\x03\t\x0f\x13\x17\x1b\x03oQ\x0b\x01%\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0b\x13\x0b\x03-\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0f#\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x01\x05\x0b\x0f\x03\x07\x13\x1b\x07\x02\n\x02\x1f\x11\x03\x05\x03\x07\x07\t\x0b\x03\r\x03\x05\x0f\x11\x01\x00\x05\x11\x05\x13\x05\x15\x1d\x13\x01\x05\x17\x1d\x17\x01\x05\x19\x1d\x1b\x1d\x05\x1b\x17\x1f\xba\r\x1b\x05\x1d\x03\x03#E\x05\x1f\x03\x01\x1d!\x1d#\x1d%\x1d'\x03\x0515\r\x05'3)+\x1d)\r\x05'-)+#\x07\x03\x03;\r\x07=?'-)+\x1d+\x1d-\x1d/\x1d1\r\x03G-\x1d3\x0b\x03\x1d5\x1d7\x05\x03\x01\t\x01\x02\x02)\x03\x05\t\x11\x05\x05\x05\x03\x05\t\x04e\x05\x01Q\x01\x05\x01\x07\x04S\x03\x01\x05\x03P\x01\x03\x07\x04?\x03\t\x0f\x05\x0b\x11\x0b\x15\x00\x05\x06\x19\x03\x05\x05\x01\x03\x07G\x01!\x05\x03\x05\x03\x05\t\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00\x9a\x0695\x03-\x0f\x0b\x0f!\x0f\x19'\x1d#3i1\x05\x05\x13%)9\x15\x1f\x0f\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00add_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00x\x00y\x00jit(func)/jit(main)/add\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.frontend_attributes\x00mhlo.memory_kind\x00mhlo.sharding\x00{maximal device=0}\x00pinned_host\x00device\x00jax.result_info\x00result\x00main\x00public\x00_xla_buffer_placement\x00\x00annotate_device_placement\x00\x08'\x07\x05\x1f\x01\x0b/79AC\x11IKM%O%%%", + xla_call_module_version=9, + nr_devices=1, +) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_util.py b/jax/_src/internal_test_util/export_back_compat_test_util.py index 5d5e95b5cb9a..b86b24e2b4fc 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_util.py +++ b/jax/_src/internal_test_util/export_back_compat_test_util.py @@ -90,6 +90,7 @@ def func(...): ... from jax.experimental import pjit from jax._src import core +from jax._src import stages from jax._src import test_util as jtu from jax._src import xla_bridge as xb @@ -165,7 +166,8 @@ def load_testdata_nested(self, testdata_nest) -> Iterable[CompatTestData]: else: assert False, testdata_nest - def run_one_test(self, func: Callable[..., jax.Array], + def run_one_test(self, + func: Callable[..., jax.Array] | stages.Wrapped, data: CompatTestData, polymorphic_shapes: Sequence[str] | None = None, rtol: float | None = None, @@ -176,7 +178,8 @@ def run_one_test(self, func: Callable[..., jax.Array], """Run one compatibility test. Args: - func: the JAX function to serialize and run + func: the JAX function to serialize and run, either as a Python Callable + or as a `jax.jit(callable)`. data: the test data polymorphic_shapes: when using shape polymorphism, the specification for each argument of `func`. @@ -269,19 +272,22 @@ def run_one_test(self, func: Callable[..., jax.Array], expect_current_custom_calls = data.custom_call_targets self.assertItemsEqual(expect_current_custom_calls, current_custom_call_targets) - def run_current(self, func: Callable, data: CompatTestData): + def run_current(self, + func: Callable | stages.Wrapped, + data: CompatTestData): """Lowers and runs the test function at the current JAX version.""" - return jax.jit(func)(*data.inputs) + jit_func = func if isinstance(func, stages.Wrapped) else jax.jit(func) + return jit_func(*data.inputs) def serialize(self, - func: Callable, data: CompatTestData, *, + func: Callable | stages.Wrapped, data: CompatTestData, *, polymorphic_shapes: Sequence[str] | None = None, allow_unstable_custom_call_targets: Sequence[str] = () ) -> tuple[bytes, str, int, int]: """Serializes the test function. Args: - func: the function to serialize + func: the function to serialize. polymorphic_shapes: the polymorphic_shapes to use for serialization allow_unstable_custom_call_targets: whether to allow additional custom call targets besides those known as stable. @@ -292,8 +298,9 @@ def serialize(self, """ # Use the native exporter, to make sure we get the proper serialization. args_specs = export.symbolic_args_specs(data.inputs, polymorphic_shapes) + jit_func = func if isinstance(func, stages.Wrapped) else jax.jit(func) exported = export.export( - jax.jit(func), + jit_func, platforms=(self.default_jax_backend(),), disabled_checks=tuple( export.DisabledSafetyCheck.custom_call(target) diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index 789838f99d14..fd2b349f6c95 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -31,6 +31,7 @@ from jax._src.internal_test_util import export_back_compat_test_util as bctu +from jax._src.internal_test_util.export_back_compat_test_data import annotate_data_placement from jax._src.internal_test_util.export_back_compat_test_data import cpu_cholesky_lapack_potrf from jax._src.internal_test_util.export_back_compat_test_data import cpu_eig_lapack_geev from jax._src.internal_test_util.export_back_compat_test_data import cuda_eigh_cusolver_syev @@ -161,6 +162,8 @@ def test_custom_call_coverage(self): stablehlo_dynamic_top_k.data_2023_07_16, stablehlo_dynamic_top_k.data_2023_08_11, # with shape_assertion stablehlo_dynamic_approx_top_k.data_2024_05_30, + annotate_data_placement.data_2025_04_07_tpu, + annotate_data_placement.data_2025_04_07_cuda, ] # Some of the above are nested structures. covering_testdatas = itertools.chain( @@ -817,7 +820,7 @@ def func(x): ) self.run_one_test(func, data, rtol=rtol, atol=atol) - def test_approx_top_k(self): + def test_tpu_approx_top_k(self): def func(): x = np.array([3.0, 1.0, 4.0, 2.0, 5.0, 6.0, 7.0]) y = lax.approx_max_k(x, 3) @@ -834,7 +837,7 @@ def func(x): data = self.load_testdata(cuda_threefry2x32.data_2024_07_30) self.run_one_test(func, data) - def test_sharding(self): + def test_tpu_sharding(self): # Tests "Sharding", "SPMDShardToFullShape", "SPMDFullToShardShape" on TPU if not jtu.test_device_matches(["tpu"]) or len(jax.devices()) < 2: self.skipTest("Test runs only on TPU with at least 2 devices") @@ -856,6 +859,31 @@ def func(x): # b: f32[2, 4] with mesh: self.run_one_test(func, data) + @parameterized.named_parameters( + dict(testcase_name=f"_platform={platform}", platform=platform) + for platform in ("tpu", "gpu")) + def test_annotate_device_placement(self, platform): + if not jtu.test_device_matches([platform]): + self.skipTest(f"Test enabled only for {platform}") + + mesh = Mesh(jax.local_devices()[0:1], axis_names=("a")) + + dev_sharding = NS(mesh, P("a")) + host_sharding = NS(mesh, P("a"), memory_kind="pinned_host") + + @partial(jax.jit, + in_shardings=(dev_sharding, host_sharding), + out_shardings=host_sharding) + def func(x, y): + return x + y + + if platform == "tpu": + data = self.load_testdata(annotate_data_placement.data_2025_04_07_tpu) + else: + data = self.load_testdata(annotate_data_placement.data_2025_04_07_cuda) + + self.run_one_test(func, data) + def test_tpu_stablehlo_dynamic_reduce_window_unary(self): # stablehlo.dynamic_reduce_window is used temporarily on TPU for a # reduce window with dynamic shapes. From 19fcae12078b7a8823524203bd3e48cee9c254f5 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Tue, 8 Apr 2025 00:33:16 -0700 Subject: [PATCH 453/483] [Mosaic GPU] Add support for replicated warp_dim parsing and a dedicated test for parsing all canonical layouts. PiperOrigin-RevId: 745015431 --- jax/experimental/mosaic/gpu/layouts.py | 26 ++++++++++++++++--------- jaxlib/mosaic/dialect/gpu/mosaic_gpu.td | 2 +- tests/mosaic/gpu_dialect_test.py | 12 ++++++++++++ 3 files changed, 30 insertions(+), 10 deletions(-) diff --git a/jax/experimental/mosaic/gpu/layouts.py b/jax/experimental/mosaic/gpu/layouts.py index d9b1a01a24b5..0a4f3ed09116 100644 --- a/jax/experimental/mosaic/gpu/layouts.py +++ b/jax/experimental/mosaic/gpu/layouts.py @@ -96,7 +96,7 @@ def is_strided_fragmented_layout(attr: ir.Attribute) -> bool: _tiled_layout_attr_pattern = re.compile( r"^#mosaic_gpu.TiledLayout<\[(?P.*)\]," - r" warp_dim\s*=\s*(?P[-\d]+)," + r" warp_dim\s*=\s*(?P.+)," r" lane_dims\s*=\s*\[(?P.*)\]," r" vector_dim\s*=\s*(?P[-\d]+)>$" ) @@ -107,22 +107,26 @@ def to_tiled_layout_attr( ) -> ir.Attribute: """Constructs a #mosaic_gpu.TiledLayout attribute from a TiledLayout.""" - def _lane_dim_str(d: int | fa.Replicated) -> str: + def _int_or_replicated(d: int | fa.Replicated) -> str: if isinstance(d, fa.Replicated): return f"#mosaic_gpu.Replicated" return str(d) tile_str = lambda tile: "[" + ", ".join(str(d) for d in tile) + "]" tiling = "[" + ", ".join(tile_str(tile) for tile in layout.tiling.tiles) + "]" - lane_dims = "[" + ",".join(_lane_dim_str(d) for d in layout.lane_dims) + "]" + lane_dims = ( + "[" + ",".join(_int_or_replicated(d) for d in layout.lane_dims) + "]" + ) return ir.Attribute.parse( - f"#mosaic_gpu.TiledLayout<{tiling}, warp_dim={layout.warp_dim}," + f"#mosaic_gpu.TiledLayout<{tiling}," + f" warp_dim={_int_or_replicated(layout.warp_dim)}," f" lane_dims={lane_dims}, vector_dim={layout.vector_dim}>" ) _list_of_lists_delimiter = re.compile(r"\]\s*,\s*\[") +_int_pattern = re.compile(r"^(?P[-\d]+)(\s*:\s*\w+)?$") _replicated_pattern = re.compile( r"^#mosaic_gpu.Replicated<\s*times\s*=\s*(?P\d+)\s*>\s*$" ) @@ -143,11 +147,14 @@ def from_tiled_layout_attr( f"Expected a #mosaic_gpu.TiledLayout attribute, got {attr}" ) - def _lane_dim(lane_dim_str: str) -> int | fa.Replicated: - match = _replicated_pattern.fullmatch(lane_dim_str) + def _int_or_replicated(replicated_dim: str) -> int | fa.Replicated: + match = _replicated_pattern.fullmatch(replicated_dim) if match: return fa.Replicated(int(match.group("times"))) - return int(lane_dim_str) + match = _int_pattern.fullmatch(replicated_dim) + if match: + return int(match.group("num")) + raise ValueError(f"Unexpected format for replicated dim {replicated_dim}") tiling_str = match.group("tiling") tile_strings = [] @@ -156,9 +163,10 @@ def _lane_dim(lane_dim_str: str) -> int | fa.Replicated: tiles = tuple(tuple(map(int, ts.split(","))) for ts in tile_strings) return fa.TiledLayout( tiling=fa.Tiling(tiles), - warp_dim=int(match.group("warp_dim")), + warp_dim=_int_or_replicated(match.group("warp_dim")), lane_dims=tuple( - _lane_dim(s) for s in match.group("lane_dims").split(",") + _int_or_replicated(s.strip()) + for s in match.group("lane_dims").split(",") ), vector_dim=int(match.group("vector_dim")), ) diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td index 36f9f6f374e5..86219dbc87ac 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td @@ -161,7 +161,7 @@ def MosaicGPU_TiledLayout : AttrDef { let parameters = (ins "::mlir::ArrayAttr":$tiling, - "int":$warp_dim, + "::mlir::Attribute":$warp_dim, "::mlir::ArrayAttr":$lane_dims, "int":$vector_dim ); diff --git a/tests/mosaic/gpu_dialect_test.py b/tests/mosaic/gpu_dialect_test.py index 7e211abb955a..2d75c42424ef 100644 --- a/tests/mosaic/gpu_dialect_test.py +++ b/tests/mosaic/gpu_dialect_test.py @@ -593,6 +593,18 @@ def test_wgmma_b_n_dim_not_equal_to_acc_n_dim(self): ): self.module.operation.verify() + def test_tiled_layout_attr_parsing(self): + with ir.InsertionPoint(self.module.body): + for layout in ( + mgpu.WGMMA_LAYOUT, + mgpu.WGMMA_ROW_LAYOUT, + mgpu.WGMMA_COL_LAYOUT, + mgpu.WGMMA_TRANSPOSED_LAYOUT, + ): + attr = layouts.to_tiled_layout_attr(layout) + parsed_layout = layouts.from_tiled_layout_attr(attr) + self.assertEqual(layout, parsed_layout) + class DialectLoweringTest(MosaicGpuTest): From bc11a63113f543779c1ed8b794b00c3e747d17f7 Mon Sep 17 00:00:00 2001 From: Jamie Townsend Date: Tue, 8 Apr 2025 09:50:31 +0200 Subject: [PATCH 454/483] Clarify jax.make_jaxpr docstring --- jax/_src/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 0055f6466dae..89dfe74acd5b 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2142,7 +2142,7 @@ def make_jaxpr( return_shape: bool = False, abstracted_axes: Any | None = None, ) -> Callable[..., core.ClosedJaxpr | tuple[core.ClosedJaxpr, Any]]: - """Creates a function that produces its jaxpr given example args. + """Create a function that returns the jaxpr of ``fun`` given example args. Args: fun: The function whose ``jaxpr`` is to be computed. Its positional From 8ed59d8b5d99619e35b3a7ab595e11fe1668ada2 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 8 Apr 2025 02:05:52 -0700 Subject: [PATCH 455/483] Removed `jax._src.raise_to_shaped` It is just an identity after the "stackless" rewrite. PiperOrigin-RevId: 745042532 --- jax/_src/core.py | 5 ----- jax/_src/pallas/fuser/fusable.py | 6 +----- jax/_src/pallas/fuser/jaxpr_fusion.py | 6 +----- tests/api_test.py | 2 +- tests/state_test.py | 4 ++-- 5 files changed, 5 insertions(+), 18 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index 8d32b9370091..6c5b7a08a0e9 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -2229,11 +2229,6 @@ def block_until_ready(self): pytype_aval_mappings[Token] = lambda _: abstract_token -# TODO(dougalm): Deprecate these. They're just here for backwards compat. -def raise_to_shaped(aval): - return aval -raise_to_shaped_mappings: dict[type, Callable] = {} - ### Operations on shapes and dimension sizes. class InconclusiveDimensionOperation(Exception): diff --git a/jax/_src/pallas/fuser/fusable.py b/jax/_src/pallas/fuser/fusable.py index aa2ea0843c0a..d9d0ee0b4682 100644 --- a/jax/_src/pallas/fuser/fusable.py +++ b/jax/_src/pallas/fuser/fusable.py @@ -29,10 +29,6 @@ fusable_p.multiple_results = True -def _get_aval(x): - return jax_core.raise_to_shaped(jax_core.get_aval(x)) - - def _make_trivial_fusion(x: jax.Array) -> fusion_lib.Fusion: return fusion_lib.Fusion( func=lambda: x, @@ -53,7 +49,7 @@ def wrapped(*args): flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs( lu.wrap_init(wrapped, debug_info=debug_info), in_tree ) - flat_avals = [_get_aval(x) for x in flat_args] + flat_avals = [jax_core.get_aval(x) for x in flat_args] jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals) out_tree = out_tree_thunk() out = fusable_p.bind( diff --git a/jax/_src/pallas/fuser/jaxpr_fusion.py b/jax/_src/pallas/fuser/jaxpr_fusion.py index 649037e18092..3c3c2a3d7b66 100644 --- a/jax/_src/pallas/fuser/jaxpr_fusion.py +++ b/jax/_src/pallas/fuser/jaxpr_fusion.py @@ -28,10 +28,6 @@ from jax._src.pallas.fuser.fusable import fusable_p -def _get_aval(x): - return jax_core.raise_to_shaped(jax_core.get_aval(x)) - - def fuse(f=None, *, physicalize: bool = False, debug: bool = False): """Fuses a function into a single fusable. @@ -52,7 +48,7 @@ def wrapper(*args, **kwargs): flat_fun, out_tree_thunk = api_util.flatten_fun( lu.wrap_init(f, debug_info=debug_info), in_tree ) - flat_avals = [_get_aval(x) for x in flat_args] + flat_avals = [jax_core.get_aval(x) for x in flat_args] jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals) if debug: print("Jaxpr before fusion:") diff --git a/tests/api_test.py b/tests/api_test.py index 440fea1b059c..0e8cf2502540 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -5024,7 +5024,7 @@ def g(x): # Make sure that introducing constants in vmap works. constant_introducing_p = core.Primitive('introduce_constant') - constant_introducing_p.def_abstract_eval(core.raise_to_shaped) + constant_introducing_p.def_abstract_eval(lambda x: x) def _constant_introducing_batcher(xs, ds): (x,), (d,) = xs, ds return (x + np.arange(x.size, dtype=x.dtype).reshape(x.shape)), d diff --git a/tests/state_test.py b/tests/state_test.py index 60a7d8bc9f8a..03902687c40e 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -792,7 +792,7 @@ def body(i, st): lax.fori_loop(0, 5, body, init_val=()) return a_ref[...], b_ref[...] - ref = lambda x: AbstractRef(core.raise_to_shaped(core.get_aval(x))) + ref = lambda x: AbstractRef(core.get_aval(x)) f_jaxpr = jax.make_jaxpr(f)(ref(1.), ref(2.)) jaxpr, _ = discharge_state(f_jaxpr.jaxpr, (), should_discharge=[False, True]) # Effects on y_ref were discharged away but not the effects on x_ref @@ -1139,7 +1139,7 @@ def false_fun(): y_ref[...] = 2. lax.cond(pred, true_fun, false_fun) return x_ref[...], y_ref[...] - ref = lambda x: AbstractRef(core.raise_to_shaped(core.get_aval(x))) + ref = lambda x: AbstractRef(core.get_aval(x)) f_jaxpr = jax.make_jaxpr(f0)(False, ref(3.), ref(4.)) jaxpr, _ = discharge_state(f_jaxpr.jaxpr, (), should_discharge=[False, False, True]) # Effects on y_ref were discharged away but not the effects on x_ref From af072feb5a02bc75d4f9fec487ae59be60b0c01b Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 8 Apr 2025 02:37:36 -0700 Subject: [PATCH 456/483] Removed redundant `pass`es If a function or class has a docstring, it does not need a `pass`. PiperOrigin-RevId: 745052107 --- jax/_src/core.py | 1 - jax/_src/errors.py | 1 - jax/_src/pallas/fuser/fusable_dtype.py | 2 -- jax/_src/profiler.py | 1 - tests/api_test.py | 1 - 5 files changed, 6 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index 6c5b7a08a0e9..236781c16d27 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -2233,7 +2233,6 @@ def block_until_ready(self): class InconclusiveDimensionOperation(Exception): """Raised when we cannot conclusively compute with symbolic dimensions.""" - pass def is_symbolic_dim(v: Any) -> bool: """Checks if a value is a symbolic dimension used for shape polymorphism. diff --git a/jax/_src/errors.py b/jax/_src/errors.py index 6540fd1f5d41..b9831bfe3b1a 100644 --- a/jax/_src/errors.py +++ b/jax/_src/errors.py @@ -680,4 +680,3 @@ class KeyReuseError(JAXTypeError): must be manually split; For more information on this see `the Pseudorandom Numbers tutorial `_. """ - pass diff --git a/jax/_src/pallas/fuser/fusable_dtype.py b/jax/_src/pallas/fuser/fusable_dtype.py index e5bc9ab683ab..99c80e652791 100644 --- a/jax/_src/pallas/fuser/fusable_dtype.py +++ b/jax/_src/pallas/fuser/fusable_dtype.py @@ -83,8 +83,6 @@ def unpack(x): class FusableElementDType(dtypes.extended): """Scalar dtype for fusable dtypes.""" - pass - class FusableTyRules: allow_conversion: bool = False diff --git a/jax/_src/profiler.py b/jax/_src/profiler.py index 96e742f33904..0e9949f27f55 100644 --- a/jax/_src/profiler.py +++ b/jax/_src/profiler.py @@ -272,7 +272,6 @@ class TraceAnnotation(xla_client.profiler.TraceMe): This will cause a "my_label" event to show up on the trace timeline if the event occurs while the process is being traced. """ - pass class StepTraceAnnotation(TraceAnnotation): diff --git a/tests/api_test.py b/tests/api_test.py index 0e8cf2502540..2d9fcd1ff554 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -3120,7 +3120,6 @@ def test_error_for_invalid_dtype(self): def test_vmap_preserves_docstr(self): def superfun(a): """Does things with stuff.""" - pass self.assertRegex(api.vmap(superfun).__doc__, "\n".join([ "Vectorized version of superfun.*", From d12cbffd4912980f290d676ee1b606cb9d1c9ad2 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Tue, 8 Apr 2025 03:04:47 -0700 Subject: [PATCH 457/483] [Mosaic GPU] Refactor and generalize code in `optimization_barrier`. The change in `utils.py` is to enable the use of `bitwidth` when the mlir dialect is not registered. PiperOrigin-RevId: 745060221 --- jax/experimental/mosaic/gpu/core.py | 1 + .../mosaic/gpu/fragmented_array.py | 40 ++++++++++--------- 2 files changed, 22 insertions(+), 19 deletions(-) diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index e822ea5f3ebf..860b41e7e8e3 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -479,6 +479,7 @@ def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType: out_ref_tys.append(prof_spec.mlir_buffer_type(grid, block)) module = ir.Module.create() + dialect.register_dialect(module.context) attrs = module.operation.attributes attrs["sym_name"] = ir.StringAttr.get(module_name) if kernel_name is None: diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index ecd51f79eab0..9ab27927791a 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -2434,10 +2434,23 @@ def optimization_barrier(*arrays: mgpu.FragmentedArray): index = ir.IndexType.get() i32 = ir.IntegerType.get_signless(32) + def _repack(regs_it, reg_ty): + if not ir.VectorType.isinstance(reg_ty): + result_reg = next(regs_it) + assert result_reg.type == reg_ty + return result_reg + + num_i32_regs = utils.bitwidth(reg_ty) // 32 + i32_reg_ty = ir.VectorType.get((num_i32_regs,), i32) + reg = llvm.mlir_undef(i32_reg_ty) + for i_elem in range(num_i32_regs): + val = llvm.bitcast(i32, next(regs_it)) + reg = llvm.insertelement(reg, val, arith.constant(i32, i_elem)) + return vector.bitcast(reg_ty, reg) + regs = [] reg_dtypes = [] reg_constraints = [] - repack_fns = [] # We unpack each array into a flat list of registers, and prepare the # functions that invert the transform in repack_fns. for array in arrays: @@ -2451,36 +2464,25 @@ def optimization_barrier(*arrays: mgpu.FragmentedArray): for reg in array.registers.flat for pos in range(vec_len) ] - def _repack(regs, reg_ty=reg_ty): - reg = llvm.mlir_undef(reg_ty) - [vec_len] = ir.VectorType(reg_ty).shape - for i_elem in range(vec_len): - reg = llvm.insertelement( - reg, next(regs), arith.constant(i32, i_elem) - ) - return reg - repack_fns.append(_repack) else: array_regs = list(array.registers.flat) - repack_fns.append(lambda regs: next(regs)) reg_constraint = "f" elif ir.BF16Type.isinstance(dtype) or ir.F16Type.isinstance(dtype): if not ir.VectorType.isinstance(reg_ty): raise NotImplementedError(array.mlir_dtype) [vec_len] = ir.VectorType(reg_ty).shape - if vec_len != 2: + if vec_len % 2: raise NotImplementedError(vec_len) - i32_reg_ty = ir.VectorType.get((1,), i32) + num_i32_regs = vec_len // 2 + i32_reg_ty = ir.VectorType.get((num_i32_regs,), i32) array_regs = [ vector.extractelement( - vector.bitcast(i32_reg_ty, reg), position=c(0, index) + vector.bitcast(i32_reg_ty, reg), position=c(i, index) ) + for i in range(num_i32_regs) for reg in array.registers.flat ] reg_constraint = "r" - def _repack(regs, reg_ty=reg_ty, i32_reg_ty=i32_reg_ty): - return vector.bitcast(reg_ty, vector.splat(i32_reg_ty, next(regs))) - repack_fns.append(_repack) else: raise NotImplementedError(array.mlir_dtype) regs += array_regs @@ -2508,14 +2510,14 @@ def _repack(regs, reg_ty=reg_ty, i32_reg_ty=i32_reg_ty): i32 = ir.IntegerType.get_signless(32) results = [] regs_it = iter(regs) - for array, repack_fn in zip(arrays, repack_fns, strict=True): + for array in arrays: num_regs = array.registers.size reg_ty = array.registers.flat[0].type if ir.VectorType.isinstance(reg_ty): reg_ty = ir.VectorType(reg_ty) new_registers = np.empty((num_regs,), dtype=object) for i_vreg in range(num_regs): - reg = repack_fn(regs_it) + reg = _repack(regs_it, reg_ty) assert reg.type == reg_ty, (reg.type, reg_ty) new_registers[i_vreg] = reg results.append( From c4cc94a10cde3e480b7a4b6c76d304d782292895 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Tue, 8 Apr 2025 03:22:30 -0700 Subject: [PATCH 458/483] [Mosaic GPU] Add warpgroup lowering for `RunState` in Pallas. After this change we no longer skip tests that required 'RunState`. This necessitated a small fix in the pallas lowering of `while` and also enabling multiple i32 register bundling in the `optimization_barrier` lowering. PiperOrigin-RevId: 745065173 --- jax/_src/pallas/mosaic_gpu/lowering.py | 31 ++++++++++++++++---------- tests/pallas/mosaic_gpu_test.py | 22 +++++++++--------- 2 files changed, 31 insertions(+), 22 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index f7bdbccc1ad6..3fc2362decfc 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -2034,6 +2034,7 @@ def _run_scoped_lowering_rule( @register_lowering_rule(discharge.run_state_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(discharge.run_state_p, mgpu.LoweringSemantics.Warpgroup) def _run_state_lowering_rule( ctx: LoweringRuleContext, *args, @@ -2051,7 +2052,12 @@ def _run_state_lowering_rule( for arg, v, out_aval in zip(args, jaxpr.invars, ctx.avals_out): aval = v.aval if isinstance(aval, gpu_core.WGMMAAbstractAccumulatorRef): - new_input_vals.append(mgpu.WGMMAAccumulator.from_registers(arg)) + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Warpgroup: + arg = mgpu.dialect.optimization_barrier([arg]) + nvvm_dialect.wgmma_fence_aligned() + new_input_vals.append(arg) + else: + new_input_vals.append(mgpu.WGMMAAccumulator.from_registers(arg)) should_discharge.append(True) assert isinstance(out_aval, jax_core.ShapedArray) else: @@ -2273,18 +2279,19 @@ def _while_lowering_rule( ctx.module_ctx, ctx.launch_ctx, body_jaxpr.jaxpr, body_args ) loop_out = [*map(_ensure, loop_out, carry_avals)] - for idx, (carry_fa, out_fa) in enumerate(zip(carry, loop_out)): - if _is_acc(carry_fa) != _is_acc(out_fa): - raise ValueError( - f"The loop body output has unexpected accumulator type: output[{idx}]" - f" is {out_fa}, when it should be {carry_fa}." - ) + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: + for idx, (carry_fa, out_fa) in enumerate(zip(carry, loop_out)): + if _is_acc(carry_fa) != _is_acc(out_fa): + raise ValueError( + f"The loop body output has unexpected accumulator type:" + f" output[{idx}] is {out_fa}, when it should be {carry_fa}." + ) - if not _is_acc(out_fa) and carry_fa.layout != out_fa.layout: - raise ValueError( - f"The loop body output has unexpected layout: output[{idx}] has" - f" layout {out_fa.layout}, when it should be {carry_fa.layout}." - ) + if not _is_acc(out_fa) and carry_fa.layout != out_fa.layout: + raise ValueError( + f"The loop body output has unexpected layout: output[{idx}] has" + f" layout {out_fa.layout}, when it should be {carry_fa.layout}." + ) scf_dialect.yield_( carry_treedef.flatten_up_to(loop_out) if loop_out else [] ) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index f0f3bdf41c32..a73c4f82c31d 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -32,7 +32,6 @@ from jax._src.pallas.mosaic_gpu import lowering as mgpu_lowering from jax._src.pallas.mosaic_gpu import pipeline as mgpu_pipeline from jax._src.pallas.mosaic_gpu import primitives as mgpu_primitives -from jax._src.state import discharge from jax.experimental import pallas as pl import jax.experimental.mosaic.gpu as mgpu from jax.experimental.pallas import mosaic_gpu as plgpu @@ -1528,7 +1527,6 @@ def test_missing_primitive_lowerings_are_tracked(self): mgpu_primitives.layout_cast_p, mgpu_primitives.load_p, lax.slice_p, - discharge.run_state_p, } self.assertSetEqual(actual_missing_primitives, expected_missing_primitives) @@ -1538,10 +1536,14 @@ class PallasCallSm90ATest(PallasSm90ATest): @parameterized.parameters(False, True) def test_fori_loop_accumulator(self, force_while): - # ``pl.run_state`` is not supported in WG semantics. - self.skip_if_wg_semantics() - - transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) + if force_while: + # Layout inference and lowering for 'while' are not yet implemented for + # warpgroup semantics. + self.skip_if_wg_semantics() + if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane: + transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) + else: + transforms = () @functools.partial( self.pallas_call, in_specs=[plgpu.GPUBlockSpec((64, 64), transforms=transforms)], @@ -1733,9 +1735,6 @@ def scope(acc_ref): np.testing.assert_allclose(res, a @ b, rtol=1e-3) def test_wgmma_registers_init(self): - # ``pl.run_state`` is not supported in WG semantics. - self.skip_if_wg_semantics() - def kernel(a_ref, b_ref, i_ref, o_ref): def scope(acc_ref): plgpu.wgmma(acc_ref, a_ref[...], b_ref) @@ -1746,7 +1745,10 @@ def scope(acc_ref): b = jax.random.uniform(key2, shape=(128, 192), dtype=jnp.float16) i = jax.random.uniform(key3, shape=(64, 192), dtype=jnp.float16) * 10 - transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) + if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane: + transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) + else: + transforms = () res = self.pallas_call( kernel, in_specs=[ From 12811f08a8fc5fec7c39d17e3e48d14a8e339f06 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 8 Apr 2025 03:29:50 -0700 Subject: [PATCH 459/483] Removed `eager_pmap` config option It defaults to True and is not flipped to False by any internal JAX users. PiperOrigin-RevId: 745067361 --- CHANGELOG.md | 1 + jax/_src/config.py | 7 ------- jax/_src/interpreters/pxla.py | 4 ++-- tests/pmap_test.py | 2 +- 4 files changed, 4 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3aae0f432121..e744cad902de 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * Removed the `config.jax_data_dependent_tracing_fallback` config option, which was added temporarily in v0.4.36 to allow users to opt out of the new "stackless" tracing machinery. + * Removed the `config.jax_eager_pmap` config option. * Changes * The minimum CuDNN version is v9.8. diff --git a/jax/_src/config.py b/jax/_src/config.py index 8aa4ee343664..1fbb401afb61 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1514,13 +1514,6 @@ def _update_disable_jit_thread_local(val): 'compute when encountering OOM errors. However, you are ' 'likely to get better results manually with jax.checkpoint')) -# TODO(sharadmv,mattjj): set default to True, then remove -eager_pmap = bool_state( - name='jax_eager_pmap', - default=True, - upgrade=True, - help='Enable eager-mode pmap when jax_disable_jit is activated.') - no_tracing = bool_state( name='jax_no_tracing', default=False, diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 51854b457b37..45bdd4e17e8e 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -338,8 +338,8 @@ def xla_pmap_impl_lazy( donated_invars: Sequence[bool], is_explicit_global_axis_size: bool, ) -> Callable: - if (config.disable_jit.value and config.eager_pmap.value and - not is_explicit_global_axis_size and not any(d for d in donated_invars)): + if (config.disable_jit.value and + not is_explicit_global_axis_size and not any(donated_invars)): def _emap_apply_fn(*args): return _emap_impl(fun, *args, backend=backend, axis_name=axis_name, axis_size=axis_size, global_axis_size=global_axis_size, diff --git a/tests/pmap_test.py b/tests/pmap_test.py index af2d03e2945d..a07a9e271907 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -3189,7 +3189,7 @@ class EagerPmapMixin: def setUp(self): super().setUp() stack = contextlib.ExitStack() - stack.enter_context(jtu.thread_local_config_context(jax_disable_jit=True, jax_eager_pmap=True)) + stack.enter_context(jtu.thread_local_config_context(jax_disable_jit=True)) stack.enter_context(jtu.ignore_warning( message="Some donated buffers were not usable", category=UserWarning)) self.addCleanup(stack.close) From 5f33280dedb50e72abc3613461bbbe8a67b97f70 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 8 Apr 2025 04:54:30 -0700 Subject: [PATCH 460/483] [pallas:mosaic_gpu] `emit_pipeline*` now allows the grid to be dynamic PiperOrigin-RevId: 745091128 --- jax/_src/pallas/mosaic_gpu/lowering.py | 7 ++-- jax/_src/pallas/mosaic_gpu/pipeline.py | 42 +++++++++++++----------- jax/_src/pallas/mosaic_gpu/primitives.py | 5 +-- tests/pallas/mosaic_gpu_test.py | 38 +++++++++++++-------- 4 files changed, 54 insertions(+), 38 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 3fc2362decfc..b7aa01dbbfcf 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -897,8 +897,11 @@ def write_env(var: jax_core.Var, val, require_value: bool = True): if val.type != mlir_dtype: raise AssertionError(f"Scalar type must match ShapedArray dtype, got: {val.type} != {mlir_dtype}") - foreach(write_env, jaxpr.constvars, consts) - foreach(lambda v, a: write_env(v, a, require_value=False), jaxpr.invars, args) + foreach( + functools.partial(write_env, require_value=False), jaxpr.constvars, consts + ) + foreach(functools.partial(write_env, require_value=False), jaxpr.invars, args) + # TODO(justinfu): Handle transform scopes. last_local_name_stack: list[str] = [] named_regions = [] diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index df9c6668a51d..ecd7e792afbe 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -114,7 +114,7 @@ def _uses_arguments( def _is_index_invariant( - spec: pallas_core.BlockSpec, grid: pallas_core.StaticGrid + spec: pallas_core.BlockSpec, grid: pallas_core.TupleGrid ) -> bool: if (index_map := spec.index_map) is None: return True @@ -122,7 +122,7 @@ def _is_index_invariant( def _inc_grid_by_1( - indices: tuple[jax.Array, ...], grid: Sequence[int] + indices: tuple[jax.Array, ...], grid: pallas_core.TupleGrid ) -> tuple[jax.Array, ...]: next_indices = [] carry: bool | jax.Array = True @@ -161,7 +161,7 @@ def __eq__(self, other: _Slice) -> jax.Array: # type: ignore def emit_pipeline( body: Callable[..., None], *, - grid: pallas_core.StaticGrid, + grid: pallas_core.TupleGrid, in_specs: Sequence[pallas_core.BlockSpec] = (), out_specs: Sequence[pallas_core.BlockSpec] = (), max_concurrent_steps: int = 1, @@ -182,19 +182,19 @@ def emit_pipeline( ``max_concurrent_steps``. Generally, you'll want to set it to 1 if you don't await the WGMMA in the body. """ - num_steps = math.prod(grid) - if max_concurrent_steps <= delay_release: raise ValueError( "max_concurrent_steps must be greater than delay_release, but" f" {max_concurrent_steps=}, {delay_release=}" ) + num_steps = math.prod(grid) + has_dynamic_grid = not isinstance(num_steps, int) + # Shrink ``max_concurrent_steps`` if the total number of steps is lower to # reduce the size of the refs allocated in SMEM. - if max_concurrent_steps > num_steps: + if not has_dynamic_grid and max_concurrent_steps > num_steps: max_concurrent_steps = num_steps - delay_release = 0 # No need to delay anything. def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): in_gmem_refs, out_gmem_refs = util.split_list(gmem_refs, [len(in_specs)]) @@ -244,12 +244,14 @@ def scoped_pipeline( ) ] - for step, indices in enumerate( - it.islice(it.product(*map(range, grid)), max_concurrent_steps) - ): - indices = tuple(map(lambda i: jnp.asarray(i, dtype=jnp.int32), indices)) + # Initialize the pipeline. + indices = (jnp.asarray(0, dtype=jnp.int32),) * len(grid) + fetch_indices = indices + for step in range(max_concurrent_steps): for bref in in_brefs: - bref.copy_in(step, indices, barrier_ref) + bref.copy_in(step, fetch_indices, barrier_ref) + fetch_indices = _inc_grid_by_1(fetch_indices, grid) + del fetch_indices # This is true if any of the outputs need to be transferred inside the loop. copies_out_in_loop = not all(bref.is_index_invariant for bref in out_brefs) @@ -327,7 +329,6 @@ def do_fetch(): # Invariant: ``indices`` and ``fetch_indices`` are always # ``max_concurrent_steps-delay_release`` apart. - indices = (jnp.asarray(0, dtype=jnp.int32),) * len(grid) fetch_indices = indices for _ in range(max_concurrent_steps-delay_release): fetch_indices = _inc_grid_by_1(fetch_indices, grid) @@ -362,7 +363,7 @@ def do_fetch(): def emit_pipeline_warp_specialized( body: Callable[..., None], *, - grid: pallas_core.StaticGrid, + grid: pallas_core.TupleGrid, memory_registers: int, in_specs: Sequence[pl.BlockSpec] = (), out_specs: Sequence[pl.BlockSpec] = (), @@ -434,7 +435,8 @@ def body( not _is_index_invariant(spec, grid) for spec in out_specs] spec_has_seq_axis = [*in_spec_has_seq_axis, *out_spec_has_seq_axis] - num_pipeline_steps = math.prod(grid) + num_steps = math.prod(grid) + has_dynamic_grid = not isinstance(num_steps, int) def _get_slot(step, has_seq_dim): """Returns the buffer slot given the pipeline step.""" @@ -445,8 +447,8 @@ def _get_slot(step, has_seq_dim): # Shrink ``max_concurrent_steps`` if the total number of steps is lower to # reduce the size of the refs allocated in SMEM. - if max_concurrent_steps > num_pipeline_steps: - max_concurrent_steps = num_pipeline_steps + if not has_dynamic_grid and max_concurrent_steps > num_steps: + max_concurrent_steps = num_steps def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): in_gmem_refs, out_gmem_refs = util.split_list(gmem_refs, [len(in_specs)]) @@ -612,7 +614,7 @@ def compute_loop_body(step, carry): carry_init = None init_loop_carry = (init_indices, last_store_slices, carry_init) last_indices, _, final_body_carry = lax.fori_loop(0, - num_pipeline_steps, + num_steps, compute_loop_body, init_loop_carry) if has_carry: @@ -626,7 +628,7 @@ def compute_loop_body(step, carry): # written in the main pipeline loop. if not copies_out_in_loop: gpu_primitives.commit_smem() - last_slot = lax.rem(num_pipeline_steps - 1, max_concurrent_steps) + last_slot = lax.rem(num_steps - 1, max_concurrent_steps) for bref in out_brefs: if bref.is_index_invariant: bref.copy_out(_get_slot(last_slot, has_seq_dim=False), @@ -671,7 +673,7 @@ def memory_loop_body(step, carry): _get_slot(fetch_slot, not bref.is_index_invariant), indices, barrier) next_indices = _inc_grid_by_1(indices, grid) return (next_indices,) - lax.fori_loop(0, num_pipeline_steps - max_concurrent_steps, + lax.fori_loop(0, num_steps - max_concurrent_steps, memory_loop_body, (indices,)) wg_idx = lax.axis_index(wg_axis) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index a37b018860d7..c41a36da94e8 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -192,8 +192,9 @@ def _copy_smem_to_gmem_pp_eqn( pp_params = {} if not (commit_group := eqn.params["commit_group"]): pp_params["commit_group"] = commit_group - if has_user_predicate := eqn.params["has_user_predicate"]: - pp_params["has_user_predicate"] = has_user_predicate + if eqn.params["has_user_predicate"]: + flat_args, user_predicate = flat_args[:-1], flat_args[-1] + pp_params["user_predicate"] = jax_core.pp_var(user_predicate, context) if reduction_op := eqn.params["reduction_op"]: pp_params["reduction_op"] = reduction_op flat_src_transforms, flat_dst_transforms = util.split_list( diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index a73c4f82c31d..e32222775f94 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -2096,16 +2096,21 @@ def kernel_body(_, x_smem, o_smem): y = x + 1.0 np.testing.assert_array_equal(kernel_fn(x), y) - def test_emit_with_2d_grid(self): + @parameterized.product(static=[False, True]) + def test_emit_with_2d_grid(self, static): num_steps1 = 4 num_steps2 = 5 def kernel(x_gmem, o_gmem): + grid = (num_steps1, num_steps2) + if static: + grid = jax.tree.map(jnp.asarray, grid) + plgpu.emit_pipeline( kernel_body, in_specs=[pl.BlockSpec((32, 16, 8), lambda i, j: (0, i, j))], out_specs=[pl.BlockSpec((32, 16, 8), lambda i, j: (0, i, j))], - grid=(num_steps1, num_steps2), + grid=grid, max_concurrent_steps=2, )(x_gmem, o_gmem) @@ -2258,8 +2263,8 @@ def copy_kernel(_, x_smem, o_smem, o_last_block_smem, *consumed_barriers): np.testing.assert_array_equal(out, x) np.testing.assert_array_equal(out_last_block, x[-blk_m:, -blk_n:]) - @parameterized.product(m=[256], n=[256], num_compute_wgs=[1, 2]) - def test_elementwise_add(self, m, n, num_compute_wgs): + @parameterized.product(m=[256], n=[256], num_compute_wgs=[1, 2], static=[False, True]) + def test_elementwise_add(self, m, n, num_compute_wgs, static): self.skip_if_wg_semantics() # Crashes! blk_m = blk_n = 64 @@ -2273,16 +2278,21 @@ def tiled_add_kernel(_, x_smem, y_smem, o_smem): # This is currently a race, but the values written are the same. o_smem[...] = x_smem[...] + y_smem[...] - pipeline = mgpu_pipeline.emit_pipeline_warp_specialized( - tiled_add_kernel, - grid=(m // blk_m, n // blk_n), - max_concurrent_steps=2, - num_compute_wgs=num_compute_wgs, - memory_registers=40, - wg_axis="wg", - in_specs=[spec, spec], - out_specs=[spec], - ) + def pipeline(*gmem_refs): + grid = (m // blk_m, n // blk_n) + if static: + grid = jax.tree.map(jnp.asarray, grid) + return mgpu_pipeline.emit_pipeline_warp_specialized( + tiled_add_kernel, + grid=grid, + max_concurrent_steps=2, + num_compute_wgs=num_compute_wgs, + memory_registers=40, + wg_axis="wg", + in_specs=[spec, spec], + out_specs=[spec], + )(*gmem_refs) + kernel = self.kernel( pipeline, out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32), From 73ecf0bb483eb8239670c1c7a07349519bcf70ac Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 8 Apr 2025 05:24:34 -0700 Subject: [PATCH 461/483] Remove unused `return wrapper` in annotate_function that creates a self reference cycle loop in python. PiperOrigin-RevId: 745099538 --- jax/_src/profiler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/jax/_src/profiler.py b/jax/_src/profiler.py index 0e9949f27f55..912c90182977 100644 --- a/jax/_src/profiler.py +++ b/jax/_src/profiler.py @@ -332,7 +332,6 @@ def annotate_function(func: Callable, name: str | None = None, def wrapper(*args, **kwargs): with TraceAnnotation(name, **decorator_kwargs): return func(*args, **kwargs) - return wrapper return wrapper From 511f78202ff94c7fb88eb2f2ea7427a043c52962 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 8 Apr 2025 11:52:53 +0000 Subject: [PATCH 462/483] Add a skeleton for Pallas:Mosaic GPU documentation --- docs/_static/pallas/gpu/nvidia_sm.svg | 99 +++++++++++++++++ docs/pallas/gpu/index.rst | 14 +++ docs/pallas/gpu/reference.md | 150 ++++++++++++++++++++++++++ docs/pallas/index.rst | 8 +- 4 files changed, 270 insertions(+), 1 deletion(-) create mode 100644 docs/_static/pallas/gpu/nvidia_sm.svg create mode 100644 docs/pallas/gpu/index.rst create mode 100644 docs/pallas/gpu/reference.md diff --git a/docs/_static/pallas/gpu/nvidia_sm.svg b/docs/_static/pallas/gpu/nvidia_sm.svg new file mode 100644 index 000000000000..76b4edb2afad --- /dev/null +++ b/docs/_static/pallas/gpu/nvidia_sm.svg @@ -0,0 +1,99 @@ + + + + + Streaming Multiprocessor + + + + + + Warp Scheduler + + TensorCore + + ALU + (Float/Int) + + Load/Store + + + Special + Functions + + + + + + + Warp Scheduler + + TensorCore + + ALU + (Float/Int) + + Load/Store + + + Special + Functions + + + + + + + Warp Scheduler + + TensorCore + + ALU + (Float/Int) + + Load/Store + + + Special + Functions + + + + + + + Warp Scheduler + + TensorCore + + ALU + (Float/Int) + + Load/Store + + + Special + Functions + + + + + + Shared Memory / L1 Cache + + + diff --git a/docs/pallas/gpu/index.rst b/docs/pallas/gpu/index.rst new file mode 100644 index 000000000000..2d95d5c928c4 --- /dev/null +++ b/docs/pallas/gpu/index.rst @@ -0,0 +1,14 @@ +Pallas:Mosaic GPU +================= +Backend specific documentation for the Mosaic GPU backend. + +.. toctree:: + :caption: Reference documentation + :maxdepth: 2 + + reference + +.. toctree:: + :caption: Guides + :maxdepth: 2 + diff --git a/docs/pallas/gpu/reference.md b/docs/pallas/gpu/reference.md new file mode 100644 index 000000000000..416679d9654c --- /dev/null +++ b/docs/pallas/gpu/reference.md @@ -0,0 +1,150 @@ +# Writing Mosaic GPU kernels with Pallas + +This page is a reference for the most important features of the Pallas:MGPU backend. +It's not a tutorial and as such we do not expect everyone to read it top to bottom. +Still, it is worth going over +just to familiarise yourself with some patterns you can find in other tutorials. + +In the following examples, we're going to assume the following imports are in scope: +```python +import jax.experimental.pallas as pl +import jax.experimental.pallas.mosaic_gpu as plgpu +``` + +## What is a GPU? + +Technically, the NVIDIA GPU architecture looks as follows: the GPU is partitioned into +_streaming multiprocessors_ (SMs). The way this manifests in the CUDA programming model +is that each _CUDA thread block_ (or CTA) is scheduled on exactly one SM, but multiple +blocks can be scheduled onto a single SM at a time. + +Each SM contains a chunk of fast memory called _shared memory_ (SMEM) and 4 subdivisions, +each containing a _warp scheduler_ and compute units (ALU, TensorCore, ...). +This is also reflected in the CUDA programs: each _warp_ (a group of consecutive 32 CUDA +threads in a block) is assigned to one of those subdivisions in a round-robin fashion. +Similarly to blocks, each warp is assigned to exactly one subdivision (it never migrates), +but multiple warps can be assigned to the same SM subdivision. At each clock cycle, the +warp scheduler from each subdivision tries to select one of its resident warps to execute +the next instruction. + +![A diagram of one SM](../../_static/pallas/gpu/nvidia_sm.svg) + +Going further, recent CUDA versions also outline the concept of a _warpgroup_, which are +4 consecutive warps. Knowing how the hardware looks like, we can see where this is comming +from: 4 consecutive warps occupy the 4 quarters of an SM and let us issue instructions +that utilize the whole SM. + +> A GPU can be viewed in many different ways and in here we want to focus on a slightly + simplified model that is very TensorCore-centric. This should help you navigate the + complexities of writing kernels involving the TensorCore, but keep in mind that the + real picture is more complicated. + +For our purposes, TensorCore operations have grown so big that it no longer makes much +sense to follow the CUDA model. As such, to us, a GPU is a collection of single-threaded cores +(SMs) with one thread of Pallas:MGPU corresponding to a CUDA warpgroup. In this model, each +operation you perform in the kernel occupies the whole CUDA warpgroup, and its constituent +warps always run in lockstep (modulo the jitter from hardware scheduling) and never take +different paths through control flow (with the small exception of `core_map` that we will +discuss later). One notable addition here is that we still allow you to co-schedule multiple +of those Pallas-level threads on the same SM so that they can cooperate and communicate +through shared memory (we relize that by putting them in the same CUDA block). + +> This is very similar to a programming model popularized by [Triton](https://triton-lang.org/), + but as you will see there are a few differences. Mosaic GPU tends to be more low level, + which usually means you will have to put in more work, but it also puts you more in control. + In our view both approaches have their merits and we encourage you to pick the backend that + suits your needs the best! Pallas supports and will continue to support Triton as an alternative + GPU backend. + +### In-order execution & using multiple hardware units + +Unlike more complicated CPU architectures GPU only support in-order execution. That, however, +does not mean that at any given time only a single instruction is running! Each SM quarter +has multiple independent functional units: TensorCore, Arithmetic logic unit (ALU), +Load/Store (LSU), Special function unit (SFU). If the first instruction targets one of the +units and is followed by another one (that does not use the result of the first one), then the +warp scheduler can issue the second one before the first one completes. This is often referred +to as instruction-level parallelism (ILP) and is a common theme in modern TensorCore kernels: +TensorCore operations are so big and take so many cycles to complete, that it is a waste to not +try to use other units in the meantime. + +To extend this even further, we can take advantage of this hardware-unit-level parallelism by +allowing multiple Pallas threads (warpgroups) to run concurrently. If one of the threads primarily +occupies the ALU, while another one primarily issues TensorCore related instructions, we can +take advantage of the efficient context switching built into the warp schedulers to keep both +units busy. This is one of the core idea behind algorithms such as [FlashAttention 3](https://arxiv.org/abs/2407.08608) +or [CUTLASS ping-pong matmul kernels](https://pytorch.org/blog/cutlass-ping-pong-gemm-kernel/). + +For more information on how warp scheduling and instruction issue works, we recommend reading +[Analyzing Modern NVIDIA GPU cores](https://arxiv.org/abs/2503.20481). + +## Array layouts and reference transforms + +TODO + +## MMA (TensorCore) + +In this section, we focus on how Pallas:MGPU kernels can utilize the TensorCore unit. +NVIDIA continues to change the programming interface of the TensorCore significantly +between different hardware generations, which is why the lowest-level interfaces +differ in Pallas:MGPU as well. + +### Hopper (`wgmma`) + +TODO + +### Blackwell (`tcgen05`) + +TODO + +## Using `core_map` + +TODO + +## Synchronization structures and primitives + +### `commit_smem` + +TODO + +### `Barrier` + +This is essentially a thin wrapper around an array of PTX `mbarrier` types and is +passed in as a reference. All functions involving barriers expect to only get a single +barrier argument, and so if the reference contains multiple, you have to extract one +of them explicitly using `barriers.at[index]`. + +`Barrier`s are always allocated in SMEM and as such have relatively low overheads. +There are three primary use cases that require the use of `Barrier`s: + +1. Awaiting asynchronous GMEM-to-SMEM copies + +TODO + +2. Cross-warpgroup synchronization + +TODO + +3. Awaiting `tcgen05` TensorCore instructions + +TODO + +### `ClusterBarrier` + +TODO + +### `Semaphore` + +TODO + +## Asynchronous copies + +TODO + +## Inline Mosaic GPU + +TODO + +## Compiler parameters + +TODO diff --git a/docs/pallas/index.rst b/docs/pallas/index.rst index b2e2fca6c82e..6c1a048298c1 100644 --- a/docs/pallas/index.rst +++ b/docs/pallas/index.rst @@ -26,11 +26,17 @@ See also the :class:`jax.experimental.pallas` module API documentation. .. toctree:: - :caption: Platform Features + :caption: TPU backend guide :maxdepth: 2 tpu/index +.. toctree:: + :caption: Mosaic GPU backend guide + :maxdepth: 2 + + gpu/index + .. toctree:: :caption: Design Notes :maxdepth: 2 From d6524dc4616409808d8b1b0b9cd477d09fb0d818 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 8 Apr 2025 07:10:59 -0700 Subject: [PATCH 463/483] Update XLA dependency to use revision http://github.com/openxla/xla/commit/3764aee831189bd32a9c7dea56926b8f31ae86bf. PiperOrigin-RevId: 745130406 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index d4df9ee38034..0b9751ead471 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "77635006f6a898f71f19db360e9b4485aa5106da" -XLA_SHA256 = "d2a63a3cd2f354cd07699f30e7b5c16c7513e686e498b8ad712fb577ab677121" +XLA_COMMIT = "3764aee831189bd32a9c7dea56926b8f31ae86bf" +XLA_SHA256 = "845ce079537b7c25ca236d9910e460803b4148564f5c9c5440b6dab479919e68" def repo(): tf_http_archive( From b926fac66e7f80d4869a4da35a3630c00e050c54 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 8 Apr 2025 07:39:09 -0700 Subject: [PATCH 464/483] [Mosaic GPU] Simplify load/store methods now that we have fewer layouts PiperOrigin-RevId: 745139008 --- jax/_src/pallas/mosaic_gpu/lowering.py | 2 +- .../mosaic/gpu/fragmented_array.py | 72 +++-------------- tests/mosaic/gpu_test.py | 78 +++++++++---------- 3 files changed, 50 insertions(+), 102 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index b7aa01dbbfcf..9b44a1165cdf 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1226,7 +1226,7 @@ def _swap_lowering_rule( is_signed=mgpu_utils.is_signed(x_aval.dtype), optimized=False, ) - value.store_untiled(x_smem) + value.store_untiled(x_smem, optimized=False) return old_value case _: old_value = mgpu.FragmentedArray.load_strided( diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 9ab27927791a..df1e03627f94 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -1788,25 +1788,23 @@ def _(val, idx): fmt_str = fmt.format(f"[{idx_fmt}]: {{}}") utils.debug_print(fmt_str, *idx, val, uniform=False) - def store_untiled(self, ref: ir.Value, *, vector_store: bool = True): + def store_untiled( + self, ref: ir.Value, *, swizzle: int = 16, optimized: bool = True + ): if not ir.MemRefType.isinstance(ref.type): raise ValueError(ref) - - def vs_unsupported(): - if not vector_store: - raise NotImplementedError( - f"Can't use non-vector stores with layout {self.layout}" - ) - match self.layout: case WGSplatFragLayout(): - vs_unsupported() + # All values are the same so swizzle does not affect anything here. self._store_untiled_splat(ref) case WGStridedFragLayout(): - vs_unsupported() + if swizzle != 16: + raise NotImplementedError self._store_untiled_wg_strided(ref) case TiledLayout(): - self._store_untiled_tiled(ref, vector_store=vector_store) + ref_shape = ir.MemRefType(ref.type).shape + ref = utils.memref_reshape(ref, (*(1 for _ in ref_shape), *ref_shape)) + self.store_tiled(ref, swizzle=swizzle, optimized=optimized) case _: raise NotImplementedError(self.layout) @@ -1861,61 +1859,15 @@ def _store_untiled_wg_strided(self, ref: ir.Value): for idx, reg in zip(idxs, self.registers.flat): vector.store(reg, ref_, idx) - def _store_untiled_tiled(self, ref: ir.Value, *, vector_store: bool = True): - """Stores an array with a tiled layout. Not optimized at the moment.""" - if utils.bitwidth(self.mlir_dtype) < 8: - raise NotImplementedError(f"Can't store sub-byte types ({self.mlir_dtype=})") - i32 = ir.IntegerType.get_signless(32) - layout = self.layout - assert isinstance(layout, TiledLayout) - ref_strides, _ = ir.MemRefType(ref.type).get_strides_and_offset() - if vector_store and ref_strides[layout.vector_dim] != 1: - raise NotImplementedError( - "Can't use vector stores with non-unit minormost stride" - ) - strides = layout.tiling.tile_strides(ref_strides) - smem_space = ir.Attribute.parse("#gpu.address_space") - ref_space = ir.MemRefType(ref.type).memory_space - memory_space = None - if str(ref_space) == str(smem_space): - memory_space = 3 - elif ref_space: - raise NotImplementedError(f"Unexpected ref space {ref_space}") - ptr = utils.memref_ptr(ref, memory_space=memory_space) - # Fold warp and lane offsets into the pointer once, since they are dynamic. - dyn_strides = [ - arith.constant(i32, s) for s in strides[-layout.tiled_tiling_rank :] - ] - warp_offset = utils.dyn_dot(layout.warp_indices(), dyn_strides) - lane_offset = utils.dyn_dot(layout.lane_indices(), dyn_strides) - dyn_offset = arith.addi(warp_offset, lane_offset) - ptr = utils.getelementptr(ptr, [dyn_offset], self.mlir_dtype) - # All warp tile offsets are static and can be fused into the store. - for tile_idx, reg in np.ndenumerate(self.registers): - if vector_store: - elems = [reg] - else: - index = ir.IndexType.get() - elems = [ - vector.extractelement(reg, position=c(i, index)) - for i in range(ir.VectorType(reg.type).shape[0]) - ] - for i, e in enumerate(elems): - tile_idx_local = list(tile_idx) - tile_idx_local[layout.vector_dim] += i - tile_idx_local = list(tile_idx_local) - lin_idx = sum(i * s for i, s in zip(tile_idx_local, strides, strict=True)) - reg_ptr = utils.getelementptr(ptr, [lin_idx], self.mlir_dtype) - llvm.store(e, reg_ptr) - - def store_tiled(self, ref, swizzle: int | None): + def store_tiled(self, ref, swizzle: int | None, optimized: bool = True): if not isinstance(self.layout, TiledLayout): raise NotImplementedError(self.layout) layout, shape = self.layout, self.shape # Note that the loop below will "race" for layouts that replicate data. # However, in that case all of the racing writes store the same data, which # is ok in the CUDA memory model. - for get, _, ptr in self.transfer_tiled2(ref, swizzle, layout, shape): + stores = self.transfer_tiled2(ref, swizzle, layout, shape, optimized) + for get, _, ptr in stores: llvm.store(get(self.registers), ptr) @classmethod diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index f0930f5de8cc..b19dee9c065c 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -489,19 +489,12 @@ def get_packed_shape(strides, shape): class WGMMALayoutTest(TestCase): - @parameterized.product(dtype=[jnp.float16, jnp.float32], - transposed_smem=[False, True]) - def test_store_untiled(self, dtype, transposed_smem): + @parameterized.product(dtype=[jnp.float16, jnp.float32]) + def test_store_untiled(self, dtype): def kernel(ctx, out, _): del ctx - if transposed_smem: - out = memref_transpose(out, (1, 0)) - iota_tensor(64, 64, dtype).store_untiled( - out, vector_store=not transposed_smem - ) + iota_tensor(64, 64, dtype).store_untiled(out, optimized=False) expected = np.arange(64 * 64, dtype=dtype).reshape(64, 64) - if transposed_smem: - expected = expected.T iota = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), expected, () )() @@ -749,7 +742,7 @@ def kernel(ctx, lhs, rhs, out, scratch): acc = mgpu.wgmma(init_acc, lhs_smem, rhs_smem, swizzle=swizzle) nvvm.wgmma_commit_group_sync_aligned() nvvm.wgmma_wait_group_sync_aligned(0) - acc.value.store_untiled(out) + acc.value.store_untiled(out, optimized=False) def quantize(x): # Quantize the input to avoid rounding when feeding the WGMMA @@ -821,7 +814,7 @@ def kernel(ctx, rhs, out, rhs_smem): acc = mgpu.wgmma(init_acc, lhs_regs, rhs_smem, swizzle=swizzle) nvvm.wgmma_commit_group_sync_aligned() nvvm.wgmma_wait_group_sync_aligned(0) - acc.value.store_untiled(out) + acc.value.store_untiled(out, optimized=False) y_shape = (n, k) if rhs_transpose else (k, n) y = self.prng.uniform(-1, 1, y_shape).astype(dtype) @@ -881,7 +874,7 @@ def kernel(ctx, rhs, out, smem): acc = mgpu.wgmma(init_acc, lhs_regs, rhs_smem, swizzle=swizzle) nvvm.wgmma_commit_group_sync_aligned() nvvm.wgmma_wait_group_sync_aligned(0) - acc.value.store_untiled(out) + acc.value.store_untiled(out, optimized=False) jax_dtype = jnp.float16 y_shape = (n, k) if rhs_transpose else (k, n) @@ -1042,7 +1035,7 @@ def kernel(ctx, lhs, rhs, out, scratch): ) tcgen05.commit_arrive(barriers[2]) barriers[2].wait(for_tensor_core=True) - acc[:].store_untiled(out) + acc[:].store_untiled(out, optimized=False) x_shape = (k, m) if lhs_transpose else (m, k) x = self.prng.uniform(-1, 1, x_shape).astype(in_jax_dtype) @@ -1145,7 +1138,7 @@ def kernel(ctx, lhs, rhs, out, scratch): tcgen05.commit_arrive(barriers[2], collective=True, ctx=ctx) barriers[2].wait(for_tensor_core=True) m_slice = ds(arith.muli(block_id, c(m_block_tile, index)), m_block_tile) - acc[:].store_untiled(memref_slice(out, m_slice)) + acc[:].store_untiled(memref_slice(out, m_slice), optimized=False) in_finfo = jnp.finfo(in_jax_dtype) exponent_bits, mantissa_bits = in_finfo.nexp, in_finfo.nmant @@ -1198,7 +1191,7 @@ def kernel(ctx, dst, scratch): final_arr = arr + mgpu.FragmentedArray.load_strided( tmp, is_signed=False ) - final_arr.store_untiled(memref_slice(dst, 0)) + final_arr.store_untiled(memref_slice(dst, 0), optimized=False) scf.yield_([]) with ir.InsertionPoint(scf.IfOp(is_second_wg).then_block): barriers[0].wait() @@ -1209,7 +1202,7 @@ def kernel(ctx, dst, scratch): barriers[2].wait() # Synchronize this warpgroup before we overwrite tmp. arr.store_untiled(tmp) barriers[1].arrive() # Signal that tmp is ready. - final_arr.store_untiled(memref_slice(dst, 1)) + final_arr.store_untiled(memref_slice(dst, 1), optimized=False) scf.yield_([]) out_shape = jax.ShapeDtypeStruct((2, 128), jnp.int32) y = mgpu.as_gpu_kernel( @@ -1670,7 +1663,7 @@ def kernel(ctx, dst, _): mlir_dtype = utils.dtype_to_ir_type(dtype) iota = iota_tensor(m, n, dtype) rhs = iota if scalar_rhs is None else c(scalar_rhs, mlir_dtype) - op(iota, rhs).store_untiled(dst) + op(iota, rhs).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), dtype) result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () @@ -1716,7 +1709,7 @@ def test_division(self, op, dtype, m=64, n=32): def kernel(ctx, dst, _): iota = iota_tensor(m, n, dtype) - op(dtype(4.2).item() * iota, iota + 1).store_untiled(dst) + op(dtype(4.2).item() * iota, iota + 1).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), dtype) result = mgpu.as_gpu_kernel( @@ -1746,14 +1739,14 @@ def kernel(ctx, dst, _): rhs = 0 if rhs_is_literal else iota + 1 res = op(iota, rhs) assert not res.is_signed - res.astype(i8, is_signed=False).store_untiled(dst) + res.astype(i8, is_signed=False).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), jnp.int8) result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () )() iota = np.arange(m * n, dtype=dtype).reshape(m, n) - rhs = rhs = 0 if rhs_is_literal else iota + 1 + rhs = 0 if rhs_is_literal else iota + 1 np.testing.assert_array_equal(result, op(iota, rhs).astype(jnp.int8)) def test_foreach_wgmma_row_array(self): @@ -1784,9 +1777,8 @@ def _(v, idx): def test_foreach(self): dtype = jnp.int32 swizzle = 128 - tile = 64, swizzle // jnp.dtype(dtype).itemsize + tiling = (8, swizzle // jnp.dtype(dtype).itemsize) shape = 128, 192 - tiled_shape = mgpu.tile_shape(shape, tile) mlir_dtype = utils.dtype_to_ir_type(dtype) cst = 9999 def causal(val, idx): @@ -1794,12 +1786,16 @@ def causal(val, idx): mask = arith.cmpi(arith.CmpIPredicate.uge, row, col) return arith.select(mask, val, c(cst, mlir_dtype)) - tiling = mgpu.TileTransform(tile) def kernel(ctx, dst, smem): x = iota_tensor(shape[0], shape[1], dtype) - x.foreach(causal, create_array=True, is_signed=False).store_untiled(smem) + x.foreach(causal, create_array=True, is_signed=False).store_tiled(smem, swizzle=128) mgpu.commit_shared() - ctx.async_copy(src_ref=smem, dst_ref=dst) + ctx.async_copy( + src_ref=smem, + dst_ref=dst, + gmem_transform=mgpu.TileTransform(tiling), + swizzle=128, + ) ctx.await_async_copy(0) iota = np.arange(np.prod(shape), dtype=dtype).reshape(*shape) @@ -1809,7 +1805,7 @@ def kernel(ctx, dst, smem): (128, 1, 1), (), jax.ShapeDtypeStruct(shape=shape, dtype=dtype), - jax.ShapeDtypeStruct(shape=shape, dtype=dtype), + jax.ShapeDtypeStruct(shape=mgpu.tile_shape(shape, tiling), dtype=dtype), )() expected = jnp.tril(iota) + jnp.triu(jnp.ones(shape), k=1) * cst np.testing.assert_array_equal(result, expected) @@ -1821,7 +1817,7 @@ def kernel(ctx, dst, smem): def test_bitwise(self, op, dtype, m=64, n=8): def kernel(ctx, dst, _): iota = iota_tensor(m, n, dtype) - op(iota, iota + 1).store_untiled(dst) + op(iota, iota + 1).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), dtype) result = mgpu.as_gpu_kernel( @@ -1845,7 +1841,7 @@ def test_unary(self, ops, dtype, m=64, n=32): def kernel(ctx, dst, _): iota = iota_tensor(m, n, dtype) - op(iota).store_untiled(dst) + op(iota).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), dtype) result = mgpu.as_gpu_kernel( @@ -1858,7 +1854,7 @@ def test_select(self, m=64, n=32): def kernel(ctx, dst, _): iota = iota_tensor(m, n, jnp.int32) - (iota < 16).select(iota * 2, iota * 3).store_untiled(dst) + (iota < 16).select(iota * 2, iota * 3).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), jnp.int32) result = mgpu.as_gpu_kernel( @@ -1881,7 +1877,7 @@ def test_math(self, ops, approx, m=64, n=32): op, np_op = ops def kernel(ctx, dst, _): iota = iota_tensor(m, n, jnp.float32) - op(iota).store_untiled(dst) + op(iota).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () @@ -1902,7 +1898,7 @@ def kernel(ctx, src, dst, scratch): src, is_signed=utils.is_signed(dtype) ) acc = src.reduce_sum(scratch).broadcast((m,)) - acc.store_untiled(dst) + acc.store_untiled(dst, optimized=False) in_shape = jax.ShapeDtypeStruct((m, n), dtype) out_shape = jax.ShapeDtypeStruct((m,), dtype) @@ -1930,7 +1926,7 @@ def kernel(ctx, dst, _): is_signed=utils.is_signed(dtype), ) acc = src.reduce_sum().broadcast((m,)) - acc.store_untiled(dst) + acc.store_untiled(dst, optimized=False) kernel_fn = mgpu.as_gpu_kernel( kernel, @@ -1950,7 +1946,7 @@ def kernel(ctx, dst, _): def test_reduce(self, op, m=64, n=32): def kernel(ctx, dst, _): iota = iota_tensor(m, n, jnp.float32) - iota.reduce(op, axis=1).broadcast_minor(n).store_untiled(dst) + iota.reduce(op, axis=1).broadcast_minor(n).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () @@ -1971,7 +1967,7 @@ def kernel(ctx, dst, _): cte = c(1, iota.mlir_dtype) cte_arr = mgpu.FragmentedArray.splat(cte, ()) cte_arr = cte_arr.reshape((1, 1)).broadcast((m, n)) - (iota + cte_arr).store_untiled(dst) + (iota + cte_arr).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () @@ -1986,7 +1982,7 @@ def kernel(ctx, dst, _): t = mgpu.FragmentedArray.splat( v, (128,), mgpu.WGMMA_ROW_LAYOUT ) - t.broadcast_minor(32).store_untiled(dst) + t.broadcast_minor(32).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((128, 32), jnp.float32) result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () @@ -2005,7 +2001,7 @@ def kernel(ctx, src, dst, _): assert isinstance(pi_arr_sq.layout, mgpu.WGStridedFragLayout) pi_arr_cube = pi_splat.broadcast(pi_arr.shape) * pi_arr_sq assert isinstance(pi_arr_cube.layout, mgpu.WGStridedFragLayout) - (pi_arr == pi_arr).select(pi_splat, pi_arr_cube).store_untiled(dst) + (pi_arr == pi_arr).select(pi_splat, pi_arr_cube).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((128, 32), jnp.float32) inp = jnp.ones_like(out_shape) * 3.14 @@ -2077,7 +2073,7 @@ def kernel(ctx, gmem_input, gmem_output, _): t = mgpu.FragmentedArray.load_untiled( gmem_input, layout=mgpu.WGMMA_COL_LAYOUT, optimized=False ) - t.broadcast_major(m).store_untiled(gmem_output) + t.broadcast_major(m).store_untiled(gmem_output, optimized=False) inp = self.prng.uniform(-1, 1, (n,)).astype(jnp.float16) out_shape = jax.ShapeDtypeStruct((m, n), jnp.float16) @@ -2114,7 +2110,7 @@ def kernel(ctx, inp, out, smem): del ctx, smem arr = mgpu.FragmentedArray.load_strided(inp, is_signed=True) assert ir.VectorType(arr.registers.flat[0].type).shape == [reg_length] - arr.astype(mlir_dtype_to).store_untiled(out) + arr.astype(mlir_dtype_to).store_untiled(out, optimized=False) x = jnp.arange(-128, 128, dtype=jax_dtype_from) x = jnp.tile(x, reg_length // 2) @@ -2190,7 +2186,7 @@ def test_convert_bool_to_u8(self): def kernel(ctx, dst, _): i8 = ir.IntegerType.get_signless(8) iota = iota_tensor(m, n, jnp.uint8) - (iota > 10).astype(i8, is_signed=False).store_untiled(dst) + (iota > 10).astype(i8, is_signed=False).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), jnp.int8) result = mgpu.as_gpu_kernel( @@ -2318,7 +2314,7 @@ def kernel(ctx, dst, _): ) self.assertEqual(tiled.shape, shape) self.assertEqual(tiled.mlir_dtype, iota.mlir_dtype) - tiled.store_untiled(dst) + tiled.store_untiled(dst, optimized=False) ty = jax.ShapeDtypeStruct(shape, dtype) f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), (), ty, ()) expected = np.arange(math.prod(shape), dtype=dtype).reshape(shape) From f5d73b89ca8dc2a2d862154dff3f56362d33fc82 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 8 Apr 2025 07:58:52 -0700 Subject: [PATCH 465/483] [pallas:mosaic_gpu] Added test for custom pretty-printing rules PiperOrigin-RevId: 745145207 --- jax/_src/pallas/mosaic_gpu/primitives.py | 5 +- tests/pallas/mosaic_gpu_test.py | 74 ++++++++++++++++++++++++ 2 files changed, 77 insertions(+), 2 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index c41a36da94e8..f996d620af8f 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -857,13 +857,14 @@ def _wgmma_ref_pp_eqn( acc, a, b, *leaves = eqn.invars a_transforms_treedef = eqn.params["a_transforms_tree"] b_transforms_treedef = eqn.params["b_transforms_tree"] + split = getattr(a_transforms_treedef, "num_leaves", 0) a_transforms = ( - a_transforms_treedef.unflatten(leaves[: a_transforms_treedef.num_leaves]) + a_transforms_treedef.unflatten(leaves[:split]) if a_transforms_treedef is not None else [] ) b_transforms = ( - b_transforms_treedef.unflatten(leaves[a_transforms_treedef.num_leaves :]) + b_transforms_treedef.unflatten(leaves[split:]) if b_transforms_treedef is not None else [] ) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index e32222775f94..cd4f2f8ab602 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -2634,6 +2634,80 @@ class CoreMapWGTest( ... +class PrettyPrintingTest(PallasTest): + + def test_load(self): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct([2, 128], jnp.float32), + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.SMEM), + ) + def kernel(x_ref, o_ref): + for i in range(2): + x = plgpu.load(x_ref, (i,)) + o_ref[i, ...] = x + + _ = str(jax.make_jaxpr(kernel)(jax.ShapeDtypeStruct((2, 128), jnp.float32))) + + def test_copy_primitives(self): + num_steps = 4 + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((64, 64), jnp.float32), + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + ) + def kernel(x_gmem, o_gmem): + # ``plgpu.emit_pipeline`` is implemented in terms of async copy and + # synchronization primitives. + plgpu.emit_pipeline( + kernel_body, + in_specs=[pl.BlockSpec((64, 64), lambda i: (0, i))], + out_specs=[ + pl.BlockSpec( + (64, 64), + lambda i: (0, i), + ) + ], + grid=(num_steps,), + max_concurrent_steps=2, + )(x_gmem, o_gmem) + + def kernel_body(_, x_smem, o_smem): + o_smem[...] = x_smem[...] + 1.0 + + _ = str(jax.make_jaxpr(kernel)(jax.ShapeDtypeStruct((64, 64), jnp.float32))) + + def test_wgmma(self): + transforms = () + if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane: + transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32), + in_specs=[ + plgpu.GPUBlockSpec(transforms=transforms), + plgpu.GPUBlockSpec(transforms=transforms), + ], + ) + def kernel(a_ref, b_ref, o_ref): + def scope(acc_ref): + plgpu.wgmma(acc_ref, a_ref[...], b_ref) + return acc_ref[...] + + o_ref[...] = pl.run_scoped(scope, plgpu.ACC((64, 192), jnp.float32)) + + _ = str( + jax.make_jaxpr(kernel)( + jax.ShapeDtypeStruct((64, 128), jnp.float16), + jax.ShapeDtypeStruct((128, 192), jnp.float16), + ) + ) + + class ExamplesTest(PallasTest): # Basic From b8353d1b903b57e3a86e666847c126b6d4bb8f7e Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 8 Apr 2025 08:15:39 -0700 Subject: [PATCH 466/483] [Mosaic TPU] Add support for non-32bit types in vector.extract At least for as long as the extracted value is not a scalar. PiperOrigin-RevId: 745151577 --- .../mosaic/dialect/tpu/transforms/apply_vector_layout.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index e68d5da466eb..25aebefa4506 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -3740,10 +3740,6 @@ LogicalResult vector_extract_rule(RewriteContext &ctx, Operation &op, TPU_ASSERT_EQ_OP(layouts_out.size(), 1); TPU_ASSERT_OP(layouts_in.front().has_value()); const VectorLayout &layout_in = *layouts_in.front(); - if (layout_in.bitwidth() != 32) { - return op.emitOpError( - "Not implemented: Only 32-bit vector.extract supported"); - } const VectorType res_vty = dyn_cast(extract_op.getResult().getType()); if (res_vty != nullptr) { @@ -3772,6 +3768,10 @@ LogicalResult vector_extract_rule(RewriteContext &ctx, Operation &op, op.erase(); return success(); } else { + if (layout_in.bitwidth() != 32) { + return op.emitOpError( + "Not implemented: Only 32-bit vector.extract supported"); + } // TODO(b/367459476): Support non-zero offsets. if (layout_in.offsets() != LayoutOffsets{0, 0}) { return op.emitOpError("Not implemented: Unsupported layout"); From e02faabfb2ed7eacd82b7c438a119fde9e362739 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 8 Apr 2025 08:32:59 -0700 Subject: [PATCH 467/483] Replace references to jax.readthedocs.io with docs.jax.dev. PiperOrigin-RevId: 745156931 --- CHANGELOG.md | 68 +++++++++---------- CONTRIBUTING.md | 2 +- README.md | 60 ++++++++-------- cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb | 2 +- cloud_tpu_colabs/JAX_demo.ipynb | 2 +- cloud_tpu_colabs/Pmap_Cookbook.ipynb | 4 +- cloud_tpu_colabs/README.md | 2 +- docs/README.md | 2 +- docs/about.md | 16 ++--- docs/advanced-autodiff.md | 4 +- docs/aot.md | 2 +- docs/api_compatibility.md | 2 +- docs/autodidax.ipynb | 4 +- docs/autodidax.md | 4 +- docs/autodidax.py | 4 +- docs/building_on_jax.md | 4 +- docs/contributing.md | 2 +- docs/control-flow.md | 10 +-- docs/developer.md | 6 +- docs/export/export.md | 2 +- docs/export/shape_poly.md | 4 +- docs/faq.rst | 18 ++--- docs/ffi.ipynb | 4 +- docs/ffi.md | 4 +- docs/gpu_memory_allocation.rst | 2 +- docs/installation.md | 2 +- docs/jax-primitives.md | 2 +- docs/jax_array_migration.md | 2 +- docs/jep/10657-sequencing-effects.md | 2 +- docs/jep/12049-type-annotations.md | 2 +- docs/jep/14273-shard-map.md | 4 +- docs/jep/15856-jex.md | 14 ++-- docs/jep/17111-shmap-transpose.md | 2 +- docs/jep/2026-custom-derivatives.md | 2 +- docs/jep/4008-custom-vjp-update.md | 2 +- docs/jep/4410-omnistaging.md | 2 +- docs/jep/9407-type-promotion.ipynb | 8 +-- docs/jep/9407-type-promotion.md | 8 +-- docs/jit-compilation.md | 2 +- docs/notebooks/Common_Gotchas_in_JAX.ipynb | 12 ++-- docs/notebooks/Common_Gotchas_in_JAX.md | 8 +-- ...tom_derivative_rules_for_Python_code.ipynb | 6 +- ...Custom_derivative_rules_for_Python_code.md | 6 +- ...arrays_and_automatic_parallelization.ipynb | 6 +- ...ed_arrays_and_automatic_parallelization.md | 6 +- docs/notebooks/README.md | 2 +- .../Writing_custom_interpreters_in_Jax.ipynb | 2 +- .../Writing_custom_interpreters_in_Jax.md | 2 +- docs/notebooks/autodiff_remat.ipynb | 2 +- docs/notebooks/autodiff_remat.md | 2 +- .../neural_network_with_tfds_data.ipynb | 2 +- .../neural_network_with_tfds_data.md | 2 +- docs/notebooks/shard_map.ipynb | 10 +-- docs/notebooks/shard_map.md | 10 +-- docs/notebooks/thinking_in_jax.ipynb | 12 ++-- docs/notebooks/thinking_in_jax.md | 8 +-- docs/pallas/CHANGELOG.md | 2 +- docs/quickstart.md | 2 +- docs/stateful-computations.md | 2 +- docs/type_promotion.rst | 2 +- docs/xla_flags.md | 2 +- examples/ffi/README.md | 2 +- examples/ffi/src/jax_ffi_example/rms_norm.py | 2 +- jax/BUILD | 2 +- jax/_src/ad_checkpoint.py | 4 +- jax/_src/api.py | 8 +-- jax/_src/basearray.py | 2 +- jax/_src/callback.py | 4 +- jax/_src/compilation_cache.py | 2 +- jax/_src/config.py | 8 +-- jax/_src/custom_derivatives.py | 4 +- jax/_src/debugging.py | 2 +- jax/_src/effects.py | 2 +- jax/_src/errors.py | 8 +-- jax/_src/export/_export.py | 20 +++--- jax/_src/export/shape_poly.py | 28 ++++---- jax/_src/flatten_util.py | 2 +- jax/_src/interpreters/mlir.py | 4 +- jax/_src/lax/lax.py | 4 +- jax/_src/mesh.py | 2 +- jax/_src/named_sharding.py | 2 +- jax/_src/numpy/array_methods.py | 2 +- jax/_src/numpy/lax_numpy.py | 12 ++-- jax/_src/numpy/util.py | 4 +- jax/_src/numpy/vectorize.py | 2 +- jax/_src/pallas/mosaic/lowering.py | 2 +- jax/_src/pallas/mosaic_gpu/lowering.py | 2 +- jax/_src/pallas/pallas_call.py | 4 +- jax/_src/pjit.py | 6 +- jax/_src/random.py | 2 +- jax/_src/xla_bridge.py | 6 +- jax/core.py | 16 ++--- jax/experimental/host_callback.py | 2 +- jax/experimental/jax2tf/README.md | 10 +-- .../jax2tf/g3doc/no_xla_limitations.md | 8 +-- jax/experimental/jax2tf/jax2tf.py | 2 +- .../jax2tf/tests/shape_poly_test.py | 8 +-- jax/experimental/pallas/__init__.py | 2 +- jax/experimental/pallas/ops/tpu/matmul.py | 2 +- jax/experimental/shard_map.py | 2 +- jax/extend/__init__.py | 10 +-- jax/lib/xla_client.py | 2 +- jax/random.py | 4 +- jax/stages.py | 2 +- jax/typing.py | 4 +- jaxlib/xla/pytree.h | 2 +- tests/api_test.py | 2 +- tests/debug_info_test.py | 2 +- tests/errors_test.py | 2 +- tests/export_test.py | 8 +-- tests/lax_test.py | 4 +- tests/random_test.py | 2 +- 112 files changed, 323 insertions(+), 323 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e744cad902de..86d1c82f6401 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ # Change log -Best viewed [here](https://jax.readthedocs.io/en/latest/changelog.html). +Best viewed [here](https://docs.jax.dev/en/latest/changelog.html). For the changes specific to the experimental Pallas APIs, see {ref}`pallas-changelog`. @@ -126,7 +126,7 @@ Patch release of 0.5.1 ## jax 0.5.0 (Jan 17, 2025) As of this release, JAX now uses -[effort-based versioning](https://jax.readthedocs.io/en/latest/jep/25516-effver.html). +[effort-based versioning](https://docs.jax.dev/en/latest/jep/25516-effver.html). Since this release makes a breaking change to PRNG key semantics that may require users to update their code, we are bumping the "meso" version of JAX to signify this. @@ -217,7 +217,7 @@ to signify this. * New Features * {func}`jax.export.export` can be used for device-polymorphic export with shardings constructed with {func}`jax.sharding.AbstractMesh`. - See the [jax.export documentation](https://jax.readthedocs.io/en/latest/export/export.html#device-polymorphic-export). + See the [jax.export documentation](https://docs.jax.dev/en/latest/export/export.html#device-polymorphic-export). * Added {func}`jax.lax.split`. This is a primitive version of {func}`jax.numpy.split`, added because it yields a more compact transpose during automatic differentiation. @@ -259,7 +259,7 @@ This is a patch release of jax 0.4.36. Only "jax" was released at this version. after being deprecated in JAX v0.4.31. Instead use `xb = jax.lib.xla_bridge`, `xc = jax.lib.xla_client`, and `xe = jax.lib.xla_extension`. * The deprecated module `jax.experimental.export` has been removed. It was replaced - by {mod}`jax.export` in JAX v0.4.30. See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export) + by {mod}`jax.export` in JAX v0.4.30. See the [migration guide](https://docs.jax.dev/en/latest/export/export.html#migration-guide-from-jax-experimental-export) for information on migrating to the new API. * The `initial` argument to {func}`jax.nn.softmax` and {func}`jax.nn.log_softmax` has been removed, after being deprecated in v0.4.27. @@ -297,7 +297,7 @@ This is a patch release of jax 0.4.36. Only "jax" was released at this version. call that we guarantee export stability. This is because this custom call relies on Triton IR, which is not guaranteed to be stable. If you need to export code that uses this custom call, you can use the `disabled_checks` - parameter. See more details in the [documentation](https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees-for-custom-calls). + parameter. See more details in the [documentation](https://docs.jax.dev/en/latest/export/export.html#compatibility-guarantees-for-custom-calls). * New Features * {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for @@ -577,7 +577,7 @@ See the 0.4.33 release notes for more details. * Added an API for exporting and serializing JAX functions. This used to exist in `jax.experimental.export` (which is being deprecated), and will now live in `jax.export`. - See the [documentation](https://jax.readthedocs.io/en/latest/export/index.html). + See the [documentation](https://docs.jax.dev/en/latest/export/index.html). * Deprecations * Internal pretty-printing tools `jax.core.pp_*` are deprecated, and will be removed @@ -586,7 +586,7 @@ See the 0.4.33 release notes for more details. release. This previously was the case, but there was an inadvertent regression in the last several JAX releases. * `jax.experimental.export` is deprecated. Use {mod}`jax.export` instead. - See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export). + See the [migration guide](https://docs.jax.dev/en/latest/export/export.html#migration-guide-from-jax-experimental-export). * Passing an array in place of a dtype is now deprecated in most cases; e.g. for arrays `x` and `y`, `x.astype(y)` will raise a warning. To silence it use `x.astype(y.dtype)`. * `jax.xla_computation` is deprecated and will be removed in a future release. @@ -798,7 +798,7 @@ See the 0.4.33 release notes for more details. deprecated. Use `jax.experimental.shard_map` or `jax.vmap` with the `spmd_axis_name` argument for expressing SPMD device-parallel computations. * The `jax.experimental.host_callback` module is deprecated. - Use instead the [new JAX external callbacks](https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html). + Use instead the [new JAX external callbacks](https://docs.jax.dev/en/latest/notebooks/external_callbacks.html). Added `JAX_HOST_CALLBACK_LEGACY` flag to assist in the transition to the new callbacks. See {jax-issue}`#20385` for a discussion. * Passing arguments to {func}`jax.numpy.array_equal` and {func}`jax.numpy.array_equiv` @@ -1270,9 +1270,9 @@ See the 0.4.33 release notes for more details. * Deprecations * Python 3.8 support has been dropped as per - https://jax.readthedocs.io/en/latest/deprecation.html + https://docs.jax.dev/en/latest/deprecation.html * JAX now requires NumPy 1.22 or newer as per - https://jax.readthedocs.io/en/latest/deprecation.html + https://docs.jax.dev/en/latest/deprecation.html * Passing optional arguments to {func}`jax.numpy.ndarray.at` by position is no longer supported, after being deprecated in JAX version 0.4.7. For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)` @@ -1317,7 +1317,7 @@ See the 0.4.33 release notes for more details. * Deprecations * Python 3.8 support has been dropped as per - https://jax.readthedocs.io/en/latest/deprecation.html + https://docs.jax.dev/en/latest/deprecation.html ## jax 0.4.13 (June 22, 2023) @@ -1496,7 +1496,7 @@ See the 0.4.33 release notes for more details. ## jax 0.4.7 (March 27, 2023) * Changes - * As per https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration + * As per https://docs.jax.dev/en/latest/jax_array_migration.html#jax-array-migration `jax.config.jax_array` cannot be disabled anymore. * `jax.config.jax_jit_pjit_api_merge` cannot be disabled anymore. * {func}`jax.experimental.jax2tf.convert` now supports the `native_serialization` @@ -1580,7 +1580,7 @@ Changes: on top of each other. With the `jit`-`pjit` implementation merge, `jit` becomes an initial style primitive which means that we trace to jaxpr as early as possible. For more information see - [this section in autodidax](https://jax.readthedocs.io/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing). + [this section in autodidax](https://docs.jax.dev/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing). Moving to initial style should simplify JAX's internals and make development of features like dynamic shapes, etc easier. You can disable it only via the environment variable i.e. @@ -1665,9 +1665,9 @@ Changes: simplifies and unifies JAX internals, and allows us to unify `jit` and `pjit`. `jax.Array` has been enabled by default in JAX 0.4 and makes some breaking change to the `pjit` API. The [jax.Array migration - guide](https://jax.readthedocs.io/en/latest/jax_array_migration.html) can + guide](https://docs.jax.dev/en/latest/jax_array_migration.html) can help you migrate your codebase to `jax.Array`. You can also look at the - [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) + [Distributed arrays and automatic parallelization](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) tutorial to understand the new concepts. * `PartitionSpec` and `Mesh` are now out of experimental. The new API endpoints are `jax.sharding.PartitionSpec` and `jax.sharding.Mesh`. @@ -1696,7 +1696,7 @@ Changes: * The behavior of `XLA_PYTHON_CLIENT_MEM_FRACTION=.XX` has been changed to allocate XX% of the total GPU memory instead of the previous behavior of using currently available GPU memory to calculate preallocation. Please refer to - [GPU memory allocation](https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html) for + [GPU memory allocation](https://docs.jax.dev/en/latest/gpu_memory_allocation.html) for more details. * The deprecated method `.block_host_until_ready()` has been removed. Use `.block_until_ready()` instead. @@ -1810,7 +1810,7 @@ Changes: * Changes * Ahead-of-time lowering and compilation functionality (tracked in {jax-issue}`#7733`) is stable and public. See [the - overview](https://jax.readthedocs.io/en/latest/aot.html) and the API docs + overview](https://docs.jax.dev/en/latest/aot.html) and the API docs for {mod}`jax.stages`. * Introduced {class}`jax.Array`, intended to be used for both `isinstance` checks and type annotations for array types in JAX. Notice that this included some subtle @@ -1831,7 +1831,7 @@ Changes: * Breaking changes * {func}`jax.checkpoint`, also known as {func}`jax.remat`, no longer supports the `concrete` option, following the previous version's deprecation; see - [JEP 11830](https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html). + [JEP 11830](https://docs.jax.dev/en/latest/jep/11830-new-remat-checkpoint.html). * Changes * Added {func}`jax.pure_callback` that enables calling back to pure Python functions from compiled functions (e.g. functions decorated with `jax.jit` or `jax.pmap`). * Deprecations: @@ -1843,7 +1843,7 @@ Changes: * [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.15...main). * Breaking changes * Support for NumPy 1.19 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). Please upgrade to NumPy 1.20 or newer. * Changes * Added {mod}`jax.debug` that includes utilities for runtime value debugging such at {func}`jax.debug.print` and {func}`jax.debug.breakpoint`. @@ -1861,7 +1861,7 @@ Changes: {mod}`jax.example_libraries.optimizers`. * {func}`jax.checkpoint`, also known as {func}`jax.remat`, has a new implementation switched on by default, meaning the old implementation is - deprecated; see [JEP 11830](https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html). + deprecated; see [JEP 11830](https://docs.jax.dev/en/latest/jep/11830-new-remat-checkpoint.html). ## jax 0.3.15 (July 22, 2022) * [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.14...jax-v0.3.15). @@ -1993,7 +1993,7 @@ Changes: * {func}`jax.numpy.linalg.matrix_rank` on TPUs now accepts complex input. * {func}`jax.scipy.cluster.vq.vq` has been added. * `jax.experimental.maps.mesh` has been deleted. - Please use `jax.experimental.maps.Mesh`. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh + Please use `jax.experimental.maps.Mesh`. Please see https://docs.jax.dev/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh for more information. * {func}`jax.scipy.linalg.qr` now returns a length-1 tuple rather than the raw array when `mode='r'`, in order to match the behavior of `scipy.linalg.qr` ({jax-issue}`#10452`) @@ -2109,7 +2109,7 @@ Changes: * Changes: * The functions `jax.ops.index_update`, `jax.ops.index_add`, which were deprecated in 0.2.22, have been removed. Please use - [the `.at` property on JAX arrays](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html) + [the `.at` property on JAX arrays](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html) instead, e.g., `x.at[idx].set(y)`. * Moved `jax.experimental.ann.approx_*_k` into `jax.lax`. These functions are optimized alternatives to `jax.lax.top_k`. @@ -2155,13 +2155,13 @@ Changes: commits](https://github.com/jax-ml/jax/compare/jax-v0.2.28...jax-v0.3.0). * Changes - * jax version has been bumped to 0.3.0. Please see the [design doc](https://jax.readthedocs.io/en/latest/design_notes/jax_versioning.html) + * jax version has been bumped to 0.3.0. Please see the [design doc](https://docs.jax.dev/en/latest/design_notes/jax_versioning.html) for the explanation. ## jaxlib 0.3.0 (Feb 10, 2022) * Changes * Bazel 5.0.0 is now required to build jaxlib. - * jaxlib version has been bumped to 0.3.0. Please see the [design doc](https://jax.readthedocs.io/en/latest/design_notes/jax_versioning.html) + * jaxlib version has been bumped to 0.3.0. Please see the [design doc](https://docs.jax.dev/en/latest/design_notes/jax_versioning.html) for the explanation. ## jax 0.2.28 (Feb 1, 2022) @@ -2183,7 +2183,7 @@ Changes: by default. * Breaking changes * Support for NumPy 1.18 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). Please upgrade to a supported NumPy version. * Bug fixes * Fixed a bug where apparently identical pytreedef objects constructed by different routes @@ -2195,7 +2195,7 @@ Changes: * Breaking changes: * Support for NumPy 1.18 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). Please upgrade to a supported NumPy version. * The host_callback primitives have been simplified to drop the special autodiff handling for hcb.id_tap and id_print. @@ -2322,7 +2322,7 @@ Changes: * Deprecations * The functions `jax.ops.index_update`, `jax.ops.index_add` etc. are deprecated and will be removed in a future JAX release. Please use - [the `.at` property on JAX arrays](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html) + [the `.at` property on JAX arrays](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html) instead, e.g., `x.at[idx].set(y)`. For now, these functions produce a `DeprecationWarning`. * New features: @@ -2386,7 +2386,7 @@ Changes: commits](https://github.com/jax-ml/jax/compare/jax-v0.2.18...jax-v0.2.19). * Breaking changes: * Support for NumPy 1.17 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). Please upgrade to a supported NumPy version. * The `jit` decorator has been added around the implementation of a number of operators on JAX arrays. This speeds up dispatch times for common @@ -2407,10 +2407,10 @@ Changes: ## jaxlib 0.1.70 (Aug 9, 2021) * Breaking changes: * Support for Python 3.6 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). Please upgrade to a supported Python version. * Support for NumPy 1.17 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). Please upgrade to a supported NumPy version. * The host_callback mechanism now uses one thread per local device for @@ -2424,7 +2424,7 @@ Changes: * Breaking changes: * Support for Python 3.6 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). Please upgrade to a supported Python version. * The minimum jaxlib version is now 0.1.69. * The `backend` argument to {py:func}`jax.dlpack.from_dlpack` has been @@ -2473,7 +2473,7 @@ Changes: * Breaking changes: * Support for NumPy 1.16 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). * Bug fixes: * Fixed bug that prevented round-tripping from JAX to TF and back: @@ -3013,7 +3013,7 @@ Changes: * Support for reduction over subsets of a pmapped axis using `axis_index_groups` {jax-issue}`#2382`. * Experimental support for printing and calling host-side Python function from - compiled code. See [id_print and id_tap](https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html) + compiled code. See [id_print and id_tap](https://docs.jax.dev/en/latest/jax.experimental.host_callback.html) ({jax-issue}`#3006`). * Notable changes: * The visibility of names exported from {mod}`jax.numpy` has been @@ -3085,7 +3085,7 @@ Changes: ## jax 0.1.63 (April 12, 2020) * [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.62...jax-v0.1.63). -* Added `jax.custom_jvp` and `jax.custom_vjp` from {jax-issue}`#2026`, see the [tutorial notebook](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html). Deprecated `jax.custom_transforms` and removed it from the docs (though it still works). +* Added `jax.custom_jvp` and `jax.custom_vjp` from {jax-issue}`#2026`, see the [tutorial notebook](https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html). Deprecated `jax.custom_transforms` and removed it from the docs (though it still works). * Add `scipy.sparse.linalg.cg` {jax-issue}`#2566`. * Changed how Tracers are printed to show more useful information for debugging {jax-issue}`#2591`. * Made `jax.numpy.isclose` handle `nan` and `inf` correctly {jax-issue}`#2501`. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 314d4387a044..046d3df3195c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,4 +1,4 @@ # Contributing to JAX For information on how to contribute to JAX, see -[Contributing to JAX](https://jax.readthedocs.io/en/latest/contributing.html) +[Contributing to JAX](https://docs.jax.dev/en/latest/contributing.html) diff --git a/README.md b/README.md index 0aca7cf58e6e..00391f314044 100644 --- a/README.md +++ b/README.md @@ -11,8 +11,8 @@ | [**Transformations**](#transformations) | [**Install guide**](#installation) | [**Neural net libraries**](#neural-network-libraries) -| [**Change logs**](https://jax.readthedocs.io/en/latest/changelog.html) -| [**Reference docs**](https://jax.readthedocs.io/en/latest/) +| [**Change logs**](https://docs.jax.dev/en/latest/changelog.html) +| [**Reference docs**](https://docs.jax.dev/en/latest/) ## What is JAX? @@ -48,7 +48,7 @@ are instances of such transformations. Others are parallel programming of multiple accelerators, with more to come. This is a research project, not an official Google product. Expect -[sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html). +[sharp edges](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html). Please help by trying it out, [reporting bugs](https://github.com/jax-ml/jax/issues), and letting us know what you think! @@ -83,15 +83,15 @@ perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0))) # fast per-example gra ## Quickstart: Colab in the Cloud Jump right in using a notebook in your browser, connected to a Google Cloud GPU. Here are some starter notebooks: -- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://jax.readthedocs.io/en/latest/quickstart.html) +- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://docs.jax.dev/en/latest/quickstart.html) - [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) **JAX now runs on Cloud TPUs.** To try out the preview, see the [Cloud TPU Colabs](https://github.com/jax-ml/jax/tree/main/cloud_tpu_colabs). For a deeper dive into JAX: -- [The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html) -- [Common gotchas and sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) +- [The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html) +- [Common gotchas and sharp edges](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html) - See the [full list of notebooks](https://github.com/jax-ml/jax/tree/main/docs/notebooks). @@ -105,7 +105,7 @@ Here are four transformations of primary interest: `grad`, `jit`, `vmap`, and JAX has roughly the same API as [Autograd](https://github.com/hips/autograd). The most popular function is -[`grad`](https://jax.readthedocs.io/en/latest/jax.html#jax.grad) +[`grad`](https://docs.jax.dev/en/latest/jax.html#jax.grad) for reverse-mode gradients: ```python @@ -129,13 +129,13 @@ print(grad(grad(grad(tanh)))(1.0)) ``` For more advanced autodiff, you can use -[`jax.vjp`](https://jax.readthedocs.io/en/latest/jax.html#jax.vjp) for +[`jax.vjp`](https://docs.jax.dev/en/latest/jax.html#jax.vjp) for reverse-mode vector-Jacobian products and -[`jax.jvp`](https://jax.readthedocs.io/en/latest/jax.html#jax.jvp) for +[`jax.jvp`](https://docs.jax.dev/en/latest/jax.html#jax.jvp) for forward-mode Jacobian-vector products. The two can be composed arbitrarily with one another, and with other JAX transformations. Here's one way to compose those to make a function that efficiently computes [full Hessian -matrices](https://jax.readthedocs.io/en/latest/_autosummary/jax.hessian.html#jax.hessian): +matrices](https://docs.jax.dev/en/latest/_autosummary/jax.hessian.html#jax.hessian): ```python from jax import jit, jacfwd, jacrev @@ -160,15 +160,15 @@ print(abs_val_grad(-1.0)) # prints -1.0 (abs_val is re-evaluated) ``` See the [reference docs on automatic -differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) +differentiation](https://docs.jax.dev/en/latest/jax.html#automatic-differentiation) and the [JAX Autodiff -Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html) +Cookbook](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html) for more. ### Compilation with `jit` You can use XLA to compile your functions end-to-end with -[`jit`](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit), +[`jit`](https://docs.jax.dev/en/latest/jax.html#just-in-time-compilation-jit), used either as an `@jit` decorator or as a higher-order function. ```python @@ -189,12 +189,12 @@ You can mix `jit` and `grad` and any other JAX transformation however you like. Using `jit` puts constraints on the kind of Python control flow the function can use; see -the tutorial on [Control Flow and Logical Operators with JIT](https://jax.readthedocs.io/en/latest/control-flow.html) +the tutorial on [Control Flow and Logical Operators with JIT](https://docs.jax.dev/en/latest/control-flow.html) for more. ### Auto-vectorization with `vmap` -[`vmap`](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) is +[`vmap`](https://docs.jax.dev/en/latest/jax.html#vectorization-vmap) is the vectorizing map. It has the familiar semantics of mapping a function along array axes, but instead of keeping the loop on the outside, it pushes the loop down into a @@ -259,7 +259,7 @@ differentiation for fast Jacobian and Hessian matrix calculations in ### SPMD programming with `pmap` For parallel programming of multiple accelerators, like multiple GPUs, use -[`pmap`](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap). +[`pmap`](https://docs.jax.dev/en/latest/jax.html#parallelization-pmap). With `pmap` you write single-program multiple-data (SPMD) programs, including fast parallel collective communication operations. Applying `pmap` will mean that the function you write is compiled by XLA (similarly to `jit`), then @@ -284,7 +284,7 @@ print(pmap(jnp.mean)(result)) ``` In addition to expressing pure maps, you can use fast [collective communication -operations](https://jax.readthedocs.io/en/latest/jax.lax.html#parallel-operators) +operations](https://docs.jax.dev/en/latest/jax.lax.html#parallel-operators) between devices: ```python @@ -341,20 +341,20 @@ for more. For a more thorough survey of current gotchas, with examples and explanations, we highly recommend reading the [Gotchas -Notebook](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html). +Notebook](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html). Some standouts: 1. JAX transformations only work on [pure functions](https://en.wikipedia.org/wiki/Pure_function), which don't have side-effects and respect [referential transparency](https://en.wikipedia.org/wiki/Referential_transparency) (i.e. object identity testing with `is` isn't preserved). If you use a JAX transformation on an impure Python function, you might see an error like `Exception: Can't lift Traced...` or `Exception: Different traces at same level`. 1. [In-place mutating updates of - arrays](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://jax.readthedocs.io/en/latest/jax.ops.html). Under a `jit`, those functional alternatives will reuse buffers in-place automatically. + arrays](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://docs.jax.dev/en/latest/jax.ops.html). Under a `jit`, those functional alternatives will reuse buffers in-place automatically. 1. [Random numbers are - different](https://jax.readthedocs.io/en/latest/random-numbers.html), but for [good reasons](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md). + different](https://docs.jax.dev/en/latest/random-numbers.html), but for [good reasons](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md). 1. If you're looking for [convolution - operators](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html), + operators](https://docs.jax.dev/en/latest/notebooks/convolutions.html), they're in the `jax.lax` package. 1. JAX enforces single-precision (32-bit, e.g. `float32`) values by default, and [to enable - double-precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision) + double-precision](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision) (64-bit, e.g. `float64`) one needs to set the `jax_enable_x64` variable at startup (or set the environment variable `JAX_ENABLE_X64=True`). On TPU, JAX uses 32-bit values by default for everything _except_ internal @@ -368,14 +368,14 @@ Some standouts: and NumPy types aren't preserved, namely `np.add(1, np.array([2], np.float32)).dtype` is `float64` rather than `float32`. 1. Some transformations, like `jit`, [constrain how you can use Python control - flow](https://jax.readthedocs.io/en/latest/control-flow.html). + flow](https://docs.jax.dev/en/latest/control-flow.html). You'll always get loud errors if something goes wrong. You might have to use [`jit`'s `static_argnums` - parameter](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit), + parameter](https://docs.jax.dev/en/latest/jax.html#just-in-time-compilation-jit), [structured control flow - primitives](https://jax.readthedocs.io/en/latest/jax.lax.html#control-flow-operators) + primitives](https://docs.jax.dev/en/latest/jax.lax.html#control-flow-operators) like - [`lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan), + [`lax.scan`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan), or just use `jit` on smaller subfunctions. ## Installation @@ -403,7 +403,7 @@ Some standouts: | Mac GPU | Follow [Apple's instructions](https://developer.apple.com/metal/jax/). | | Intel GPU | Follow [Intel's instructions](https://github.com/intel/intel-extension-for-openxla/blob/main/docs/acc_jax.md). | -See [the documentation](https://jax.readthedocs.io/en/latest/installation.html) +See [the documentation](https://docs.jax.dev/en/latest/installation.html) for information on alternative installation strategies. These include compiling from source, installing with Docker, using other versions of CUDA, a community-supported conda build, and answers to some frequently-asked questions. @@ -417,7 +417,7 @@ for training neural networks in JAX. If you want a fully featured library for ne training with examples and how-to guides, try [Flax](https://github.com/google/flax) and its [documentation site](https://flax.readthedocs.io/en/latest/nnx/index.html). -Check out the [JAX Ecosystem section](https://jax.readthedocs.io/en/latest/#ecosystem) +Check out the [JAX Ecosystem section](https://docs.jax.dev/en/latest/#ecosystem) on the JAX documentation site for a list of JAX-based network libraries, which includes [Optax](https://github.com/deepmind/optax) for gradient processing and optimization, [chex](https://github.com/deepmind/chex) for reliable code and testing, and @@ -452,7 +452,7 @@ paper. ## Reference documentation For details about the JAX API, see the -[reference documentation](https://jax.readthedocs.io/). +[reference documentation](https://docs.jax.dev/). For getting started as a JAX developer, see the -[developer documentation](https://jax.readthedocs.io/en/latest/developer.html). +[developer documentation](https://docs.jax.dev/en/latest/developer.html). diff --git a/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb b/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb index edaa71b93e85..5bc045d0f606 100644 --- a/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb +++ b/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb @@ -225,7 +225,7 @@ "* Jacobian pre-accumulation for elementwise operations (like `gelu`)\n", "\n", "\n", - "For much more, see the [JAX Autodiff Cookbook (Part 1)](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)." + "For much more, see the [JAX Autodiff Cookbook (Part 1)](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html)." ] }, { diff --git a/cloud_tpu_colabs/JAX_demo.ipynb b/cloud_tpu_colabs/JAX_demo.ipynb index d7ba5ed334f4..b69246c57e0b 100644 --- a/cloud_tpu_colabs/JAX_demo.ipynb +++ b/cloud_tpu_colabs/JAX_demo.ipynb @@ -315,7 +315,7 @@ "* Jacobian pre-accumulation for elementwise operations (like `gelu`)\n", "\n", "\n", - "For much more, see the [JAX Autodiff Cookbook (Part 1)](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)." + "For much more, see the [JAX Autodiff Cookbook (Part 1)](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html)." ] }, { diff --git a/cloud_tpu_colabs/Pmap_Cookbook.ipynb b/cloud_tpu_colabs/Pmap_Cookbook.ipynb index ea126ac4f1e7..8b16cd7694eb 100644 --- a/cloud_tpu_colabs/Pmap_Cookbook.ipynb +++ b/cloud_tpu_colabs/Pmap_Cookbook.ipynb @@ -59,7 +59,7 @@ "id": "2e_06-OAJNyi" }, "source": [ - "A basic starting point is expressing parallel maps with [`pmap`](https://jax.readthedocs.io/en/latest/jax.html#jax.pmap):" + "A basic starting point is expressing parallel maps with [`pmap`](https://docs.jax.dev/en/latest/jax.html#jax.pmap):" ] }, { @@ -407,7 +407,7 @@ "source": [ "When writing nested `pmap` functions in the decorator style, axis names are resolved according to lexical scoping.\n", "\n", - "Check [the JAX reference documentation](https://jax.readthedocs.io/en/latest/jax.lax.html#parallel-operators) for a complete list of the parallel operators. More are being added!\n", + "Check [the JAX reference documentation](https://docs.jax.dev/en/latest/jax.lax.html#parallel-operators) for a complete list of the parallel operators. More are being added!\n", "\n", "Here's how to use `lax.ppermute` to implement a simple halo exchange for a [Rule 30](https://en.wikipedia.org/wiki/Rule_30) simulation:" ] diff --git a/cloud_tpu_colabs/README.md b/cloud_tpu_colabs/README.md index db3dc5f30814..6e5501584da0 100644 --- a/cloud_tpu_colabs/README.md +++ b/cloud_tpu_colabs/README.md @@ -4,7 +4,7 @@ The same JAX code that runs on CPU and GPU can also be run on TPU. Cloud TPUs have the advantage of quickly giving you access to multiple TPU accelerators, including in [Colab](https://research.google.com/colaboratory/). All of the example notebooks here use -[`jax.pmap`](https://jax.readthedocs.io/en/latest/jax.html#jax.pmap) to run JAX +[`jax.pmap`](https://docs.jax.dev/en/latest/jax.html#jax.pmap) to run JAX computation across multiple TPU cores from Colab. You can also run the same code directly on a [Cloud TPU VM](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm). diff --git a/docs/README.md b/docs/README.md index 12e00425592f..54b8a67477b0 100644 --- a/docs/README.md +++ b/docs/README.md @@ -1,2 +1,2 @@ To rebuild the documentation, -see [Update Documentation](https://jax.readthedocs.io/en/latest/developer.html#update-documentation). +see [Update Documentation](https://docs.jax.dev/en/latest/developer.html#update-documentation). diff --git a/docs/about.md b/docs/about.md index 58e1703842b9..baeed941c8c3 100644 --- a/docs/about.md +++ b/docs/about.md @@ -19,7 +19,7 @@ technology stack](#components). First, we design the `jax` module to be [composable](https://github.com/jax-ml/jax?tab=readme-ov-file#transformations) and -[extensible](https://jax.readthedocs.io/en/latest/jax.extend.html), so +[extensible](https://docs.jax.dev/en/latest/jax.extend.html), so that a wide variety of domain-specific libraries can thrive outside of it in a decentralized manner. Second, we lean heavily on a modular backend stack (compiler and runtime) to target different @@ -42,10 +42,10 @@ scale. JAX's day-to-day development takes place in the open on GitHub, using pull requests, the issue tracker, discussions, and [JAX Enhancement Proposals -(JEPs)](https://jax.readthedocs.io/en/latest/jep/index.html). Reading +(JEPs)](https://docs.jax.dev/en/latest/jep/index.html). Reading and participating in these is a good way to get involved. We also maintain [developer -notes](https://jax.readthedocs.io/en/latest/contributor_guide.html) +notes](https://docs.jax.dev/en/latest/contributor_guide.html) that cover JAX's internal design. The JAX core team determines whether to accept changes and @@ -56,7 +56,7 @@ intricate decision structure over time (e.g. with designated area owners) if/when it becomes useful to do so. For more see [contributing to -JAX](https://jax.readthedocs.io/en/latest/contributing.html). +JAX](https://docs.jax.dev/en/latest/contributing.html). (components)= ## A modular stack @@ -71,7 +71,7 @@ and (b) an advancing hardware landscape, we lean heavily on While the JAX core library focuses on the fundamentals, we want to encourage domain-specific libraries and tools to be built on top of JAX. Indeed, [many -libraries](https://jax.readthedocs.io/en/latest/#ecosystem) have +libraries](https://docs.jax.dev/en/latest/#ecosystem) have emerged around JAX to offer higher-level features and extensions. How do we encourage such decentralized development? We guide it with @@ -80,11 +80,11 @@ building blocks (e.g. numerical primitives, NumPy operations, arrays, and transformations), encouraging auxiliary libraries to develop utilities as needed for their domain. In addition, JAX exposes a handful of more advanced APIs for -[customization](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) +[customization](https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) and -[extensibility](https://jax.readthedocs.io/en/latest/jax.extend.html). Libraries +[extensibility](https://docs.jax.dev/en/latest/jax.extend.html). Libraries can [lean on these -APIs](https://jax.readthedocs.io/en/latest/building_on_jax.html) in +APIs](https://docs.jax.dev/en/latest/building_on_jax.html) in order to use JAX as an internal means of implementation, to integrate more with its transformations like autodiff, and more. diff --git a/docs/advanced-autodiff.md b/docs/advanced-autodiff.md index eaa3bc7317c8..bef2fd088a3a 100644 --- a/docs/advanced-autodiff.md +++ b/docs/advanced-autodiff.md @@ -876,7 +876,7 @@ There are two ways to define differentiation rules in JAX: 1. Using {func}`jax.custom_jvp` and {func}`jax.custom_vjp` to define custom differentiation rules for Python functions that are already JAX-transformable; and 2. Defining new `core.Primitive` instances along with all their transformation rules, for example to call into functions from other systems like solvers, simulators, or general numerical computing systems. -This notebook is about #1. To read instead about #2, refer to the [notebook on adding primitives](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html). +This notebook is about #1. To read instead about #2, refer to the [notebook on adding primitives](https://docs.jax.dev/en/latest/notebooks/How_JAX_primitives_work.html). ### TL;DR: Custom JVPs with {func}`jax.custom_jvp` @@ -1608,7 +1608,7 @@ Array(-0.91113025, dtype=float32) #### Working with `list` / `tuple` / `dict` containers (and other pytrees) -You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any [pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) are permissible, so long as their structures are consistent according to the type constraints. +You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any [pytrees](https://docs.jax.dev/en/latest/pytrees.html) are permissible, so long as their structures are consistent according to the type constraints. Here's a contrived example with {func}`jax.custom_jvp`: diff --git a/docs/aot.md b/docs/aot.md index 1fcf11ab945d..8f68c2758148 100644 --- a/docs/aot.md +++ b/docs/aot.md @@ -26,7 +26,7 @@ are arrays, JAX does the following in order: carries out this specialization by a process that we call _tracing_. During tracing, JAX stages the specialization of `F` to a jaxpr, which is a function in the [Jaxpr intermediate - language](https://jax.readthedocs.io/en/latest/jaxpr.html). + language](https://docs.jax.dev/en/latest/jaxpr.html). 2. **Lower** this specialized, staged-out computation to the XLA compiler's input language, StableHLO. diff --git a/docs/api_compatibility.md b/docs/api_compatibility.md index 749c5907bc6b..9dca1fc08f50 100644 --- a/docs/api_compatibility.md +++ b/docs/api_compatibility.md @@ -91,7 +91,7 @@ guarantees of the main JAX package. If you have code that uses `jax.extend`, we would strongly recommend CI tests against JAX's nightly releases, so as to catch potential changes before they are released. -For details on `jax.extend`, see the [`jax.extend` module docuementation](https://jax.readthedocs.io/en/latest/jax.extend.html), or the design document, {ref}`jax-extend-jep`. +For details on `jax.extend`, see the [`jax.extend` module docuementation](https://docs.jax.dev/en/latest/jax.extend.html), or the design document, {ref}`jax-extend-jep`. ## Numerics and randomness diff --git a/docs/autodidax.ipynb b/docs/autodidax.ipynb index 7ec91affa05d..b6f12b624f8b 100644 --- a/docs/autodidax.ipynb +++ b/docs/autodidax.ipynb @@ -72,7 +72,7 @@ "outputs, we want to override primitive application and let different values\n", "flow through our program. For example, we might want to replace the\n", "application of every primitive with an application of [its JVP\n", - "rule](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html),\n", + "rule](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html),\n", "and let primal-tangent pairs flow through our program. Moreover, we want to be\n", "able to compose multiple transformations, leading to stacks of interpreters." ] @@ -3620,7 +3620,7 @@ "source": [ "Notice that we're not currently supporting the case where the predicate value\n", "itself is batched. In mainline JAX, we handle this case by transforming the\n", - "conditional to a [select primitive](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html).\n", + "conditional to a [select primitive](https://docs.jax.dev/en/latest/_autosummary/jax.lax.select.html).\n", "That transformation is semantically correct so long as `true_fun` and\n", "`false_fun` do not involve any side-effecting primitives.\n", "\n", diff --git a/docs/autodidax.md b/docs/autodidax.md index 2d4d6cd528af..1c375e21227c 100644 --- a/docs/autodidax.md +++ b/docs/autodidax.md @@ -72,7 +72,7 @@ where we apply primitive operations to numerical inputs to produce numerical outputs, we want to override primitive application and let different values flow through our program. For example, we might want to replace the application of every primitive with an application of [its JVP -rule](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html), +rule](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html), and let primal-tangent pairs flow through our program. Moreover, we want to be able to compose multiple transformations, leading to stacks of interpreters. @@ -2843,7 +2843,7 @@ print(out) Notice that we're not currently supporting the case where the predicate value itself is batched. In mainline JAX, we handle this case by transforming the -conditional to a [select primitive](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html). +conditional to a [select primitive](https://docs.jax.dev/en/latest/_autosummary/jax.lax.select.html). That transformation is semantically correct so long as `true_fun` and `false_fun` do not involve any side-effecting primitives. diff --git a/docs/autodidax.py b/docs/autodidax.py index f8c6372fe30d..6329234224cb 100644 --- a/docs/autodidax.py +++ b/docs/autodidax.py @@ -62,7 +62,7 @@ # outputs, we want to override primitive application and let different values # flow through our program. For example, we might want to replace the # application of every primitive with an application of [its JVP -# rule](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html), +# rule](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html), # and let primal-tangent pairs flow through our program. Moreover, we want to be # able to compose multiple transformations, leading to stacks of interpreters. @@ -2837,7 +2837,7 @@ def cond_vmap_rule(axis_size, vals_in, dims_in, *, true_jaxpr, false_jaxpr): # Notice that we're not currently supporting the case where the predicate value # itself is batched. In mainline JAX, we handle this case by transforming the -# conditional to a [select primitive](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html). +# conditional to a [select primitive](https://docs.jax.dev/en/latest/_autosummary/jax.lax.select.html). # That transformation is semantically correct so long as `true_fun` and # `false_fun` do not involve any side-effecting primitives. # diff --git a/docs/building_on_jax.md b/docs/building_on_jax.md index 9416b16cde10..6d13f517f50b 100644 --- a/docs/building_on_jax.md +++ b/docs/building_on_jax.md @@ -45,8 +45,8 @@ Here are more specific examples of each pattern. ### Direct usage Jax can be directly imported and utilized to build models “from scratch” as shown across this website, -for example in [JAX Tutorials](https://jax.readthedocs.io/en/latest/tutorials.html) -or [Neural Network with JAX](https://jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html). +for example in [JAX Tutorials](https://docs.jax.dev/en/latest/tutorials.html) +or [Neural Network with JAX](https://docs.jax.dev/en/latest/notebooks/neural_network_with_tfds_data.html). This may be the best option if you are unable to find prebuilt code for your particular challenge, or if you're looking to reduce the number of dependencies in your codebase. diff --git a/docs/contributing.md b/docs/contributing.md index 99d78453c436..53a863fdcd8c 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -6,7 +6,7 @@ Everyone can contribute to JAX, and we value everyone's contributions. There are ways to contribute, including: - Answering questions on JAX's [discussions page](https://github.com/jax-ml/jax/discussions) -- Improving or expanding JAX's [documentation](http://jax.readthedocs.io/) +- Improving or expanding JAX's [documentation](http://docs.jax.dev/) - Contributing to JAX's [code-base](http://github.com/jax-ml/jax/) - Contributing in any of the above ways to the broader ecosystem of [libraries built on JAX](https://github.com/jax-ml/jax#neural-network-libraries) diff --git a/docs/control-flow.md b/docs/control-flow.md index 7cb959f3e434..8f59bd92add7 100644 --- a/docs/control-flow.md +++ b/docs/control-flow.md @@ -244,19 +244,19 @@ lax.cond(False, lambda x: x+1, lambda x: x-1, operand) `jax.lax` provides two other functions that allow branching on dynamic predicates: -- [`lax.select`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html) is +- [`lax.select`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.select.html) is like a batched version of `lax.cond`, with the choices expressed as pre-computed arrays rather than as functions. -- [`lax.switch`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.switch.html) is +- [`lax.switch`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.switch.html) is like `lax.cond`, but allows switching between any number of callable choices. In addition, `jax.numpy` provides several numpy-style interfaces to these functions: -- [`jnp.where`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.where.html) with +- [`jnp.where`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.where.html) with three arguments is the numpy-style wrapper of `lax.select`. -- [`jnp.piecewise`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.piecewise.html) +- [`jnp.piecewise`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.piecewise.html) is a numpy-style wrapper of `lax.switch`, but switches on a list of boolean conditions rather than a single scalar index. -- [`jnp.select`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.select.html) has +- [`jnp.select`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.select.html) has an API similar to `jnp.piecewise`, but the choices are given as pre-computed arrays rather than as functions. It is implemented in terms of multiple calls to `lax.select`. diff --git a/docs/developer.md b/docs/developer.md index b1a978ffd0d6..9edeaeac83f8 100644 --- a/docs/developer.md +++ b/docs/developer.md @@ -789,7 +789,7 @@ desired formats, and which the `jupytext --sync` command recognizes when invoked #### Notebooks within the Sphinx build Some of the notebooks are built automatically as part of the pre-submit checks and -as part of the [Read the docs](https://jax.readthedocs.io/en/latest) build. +as part of the [Read the docs](https://docs.jax.dev/en/latest) build. The build will fail if cells raise errors. If the errors are intentional, you can either catch them, or tag the cell with `raises-exceptions` metadata ([example PR](https://github.com/jax-ml/jax/pull/2402/files)). You have to add this metadata by hand in the `.ipynb` file. It will be preserved when somebody else @@ -800,7 +800,7 @@ See `exclude_patterns` in [conf.py](https://github.com/jax-ml/jax/blob/main/docs ### Documentation building on `readthedocs.io` -JAX's auto-generated documentation is at . +JAX's auto-generated documentation is at . The documentation building is controlled for the entire project by the [readthedocs JAX settings](https://readthedocs.org/dashboard/jax). The current settings @@ -813,7 +813,7 @@ For each automated documentation build you can see the If you want to test the documentation generation on Readthedocs, you can push code to the `test-docs` branch. That branch is also built automatically, and you can -see the generated documentation [here](https://jax.readthedocs.io/en/test-docs/). If the documentation build +see the generated documentation [here](https://docs.jax.dev/en/test-docs/). If the documentation build fails you may want to [wipe the build environment for test-docs](https://docs.readthedocs.io/en/stable/guides/wipe-environment.html). For a local test, I was able to do it in a fresh directory by replaying the commands diff --git a/docs/export/export.md b/docs/export/export.md index 18cdcc6c51d0..63c0db14f905 100644 --- a/docs/export/export.md +++ b/docs/export/export.md @@ -161,7 +161,7 @@ e.g., the inference system.) What **matters is when the exporting and consuming components were built**, not the time when the exporting and the compilation happen. For external JAX users, it is -[possible to run JAX and jaxlib at different versions](https://jax.readthedocs.io/en/latest/jep/9419-jax-versioning.html#how-are-jax-and-jaxlib-versioned); +[possible to run JAX and jaxlib at different versions](https://docs.jax.dev/en/latest/jep/9419-jax-versioning.html#how-are-jax-and-jaxlib-versioned); what matters is when the jaxlib release was built. To reduce chances of incompatibility, internal JAX users should: diff --git a/docs/export/shape_poly.md b/docs/export/shape_poly.md index 9254030a4e1c..6b63a536ab48 100644 --- a/docs/export/shape_poly.md +++ b/docs/export/shape_poly.md @@ -86,7 +86,7 @@ matching the structure of the arguments passed to it. The polymorphic shapes specification can be a pytree prefix in cases where one specification should apply to multiple arguments, as in the above example. -See [how optional parameters are matched to arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees). +See [how optional parameters are matched to arguments](https://docs.jax.dev/en/latest/pytrees.html#applying-optional-parameters-to-pytrees). A few examples of shape specifications: @@ -609,7 +609,7 @@ Division had remainder 1 when computing the value of 'd'. Using the following polymorphic shapes specifications: args[0].shape = (b, b, 2*d). Obtained dimension variables: 'b' = 3 from specification 'b' for dimension args[0].shape[0] (= 3), . -Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details. +Please see https://docs.jax.dev/en/latest/export/shape_poly.html#shape-assertion-errors for more details. ``` diff --git a/docs/faq.rst b/docs/faq.rst index 44267f6f5f7d..f5d43d25afb6 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -4,7 +4,7 @@ Frequently asked questions (FAQ) .. comment RST primer for Sphinx: https://thomas-cokelaer.info/tutorials/sphinx/rest_syntax.html .. comment Some links referenced here. Use `JAX - The Sharp Bits`_ (underscore at the end) to reference -.. _JAX - The Sharp Bits: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html +.. _JAX - The Sharp Bits: https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html We are collecting answers to frequently asked questions here. Contributions welcome! @@ -116,7 +116,7 @@ code in JAX's internal representation, typically because it makes heavy use of Python control flow such as ``for`` loops. For a handful of loop iterations, Python is OK, but if you need *many* loop iterations, you should rewrite your code to make use of JAX's -`structured control flow primitives `_ +`structured control flow primitives `_ (such as :func:`lax.scan`) or avoid wrapping the loop with ``jit`` (you can still use ``jit`` decorated functions *inside* the loop). @@ -454,8 +454,8 @@ performing matrix-matrix multiplication) to amortize the increased overhead of JAX/accelerators vs NumPy/CPU. For example, if we switch this example to use 10x10 input instead, JAX/GPU runs 10x slower than NumPy/CPU (100 µs vs 10 µs). -.. _To JIT or not to JIT: https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#to-jit-or-not-to-jit -.. _Double (64 bit) precision: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision +.. _To JIT or not to JIT: https://docs.jax.dev/en/latest/notebooks/thinking_in_jax.html#to-jit-or-not-to-jit +.. _Double (64 bit) precision: https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision .. _`%time and %timeit magics`: https://ipython.readthedocs.io/en/stable/interactive/magics.html#magic-time .. _Colab: https://colab.research.google.com/ @@ -841,12 +841,12 @@ reducing :code:`XLA_PYTHON_CLIENT_MEM_FRACTION` from the default of :code:`.75`, or setting :code:`XLA_PYTHON_CLIENT_PREALLOCATE=false`. For more details, please see the page on `JAX GPU memory allocation`_. -.. _JIT mechanics: https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#jit-mechanics-tracing-and-static-variables -.. _External callbacks in JAX: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html -.. _Pure callback example: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html#example-pure-callback-with-custom-jvp -.. _IO callback example: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html#exploring-jax-experimental-io-callback +.. _JIT mechanics: https://docs.jax.dev/en/latest/notebooks/thinking_in_jax.html#jit-mechanics-tracing-and-static-variables +.. _External callbacks in JAX: https://docs.jax.dev/en/latest/notebooks/external_callbacks.html +.. _Pure callback example: https://docs.jax.dev/en/latest/notebooks/external_callbacks.html#example-pure-callback-with-custom-jvp +.. _IO callback example: https://docs.jax.dev/en/latest/notebooks/external_callbacks.html#exploring-jax-experimental-io-callback .. _Heaviside Step Function: https://en.wikipedia.org/wiki/Heaviside_step_function .. _Sigmoid Function: https://en.wikipedia.org/wiki/Sigmoid_function .. _algebraic_simplifier.cc: https://github.com/openxla/xla/blob/33f815e190982dac4f20d1f35adb98497a382377/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc#L4851 -.. _JAX GPU memory allocation: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html +.. _JAX GPU memory allocation: https://docs.jax.dev/en/latest/gpu_memory_allocation.html .. _dynamic linker search pattern: https://man7.org/linux/man-pages/man8/ld.so.8.html diff --git a/docs/ffi.ipynb b/docs/ffi.ipynb index b622fba9d5bc..f74ae9d58a78 100644 --- a/docs/ffi.ipynb +++ b/docs/ffi.ipynb @@ -439,7 +439,7 @@ "As far as JAX is concerned, the foreign function is a black box that can't be inspected to determine the appropriate behavior when differentiated.\n", "Therefore, it is the {func}`~jax.ffi.ffi_call` user's responsibility to define a custom derivative rule.\n", "\n", - "More details about custom derivative rules can be found in the [custom derivatives tutorial](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html), but the most common pattern used for implementing differentiation for foreign functions is to define a {func}`~jax.custom_vjp` which itself calls a foreign function.\n", + "More details about custom derivative rules can be found in the [custom derivatives tutorial](https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html), but the most common pattern used for implementing differentiation for foreign functions is to define a {func}`~jax.custom_vjp` which itself calls a foreign function.\n", "In this case, we actually define two new FFI calls:\n", "\n", "1. `rms_norm_fwd` returns two outputs: (a) the \"primal\" result, and (b) the \"residuals\" which are used in the backwards pass.\n", @@ -785,7 +785,7 @@ "{func}`~jax.experimental.custom_partitioning.custom_partitioning` works by adding Python callbacks into the XLA compiler's partitioning pass, which allows very flexible logic, but also comes with some rough edges.\n", "We won't go into too much detail on the caveats here, but the main issues that you should be aware of are:\n", "\n", - "1. `custom_partitioning` can cause unexpected cache misses when used with the JAX's [Persistent compilation cache](https://jax.readthedocs.io/en/latest/persistent_compilation_cache.html). This can be mitigated using the `jax_remove_custom_partitioning_ptr_from_cache_key` configuration flag, but that isn't always appropriate either.\n", + "1. `custom_partitioning` can cause unexpected cache misses when used with the JAX's [Persistent compilation cache](https://docs.jax.dev/en/latest/persistent_compilation_cache.html). This can be mitigated using the `jax_remove_custom_partitioning_ptr_from_cache_key` configuration flag, but that isn't always appropriate either.\n", "2. Debugging `custom_partitioning` logic can be tedious because Python errors don't always get propagated, instead causing your Python process to exit. That being said, any exceptions will show up in the process logs, so you should be able to track them down there.\n", "\n", "All that being said, here's how we can wrap our FFI implementation of `rms_norm` using {func}`~jax.experimental.custom_partitioning.custom_partitioning`:" diff --git a/docs/ffi.md b/docs/ffi.md index 4aa03c217855..97648c78e118 100644 --- a/docs/ffi.md +++ b/docs/ffi.md @@ -353,7 +353,7 @@ Unlike with batching, {func}`~jax.ffi.ffi_call` doesn't provide any default supp As far as JAX is concerned, the foreign function is a black box that can't be inspected to determine the appropriate behavior when differentiated. Therefore, it is the {func}`~jax.ffi.ffi_call` user's responsibility to define a custom derivative rule. -More details about custom derivative rules can be found in the [custom derivatives tutorial](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html), but the most common pattern used for implementing differentiation for foreign functions is to define a {func}`~jax.custom_vjp` which itself calls a foreign function. +More details about custom derivative rules can be found in the [custom derivatives tutorial](https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html), but the most common pattern used for implementing differentiation for foreign functions is to define a {func}`~jax.custom_vjp` which itself calls a foreign function. In this case, we actually define two new FFI calls: 1. `rms_norm_fwd` returns two outputs: (a) the "primal" result, and (b) the "residuals" which are used in the backwards pass. @@ -591,7 +591,7 @@ If you can't use {func}`~jax.experimental.shard_map.shard_map`, an alternative a {func}`~jax.experimental.custom_partitioning.custom_partitioning` works by adding Python callbacks into the XLA compiler's partitioning pass, which allows very flexible logic, but also comes with some rough edges. We won't go into too much detail on the caveats here, but the main issues that you should be aware of are: -1. `custom_partitioning` can cause unexpected cache misses when used with the JAX's [Persistent compilation cache](https://jax.readthedocs.io/en/latest/persistent_compilation_cache.html). This can be mitigated using the `jax_remove_custom_partitioning_ptr_from_cache_key` configuration flag, but that isn't always appropriate either. +1. `custom_partitioning` can cause unexpected cache misses when used with the JAX's [Persistent compilation cache](https://docs.jax.dev/en/latest/persistent_compilation_cache.html). This can be mitigated using the `jax_remove_custom_partitioning_ptr_from_cache_key` configuration flag, but that isn't always appropriate either. 2. Debugging `custom_partitioning` logic can be tedious because Python errors don't always get propagated, instead causing your Python process to exit. That being said, any exceptions will show up in the process logs, so you should be able to track them down there. All that being said, here's how we can wrap our FFI implementation of `rms_norm` using {func}`~jax.experimental.custom_partitioning.custom_partitioning`: diff --git a/docs/gpu_memory_allocation.rst b/docs/gpu_memory_allocation.rst index 6667589e7b72..be40dfc8004c 100644 --- a/docs/gpu_memory_allocation.rst +++ b/docs/gpu_memory_allocation.rst @@ -69,7 +69,7 @@ Common causes of OOM failures disabling the automatic remat pass produces different trade-offs between compute and memory. Note however, that the algorithm is basic and you can often get better trade-off between compute and memory by disabling the automatic remat pass and doing - it manually with `the jax.remat API `_ + it manually with `the jax.remat API `_ Experimental features diff --git a/docs/installation.md b/docs/installation.md index ee675dd1e586..34274d7596aa 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -229,7 +229,7 @@ refer to JAX has experimental ROCm support. There are two ways to install JAX: * Use [AMD's Docker container](https://hub.docker.com/r/rocm/jax-community/tags); or -* Build from source. Refer to the section [Additional notes for building a ROCm jaxlib for AMD GPUs](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). +* Build from source. Refer to the section [Additional notes for building a ROCm jaxlib for AMD GPUs](https://docs.jax.dev/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). (install-intel-gpu)= ## Intel GPU diff --git a/docs/jax-primitives.md b/docs/jax-primitives.md index 38a45ef4823e..819d0418e894 100644 --- a/docs/jax-primitives.md +++ b/docs/jax-primitives.md @@ -300,7 +300,7 @@ def multiply_add_lowering(ctx, xc, yc, zc): return [hlo.AddOp(hlo.MulOp(xc, yc), zc).result] # Now, register the lowering rule with JAX. -# For GPU, refer to the https://jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html +# For GPU, refer to the https://docs.jax.dev/en/latest/Custom_Operation_for_GPUs.html from jax.interpreters import mlir mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform='cpu') diff --git a/docs/jax_array_migration.md b/docs/jax_array_migration.md index a557f4ae7efc..3cc1629b2068 100644 --- a/docs/jax_array_migration.md +++ b/docs/jax_array_migration.md @@ -27,7 +27,7 @@ the unified jax.Array After the migration is complete `jax.Array` will be the only type of array in JAX. -This doc explains how to migrate existing codebases to `jax.Array`. For more information on using `jax.Array` and JAX parallelism APIs, see the [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) tutorial. +This doc explains how to migrate existing codebases to `jax.Array`. For more information on using `jax.Array` and JAX parallelism APIs, see the [Distributed arrays and automatic parallelization](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) tutorial. ### How to enable jax.Array? diff --git a/docs/jep/10657-sequencing-effects.md b/docs/jep/10657-sequencing-effects.md index 5f7eb0da4c04..ac3024519101 100644 --- a/docs/jep/10657-sequencing-effects.md +++ b/docs/jep/10657-sequencing-effects.md @@ -47,7 +47,7 @@ g() In many cases, JAX will execute `f` and `g` *in parallel*, dispatching the computations onto different threads -- `g` might actually be executed before `f`. Parallel execution is a nice performance optimization, especially if copying -to and from a device is expensive (see the [asynchronous dispatch note](https://jax.readthedocs.io/en/latest/async_dispatch.html) for more details). +to and from a device is expensive (see the [asynchronous dispatch note](https://docs.jax.dev/en/latest/async_dispatch.html) for more details). In practice, however, we often don't need to think about asynchronous dispatch because we're writing pure functions and only care about the inputs and outputs of functions -- we'll naturally block on future diff --git a/docs/jep/12049-type-annotations.md b/docs/jep/12049-type-annotations.md index 5ed760dd6c5c..bf6123b2bc7f 100644 --- a/docs/jep/12049-type-annotations.md +++ b/docs/jep/12049-type-annotations.md @@ -35,7 +35,7 @@ def slice(operand: Array, start_indices: Sequence[int], For the purposes of static type checking, this use of `Array = Any` for array type annotations puts no constraint on the argument values (`Any` is equivalent to no annotation at all), but it does serve as a form of useful in-code documentation for the developer. -For the sake of generated documentation, the name of the alias gets lost (the [HTML docs](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.slice.html) for `jax.lax.slice` report operand as type `Any`), so the documentation benefit does not go beyond the source code (though we could enable some `sphinx-autodoc` options to improve this: See [autodoc_type_aliases](https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#confval-autodoc_type_aliases)). +For the sake of generated documentation, the name of the alias gets lost (the [HTML docs](https://docs.jax.dev/en/latest/_autosummary/jax.lax.slice.html) for `jax.lax.slice` report operand as type `Any`), so the documentation benefit does not go beyond the source code (though we could enable some `sphinx-autodoc` options to improve this: See [autodoc_type_aliases](https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#confval-autodoc_type_aliases)). A benefit of this level of type annotation is that it is never wrong to annotate a value with `Any`, so it will provide a concrete benefit to developers and users in the form of documentation, without added complexity of satisfying the stricter needs of any particular static type checker. diff --git a/docs/jep/14273-shard-map.md b/docs/jep/14273-shard-map.md index 63742bc852c6..fa6681551d17 100644 --- a/docs/jep/14273-shard-map.md +++ b/docs/jep/14273-shard-map.md @@ -4,7 +4,7 @@ *January 2023* **This was the design doc proposing `shard_map`. You may instead want -[the up-to-date user docs](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html).** +[the up-to-date user docs](https://docs.jax.dev/en/latest/notebooks/shard_map.html).** ## Motivation @@ -18,7 +18,7 @@ We need great APIs for both, and rather than being mutually exclusive alternatives, they need to compose with each other. With `pjit` (now just `jit`) we have [a next-gen -API](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) +API](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) for the first school. But we haven't quite leveled-up the second school. `pmap` follows the second school, but over time we found it has [fatal flaws](#why-dont-pmap-or-xmap-already-solve-this). `xmap` solved those flaws, diff --git a/docs/jep/15856-jex.md b/docs/jep/15856-jex.md index a5625abf8930..a821405c399e 100644 --- a/docs/jep/15856-jex.md +++ b/docs/jep/15856-jex.md @@ -14,13 +14,13 @@ import jax.extend as jex Several projects depend on JAX's codebase internals, often to use its core machinery (e.g. to write a -[transformation over its IR](https://jax.readthedocs.io/en/latest/notebooks/Writing_custom_interpreters_in_Jax.html)) +[transformation over its IR](https://docs.jax.dev/en/latest/notebooks/Writing_custom_interpreters_in_Jax.html)) or to extend it (e.g. to [define new primitives](https://github.com/dfm/extending-jax)). Two challenges for these dependencies are (a) that our internals aren't all solidly designed for external use, and (b) that circumventing JAX's public API is -[unsupported](https://jax.readthedocs.io/en/latest/api_compatibility.html). +[unsupported](https://docs.jax.dev/en/latest/api_compatibility.html). In other words, our internals are often used like a library, but are neither structured nor updated like one. @@ -50,12 +50,12 @@ removed altogether. To keep development overhead low, `jax.extend` would not follow the public -[API compatibility](https://jax.readthedocs.io/en/latest/api_compatibility.html) +[API compatibility](https://docs.jax.dev/en/latest/api_compatibility.html) policy. It would promise no deprecation windows nor backwards compatibility between releases. Every release may break existing callers without simple recourse (e.g. without a flag reintroducing prior behavior). We would rely on the -[changelog](https://jax.readthedocs.io/en/latest/changelog.html) +[changelog](https://docs.jax.dev/en/latest/changelog.html) to call out such changes. Callers of `jax.extend` that need to upgrade their code regularly @@ -108,7 +108,7 @@ to process the Jaxpr IR (the output of At initialization, this module will contain many more symbols than what's needed to define primitives and rules, including various names used in setting up -["final-style transformations"](https://jax.readthedocs.io/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing), +["final-style transformations"](https://docs.jax.dev/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing), such as the current `jax._src.core.Trace` and `Tracer` classes. We can revisit whether `jex.core` should also support final-style extensions alongside initial style approaches, and whether it can do so by a more @@ -137,7 +137,7 @@ tracer types from `jex`. This module plus `jex.core` ought to suffice for replicating today's custom primitive tutorials (e.g. -[ours](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html) +[ours](https://docs.jax.dev/en/latest/notebooks/How_JAX_primitives_work.html) and [dfm's](https://github.com/dfm/extending-jax)). For instance, defining a primitive and its behavior under `jax.jit` @@ -184,6 +184,6 @@ arrays. We have only one item in mind for now. The XLA compiler's array sharding format is more expressive than [those provided by -JAX](https://jax.readthedocs.io/en/latest/jax.sharding.html). We could +JAX](https://docs.jax.dev/en/latest/jax.sharding.html). We could provide this as `jex.sharding.XlaOpShardingProto`, corresponding to today's `jax._src.lib.xla_client.OpSharding` internally. diff --git a/docs/jep/17111-shmap-transpose.md b/docs/jep/17111-shmap-transpose.md index 2fdf5f822835..00d8a3f383fd 100644 --- a/docs/jep/17111-shmap-transpose.md +++ b/docs/jep/17111-shmap-transpose.md @@ -497,7 +497,7 @@ of every function instance along which the outputs are mapped, whereas for mesh axes over which the output is unmapped only one copy of the value is used. See [the `shmap` -JEP](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html) for examples +JEP](https://docs.jax.dev/en/latest/jep/14273-shard-map.html) for examples of unmapped inputs and outputs. For comparison, in `vmap` unmapped inputs/outputs are indicated by using `in_axes` / `out_axes` of `None` (rather than an `int`). diff --git a/docs/jep/2026-custom-derivatives.md b/docs/jep/2026-custom-derivatives.md index ce149fa6fb35..b09926425667 100644 --- a/docs/jep/2026-custom-derivatives.md +++ b/docs/jep/2026-custom-derivatives.md @@ -2,7 +2,7 @@ This is a design document, explaining some of the thinking behind the design and implementation of `jax.custom_jvp` and `jax.custom_vjp`. For user-oriented -documentation, see [the tutorial notebook](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html). +documentation, see [the tutorial notebook](https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html). There are two ways to define differentiation rules in JAX: 1. using `jax.custom_jvp` and `jax.custom_vjp` to define custom differentiation diff --git a/docs/jep/4008-custom-vjp-update.md b/docs/jep/4008-custom-vjp-update.md index 1e2270e052a6..c3f2be151ef7 100644 --- a/docs/jep/4008-custom-vjp-update.md +++ b/docs/jep/4008-custom-vjp-update.md @@ -4,7 +4,7 @@ _Oct 14 2020_ This doc assumes familiarity with `jax.custom_vjp`, as described in the [Custom derivative rules for JAX-transformable Python -functions](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) +functions](https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) notebook. ## What to update diff --git a/docs/jep/4410-omnistaging.md b/docs/jep/4410-omnistaging.md index f95c15f404b6..5b4536864ac2 100644 --- a/docs/jep/4410-omnistaging.md +++ b/docs/jep/4410-omnistaging.md @@ -266,7 +266,7 @@ While tracing the function ex1 at ex1.py:4, this value became a tracer due to JA You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions. -See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information. +See https://docs.jax.dev/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information. Encountered tracer value: Tracedwith ``` diff --git a/docs/jep/9407-type-promotion.ipynb b/docs/jep/9407-type-promotion.ipynb index a1ede3177a3a..5f12877c97a9 100644 --- a/docs/jep/9407-type-promotion.ipynb +++ b/docs/jep/9407-type-promotion.ipynb @@ -12,7 +12,7 @@ "\n", "*Jake VanderPlas, December 2021*\n", "\n", - "One of the challenges faced in the design of any numerical computing library is the choice of how to handle operations between values of different types. This document outlines the thought process behind the promotion semantics used by JAX, summarized in [JAX Type Promotion Semantics](https://jax.readthedocs.io/en/latest/type_promotion.html)." + "One of the challenges faced in the design of any numerical computing library is the choice of how to handle operations between values of different types. This document outlines the thought process behind the promotion semantics used by JAX, summarized in [JAX Type Promotion Semantics](https://docs.jax.dev/en/latest/type_promotion.html)." ] }, { @@ -1335,7 +1335,7 @@ "However, these advantages comes with a few tradeoffs:\n", "\n", "- mixed float/integer promotion is very prone to precision loss: for example, `int64` (with a maximum value of $9.2 \\times 10^{18}$) can be promoted to `float16` (with a maximum value of $6.5 \\times 10^4$), meaning most representable values will become `inf`.\n", - "- as mentioned above, `f*` can no longer be thought of as a \"scalar type\", but as a different flavor of float64. In JAX's parlance, this is referred to as a [*weak type*](https://jax.readthedocs.io/en/latest/type_promotion.html#weakly-typed-values-in-jax), in that it is represented as 64-bit, but only weakly holds to this bit width in promotion with other values.\n", + "- as mentioned above, `f*` can no longer be thought of as a \"scalar type\", but as a different flavor of float64. In JAX's parlance, this is referred to as a [*weak type*](https://docs.jax.dev/en/latest/type_promotion.html#weakly-typed-values-in-jax), in that it is represented as 64-bit, but only weakly holds to this bit width in promotion with other values.\n", "\n", "Note that also, this approach still leaves the `uint64` promotion question unanswered, although it is perhaps reasonable to close the lattice by connecting `u64` to `f*`." ] @@ -1413,7 +1413,7 @@ "id": "o0-E2KWjYEXO" }, "source": [ - "The behavior resulting from this choice is summarized in [JAX Type Promotion Semantics](https://jax.readthedocs.io/en/latest/type_promotion.html). Notably, aside from the inclusion of larger unsigned types (`u16`, `u32`, `u64`) and some details about the behavior of scalar/weak types (`i*`, `f*`, `c*`), this type promotion scheme turns out to be very close to that chosen by PyTorch.\n", + "The behavior resulting from this choice is summarized in [JAX Type Promotion Semantics](https://docs.jax.dev/en/latest/type_promotion.html). Notably, aside from the inclusion of larger unsigned types (`u16`, `u32`, `u64`) and some details about the behavior of scalar/weak types (`i*`, `f*`, `c*`), this type promotion scheme turns out to be very close to that chosen by PyTorch.\n", "\n", "For those interested, the appendix below prints the full promotion tables used by NumPy, Tensorflow, PyTorch, and JAX." ] @@ -2883,7 +2883,7 @@ "source": [ "### JAX Type Promotion: `jax.numpy`\n", "\n", - "`jax.numpy` follows type promotion rules laid out at https://jax.readthedocs.io/en/latest/type_promotion.html. Here we use `i*`, `f*`, `c*` to indicate both Python scalars and weakly-typed arrays." + "`jax.numpy` follows type promotion rules laid out at https://docs.jax.dev/en/latest/type_promotion.html. Here we use `i*`, `f*`, `c*` to indicate both Python scalars and weakly-typed arrays." ] }, { diff --git a/docs/jep/9407-type-promotion.md b/docs/jep/9407-type-promotion.md index ff67a8c21399..c047d76c1b18 100644 --- a/docs/jep/9407-type-promotion.md +++ b/docs/jep/9407-type-promotion.md @@ -20,7 +20,7 @@ kernelspec: *Jake VanderPlas, December 2021* -One of the challenges faced in the design of any numerical computing library is the choice of how to handle operations between values of different types. This document outlines the thought process behind the promotion semantics used by JAX, summarized in [JAX Type Promotion Semantics](https://jax.readthedocs.io/en/latest/type_promotion.html). +One of the challenges faced in the design of any numerical computing library is the choice of how to handle operations between values of different types. This document outlines the thought process behind the promotion semantics used by JAX, summarized in [JAX Type Promotion Semantics](https://docs.jax.dev/en/latest/type_promotion.html). +++ {"id": "Rod6OOyUVbQ8"} @@ -680,7 +680,7 @@ This is important because `f16` and `bf16` are not comparable because they utili However, these advantages comes with a few tradeoffs: - mixed float/integer promotion is very prone to precision loss: for example, `int64` (with a maximum value of $9.2 \times 10^{18}$) can be promoted to `float16` (with a maximum value of $6.5 \times 10^4$), meaning most representable values will become `inf`. -- as mentioned above, `f*` can no longer be thought of as a "scalar type", but as a different flavor of float64. In JAX's parlance, this is referred to as a [*weak type*](https://jax.readthedocs.io/en/latest/type_promotion.html#weakly-typed-values-in-jax), in that it is represented as 64-bit, but only weakly holds to this bit width in promotion with other values. +- as mentioned above, `f*` can no longer be thought of as a "scalar type", but as a different flavor of float64. In JAX's parlance, this is referred to as a [*weak type*](https://docs.jax.dev/en/latest/type_promotion.html#weakly-typed-values-in-jax), in that it is represented as 64-bit, but only weakly holds to this bit width in promotion with other values. Note that also, this approach still leaves the `uint64` promotion question unanswered, although it is perhaps reasonable to close the lattice by connecting `u64` to `f*`. @@ -730,7 +730,7 @@ nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos +++ {"id": "o0-E2KWjYEXO"} -The behavior resulting from this choice is summarized in [JAX Type Promotion Semantics](https://jax.readthedocs.io/en/latest/type_promotion.html). Notably, aside from the inclusion of larger unsigned types (`u16`, `u32`, `u64`) and some details about the behavior of scalar/weak types (`i*`, `f*`, `c*`), this type promotion scheme turns out to be very close to that chosen by PyTorch. +The behavior resulting from this choice is summarized in [JAX Type Promotion Semantics](https://docs.jax.dev/en/latest/type_promotion.html). Notably, aside from the inclusion of larger unsigned types (`u16`, `u32`, `u64`) and some details about the behavior of scalar/weak types (`i*`, `f*`, `c*`), this type promotion scheme turns out to be very close to that chosen by PyTorch. For those interested, the appendix below prints the full promotion tables used by NumPy, Tensorflow, PyTorch, and JAX. @@ -900,7 +900,7 @@ display.HTML(table.to_html()) ### JAX Type Promotion: `jax.numpy` -`jax.numpy` follows type promotion rules laid out at https://jax.readthedocs.io/en/latest/type_promotion.html. Here we use `i*`, `f*`, `c*` to indicate both Python scalars and weakly-typed arrays. +`jax.numpy` follows type promotion rules laid out at https://docs.jax.dev/en/latest/type_promotion.html. Here we use `i*`, `f*`, `c*` to indicate both Python scalars and weakly-typed arrays. ```{code-cell} :cellView: form diff --git a/docs/jit-compilation.md b/docs/jit-compilation.md index 5e5be308068a..093f5ec4ab72 100644 --- a/docs/jit-compilation.md +++ b/docs/jit-compilation.md @@ -55,7 +55,7 @@ The {ref}`jax-internals-jaxpr` section of the documentation provides more inform Importantly, notice that the jaxpr does not capture the side-effect present in the function: there is nothing in it corresponding to `global_list.append(x)`. This is a feature, not a bug: JAX transformations are designed to understand side-effect-free (a.k.a. functionally pure) code. -If *pure function* and *side-effect* are unfamiliar terms, this is explained in a little more detail in [🔪 JAX - The Sharp Bits 🔪: Pure Functions](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions). +If *pure function* and *side-effect* are unfamiliar terms, this is explained in a little more detail in [🔪 JAX - The Sharp Bits 🔪: Pure Functions](https:docs.jax.devio/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions). Impure functions are dangerous because under JAX transformations they are likely not to behave as intended; they might fail silently, or produce surprising downstream errors like leaked Tracers. Moreover, JAX often can't detect when side effects are present. diff --git a/docs/notebooks/Common_Gotchas_in_JAX.ipynb b/docs/notebooks/Common_Gotchas_in_JAX.ipynb index a1435c4e557e..de6da98b7d62 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.ipynb +++ b/docs/notebooks/Common_Gotchas_in_JAX.ipynb @@ -346,7 +346,7 @@ "evalue": "ignored", "output_type": "error", "traceback": [ - "\u001b[0;31mTypeError\u001b[0m\u001b[0;31m:\u001b[0m '' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html\n" + "\u001b[0;31mTypeError\u001b[0m\u001b[0;31m:\u001b[0m '' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html\n" ] } ], @@ -365,7 +365,7 @@ "source": [ "Allowing mutation of variables in-place makes program analysis and transformation difficult. JAX requires that programs are pure functions.\n", "\n", - "Instead, JAX offers a _functional_ array update using the [`.at` property on JAX arrays](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at)." + "Instead, JAX offers a _functional_ array update using the [`.at` property on JAX arrays](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at)." ] }, { @@ -521,7 +521,7 @@ "id": "sTjJ3WuaDyqU" }, "source": [ - "For more details on indexed array updates, see the [documentation for the `.at` property](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at)." + "For more details on indexed array updates, see the [documentation for the `.at` property](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at)." ] }, { @@ -604,7 +604,7 @@ "id": "NAcXJNAcDi_v" }, "source": [ - "If you would like finer-grained control over the behavior for out-of-bound indices, you can use the optional parameters of [`ndarray.at`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html); for example:" + "If you would like finer-grained control over the behavior for out-of-bound indices, you can use the optional parameters of [`ndarray.at`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html); for example:" ] }, { @@ -971,7 +971,7 @@ "evalue": "ignored", "output_type": "error", "traceback": [ - "\u001b[0;31mNonConcreteBooleanIndexError\u001b[0m\u001b[0;31m:\u001b[0m Array boolean indices must be concrete; got ShapedArray(bool[5])\n\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError\n" + "\u001b[0;31mNonConcreteBooleanIndexError\u001b[0m\u001b[0;31m:\u001b[0m Array boolean indices must be concrete; got ShapedArray(bool[5])\n\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError\n" ] } ], @@ -1296,7 +1296,7 @@ "While `jax.numpy` makes every attempt to replicate the behavior of numpy's API, there do exist corner cases where the behaviors differ.\n", "Many such cases are discussed in detail in the sections above; here we list several other known places where the APIs diverge.\n", "\n", - "- For binary operations, JAX's type promotion rules differ somewhat from those used by NumPy. See [Type Promotion Semantics](https://jax.readthedocs.io/en/latest/type_promotion.html) for more details.\n", + "- For binary operations, JAX's type promotion rules differ somewhat from those used by NumPy. See [Type Promotion Semantics](https://docs.jax.dev/en/latest/type_promotion.html) for more details.\n", "- When performing unsafe type casts (i.e. casts in which the target dtype cannot represent the input value), JAX's behavior may be backend dependent, and in general may diverge from NumPy's behavior. Numpy allows control over the result in these scenarios via the `casting` argument (see [`np.ndarray.astype`](https://numpy.org/devdocs/reference/generated/numpy.ndarray.astype.html)); JAX does not provide any such configuration, instead directly inheriting the behavior of [XLA:ConvertElementType](https://www.tensorflow.org/xla/operation_semantics#convertelementtype).\n", "\n", " Here is an example of an unsafe cast with differing results between NumPy and JAX:\n", diff --git a/docs/notebooks/Common_Gotchas_in_JAX.md b/docs/notebooks/Common_Gotchas_in_JAX.md index 80ab69be1ed8..9fbc26a46c8f 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.md +++ b/docs/notebooks/Common_Gotchas_in_JAX.md @@ -201,7 +201,7 @@ jax_array[1, :] = 1.0 Allowing mutation of variables in-place makes program analysis and transformation difficult. JAX requires that programs are pure functions. -Instead, JAX offers a _functional_ array update using the [`.at` property on JAX arrays](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at). +Instead, JAX offers a _functional_ array update using the [`.at` property on JAX arrays](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at). +++ {"id": "hfloZ1QXCS_J"} @@ -261,7 +261,7 @@ print(new_jax_array) +++ {"id": "sTjJ3WuaDyqU"} -For more details on indexed array updates, see the [documentation for the `.at` property](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at). +For more details on indexed array updates, see the [documentation for the `.at` property](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at). +++ {"id": "oZ_jE2WAypdL"} @@ -292,7 +292,7 @@ jnp.arange(10)[11] +++ {"id": "NAcXJNAcDi_v"} -If you would like finer-grained control over the behavior for out-of-bound indices, you can use the optional parameters of [`ndarray.at`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html); for example: +If you would like finer-grained control over the behavior for out-of-bound indices, you can use the optional parameters of [`ndarray.at`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html); for example: ```{code-cell} ipython3 :id: -0-MaFddO-xy @@ -664,7 +664,7 @@ x.dtype # --> dtype('float64') While `jax.numpy` makes every attempt to replicate the behavior of numpy's API, there do exist corner cases where the behaviors differ. Many such cases are discussed in detail in the sections above; here we list several other known places where the APIs diverge. -- For binary operations, JAX's type promotion rules differ somewhat from those used by NumPy. See [Type Promotion Semantics](https://jax.readthedocs.io/en/latest/type_promotion.html) for more details. +- For binary operations, JAX's type promotion rules differ somewhat from those used by NumPy. See [Type Promotion Semantics](https://docs.jax.dev/en/latest/type_promotion.html) for more details. - When performing unsafe type casts (i.e. casts in which the target dtype cannot represent the input value), JAX's behavior may be backend dependent, and in general may diverge from NumPy's behavior. Numpy allows control over the result in these scenarios via the `casting` argument (see [`np.ndarray.astype`](https://numpy.org/devdocs/reference/generated/numpy.ndarray.astype.html)); JAX does not provide any such configuration, instead directly inheriting the behavior of [XLA:ConvertElementType](https://www.tensorflow.org/xla/operation_semantics#convertelementtype). Here is an example of an unsafe cast with differing results between NumPy and JAX: diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb index e550cbf36da3..e80c7ae94687 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb @@ -17,9 +17,9 @@ "1. using `jax.custom_jvp` and `jax.custom_vjp` to define custom differentiation rules for Python functions that are already JAX-transformable; and\n", "2. defining new `core.Primitive` instances along with all their transformation rules, for example to call into functions from other systems like solvers, simulators, or general numerical computing systems.\n", "\n", - "This notebook is about #1. To read instead about #2, see the [notebook on adding primitives](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html).\n", + "This notebook is about #1. To read instead about #2, see the [notebook on adding primitives](https://docs.jax.dev/en/latest/notebooks/How_JAX_primitives_work.html).\n", "\n", - "For an introduction to JAX's automatic differentiation API, see [The Autodiff Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html). This notebook assumes some familiarity with [jax.jvp](https://jax.readthedocs.io/en/latest/jax.html#jax.jvp) and [jax.grad](https://jax.readthedocs.io/en/latest/jax.html#jax.grad), and the mathematical meaning of JVPs and VJPs." + "For an introduction to JAX's automatic differentiation API, see [The Autodiff Cookbook](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html). This notebook assumes some familiarity with [jax.jvp](https://docs.jax.dev/en/latest/jax.html#jax.jvp) and [jax.grad](https://docs.jax.dev/en/latest/jax.html#jax.grad), and the mathematical meaning of JVPs and VJPs." ] }, { @@ -2035,7 +2035,7 @@ "source": [ "### Working with `list` / `tuple` / `dict` containers (and other pytrees)\n", "\n", - "You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any [pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) are permissible, so long as their structures are consistent according to the type constraints. \n", + "You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any [pytrees](https://docs.jax.dev/en/latest/pytrees.html) are permissible, so long as their structures are consistent according to the type constraints. \n", "\n", "Here's a contrived example with `jax.custom_jvp`:" ] diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.md b/docs/notebooks/Custom_derivative_rules_for_Python_code.md index 8a63f142693e..82b97e195bd9 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.md +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.md @@ -24,9 +24,9 @@ There are two ways to define differentiation rules in JAX: 1. using `jax.custom_jvp` and `jax.custom_vjp` to define custom differentiation rules for Python functions that are already JAX-transformable; and 2. defining new `core.Primitive` instances along with all their transformation rules, for example to call into functions from other systems like solvers, simulators, or general numerical computing systems. -This notebook is about #1. To read instead about #2, see the [notebook on adding primitives](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html). +This notebook is about #1. To read instead about #2, see the [notebook on adding primitives](https://docs.jax.dev/en/latest/notebooks/How_JAX_primitives_work.html). -For an introduction to JAX's automatic differentiation API, see [The Autodiff Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html). This notebook assumes some familiarity with [jax.jvp](https://jax.readthedocs.io/en/latest/jax.html#jax.jvp) and [jax.grad](https://jax.readthedocs.io/en/latest/jax.html#jax.grad), and the mathematical meaning of JVPs and VJPs. +For an introduction to JAX's automatic differentiation API, see [The Autodiff Cookbook](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html). This notebook assumes some familiarity with [jax.jvp](https://docs.jax.dev/en/latest/jax.html#jax.jvp) and [jax.grad](https://docs.jax.dev/en/latest/jax.html#jax.grad), and the mathematical meaning of JVPs and VJPs. +++ {"id": "9Fg3NFNY-2RY"} @@ -1048,7 +1048,7 @@ Array(-0.91113025, dtype=float32) ### Working with `list` / `tuple` / `dict` containers (and other pytrees) -You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any [pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) are permissible, so long as their structures are consistent according to the type constraints. +You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any [pytrees](https://docs.jax.dev/en/latest/pytrees.html) are permissible, so long as their structures are consistent according to the type constraints. Here's a contrived example with `jax.custom_jvp`: diff --git a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb index 8abee469d552..90d92c4ea241 100644 --- a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb +++ b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb @@ -1276,7 +1276,7 @@ "id": "3qfPjJdhgerc" }, "source": [ - "So computation follows data placement: when we explicitly shard data with `jax.device_put`, and apply functions to that data, the compiler attempts to parallelize the computation and decide the output sharding. This policy for sharded data is a generalization of [JAX's policy of following explicit device placement](https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices)." + "So computation follows data placement: when we explicitly shard data with `jax.device_put`, and apply functions to that data, the compiler attempts to parallelize the computation and decide the output sharding. This policy for sharded data is a generalization of [JAX's policy of following explicit device placement](https://docs.jax.dev/en/latest/faq.html#controlling-data-and-computation-placement-on-devices)." ] }, { @@ -1382,7 +1382,7 @@ "id": "6ZYcK8eXrn0p" }, "source": [ - "We say arrays that have been explicitly placed or sharded with `jax.device_put` are _committed_ to their device(s), and so won't be automatically moved. See the [device placement FAQ](https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices) for more information.\n", + "We say arrays that have been explicitly placed or sharded with `jax.device_put` are _committed_ to their device(s), and so won't be automatically moved. See the [device placement FAQ](https://docs.jax.dev/en/latest/faq.html#controlling-data-and-computation-placement-on-devices) for more information.\n", "\n", "When arrays are _not_ explicitly placed or sharded with `jax.device_put`, they are placed _uncommitted_ on the default device.\n", "Unlike committed arrays, uncommitted arrays can be moved and resharded automatically: that is, uncommitted arrays can be arguments to a computation even if other arguments are explicitly placed on different devices.\n", @@ -2339,7 +2339,7 @@ "source": [ "### Generating random numbers\n", "\n", - "JAX comes with a functional, deterministic [random number generator](https://jax.readthedocs.io/en/latest/jep/263-prng.html). It underlies the various sampling functions in the [`jax.random` module](https://jax.readthedocs.io/en/latest/jax.random.html), such as `jax.random.uniform`.\n", + "JAX comes with a functional, deterministic [random number generator](https://docs.jax.dev/en/latest/jep/263-prng.html). It underlies the various sampling functions in the [`jax.random` module](https://docs.jax.dev/en/latest/jax.random.html), such as `jax.random.uniform`.\n", "\n", "JAX's random numbers are produced by a counter-based PRNG, so in principle, random number generation should be a pure map over counter values. A pure map is a trivially partitionable operation in principle. It should require no cross-device communication, nor any redundant computation across devices.\n", "\n", diff --git a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md index c207f0ae4a00..79990fefb95d 100644 --- a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md +++ b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md @@ -427,7 +427,7 @@ jax.debug.visualize_array_sharding(w_copy) +++ {"id": "3qfPjJdhgerc"} -So computation follows data placement: when we explicitly shard data with `jax.device_put`, and apply functions to that data, the compiler attempts to parallelize the computation and decide the output sharding. This policy for sharded data is a generalization of [JAX's policy of following explicit device placement](https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices). +So computation follows data placement: when we explicitly shard data with `jax.device_put`, and apply functions to that data, the compiler attempts to parallelize the computation and decide the output sharding. This policy for sharded data is a generalization of [JAX's policy of following explicit device placement](https://docs.jax.dev/en/latest/faq.html#controlling-data-and-computation-placement-on-devices). +++ {"id": "QRB95LaWuT80"} @@ -484,7 +484,7 @@ except ValueError as e: print_exception(e) +++ {"id": "6ZYcK8eXrn0p"} -We say arrays that have been explicitly placed or sharded with `jax.device_put` are _committed_ to their device(s), and so won't be automatically moved. See the [device placement FAQ](https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices) for more information. +We say arrays that have been explicitly placed or sharded with `jax.device_put` are _committed_ to their device(s), and so won't be automatically moved. See the [device placement FAQ](https://docs.jax.dev/en/latest/faq.html#controlling-data-and-computation-placement-on-devices) for more information. When arrays are _not_ explicitly placed or sharded with `jax.device_put`, they are placed _uncommitted_ on the default device. Unlike committed arrays, uncommitted arrays can be moved and resharded automatically: that is, uncommitted arrays can be arguments to a computation even if other arguments are explicitly placed on different devices. @@ -854,7 +854,7 @@ outputId: 479c4d81-cb0b-40a5-89ba-394c10dc3297 ### Generating random numbers -JAX comes with a functional, deterministic [random number generator](https://jax.readthedocs.io/en/latest/jep/263-prng.html). It underlies the various sampling functions in the [`jax.random` module](https://jax.readthedocs.io/en/latest/jax.random.html), such as `jax.random.uniform`. +JAX comes with a functional, deterministic [random number generator](https://docs.jax.dev/en/latest/jep/263-prng.html). It underlies the various sampling functions in the [`jax.random` module](https://docs.jax.dev/en/latest/jax.random.html), such as `jax.random.uniform`. JAX's random numbers are produced by a counter-based PRNG, so in principle, random number generation should be a pure map over counter values. A pure map is a trivially partitionable operation in principle. It should require no cross-device communication, nor any redundant computation across devices. diff --git a/docs/notebooks/README.md b/docs/notebooks/README.md index 07be4441ade8..c945c197ad19 100644 --- a/docs/notebooks/README.md +++ b/docs/notebooks/README.md @@ -1,2 +1,2 @@ For instructions on how to change and test notebooks, see -[Update Documentation](https://jax.readthedocs.io/en/latest/developer.html#update-documentation). +[Update Documentation](https://docs.jax.dev/en/latest/developer.html#update-documentation). diff --git a/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb b/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb index 56b2d80fc58e..d22457c5d718 100644 --- a/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb +++ b/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb @@ -24,7 +24,7 @@ "\n", "Here we show how to add your own function transformations to the system, by writing a custom Jaxpr interpreter. And we'll get composability with all the other transformations for free.\n", "\n", - "**This example uses internal JAX APIs, which may break at any time. Anything not in [the API Documentation](https://jax.readthedocs.io/en/latest/jax.html) should be assumed internal.**" + "**This example uses internal JAX APIs, which may break at any time. Anything not in [the API Documentation](https://docs.jax.dev/en/latest/jax.html) should be assumed internal.**" ] }, { diff --git a/docs/notebooks/Writing_custom_interpreters_in_Jax.md b/docs/notebooks/Writing_custom_interpreters_in_Jax.md index 6b993a630e93..ad707a9746fc 100644 --- a/docs/notebooks/Writing_custom_interpreters_in_Jax.md +++ b/docs/notebooks/Writing_custom_interpreters_in_Jax.md @@ -27,7 +27,7 @@ etc.) that enable writing concise, accelerated code. Here we show how to add your own function transformations to the system, by writing a custom Jaxpr interpreter. And we'll get composability with all the other transformations for free. -**This example uses internal JAX APIs, which may break at any time. Anything not in [the API Documentation](https://jax.readthedocs.io/en/latest/jax.html) should be assumed internal.** +**This example uses internal JAX APIs, which may break at any time. Anything not in [the API Documentation](https://docs.jax.dev/en/latest/jax.html) should be assumed internal.** ```{code-cell} ipython3 :id: s27RDKvKXFL8 diff --git a/docs/notebooks/autodiff_remat.ipynb b/docs/notebooks/autodiff_remat.ipynb index feb906546341..d8a74e4b15fd 100644 --- a/docs/notebooks/autodiff_remat.ipynb +++ b/docs/notebooks/autodiff_remat.ipynb @@ -348,7 +348,7 @@ "source": [ "### Let's think step by step\n", "\n", - "You might want to first (re)read [the Autodiff Cookbook Part 1](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)." + "You might want to first (re)read [the Autodiff Cookbook Part 1](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html)." ] }, { diff --git a/docs/notebooks/autodiff_remat.md b/docs/notebooks/autodiff_remat.md index 8ba87dcfee18..12564bd91f30 100644 --- a/docs/notebooks/autodiff_remat.md +++ b/docs/notebooks/autodiff_remat.md @@ -156,7 +156,7 @@ print_fwd_bwd(f3, W1, W2, W3, x) ### Let's think step by step -You might want to first (re)read [the Autodiff Cookbook Part 1](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html). +You might want to first (re)read [the Autodiff Cookbook Part 1](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html). +++ {"id": "VMfwm_yinvoZ"} diff --git a/docs/notebooks/neural_network_with_tfds_data.ipynb b/docs/notebooks/neural_network_with_tfds_data.ipynb index c31a99746866..a909d9329e24 100644 --- a/docs/notebooks/neural_network_with_tfds_data.ipynb +++ b/docs/notebooks/neural_network_with_tfds_data.ipynb @@ -46,7 +46,7 @@ "\n", "![JAX](https://raw.githubusercontent.com/jax-ml/jax/main/images/jax_logo_250px.png)\n", "\n", - "Let's combine everything we showed in the [quickstart](https://jax.readthedocs.io/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P).\n", + "Let's combine everything we showed in the [quickstart](https://docs.jax.dev/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P).\n", "\n", "Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model." ] diff --git a/docs/notebooks/neural_network_with_tfds_data.md b/docs/notebooks/neural_network_with_tfds_data.md index 53b7d47358c2..9c153d704763 100644 --- a/docs/notebooks/neural_network_with_tfds_data.md +++ b/docs/notebooks/neural_network_with_tfds_data.md @@ -44,7 +44,7 @@ _Forked from_ `neural_network_and_data_loading.ipynb` ![JAX](https://raw.githubusercontent.com/jax-ml/jax/main/images/jax_logo_250px.png) -Let's combine everything we showed in the [quickstart](https://jax.readthedocs.io/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P). +Let's combine everything we showed in the [quickstart](https://docs.jax.dev/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P). Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model. diff --git a/docs/notebooks/shard_map.ipynb b/docs/notebooks/shard_map.ipynb index d73b0d4c0f3e..ecfa199c6b52 100644 --- a/docs/notebooks/shard_map.ipynb +++ b/docs/notebooks/shard_map.ipynb @@ -13,9 +13,9 @@ "\n", "`shard_map` is a single-program multiple-data (SPMD) multi-device parallelism API to map a function over shards of data. Mapped function applications, or _instances_, communicate with each other via explicit collective communication operations.\n", "\n", - "`shard_map` is complementary to, and composable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed.\n", + "`shard_map` is complementary to, and composable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed.\n", "\n", - "If you're familiar with `pmap`, think of `shard_map` as an evolution. It's more expressive, performant, and composable with other JAX APIs. It even works eagerly, for easier debugging! (For more, see [a detailed comparison to `pmap`.](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html#why-don-t-pmap-or-xmap-already-solve-this))\n", + "If you're familiar with `pmap`, think of `shard_map` as an evolution. It's more expressive, performant, and composable with other JAX APIs. It even works eagerly, for easier debugging! (For more, see [a detailed comparison to `pmap`.](https://docs.jax.dev/en/latest/jep/14273-shard-map.html#why-don-t-pmap-or-xmap-already-solve-this))\n", "\n", "By reading this tutorial, you'll learn how to use `shard_map` to get full control over your multi-device code. You'll see in detail how it composes with `jax.jit`'s automatic parallelization and `jax.grad`'s automatic differentiation. We'll also give some basic examples of neural network parallelization strategies.\n", "\n", @@ -499,7 +499,7 @@ "* `mesh` encodes devices arranged in an array and with associated axis names, just like it does for `sharding.NamedSharding`;\n", "* `in_specs` and `out_specs` are `PartitionSpec`s which can affinely mention axis names from `mesh` to express slicing/unconcatenation and concatenation of inputs and outputs, respectively, with unmentioned names corresponding to replication and untiling (assert-replicated-so-give-me-one-copy), respectively;\n", "* `auto` is an optional set of axis names corresponding to the subset of names of `mesh` to treat automatically in the body, as in the caller, rather than manually;\n", - "* `check_rep` is an optional boolean indicating whether to check statically for any replication errors in `out_specs`, and also whether to enable a related automatic differentiation optimization (see [JEP](https://jax.readthedocs.io/en/latest/jep/17111-shmap-transpose.html)).\n", + "* `check_rep` is an optional boolean indicating whether to check statically for any replication errors in `out_specs`, and also whether to enable a related automatic differentiation optimization (see [JEP](https://docs.jax.dev/en/latest/jep/17111-shmap-transpose.html)).\n", "\n", "The shapes of the arguments passed to `f` have the same ranks as the arguments\n", "passed to `shard_map`-of-`f`, and the shape of an argument to `f` is computed\n", @@ -1520,7 +1520,7 @@ "source": [ "Compare these examples with the purely [automatic partitioning examples in the\n", "\"Distributed arrays and automatic partitioning\"\n", - "doc](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html).\n", + "doc](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html).\n", "While in those automatic partitioning examples we don't need to edit the model\n", "functions to use different parallelization strategies, with `shard_map` we\n", "often do.\n", @@ -1626,7 +1626,7 @@ "parameters from the forward pass for use on the backward pass. Instead, we want\n", "to gather them again on the backward pass. We can express that by using\n", "`jax.remat` with a [custom\n", - "policy](https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#custom-policies-for-what-s-saveable)\n", + "policy](https://docs.jax.dev/en/latest/notebooks/autodiff_remat.html#custom-policies-for-what-s-saveable)\n", "(or a `custom_vjp`), though XLA typically does that rematerialization\n", "automatically.\n", "\n", diff --git a/docs/notebooks/shard_map.md b/docs/notebooks/shard_map.md index c52cf0e6d22b..095f37d0dde1 100644 --- a/docs/notebooks/shard_map.md +++ b/docs/notebooks/shard_map.md @@ -22,9 +22,9 @@ kernelspec: `shard_map` is a single-program multiple-data (SPMD) multi-device parallelism API to map a function over shards of data. Mapped function applications, or _instances_, communicate with each other via explicit collective communication operations. -`shard_map` is complementary to, and composable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed. +`shard_map` is complementary to, and composable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed. -If you're familiar with `pmap`, think of `shard_map` as an evolution. It's more expressive, performant, and composable with other JAX APIs. It even works eagerly, for easier debugging! (For more, see [a detailed comparison to `pmap`.](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html#why-don-t-pmap-or-xmap-already-solve-this)) +If you're familiar with `pmap`, think of `shard_map` as an evolution. It's more expressive, performant, and composable with other JAX APIs. It even works eagerly, for easier debugging! (For more, see [a detailed comparison to `pmap`.](https://docs.jax.dev/en/latest/jep/14273-shard-map.html#why-don-t-pmap-or-xmap-already-solve-this)) By reading this tutorial, you'll learn how to use `shard_map` to get full control over your multi-device code. You'll see in detail how it composes with `jax.jit`'s automatic parallelization and `jax.grad`'s automatic differentiation. We'll also give some basic examples of neural network parallelization strategies. @@ -346,7 +346,7 @@ where: * `mesh` encodes devices arranged in an array and with associated axis names, just like it does for `sharding.NamedSharding`; * `in_specs` and `out_specs` are `PartitionSpec`s which can affinely mention axis names from `mesh` to express slicing/unconcatenation and concatenation of inputs and outputs, respectively, with unmentioned names corresponding to replication and untiling (assert-replicated-so-give-me-one-copy), respectively; * `auto` is an optional set of axis names corresponding to the subset of names of `mesh` to treat automatically in the body, as in the caller, rather than manually; -* `check_rep` is an optional boolean indicating whether to check statically for any replication errors in `out_specs`, and also whether to enable a related automatic differentiation optimization (see [JEP](https://jax.readthedocs.io/en/latest/jep/17111-shmap-transpose.html)). +* `check_rep` is an optional boolean indicating whether to check statically for any replication errors in `out_specs`, and also whether to enable a related automatic differentiation optimization (see [JEP](https://docs.jax.dev/en/latest/jep/17111-shmap-transpose.html)). The shapes of the arguments passed to `f` have the same ranks as the arguments passed to `shard_map`-of-`f`, and the shape of an argument to `f` is computed @@ -1061,7 +1061,7 @@ params, batch = init(jax.random.key(0), layer_sizes, batch_size) Compare these examples with the purely [automatic partitioning examples in the "Distributed arrays and automatic partitioning" -doc](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html). +doc](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html). While in those automatic partitioning examples we don't need to edit the model functions to use different parallelization strategies, with `shard_map` we often do. @@ -1137,7 +1137,7 @@ There's one other ingredient we need: we don't want to store the fully gathered parameters from the forward pass for use on the backward pass. Instead, we want to gather them again on the backward pass. We can express that by using `jax.remat` with a [custom -policy](https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#custom-policies-for-what-s-saveable) +policy](https://docs.jax.dev/en/latest/notebooks/autodiff_remat.html#custom-policies-for-what-s-saveable) (or a `custom_vjp`), though XLA typically does that rematerialization automatically. diff --git a/docs/notebooks/thinking_in_jax.ipynb b/docs/notebooks/thinking_in_jax.ipynb index 5ddcdd32e2b4..d6cbf6e02198 100644 --- a/docs/notebooks/thinking_in_jax.ipynb +++ b/docs/notebooks/thinking_in_jax.ipynb @@ -248,7 +248,7 @@ "id": "yRYF0YgO3F4H" }, "source": [ - "For updating individual elements, JAX provides an [indexed update syntax](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax-numpy-ndarray-at) that returns an updated copy:" + "For updating individual elements, JAX provides an [indexed update syntax](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax-numpy-ndarray-at) that returns an updated copy:" ] }, { @@ -423,7 +423,7 @@ "id": "0GPqgT7S0q8r" }, "source": [ - "Under the hood, this NumPy operation is translated to a much more general convolution implemented by [`lax.conv_general_dilated`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv_general_dilated.html):" + "Under the hood, this NumPy operation is translated to a much more general convolution implemented by [`lax.conv_general_dilated`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.conv_general_dilated.html):" ] }, { @@ -461,7 +461,7 @@ "id": "7mdo6ycczlbd" }, "source": [ - "This is a batched convolution operation designed to be efficient for the types of convolutions often used in deep neural nets. It requires much more boilerplate, but is far more flexible and scalable than the convolution provided by NumPy (See [Convolutions in JAX](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html) for more detail on JAX convolutions).\n", + "This is a batched convolution operation designed to be efficient for the types of convolutions often used in deep neural nets. It requires much more boilerplate, but is far more flexible and scalable than the convolution provided by NumPy (See [Convolutions in JAX](https://docs.jax.dev/en/latest/notebooks/convolutions.html) for more detail on JAX convolutions).\n", "\n", "At their heart, all `jax.lax` operations are Python wrappers for operations in XLA; here, for example, the convolution implementation is provided by [XLA:ConvWithGeneralPadding](https://www.tensorflow.org/xla/operation_semantics#convwithgeneralpadding_convolution).\n", "Every JAX operation is eventually expressed in terms of these fundamental XLA operations, which is what enables just-in-time (JIT) compilation." @@ -562,7 +562,7 @@ "id": "3GvisB-CA9M8" }, "source": [ - "But due to the compilation (which includes fusing of operations, avoidance of allocating temporary arrays, and a host of other tricks), execution times can be orders of magnitude faster in the JIT-compiled case (note the use of `block_until_ready()` to account for JAX's [asynchronous dispatch](https://jax.readthedocs.io/en/latest/async_dispatch.html)):" + "But due to the compilation (which includes fusing of operations, avoidance of allocating temporary arrays, and a host of other tricks), execution times can be orders of magnitude faster in the JIT-compiled case (note the use of `block_until_ready()` to account for JAX's [asynchronous dispatch](https://docs.jax.dev/en/latest/async_dispatch.html)):" ] }, { @@ -650,7 +650,7 @@ "evalue": "ignored", "output_type": "error", "traceback": [ - "\u001b[0;31mNonConcreteBooleanIndexError\u001b[0m\u001b[0;31m:\u001b[0m Array boolean indices must be concrete; got ShapedArray(bool[10])\n\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError\n" + "\u001b[0;31mNonConcreteBooleanIndexError\u001b[0m\u001b[0;31m:\u001b[0m Array boolean indices must be concrete; got ShapedArray(bool[10])\n\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError\n" ] } ], @@ -835,7 +835,7 @@ "evalue": "ignored", "output_type": "error", "traceback": [ - "\u001b[0;31mTracerBoolConversionError\u001b[0m\u001b[0;31m:\u001b[0m Attempted boolean conversion of traced array with shape bool[]..\nThe error occurred while tracing the function f at :1 for jit. This concrete value was not available in Python because it depends on the value of the argument neg.\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError\n" + "\u001b[0;31mTracerBoolConversionError\u001b[0m\u001b[0;31m:\u001b[0m Attempted boolean conversion of traced array with shape bool[]..\nThe error occurred while tracing the function f at :1 for jit. This concrete value was not available in Python because it depends on the value of the argument neg.\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError\n" ] } ], diff --git a/docs/notebooks/thinking_in_jax.md b/docs/notebooks/thinking_in_jax.md index 0693f6ba8579..7b0bb0d9b8ce 100644 --- a/docs/notebooks/thinking_in_jax.md +++ b/docs/notebooks/thinking_in_jax.md @@ -117,7 +117,7 @@ x[0] = 10 +++ {"id": "yRYF0YgO3F4H"} -For updating individual elements, JAX provides an [indexed update syntax](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax-numpy-ndarray-at) that returns an updated copy: +For updating individual elements, JAX provides an [indexed update syntax](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax-numpy-ndarray-at) that returns an updated copy: ```{code-cell} ipython3 :id: 8zqPEAeP3UK5 @@ -189,7 +189,7 @@ jnp.convolve(x, y) +++ {"id": "0GPqgT7S0q8r"} -Under the hood, this NumPy operation is translated to a much more general convolution implemented by [`lax.conv_general_dilated`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv_general_dilated.html): +Under the hood, this NumPy operation is translated to a much more general convolution implemented by [`lax.conv_general_dilated`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.conv_general_dilated.html): ```{code-cell} ipython3 :id: pi4f6ikjzc3l @@ -206,7 +206,7 @@ result[0, 0] +++ {"id": "7mdo6ycczlbd"} -This is a batched convolution operation designed to be efficient for the types of convolutions often used in deep neural nets. It requires much more boilerplate, but is far more flexible and scalable than the convolution provided by NumPy (See [Convolutions in JAX](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html) for more detail on JAX convolutions). +This is a batched convolution operation designed to be efficient for the types of convolutions often used in deep neural nets. It requires much more boilerplate, but is far more flexible and scalable than the convolution provided by NumPy (See [Convolutions in JAX](https://docs.jax.dev/en/latest/notebooks/convolutions.html) for more detail on JAX convolutions). At their heart, all `jax.lax` operations are Python wrappers for operations in XLA; here, for example, the convolution implementation is provided by [XLA:ConvWithGeneralPadding](https://www.tensorflow.org/xla/operation_semantics#convwithgeneralpadding_convolution). Every JAX operation is eventually expressed in terms of these fundamental XLA operations, which is what enables just-in-time (JIT) compilation. @@ -261,7 +261,7 @@ np.allclose(norm(X), norm_compiled(X), atol=1E-6) +++ {"id": "3GvisB-CA9M8"} -But due to the compilation (which includes fusing of operations, avoidance of allocating temporary arrays, and a host of other tricks), execution times can be orders of magnitude faster in the JIT-compiled case (note the use of `block_until_ready()` to account for JAX's [asynchronous dispatch](https://jax.readthedocs.io/en/latest/async_dispatch.html)): +But due to the compilation (which includes fusing of operations, avoidance of allocating temporary arrays, and a host of other tricks), execution times can be orders of magnitude faster in the JIT-compiled case (note the use of `block_until_ready()` to account for JAX's [asynchronous dispatch](https://docs.jax.dev/en/latest/async_dispatch.html)): ```{code-cell} ipython3 :id: 6mUB6VdDAEIY diff --git a/docs/pallas/CHANGELOG.md b/docs/pallas/CHANGELOG.md index 2b1cad7c9a66..7533e6eda053 100644 --- a/docs/pallas/CHANGELOG.md +++ b/docs/pallas/CHANGELOG.md @@ -5,7 +5,7 @@ This is the list of changes specific to {class}`jax.experimental.pallas`. -For the overall JAX change log see [here](https://jax.readthedocs.io/en/latest/changelog.html). +For the overall JAX change log see [here](https://docs.jax.dev/en/latest/changelog.html).