|
20 | 20 | "py_library_providing_imports_info",
|
21 | 21 | "pytype_library",
|
22 | 22 | )
|
| 23 | +load( |
| 24 | + "//jaxlib:pywrap.bzl", |
| 25 | + "nanobind_pywrap_extension", |
| 26 | + "pywrap_binaries", |
| 27 | + "pywrap_library", |
| 28 | +) |
23 | 29 | load("//jaxlib:symlink_files.bzl", "symlink_files")
|
24 | 30 |
|
25 | 31 | licenses(["notice"])
|
@@ -51,6 +57,7 @@ py_library_providing_imports_info(
|
51 | 57 | lib_rule = pytype_library,
|
52 | 58 | deps = [
|
53 | 59 | ":cpu_feature_guard",
|
| 60 | + ":jax", |
54 | 61 | ":utils",
|
55 | 62 | "//jaxlib/cpu:_lapack",
|
56 | 63 | "//jaxlib/mlir",
|
@@ -98,6 +105,44 @@ exports_files([
|
98 | 105 | "setup.py",
|
99 | 106 | ])
|
100 | 107 |
|
| 108 | +pywrap_library( |
| 109 | + name = "jax", |
| 110 | + common_lib_def_files_or_filters = { |
| 111 | + "jaxlib/jax_common": "jax_common.json", |
| 112 | + }, |
| 113 | + common_lib_version_scripts = { |
| 114 | + "jaxlib/jax_common": select({ |
| 115 | + "@bazel_tools//src/conditions:windows": None, |
| 116 | + "@bazel_tools//src/conditions:darwin": "libjax_common_darwin.lds", |
| 117 | + "//conditions:default": "libjax_common.lds", |
| 118 | + }), |
| 119 | + }, |
| 120 | + deps = [ |
| 121 | + ":utils", |
| 122 | + "//jaxlib/mlir/_mlir_libs:_chlo", |
| 123 | + "//jaxlib/mlir/_mlir_libs:_mlir", |
| 124 | + "//jaxlib/mlir/_mlir_libs:_mlirDialectsGPU", |
| 125 | + "//jaxlib/mlir/_mlir_libs:_mlirDialectsLLVM", |
| 126 | + "//jaxlib/mlir/_mlir_libs:_mlirDialectsNVGPU", |
| 127 | + "//jaxlib/mlir/_mlir_libs:_mlirDialectsSparseTensor", |
| 128 | + "//jaxlib/mlir/_mlir_libs:_mlirGPUPasses", |
| 129 | + "//jaxlib/mlir/_mlir_libs:_mlirHlo", |
| 130 | + "//jaxlib/mlir/_mlir_libs:_mlirSparseTensorPasses", |
| 131 | + "//jaxlib/mlir/_mlir_libs:_mosaic_gpu_ext", |
| 132 | + "//jaxlib/mlir/_mlir_libs:_sdy", |
| 133 | + "//jaxlib/mlir/_mlir_libs:_stablehlo", |
| 134 | + "//jaxlib/mlir/_mlir_libs:_tpu_ext", |
| 135 | + "//jaxlib/mlir/_mlir_libs:_triton_ext", |
| 136 | + "//jaxlib/mlir/_mlir_libs:register_jax_dialects", |
| 137 | + "//jaxlib/xla:xla_extension", |
| 138 | + ], |
| 139 | +) |
| 140 | + |
| 141 | +pywrap_binaries( |
| 142 | + name = "jaxlib_binaries", |
| 143 | + dep = ":jax", |
| 144 | +) |
| 145 | + |
101 | 146 | cc_library(
|
102 | 147 | name = "absl_status_casters",
|
103 | 148 | hdrs = ["absl_status_casters.h"],
|
@@ -170,10 +215,9 @@ nanobind_extension(
|
170 | 215 | ],
|
171 | 216 | )
|
172 | 217 |
|
173 |
| -nanobind_extension( |
| 218 | +nanobind_pywrap_extension( |
174 | 219 | name = "utils",
|
175 | 220 | srcs = ["utils.cc"],
|
176 |
| - module_name = "utils", |
177 | 221 | deps = [
|
178 | 222 | "@com_google_absl//absl/cleanup",
|
179 | 223 | "@com_google_absl//absl/container:flat_hash_map",
|
|
0 commit comments