1
1
load ("@jax//jaxlib:symlink_files.bzl" , "symlink_inputs" )
2
2
load ("@pybind11_bazel//:build_defs.bzl" , "pybind_extension" , "pybind_library" )
3
3
load ("@llvm-project//mlir:tblgen.bzl" , "gentbl_cc_library" , "td_library" )
4
- load ("@llvm-project//llvm:tblgen.bzl" , "gentbl" )
5
4
6
5
exports_files (["enzymexlamlir-opt.cpp" ])
6
+
7
7
licenses (["notice" ])
8
8
9
9
package (
@@ -30,9 +30,9 @@ pybind_library(
30
30
"@llvm-project//llvm:AsmParser" ,
31
31
"@llvm-project//llvm:CodeGen" ,
32
32
"@llvm-project//llvm:Core" ,
33
- "@llvm-project//llvm:MC" ,
34
33
"@llvm-project//llvm:IRReader" ,
35
34
"@llvm-project//llvm:Linker" ,
35
+ "@llvm-project//llvm:MC" ,
36
36
"@llvm-project//llvm:OrcJIT" ,
37
37
"@llvm-project//llvm:Support" ,
38
38
"@llvm-project//llvm:TargetParser" ,
@@ -41,8 +41,11 @@ pybind_library(
41
41
42
42
py_library (
43
43
name = "enzyme_jax_internal" ,
44
- srcs = ["primitives.py" , "__init__.py" ],
45
- visibility = ["//visibility:public" ]
44
+ srcs = [
45
+ "__init__.py" ,
46
+ "primitives.py" ,
47
+ ],
48
+ visibility = ["//visibility:public" ],
46
49
)
47
50
48
51
symlink_inputs (
@@ -56,7 +59,7 @@ symlink_inputs(
56
59
td_library (
57
60
name = "ImplementationsCommonTdFiles" ,
58
61
srcs = [
59
- ":EnzymeImplementationsCommonTdFiles" ,
62
+ ":EnzymeImplementationsCommonTdFiles" ,
60
63
],
61
64
deps = [
62
65
":EnzymeImplementationsCommonTdFiles" ,
@@ -76,8 +79,8 @@ gentbl_cc_library(
76
79
"Implementations/HLODerivatives.td" ,
77
80
],
78
81
deps = [
79
- "@enzyme//:enzyme-tblgen" ,
80
82
":ImplementationsCommonTdFiles" ,
83
+ "@enzyme//:enzyme-tblgen" ,
81
84
],
82
85
)
83
86
@@ -94,8 +97,8 @@ gentbl_cc_library(
94
97
"Implementations/HLODerivatives.td" ,
95
98
],
96
99
deps = [
97
- "@enzyme//:enzyme-tblgen" ,
98
100
":EnzymeImplementationsCommonTdFiles" ,
101
+ "@enzyme//:enzyme-tblgen" ,
99
102
],
100
103
)
101
104
@@ -146,27 +149,31 @@ cc_library(
146
149
":EnzymeXLAPassesIncGen" ,
147
150
":mhlo-derivatives" ,
148
151
":stablehlo-derivatives" ,
149
- "@stablehlo//:stablehlo_ops" ,
150
- "@stablehlo//:stablehlo_passes" ,
151
- "@stablehlo//:reference_ops" ,
152
+ "@enzyme//:EnzymeMLIR" ,
152
153
"@llvm-project//mlir:ArithDialect" ,
154
+ "@llvm-project//mlir:CommonFolders" ,
155
+ "@llvm-project//mlir:ControlFlowInterfaces" ,
153
156
"@llvm-project//mlir:FuncDialect" ,
154
- "@llvm-project//mlir:TensorDialect" ,
155
- "@llvm-project//mlir:IR" ,
156
157
"@llvm-project//mlir:FunctionInterfaces" ,
157
- "@llvm-project//mlir:ControlFlowInterfaces " ,
158
+ "@llvm-project//mlir:IR " ,
158
159
"@llvm-project//mlir:Support" ,
159
- "@llvm-project//mlir:CommonFolders " ,
160
+ "@llvm-project//mlir:TensorDialect " ,
160
161
"@llvm-project//mlir:Transforms" ,
162
+ "@stablehlo//:reference_ops" ,
163
+ "@stablehlo//:stablehlo_ops" ,
164
+ "@stablehlo//:stablehlo_passes" ,
161
165
"@xla//xla/mlir_hlo" ,
162
- "@enzyme//:EnzymeMLIR" ,
163
- ]
166
+ ],
164
167
)
165
168
166
169
pybind_library (
167
170
name = "compile_with_xla" ,
168
171
srcs = ["compile_with_xla.cc" ],
169
- hdrs = glob (["compile_with_xla.h" , "Implementations/*.h" , "Passes/*.h" ]),
172
+ hdrs = glob ([
173
+ "compile_with_xla.h" ,
174
+ "Implementations/*.h" ,
175
+ "Passes/*.h" ,
176
+ ]),
170
177
deps = [
171
178
":XLADerivatives" ,
172
179
# This is similar to xla_binary rule and is needed to make XLA client compile.
@@ -193,7 +200,7 @@ pybind_library(
193
200
"@xla//xla/client:client_library" ,
194
201
"@xla//xla/client:executable_build_options" ,
195
202
"@xla//xla/client:xla_computation" ,
196
- "@xla//xla/service:service " ,
203
+ "@xla//xla/service" ,
197
204
"@xla//xla/service:local_service" ,
198
205
"@xla//xla/service:local_service_utils" ,
199
206
"@xla//xla/service:buffer_assignment_proto_cc" ,
@@ -212,7 +219,6 @@ pybind_library(
212
219
"@xla//xla:xla_data_proto_cc_impl" ,
213
220
"@xla//xla:xla_proto_cc" ,
214
221
"@xla//xla:xla_proto_cc_impl" ,
215
-
216
222
"@stablehlo//:stablehlo_ops" ,
217
223
218
224
# Make CPU target available to XLA.
@@ -221,7 +227,6 @@ pybind_library(
221
227
# MHLO stuff.
222
228
"@xla//xla/mlir_hlo" ,
223
229
"@xla//xla/translate/mhlo_to_hlo:mlir_hlo_to_hlo" ,
224
-
225
230
"@xla//xla/hlo/ir:hlo" ,
226
231
227
232
# This is necessary for XLA protobufs to link
@@ -233,50 +238,46 @@ pybind_library(
233
238
"@llvm-project//mlir:FuncDialect" ,
234
239
"@llvm-project//mlir:FuncExtensions" ,
235
240
"@llvm-project//mlir:TensorDialect" ,
236
-
237
241
"@llvm-project//mlir:Parser" ,
238
242
"@llvm-project//mlir:Pass" ,
239
-
240
243
"@xla//xla/mlir_hlo:all_passes" ,
241
244
"@xla//xla:printer" ,
242
245
243
- # EnzymeMLIR
246
+ # EnzymeMLIR
244
247
"@enzyme//:EnzymeMLIR" ,
245
-
246
248
"@com_google_absl//absl/status:statusor" ,
247
-
249
+
248
250
# Mosaic
249
- "@jax//jaxlib/mosaic:tpu_dialect" ,
251
+ "@jax//jaxlib/mosaic:tpu_dialect" ,
250
252
],
251
253
)
252
254
253
255
pybind_extension (
254
256
name = "enzyme_call" ,
255
257
srcs = ["enzyme_call.cc" ],
258
+ visibility = ["//visibility:public" ],
256
259
deps = [
260
+ ":clang_compile" ,
261
+ ":compile_with_xla" ,
262
+ "@com_google_absl//absl/status:statusor" ,
263
+ "@enzyme//:EnzymeMLIR" ,
264
+ "@enzyme//:EnzymeStatic" ,
257
265
"@llvm-project//llvm:Core" ,
258
266
"@llvm-project//llvm:ExecutionEngine" ,
259
267
"@llvm-project//llvm:IRReader" ,
260
268
"@llvm-project//llvm:OrcJIT" ,
261
269
"@llvm-project//llvm:OrcTargetProcess" ,
262
270
"@llvm-project//llvm:Support" ,
263
271
"@llvm-project//mlir:AllPassesAndDialects" ,
264
- ":clang_compile" ,
265
- ":compile_with_xla" ,
266
- "@com_google_absl//absl/status:statusor" ,
267
272
"@stablehlo//:stablehlo_passes" ,
268
- "@xla//xla/stream_executor:stream_executor_impl " ,
273
+ "@xla//xla/hlo/ir:hlo " ,
269
274
"@xla//xla/mlir/backends/cpu/transforms:passes" ,
270
275
"@xla//xla/mlir/memref/transforms:passes" ,
271
276
"@xla//xla/mlir/runtime/transforms:passes" ,
277
+ "@xla//xla/mlir_hlo:all_passes" ,
272
278
"@xla//xla/mlir_hlo:deallocation_passes" ,
273
279
"@xla//xla/mlir_hlo:lhlo" ,
274
- "@xla//xla/mlir_hlo:all_passes" ,
275
- "@xla//xla/hlo/ir:hlo" ,
276
280
"@xla//xla/service/cpu:cpu_executable" ,
277
- "@enzyme//:EnzymeStatic" ,
278
- "@enzyme//:EnzymeMLIR" ,
279
-
281
+ "@xla//xla/stream_executor:stream_executor_impl" ,
280
282
],
281
- visibility = ["//visibility:public" ],
282
283
)
0 commit comments