Skip to content

Commit 405b74d

Browse files
authored
Cleanup bazel files (#51)
* cleanup baze * buildifier
1 parent 5c65020 commit 405b74d

File tree

3 files changed

+60
-53
lines changed

3 files changed

+60
-53
lines changed

BUILD

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,15 @@ package(
1010

1111
py_package(
1212
name = "enzyme_jax_data",
13+
# Only include these Python packages.
14+
packages = [
15+
"@//src/enzyme_ad/jax:enzyme_call.so",
16+
"@llvm-project//clang:builtin_headers_gen",
17+
],
1318
deps = [
1419
"//src/enzyme_ad/jax:enzyme_call.so",
1520
"@llvm-project//clang:builtin_headers_gen",
1621
],
17-
# Only include these Python packages.
18-
packages = ["@//src/enzyme_ad/jax:enzyme_call.so", "@llvm-project//clang:builtin_headers_gen"],
1922
)
2023

2124
cc_binary(
@@ -49,17 +52,11 @@ cc_binary(
4952

5053
py_wheel(
5154
name = "enzyme_ad",
55+
author = "Enzyme Authors",
56+
5257
distribution = "enzyme_ad",
53-
summary = "Enzyme automatic differentiation tool.",
5458
homepage = "https://enzyme.mit.edu/",
55-
project_urls = {
56-
"GitHub": "https://github.com/EnzymeAD/Enzyme-JAX/",
57-
},
58-
author="Enzyme Authors",
59-
license="LLVM",
60-
61-
python_tag = "py3",
62-
version = "0.0.6",
59+
license = "LLVM",
6360
platform = select({
6461
"@bazel_tools//src/conditions:windows_x64": "win_amd64",
6562
"@bazel_tools//src/conditions:darwin_arm64": "macosx_11_0_arm64",
@@ -68,11 +65,20 @@ py_wheel(
6865
"@bazel_tools//src/conditions:linux_x86_64": "manylinux2014_x86_64",
6966
"@bazel_tools//src/conditions:linux_ppc64le": "manylinux2014_ppc64le",
7067
}),
71-
deps = ["//src/enzyme_ad/jax:enzyme_jax_internal", ":enzyme_jax_data"],
72-
strip_path_prefixes = ["src/"],
68+
project_urls = {
69+
"GitHub": "https://github.com/EnzymeAD/Enzyme-JAX/",
70+
},
71+
python_tag = "py3",
7372
requires = [
7473
"absl_py >= 2.0.0",
7574
"jax >= 0.4.21",
7675
"jaxlib >= 0.4.21",
7776
],
77+
strip_path_prefixes = ["src/"],
78+
summary = "Enzyme automatic differentiation tool.",
79+
version = "0.0.6",
80+
deps = [
81+
":enzyme_jax_data",
82+
"//src/enzyme_ad/jax:enzyme_jax_internal",
83+
],
7884
)

src/enzyme_ad/jax/BUILD

Lines changed: 37 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
load("@jax//jaxlib:symlink_files.bzl", "symlink_inputs")
22
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension", "pybind_library")
33
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
4-
load("@llvm-project//llvm:tblgen.bzl", "gentbl")
54

65
exports_files(["enzymexlamlir-opt.cpp"])
6+
77
licenses(["notice"])
88

99
package(
@@ -30,9 +30,9 @@ pybind_library(
3030
"@llvm-project//llvm:AsmParser",
3131
"@llvm-project//llvm:CodeGen",
3232
"@llvm-project//llvm:Core",
33-
"@llvm-project//llvm:MC",
3433
"@llvm-project//llvm:IRReader",
3534
"@llvm-project//llvm:Linker",
35+
"@llvm-project//llvm:MC",
3636
"@llvm-project//llvm:OrcJIT",
3737
"@llvm-project//llvm:Support",
3838
"@llvm-project//llvm:TargetParser",
@@ -41,8 +41,11 @@ pybind_library(
4141

4242
py_library(
4343
name = "enzyme_jax_internal",
44-
srcs = ["primitives.py", "__init__.py"],
45-
visibility = ["//visibility:public"]
44+
srcs = [
45+
"__init__.py",
46+
"primitives.py",
47+
],
48+
visibility = ["//visibility:public"],
4649
)
4750

4851
symlink_inputs(
@@ -56,7 +59,7 @@ symlink_inputs(
5659
td_library(
5760
name = "ImplementationsCommonTdFiles",
5861
srcs = [
59-
":EnzymeImplementationsCommonTdFiles",
62+
":EnzymeImplementationsCommonTdFiles",
6063
],
6164
deps = [
6265
":EnzymeImplementationsCommonTdFiles",
@@ -76,8 +79,8 @@ gentbl_cc_library(
7679
"Implementations/HLODerivatives.td",
7780
],
7881
deps = [
79-
"@enzyme//:enzyme-tblgen",
8082
":ImplementationsCommonTdFiles",
83+
"@enzyme//:enzyme-tblgen",
8184
],
8285
)
8386

@@ -94,8 +97,8 @@ gentbl_cc_library(
9497
"Implementations/HLODerivatives.td",
9598
],
9699
deps = [
97-
"@enzyme//:enzyme-tblgen",
98100
":EnzymeImplementationsCommonTdFiles",
101+
"@enzyme//:enzyme-tblgen",
99102
],
100103
)
101104

@@ -146,27 +149,31 @@ cc_library(
146149
":EnzymeXLAPassesIncGen",
147150
":mhlo-derivatives",
148151
":stablehlo-derivatives",
149-
"@stablehlo//:stablehlo_ops",
150-
"@stablehlo//:stablehlo_passes",
151-
"@stablehlo//:reference_ops",
152+
"@enzyme//:EnzymeMLIR",
152153
"@llvm-project//mlir:ArithDialect",
154+
"@llvm-project//mlir:CommonFolders",
155+
"@llvm-project//mlir:ControlFlowInterfaces",
153156
"@llvm-project//mlir:FuncDialect",
154-
"@llvm-project//mlir:TensorDialect",
155-
"@llvm-project//mlir:IR",
156157
"@llvm-project//mlir:FunctionInterfaces",
157-
"@llvm-project//mlir:ControlFlowInterfaces",
158+
"@llvm-project//mlir:IR",
158159
"@llvm-project//mlir:Support",
159-
"@llvm-project//mlir:CommonFolders",
160+
"@llvm-project//mlir:TensorDialect",
160161
"@llvm-project//mlir:Transforms",
162+
"@stablehlo//:reference_ops",
163+
"@stablehlo//:stablehlo_ops",
164+
"@stablehlo//:stablehlo_passes",
161165
"@xla//xla/mlir_hlo",
162-
"@enzyme//:EnzymeMLIR",
163-
]
166+
],
164167
)
165168

166169
pybind_library(
167170
name = "compile_with_xla",
168171
srcs = ["compile_with_xla.cc"],
169-
hdrs = glob(["compile_with_xla.h", "Implementations/*.h", "Passes/*.h"]),
172+
hdrs = glob([
173+
"compile_with_xla.h",
174+
"Implementations/*.h",
175+
"Passes/*.h",
176+
]),
170177
deps = [
171178
":XLADerivatives",
172179
# This is similar to xla_binary rule and is needed to make XLA client compile.
@@ -193,7 +200,7 @@ pybind_library(
193200
"@xla//xla/client:client_library",
194201
"@xla//xla/client:executable_build_options",
195202
"@xla//xla/client:xla_computation",
196-
"@xla//xla/service:service",
203+
"@xla//xla/service",
197204
"@xla//xla/service:local_service",
198205
"@xla//xla/service:local_service_utils",
199206
"@xla//xla/service:buffer_assignment_proto_cc",
@@ -212,7 +219,6 @@ pybind_library(
212219
"@xla//xla:xla_data_proto_cc_impl",
213220
"@xla//xla:xla_proto_cc",
214221
"@xla//xla:xla_proto_cc_impl",
215-
216222
"@stablehlo//:stablehlo_ops",
217223

218224
# Make CPU target available to XLA.
@@ -221,7 +227,6 @@ pybind_library(
221227
# MHLO stuff.
222228
"@xla//xla/mlir_hlo",
223229
"@xla//xla/translate/mhlo_to_hlo:mlir_hlo_to_hlo",
224-
225230
"@xla//xla/hlo/ir:hlo",
226231

227232
# This is necessary for XLA protobufs to link
@@ -233,50 +238,46 @@ pybind_library(
233238
"@llvm-project//mlir:FuncDialect",
234239
"@llvm-project//mlir:FuncExtensions",
235240
"@llvm-project//mlir:TensorDialect",
236-
237241
"@llvm-project//mlir:Parser",
238242
"@llvm-project//mlir:Pass",
239-
240243
"@xla//xla/mlir_hlo:all_passes",
241244
"@xla//xla:printer",
242245

243-
# EnzymeMLIR
246+
# EnzymeMLIR
244247
"@enzyme//:EnzymeMLIR",
245-
246248
"@com_google_absl//absl/status:statusor",
247-
249+
248250
# Mosaic
249-
"@jax//jaxlib/mosaic:tpu_dialect",
251+
"@jax//jaxlib/mosaic:tpu_dialect",
250252
],
251253
)
252254

253255
pybind_extension(
254256
name = "enzyme_call",
255257
srcs = ["enzyme_call.cc"],
258+
visibility = ["//visibility:public"],
256259
deps = [
260+
":clang_compile",
261+
":compile_with_xla",
262+
"@com_google_absl//absl/status:statusor",
263+
"@enzyme//:EnzymeMLIR",
264+
"@enzyme//:EnzymeStatic",
257265
"@llvm-project//llvm:Core",
258266
"@llvm-project//llvm:ExecutionEngine",
259267
"@llvm-project//llvm:IRReader",
260268
"@llvm-project//llvm:OrcJIT",
261269
"@llvm-project//llvm:OrcTargetProcess",
262270
"@llvm-project//llvm:Support",
263271
"@llvm-project//mlir:AllPassesAndDialects",
264-
":clang_compile",
265-
":compile_with_xla",
266-
"@com_google_absl//absl/status:statusor",
267272
"@stablehlo//:stablehlo_passes",
268-
"@xla//xla/stream_executor:stream_executor_impl",
273+
"@xla//xla/hlo/ir:hlo",
269274
"@xla//xla/mlir/backends/cpu/transforms:passes",
270275
"@xla//xla/mlir/memref/transforms:passes",
271276
"@xla//xla/mlir/runtime/transforms:passes",
277+
"@xla//xla/mlir_hlo:all_passes",
272278
"@xla//xla/mlir_hlo:deallocation_passes",
273279
"@xla//xla/mlir_hlo:lhlo",
274-
"@xla//xla/mlir_hlo:all_passes",
275-
"@xla//xla/hlo/ir:hlo",
276280
"@xla//xla/service/cpu:cpu_executable",
277-
"@enzyme//:EnzymeStatic",
278-
"@enzyme//:EnzymeMLIR",
279-
281+
"@xla//xla/stream_executor:stream_executor_impl",
280282
],
281-
visibility = ["//visibility:public"],
282283
)

test/BUILD

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
load("@rules_python//python:py_test.bzl", "py_test")
2-
load("@llvm-project//llvm:lit_test.bzl", "package_path", "lit_test")
2+
load("@llvm-project//llvm:lit_test.bzl", "lit_test", "package_path")
33
load("@bazel_skylib//rules:expand_template.bzl", "expand_template")
44

55
expand_template(
@@ -30,8 +30,8 @@ exports_files(
3030
data = [
3131
":lit.cfg.py",
3232
":lit_site_cfg_py",
33-
"//src/enzyme_ad/jax:enzyme_jax_internal",
3433
"//:enzymexlamlir-opt",
34+
"//src/enzyme_ad/jax:enzyme_jax_internal",
3535
"@llvm-project//clang:builtin_headers_gen",
3636
"@llvm-project//llvm:FileCheck",
3737
"@llvm-project//llvm:count",
@@ -40,7 +40,8 @@ exports_files(
4040
)
4141
for src in glob(
4242
[
43-
"**/*.pyt", "**/*.mlir",
43+
"**/*.pyt",
44+
"**/*.mlir",
4445
],
4546
)
4647
]
@@ -65,7 +66,6 @@ py_test(
6566
],
6667
)
6768

68-
6969
py_test(
7070
name = "llama",
7171
srcs = [

0 commit comments

Comments
 (0)