Skip to content

Commit 233bc5b

Browse files
authored
MLIR Reverse Mode (#37)
* WIP jax reverse more * fix * fix enzyme commit * cleanup format * fixup * tmp * continue * fix * fix * fixup * fix bug * fixup * more fixup * sliceslice * fix pad opt * cleanup * fix
1 parent 95cc9f1 commit 233bc5b

16 files changed

+1950
-95
lines changed

BUILD

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,35 @@ py_package(
1818
packages = ["@//src/enzyme_ad/jax:enzyme_call.so", "@llvm-project//clang:builtin_headers_gen"],
1919
)
2020

21+
cc_binary(
22+
name = "enzymexlamlir-opt",
23+
srcs = ["//src/enzyme_ad/jax:enzymexlamlir-opt.cpp"],
24+
visibility = ["//visibility:public"],
25+
deps = [
26+
"//src/enzyme_ad/jax:XLADerivatives",
27+
"@enzyme//:EnzymeMLIR",
28+
"@llvm-project//mlir:AffineDialect",
29+
"@llvm-project//mlir:AllPassesAndDialects",
30+
"@llvm-project//mlir:ArithDialect",
31+
"@llvm-project//mlir:AsyncDialect",
32+
"@llvm-project//mlir:ControlFlowDialect",
33+
"@llvm-project//mlir:ConversionPasses",
34+
"@llvm-project//mlir:DLTIDialect",
35+
"@llvm-project//mlir:FuncDialect",
36+
"@llvm-project//mlir:GPUDialect",
37+
"@llvm-project//mlir:LLVMDialect",
38+
"@llvm-project//mlir:LinalgDialect",
39+
"@llvm-project//mlir:MathDialect",
40+
"@llvm-project//mlir:MemRefDialect",
41+
"@llvm-project//mlir:MlirOptLib",
42+
"@llvm-project//mlir:NVVMDialect",
43+
"@llvm-project//mlir:OpenMPDialect",
44+
"@llvm-project//mlir:Pass",
45+
"@llvm-project//mlir:SCFDialect",
46+
"@llvm-project//mlir:Transforms",
47+
],
48+
)
49+
2150
py_wheel(
2251
name = "enzyme_ad",
2352
distribution = "enzyme_ad",

WORKSPACE

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ http_archive(
3939
strip_prefix = "xla-" + XLA_COMMIT,
4040
urls = ["https://github.com/wsmoses/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)],
4141
patch_args = ["-p1"],
42-
patches = ["//:patches/xla.patch"],
42+
patches = ["//:patches/xla.patch", "//:patches/xla2.patch", ],
4343
)
4444

4545
PYRULES_COMMIT = "fe33a4582c37499f3caeb49a07a78fc7948a8949"
@@ -60,12 +60,8 @@ load("@rules_python//python/pip_install:repositories.bzl", "pip_install_dependen
6060

6161
pip_install_dependencies()
6262

63-
ENZYME_COMMIT = "0b621884bc531329095d202f042f6599a86614ec"
64-
ENZYME_SHA256 = "f9479530b08aeb3ecbf0c420d0e2f222fdf8bcf6c20a218271b365db3a3053ad"
65-
# local_repository(
66-
# name = "enzyme",
67-
# path = "../Enzyme/enzyme"
68-
# )
63+
ENZYME_COMMIT = "0a129ae7e45114a08f281e50632b9f967fae8396"
64+
ENZYME_SHA256 = "715982efd0a0ef8038e8ad35047e9c1941eb3f9cb038883342969b0bcc8915ad"
6965

7066
http_archive(
7167
name = "enzyme",

patches/xla.patch

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,38 @@
1616
)
1717

1818
cc_library(
19+
20+
--- a/xla/mlir/backends/cpu/transforms/BUILD
21+
+++ b/xla/mlir/backends/cpu/transforms/BUILD
22+
@@ -4,7 +4,7 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")
23+
24+
package(
25+
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
26+
- default_visibility = ["//xla:internal"],
27+
+ default_visibility = ["//xla:friends"],
28+
licenses = ["notice"],
29+
)
30+
31+
gentbl_cc_library(
32+
33+
--- a/xla/mlir/memref/BUILD
34+
+++ b/xla/mlir/memref/BUILD
35+
@@ -1,6 +1,7 @@
36+
package_group(
37+
name = "friends",
38+
packages = [
39+
+ "public",
40+
"//xla/mlir/...",
41+
# copybara:uncomment_begin(google-only)
42+
# # TODO(ezhulenev): Clean up dependencies that are leforvers from Autofusion project.
43+
44+
--- a/xla/mlir/math/BUILD
45+
+++ b/xla/mlir/math/BUILD
46+
@@ -1,6 +1,7 @@
47+
package_group(
48+
name = "friends",
49+
packages = [
50+
+ "public",
51+
"//xla/mlir/...",
52+
# copybara:uncomment_begin(google-only)
53+
# # TODO(ezhulenev): Clean up dependencies that are leforvers from Autofusion project.

patches/xla2.patch

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
--- a/xla/mlir/runtime/BUILD
2+
+++ b/xla/mlir/runtime/BUILD
3+
@@ -19,6 +19,7 @@ package_group(
4+
# TODO(ezhulenev): All targets depending on mlir must be under xla/mlir folder
5+
"//xla/service/cpu/...",
6+
"//xla/service/gpu/...",
7+
+ "public",
8+
],
9+
)
10+

src/enzyme_ad/jax/BUILD

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ load("@pybind11_bazel//:build_defs.bzl", "pybind_extension", "pybind_library")
33
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
44
load("@llvm-project//llvm:tblgen.bzl", "gentbl")
55

6+
exports_files(["enzymexlamlir-opt.cpp"])
67
licenses(["notice"])
78

89
package(
@@ -29,10 +30,12 @@ pybind_library(
2930
"@llvm-project//llvm:AsmParser",
3031
"@llvm-project//llvm:CodeGen",
3132
"@llvm-project//llvm:Core",
33+
"@llvm-project//llvm:MC",
3234
"@llvm-project//llvm:IRReader",
3335
"@llvm-project//llvm:Linker",
3436
"@llvm-project//llvm:OrcJIT",
3537
"@llvm-project//llvm:Support",
38+
"@llvm-project//llvm:TargetParser",
3639
],
3740
)
3841

@@ -139,6 +142,16 @@ cc_library(
139142
":stablehlo-derivatives",
140143
"@stablehlo//:stablehlo_ops",
141144
"@stablehlo//:stablehlo_passes",
145+
"@stablehlo//:reference_ops",
146+
"@llvm-project//mlir:ArithDialect",
147+
"@llvm-project//mlir:FuncDialect",
148+
"@llvm-project//mlir:TensorDialect",
149+
"@llvm-project//mlir:IR",
150+
"@llvm-project//mlir:FunctionInterfaces",
151+
"@llvm-project//mlir:ControlFlowInterfaces",
152+
"@llvm-project//mlir:Support",
153+
"@llvm-project//mlir:CommonFolders",
154+
"@llvm-project//mlir:Transforms",
142155
"@xla//xla/mlir_hlo",
143156
"@enzyme//:EnzymeMLIR",
144157
]
@@ -174,6 +187,9 @@ pybind_library(
174187
"@xla//xla/client:client_library",
175188
"@xla//xla/client:executable_build_options",
176189
"@xla//xla/client:xla_computation",
190+
"@xla//xla/service:service",
191+
"@xla//xla/service:local_service",
192+
"@xla//xla/service:local_service_utils",
177193
"@xla//xla/service:buffer_assignment_proto_cc",
178194
"@xla//xla/service:buffer_assignment_proto_cc_impl",
179195
"@xla//xla/service/cpu:cpu_executable",
@@ -191,23 +207,37 @@ pybind_library(
191207
"@xla//xla:xla_proto_cc",
192208
"@xla//xla:xla_proto_cc_impl",
193209

210+
"@stablehlo//:stablehlo_ops",
211+
194212
# Make CPU target available to XLA.
195213
"@xla//xla/service:cpu_plugin",
196214

197215
# MHLO stuff.
198216
"@xla//xla/mlir_hlo",
199217
"@xla//xla/translate/mhlo_to_hlo:mlir_hlo_to_hlo",
200218

219+
"@xla//xla/hlo/ir:hlo",
220+
201221
# This is necessary for XLA protobufs to link
202222
"@com_google_protobuf//:protobuf",
203223

204224
# MLIR dialects and parser.
225+
"@llvm-project//llvm:Support",
205226
"@llvm-project//mlir:ArithDialect",
206227
"@llvm-project//mlir:FuncDialect",
228+
"@llvm-project//mlir:FuncExtensions",
229+
"@llvm-project//mlir:TensorDialect",
230+
207231
"@llvm-project//mlir:Parser",
232+
"@llvm-project//mlir:Pass",
208233

234+
"@xla//xla/mlir_hlo:all_passes",
235+
"@xla//xla:printer",
236+
209237
# EnzymeMLIR
210238
"@enzyme//:EnzymeMLIR",
239+
240+
"@com_google_absl//absl/status:statusor",
211241

212242
# Mosaic
213243
"@jax//jaxlib/mosaic:tpu_dialect",
@@ -230,6 +260,19 @@ pybind_extension(
230260
"@com_google_absl//absl/status:statusor",
231261
"@stablehlo//:stablehlo_passes",
232262
"@xla//xla/stream_executor:stream_executor_impl",
263+
"@xla//xla/mlir/backends/cpu/transforms:passes",
264+
"@xla//xla/mlir/memref/transforms:passes",
265+
"@xla//xla/mlir/math/transforms:passes",
266+
"@xla//xla/mlir/runtime/transforms:passes",
267+
"@xla//xla/mlir_hlo:deallocation_passes",
268+
"@xla//xla/mlir_hlo:lhlo",
269+
"@xla//xla/mlir_hlo:lhlo_gpu",
270+
"@xla//xla/mlir_hlo:all_passes",
271+
"@xla//xla/hlo/ir:hlo",
272+
"@xla//xla/service/cpu:cpu_executable",
273+
"@enzyme//:EnzymeStatic",
274+
"@enzyme//:EnzymeMLIR",
275+
233276
],
234277
visibility = ["//visibility:public"],
235278
)

0 commit comments

Comments
 (0)