Skip to content

Commit db11efa

Browse files
hawkinspGoogle-ML-Automation
authored andcommitted
Migrate jaxlib to use a single common .so file for all C++ dependencies.
The idea is to move all of the jaxlib contents into a single .so file, and have all of the other Python extensions be tiny stubs that reexport part of the larger .so file. This has two main benefits: * it reduces the size of the jaxlib wheel, by about 70-80MB when installed. The benefit of the change is that it avoid duplication between the MLIR CAPI code and the copy of MLIR in XLA. * it gives us flexibility to split and merge Python extensions as we see fit. Issue #11225 PiperOrigin-RevId: 744855997
1 parent 74917ce commit db11efa

14 files changed

+281
-104
lines changed

.bazelrc

+1
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ build:windows --incompatible_strict_action_env=true
9898
# #############################################################################
9999
build:nonccl --define=no_nccl_support=true
100100

101+
build --repo_env USE_PYWRAP_RULES=1
101102
build:posix --copt=-fvisibility=hidden
102103
build:posix --copt=-Wno-sign-compare
103104
build:posix --cxxopt=-std=c++17

jaxlib/BUILD

+46-2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,12 @@ load(
2020
"py_library_providing_imports_info",
2121
"pytype_library",
2222
)
23+
load(
24+
"//jaxlib:pywrap.bzl",
25+
"nanobind_pywrap_extension",
26+
"pywrap_binaries",
27+
"pywrap_library",
28+
)
2329
load("//jaxlib:symlink_files.bzl", "symlink_files")
2430

2531
licenses(["notice"])
@@ -51,6 +57,7 @@ py_library_providing_imports_info(
5157
lib_rule = pytype_library,
5258
deps = [
5359
":cpu_feature_guard",
60+
":jax",
5461
":utils",
5562
"//jaxlib/cpu:_lapack",
5663
"//jaxlib/mlir",
@@ -98,6 +105,44 @@ exports_files([
98105
"setup.py",
99106
])
100107

108+
pywrap_library(
109+
name = "jax",
110+
common_lib_def_files_or_filters = {
111+
"jaxlib/jax_common": "jax_common.json",
112+
},
113+
common_lib_version_scripts = {
114+
"jaxlib/jax_common": select({
115+
"@bazel_tools//src/conditions:windows": None,
116+
"@bazel_tools//src/conditions:darwin": "libjax_common_darwin.lds",
117+
"//conditions:default": "libjax_common.lds",
118+
}),
119+
},
120+
deps = [
121+
":utils",
122+
"//jaxlib/mlir/_mlir_libs:_chlo",
123+
"//jaxlib/mlir/_mlir_libs:_mlir",
124+
"//jaxlib/mlir/_mlir_libs:_mlirDialectsGPU",
125+
"//jaxlib/mlir/_mlir_libs:_mlirDialectsLLVM",
126+
"//jaxlib/mlir/_mlir_libs:_mlirDialectsNVGPU",
127+
"//jaxlib/mlir/_mlir_libs:_mlirDialectsSparseTensor",
128+
"//jaxlib/mlir/_mlir_libs:_mlirGPUPasses",
129+
"//jaxlib/mlir/_mlir_libs:_mlirHlo",
130+
"//jaxlib/mlir/_mlir_libs:_mlirSparseTensorPasses",
131+
"//jaxlib/mlir/_mlir_libs:_mosaic_gpu_ext",
132+
"//jaxlib/mlir/_mlir_libs:_sdy",
133+
"//jaxlib/mlir/_mlir_libs:_stablehlo",
134+
"//jaxlib/mlir/_mlir_libs:_tpu_ext",
135+
"//jaxlib/mlir/_mlir_libs:_triton_ext",
136+
"//jaxlib/mlir/_mlir_libs:register_jax_dialects",
137+
"//jaxlib/xla:xla_extension",
138+
],
139+
)
140+
141+
pywrap_binaries(
142+
name = "jaxlib_binaries",
143+
dep = ":jax",
144+
)
145+
101146
cc_library(
102147
name = "absl_status_casters",
103148
hdrs = ["absl_status_casters.h"],
@@ -170,10 +215,9 @@ nanobind_extension(
170215
],
171216
)
172217

173-
nanobind_extension(
218+
nanobind_pywrap_extension(
174219
name = "utils",
175220
srcs = ["utils.cc"],
176-
module_name = "utils",
177221
deps = [
178222
"@com_google_absl//absl/cleanup",
179223
"@com_google_absl//absl/container:flat_hash_map",

jaxlib/jax_common.json

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
{
2+
"global": [
3+
"Wrapped_PyInit_*"
4+
],
5+
"local": [
6+
"*"
7+
]
8+
}

jaxlib/libjax_common.lds

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
{
2+
global:
3+
Wrapped_PyInit_*;
4+
5+
local:
6+
*;
7+
};

jaxlib/libjax_common_darwin.lds

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
*Wrapped_PyInit_*

0 commit comments

Comments
 (0)