Skip to content

Commit a249db8

Browse files
tensorflower-gardener
authored and committed
[IFRT] Add pass to convert an ifrt.Reshard to an ifrt.Call.
PiperOrigin-RevId: 688773441
1 parent b80d16b commit a249db8

File tree

7 files changed

+489
-0
lines changed

7 files changed

+489
-0
lines changed

third_party/xla/xla/python/ifrt/ir/constants.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,13 @@ inline constexpr llvm::StringLiteral kCalleeMainFuncName = "main";
5656
// Name of StringAttr used to store the HloSharding.
5757
inline constexpr llvm::StringLiteral kHloShardingAttrName = "mhlo.sharding";
5858

59+
inline constexpr llvm::StringLiteral kIfrtModuleTypeAttrName =
60+
"ifrt.module_type";
61+
62+
inline constexpr llvm::StringLiteral kIfrtModuleTypeXla = "xla";
63+
inline constexpr llvm::StringLiteral kIfrtModuleTypeMpmdReshard =
64+
"mpmd_reshard";
65+
5966
} // namespace ifrt
6067
} // namespace xla
6168

third_party/xla/xla/python/ifrt/ir/tests/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ lit_test_suite(
1111
srcs = enforce_glob(
1212
[
1313
"ifrt_duplicated_callee_elimination.mlir",
14+
"ifrt_lower_mpmd_reshard_to_call.mlir",
1415
"ifrt_lower_sharding_to_xla.mlir",
1516
"ifrt_merge_reshards.mlir",
1617
"ifrt_outline_atom_program_to_module.mlir",
Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
// RUN: ifrt-opt %s -ifrt-lower-mpmd-reshard-to-call -split-input-file -verify-diagnostics | FileCheck %s
2+
3+
!array0 = !ifrt.array<tensor<2x2xi32>,
4+
#ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]>
5+
!array1 = !ifrt.array<tensor<2x2xi32>,
6+
#ifrt.sharding_param<1x1 to [0] on 1>, [2]>
7+
// CHECK-LABEL: @reshard_without_donation
8+
module @reshard_without_donation {
9+
func.func public @main(%arg0: !array0) -> (!array1)
10+
attributes {ifrt.function} {
11+
// CHECK: ifrt.Call @reshard_4784300543980450571::@main(%arg0) on devices [0, 1, 2] {ifrt.module_type = "mpmd_reshard"}
12+
%0, %ctrl_0 = ifrt.Reshard(%arg0) : (!array0) -> !array1
13+
return %0 : !array1
14+
}
15+
16+
// CHECK: module @reshard_4784300543980450571
17+
// CHECK-SAME: attributes {
18+
// CHECK-DAG: ifrt.num_devices = 3
19+
// CHECK-DAG: sym_visibility = "private"
20+
// CHECK-SAME: }
21+
// CHECK: func.func @main(
22+
// CHECK: %arg0: !ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]>
23+
// CHECK: -> !ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<1x1 to [0] on 1>, [2]>
24+
}
25+
26+
// -----
27+
28+
!array0 = !ifrt.array<tensor<2x2xi32>,
29+
#ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]>
30+
!array1 = !ifrt.array<tensor<2x2xi32>,
31+
#ifrt.sharding_param<1x1 to [0] on 1>, [2]>
32+
// CHECK-LABEL: @reshard_with_donation
33+
module @reshard_with_donation {
34+
func.func public @main(%arg0: !array0) -> (!array1)
35+
attributes {ifrt.function} {
36+
// CHECK: ifrt.Call @reshard_4784300543980450571::@main(%arg0) on devices [0, 1, 2]
37+
// CHECK-SAME: {
38+
// CHECK-DAG: ifrt.module_type = "mpmd_reshard"
39+
// CHECK-DAG: donated_input_indices = array<i32: 0>
40+
// CHECK-SAME: }
41+
%0, %ctrl_0 = ifrt.Reshard(%arg0) {donated=true} : (!array0) -> !array1
42+
return %0 : !array1
43+
}
44+
45+
// CHECK: module @reshard_4784300543980450571
46+
// CHECK-SAME: attributes {
47+
// CHECK-DAG: ifrt.num_devices = 3
48+
// CHECK-DAG: sym_visibility = "private"
49+
// CHECK-SAME: }
50+
// CHECK: func.func @main(
51+
// CHECK: %arg0: !ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]>
52+
// CHECK: -> !ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<1x1 to [0] on 1>, [2]>
53+
}
54+
55+
// -----
56+
57+
!array0 = !ifrt.array<tensor<2x2xi32>,
58+
#ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]>
59+
!array1 = !ifrt.array<tensor<2x2xi32>,
60+
#ifrt.sharding_param<2x1 to [0] on 2>, [2, 3]>
61+
// ifrt.Reshard does not need to be converted to an MPMD reshard ifrt.Call
62+
// because the reshard is a 1:1 buffer copy between devices.
63+
module @reshard_is_not_converted_to_call {
64+
func.func public @main(%arg0: !array0) -> (!array1)
65+
attributes {ifrt.function} {
66+
// expected-error@+1 {{'ifrt.Reshard' op does not reshard any arrays. Use CopyArraysOp instead}}
67+
%0, %ctrl_0 = ifrt.Reshard(%arg0) : (!array0) -> !array1
68+
return %0 : !array1
69+
}
70+
}
71+
72+
// -----
73+
74+
!array0 = !ifrt.array<tensor<2x2xi32>,
75+
#ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]>
76+
!array1 = !ifrt.array<tensor<2x2xi32>,
77+
#ifrt.sharding_param<1x1 to [0] on 1>, [2]>
78+
// CHECK-LABEL: @reshard_after_call_to_module
79+
module @reshard_after_call_to_module {
80+
func.func public @main(%arg0: !array0) -> (!array1)
81+
attributes {ifrt.function} {
82+
// CHECK: %[[OUT_1:.*]], %[[CTRL_OUT:.*]] = ifrt.Call @add_one
83+
%0, %ctrl_0 = ifrt.Call @add_one(%arg0) on devices [0, 1]
84+
: (!array0) -> !array0
85+
// CHECK: %[[OUT_2:.*]], %{{.+}} = ifrt.Call @reshard_4784300543980450571::@main(%[[OUT_1]]) after %[[CTRL_OUT]]
86+
// CHECK: {ifrt.module_type = "mpmd_reshard"}
87+
%1, %ctrl_1 = ifrt.Reshard(%0) after %ctrl_0 : (!array0) -> !array1
88+
return %1 : !array1
89+
}
90+
91+
func.func private @add_one(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> {
92+
%0 = mhlo.constant dense<1> : tensor<2x2xi32>
93+
%1 = mhlo.add %arg0, %0 : tensor<2x2xi32>
94+
return %1 : tensor<2x2xi32>
95+
}
96+
97+
// CHECK: module @reshard_4784300543980450571
98+
// CHECK-SAME: attributes {
99+
// CHECK-DAG: ifrt.num_devices = 3
100+
// CHECK-DAG: sym_visibility = "private"
101+
// CHECK-SAME: }
102+
// CHECK: func.func @main(
103+
// CHECK: %arg0: !ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]>
104+
// CHECK: -> !ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<1x1 to [0] on 1>, [2]>
105+
}
106+
107+
// -----
108+
109+
!array0 = !ifrt.array<tensor<2x2xi32>,
110+
#ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]>
111+
!array1 = !ifrt.array<tensor<2x2xi32>,
112+
#ifrt.sharding_param<1x1 to [0] on 1>, [2]>
113+
// CHECK-LABEL: @reshard_before_call_to_module
114+
module @reshard_before_call_to_module {
115+
func.func public @main(%arg0: !array0) -> (!array1)
116+
attributes {ifrt.function} {
117+
// CHECK: ifrt.Call @reshard_4784300543980450571::@main(%arg0) on devices [0, 1, 2] {ifrt.module_type = "mpmd_reshard"}
118+
%0, %ctrl_0 = ifrt.Reshard(%arg0) : (!array0) -> !array1
119+
// CHECK: %[[OUT:.*]], %[[CTRL_OUT:.*]] = ifrt.Call @add_one
120+
%1, %ctrl_1 = ifrt.Call @add_one(%0) on devices [2]
121+
: (!array1) -> !array1
122+
return %1 : !array1
123+
}
124+
125+
func.func private @add_one(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> {
126+
%0 = mhlo.constant dense<1> : tensor<2x2xi32>
127+
%1 = mhlo.add %arg0, %0 : tensor<2x2xi32>
128+
return %1 : tensor<2x2xi32>
129+
}
130+
131+
// CHECK: module @reshard_4784300543980450571
132+
// CHECK-SAME: attributes {
133+
// CHECK-DAG: ifrt.num_devices = 3
134+
// CHECK-DAG: sym_visibility = "private"
135+
// CHECK-SAME: }
136+
// CHECK: func.func @main(
137+
// CHECK: %arg0: !ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]>
138+
// CHECK: -> !ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<1x1 to [0] on 1>, [2]>
139+
}
140+
141+
// -----
142+
143+
!array0 = !ifrt.array<tensor<2x2xi32>,
144+
#ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]>
145+
!array1 = !ifrt.array<tensor<2x2xi32>,
146+
#ifrt.sharding_param<1x1 to [0] on 1>, [2]>
147+
// CHECK-LABEL: @two_identical_reshards_single_module
148+
module @two_identical_reshards_single_module {
149+
func.func public @main(%arg0: !array0, %arg1: !array0) -> (!array1, !array1)
150+
attributes {ifrt.function} {
151+
// CHECK: ifrt.Call @reshard_4784300543980450571::@main(%arg0) on devices [0, 1, 2] {ifrt.module_type = "mpmd_reshard"}
152+
%0, %ctrl_0 = ifrt.Reshard(%arg0) : (!array0) -> !array1
153+
// CHECK: ifrt.Call @reshard_4784300543980450571::@main(%arg1) on devices [0, 1, 2] {ifrt.module_type = "mpmd_reshard"}
154+
%1, %ctrl_1 = ifrt.Reshard(%arg1) : (!array0) -> !array1
155+
return %0, %1 : !array1, !array1
156+
}
157+
158+
// CHECK: module @reshard_4784300543980450571
159+
// CHECK-SAME: attributes {
160+
// CHECK-DAG: ifrt.num_devices = 3
161+
// CHECK-DAG: sym_visibility = "private"
162+
// CHECK-SAME: }
163+
// CHECK: func.func @main(
164+
// CHECK: %arg0: !ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]>
165+
// CHECK: -> !ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<1x1 to [0] on 1>, [2]>
166+
}
167+
168+
// -----
169+
170+
!array0 = !ifrt.array<tensor<2x2xi32>,
171+
#ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]>
172+
!array1 = !ifrt.array<tensor<2x2xi32>,
173+
#ifrt.sharding_param<1x1 to [0] on 1>, [2]>
174+
// CHECK-LABEL: @two_reshards_two_modules
175+
module @two_reshards_two_modules {
176+
func.func public @main(%arg0: !array0) -> (!array0)
177+
attributes {ifrt.function} {
178+
// CHECK: %[[OUT:.+]], %{{.+}} = ifrt.Call @reshard_4784300543980450571::@main(%arg0) on devices [0, 1, 2] {ifrt.module_type = "mpmd_reshard"}
179+
%0, %ctrl_0 = ifrt.Reshard(%arg0) : (!array0) -> !array1
180+
// CHECK: ifrt.Call @reshard_17322361279023763284::@main(%[[OUT]]) on devices [0, 1, 2] {ifrt.module_type = "mpmd_reshard"}
181+
%1, %ctrl_1 = ifrt.Reshard(%0) : (!array1) -> !array0
182+
return %1 : !array0
183+
}
184+
185+
// CHECK: module @reshard_4784300543980450571
186+
// CHECK-SAME: attributes {
187+
// CHECK-DAG: ifrt.num_devices = 3
188+
// CHECK-DAG: sym_visibility = "private"
189+
// CHECK-SAME: }
190+
// CHECK: func.func @main(
191+
// CHECK: %arg0: !ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]>
192+
// CHECK: -> !ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<1x1 to [0] on 1>, [2]>
193+
194+
// CHECK: module @reshard_17322361279023763284
195+
// CHECK-SAME: attributes {
196+
// CHECK-DAG: ifrt.num_devices = 3
197+
// CHECK-DAG: sym_visibility = "private"
198+
// CHECK-SAME: }
199+
// CHECK: func.func @main(
200+
// CHECK: %arg0: !ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<1x1 to [0] on 1>, [2]>
201+
// CHECK: -> !ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]>
202+
}
203+
204+
// -----
205+
206+
!array0 = !ifrt.array<tensor<2x2xi32>,
207+
#ifrt.sharding_param<1x1 to [0] on 1>, [0]>
208+
!array1 = !ifrt.array<tensor<2x2xi32>,
209+
#ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]>
210+
// Tests that the module for the MPMD reshard has unique devices.
211+
// CHECK-LABEL: @check_reshard_module_has_unique_devices
212+
module @check_reshard_module_has_unique_devices {
213+
func.func @main(%arg0: !array0) -> !array1 attributes {ifrt.function} {
214+
// CHECK: ifrt.Call @reshard_6746659470058475136::@main(%arg0) on devices [0, 1] {ifrt.module_type = "mpmd_reshard"}
215+
%0, %ctrl_0 = ifrt.Reshard(%arg0) : (!array0) -> !array1
216+
return %0 : !array1
217+
}
218+
219+
// CHECK: module @reshard_6746659470058475136
220+
// CHECK-SAME: attributes {
221+
// CHECK-DAG: ifrt.num_devices = 2
222+
// CHECK-DAG: sym_visibility = "private"
223+
// CHECK-SAME: }
224+
// CHECK: func.func @main(
225+
// CHECK: %arg0: !ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<1x1 to [0] on 1>, [0]>
226+
// CHECK: -> !ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]>
227+
}

third_party/xla/xla/python/ifrt/ir/transforms/BUILD

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ cc_library(
3030
name = "passes",
3131
srcs = [
3232
"ifrt_duplicated_callee_elimination_pass.cc",
33+
"ifrt_lower_mpmd_reshard_to_call_pass.cc",
3334
"ifrt_lower_sharding_to_xla_pass.cc",
3435
"ifrt_merge_reshards_pass.cc",
3536
"ifrt_outline_atom_program_to_module_pass.cc",
@@ -52,17 +53,20 @@ cc_library(
5253
"//xla/mlir_hlo",
5354
"//xla/python/ifrt/ir",
5455
"//xla/python/ifrt/support:sharding_conversions",
56+
"@com_google_absl//absl/container:btree",
5557
"@com_google_absl//absl/container:flat_hash_set",
5658
"@com_google_absl//absl/log:check",
5759
"@com_google_absl//absl/status",
5860
"@com_google_absl//absl/status:statusor",
61+
"@com_google_absl//absl/strings",
5962
"@llvm-project//llvm:Support",
6063
"@llvm-project//mlir:FuncDialect",
6164
"@llvm-project//mlir:IR",
6265
"@llvm-project//mlir:Pass",
6366
"@llvm-project//mlir:Support",
6467
"@llvm-project//mlir:TransformUtils",
6568
"@llvm-project//mlir:Transforms",
69+
"@local_tsl//tsl/platform:fingerprint",
6670
"@stablehlo//:stablehlo_ops",
6771
],
6872
)

0 commit comments

Comments
 (0)