Skip to content

Commit f26b03d

Browse files
cotaGoogle-ML-Automation
authored andcommitted
[xla:cpu] add scatter fusion emitter
Note that the emitter is disabled for now. We will enable it in a follow-up CL. PiperOrigin-RevId: 736516740
1 parent 2bcd3f0 commit f26b03d

12 files changed

+1594
-14
lines changed
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
load("//xla:xla.bzl", "xla_cc_test")
2+
load("//xla/tsl/platform:rules_cc.bzl", "cc_library")
3+
4+
package(
5+
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
6+
default_visibility = ["//xla/backends/cpu:xla_backend_cpu_internal_access"],
7+
licenses = ["notice"],
8+
)
9+
10+
package_group(
11+
name = "friends",
12+
includes = [
13+
"//xla:friends",
14+
],
15+
)
16+
17+
cc_library(
18+
name = "cpu_fusion_emitter_config",
19+
hdrs = ["cpu_fusion_emitter_config.h"],
20+
)
21+
22+
cc_library(
23+
name = "cpu_fusion_emitters",
24+
srcs = [
25+
"cpu_fusion_emitter.cc",
26+
"cpu_scatter_emitter.cc",
27+
],
28+
hdrs = [
29+
"cpu_fusion_emitter.h",
30+
"cpu_scatter_emitter.h",
31+
],
32+
deps = [
33+
"//xla:shape_util",
34+
"//xla:status_macros",
35+
"//xla:util",
36+
"//xla:xla_data_proto_cc",
37+
"//xla/backends/cpu/codegen:kernel_api_ir_builder",
38+
"//xla/backends/cpu/codegen/emitters/ir:xla_cpu",
39+
"//xla/backends/cpu/codegen/emitters/transforms:passes",
40+
"//xla/codegen/emitters:computation_partitioner",
41+
"//xla/codegen/emitters:elemental_hlo_to_mlir",
42+
"//xla/codegen/emitters:type_util",
43+
"//xla/codegen/emitters/ir:xla",
44+
"//xla/codegen/emitters/transforms:passes",
45+
"//xla/hlo/analysis:indexing_analysis",
46+
"//xla/hlo/ir:hlo",
47+
"//xla/mlir/tools/mlir_replay/public:compiler_trace_proto_cc",
48+
"//xla/mlir_hlo",
49+
"//xla/mlir_hlo:mhlo_passes",
50+
"//xla/service:buffer_assignment",
51+
"//xla/service:dump",
52+
"//xla/service:scatter_simplifier",
53+
"//xla/service/cpu:backend_config_proto_cc",
54+
"//xla/service/llvm_ir:llvm_util",
55+
"//xla/tsl/framework/mlir:status_scoped_diagnostic_handler",
56+
"//xla/tsl/platform:errors",
57+
"//xla/tsl/platform:statusor",
58+
"@com_google_absl//absl/algorithm:container",
59+
"@com_google_absl//absl/container:flat_hash_map",
60+
"@com_google_absl//absl/container:flat_hash_set",
61+
"@com_google_absl//absl/log",
62+
"@com_google_absl//absl/log:check",
63+
"@com_google_absl//absl/status",
64+
"@com_google_absl//absl/status:statusor",
65+
"@com_google_absl//absl/strings",
66+
"@com_google_absl//absl/types:span",
67+
"@llvm-project//llvm:Linker",
68+
"@llvm-project//llvm:Support",
69+
"@llvm-project//llvm:ir_headers",
70+
"@llvm-project//mlir:AffineDialect",
71+
"@llvm-project//mlir:AffineToStandard",
72+
"@llvm-project//mlir:ArithDialect",
73+
"@llvm-project//mlir:BufferizationInterfaces",
74+
"@llvm-project//mlir:BuiltinToLLVMIRTranslation",
75+
"@llvm-project//mlir:ComplexToStandard",
76+
"@llvm-project//mlir:ControlFlowDialect",
77+
"@llvm-project//mlir:DLTIDialect",
78+
"@llvm-project//mlir:DataLayoutInterfaces",
79+
"@llvm-project//mlir:FuncDialect",
80+
"@llvm-project//mlir:FuncExtensions",
81+
"@llvm-project//mlir:IR",
82+
"@llvm-project//mlir:LLVMDialect",
83+
"@llvm-project//mlir:LLVMIRTransforms",
84+
"@llvm-project//mlir:LLVMToLLVMIRTranslation",
85+
"@llvm-project//mlir:MathDialect",
86+
"@llvm-project//mlir:MemRefTransforms",
87+
"@llvm-project//mlir:NVVMToLLVMIRTranslation",
88+
"@llvm-project//mlir:Pass",
89+
"@llvm-project//mlir:ROCDLToLLVMIRTranslation",
90+
"@llvm-project//mlir:ReconcileUnrealizedCasts",
91+
"@llvm-project//mlir:SCFDialect",
92+
"@llvm-project//mlir:SCFToControlFlow",
93+
"@llvm-project//mlir:Support",
94+
"@llvm-project//mlir:TensorDialect",
95+
"@llvm-project//mlir:ToLLVMIRTranslation",
96+
"@llvm-project//mlir:Transforms",
97+
"@llvm-project//mlir:VectorDialect",
98+
],
99+
)
100+
101+
xla_cc_test(
102+
name = "cpu_fusion_emitter_test",
103+
srcs = ["cpu_fusion_emitter_test.cc"],
104+
deps = [
105+
":cpu_fusion_emitters",
106+
"//xla/hlo/analysis:hlo_ordering",
107+
"//xla/hlo/ir:hlo",
108+
"//xla/hlo/testlib:filecheck",
109+
"//xla/mlir_hlo",
110+
"//xla/service:buffer_assignment",
111+
"//xla/service:logical_buffer",
112+
"//xla/tests:hlo_test_base",
113+
"//xla/tests:xla_internal_test_main",
114+
"//xla/tsl/platform:statusor",
115+
"@com_google_absl//absl/status:statusor",
116+
"@com_google_absl//absl/strings:string_view",
117+
"@com_google_googletest//:gtest",
118+
"@llvm-project//llvm:Support",
119+
"@llvm-project//llvm:ir_headers",
120+
"@llvm-project//mlir:AffineDialect",
121+
"@llvm-project//mlir:ArithDialect",
122+
"@llvm-project//mlir:BuiltinToLLVMIRTranslation",
123+
"@llvm-project//mlir:ComplexDialect",
124+
"@llvm-project//mlir:FuncDialect",
125+
"@llvm-project//mlir:FuncExtensions",
126+
"@llvm-project//mlir:IR",
127+
"@llvm-project//mlir:LLVMToLLVMIRTranslation",
128+
"@llvm-project//mlir:MathDialect",
129+
"@llvm-project//mlir:NVVMToLLVMIRTranslation",
130+
"@llvm-project//mlir:Pass",
131+
"@llvm-project//mlir:ROCDLToLLVMIRTranslation",
132+
"@llvm-project//mlir:SCFDialect",
133+
"@llvm-project//mlir:TensorDialect",
134+
],
135+
)

0 commit comments

Comments
 (0)