@@ -25,20 +25,27 @@ limitations under the License.
25
25
#include " xla/hlo/ir/hlo_instruction.h"
26
26
#include " xla/hlo/ir/hlo_instructions.h"
27
27
#include " xla/hlo/ir/hlo_module.h"
28
- #include " xla/hlo/parser/hlo_parser.h"
29
28
#include " xla/tests/filecheck.h"
30
29
#include " xla/tests/hlo_test_base.h"
31
- #include " xla/tests/test_utils .h"
30
+ #include " xla/tsl/lib/core/status_test_util .h"
32
31
#include " tsl/platform/statusor.h"
33
32
34
33
namespace xla {
35
34
namespace {
36
35
37
36
using ::testing::HasSubstr;
38
37
using CollectivePermuteCycleDecomposerTest = HloTestBase;
38
+ using Decomposer = CollectivePermuteCycleDecomposer;
39
39
40
- TEST_F (CollectivePermuteCycleDecomposerTest, TrivialNotTransformed) {
41
- const absl::string_view kModuleStr = R"(
40
// Returns the HloPrintOptions that the FileCheck-based tests in this file pass
// to module->ToString(). Starts from a default-constructed HloPrintOptions and
// calls set_print_operand_shape(false) and set_include_layout_in_shapes(false).
// NOTE(review): presumably this keeps operand shapes and layouts out of the
// printed HLO so the CHECK patterns stay short -- confirm against
// HloPrintOptions.
+ HloPrintOptions PrintOptions () {
41
+ HloPrintOptions options;
42
+ options.set_print_operand_shape (false );
43
+ options.set_include_layout_in_shapes (false );
44
+ return options;
45
+ }
46
+
47
+ TEST_F (CollectivePermuteCycleDecomposerTest, NoCycle_NotTransformed) {
48
+ absl::string_view kHlo = R"(
42
49
HloModule test
43
50
ENTRY test_computation {
44
51
p = u32[8,8] parameter(0)
@@ -47,17 +54,14 @@ TEST_F(CollectivePermuteCycleDecomposerTest, TrivialNotTransformed) {
47
54
}
48
55
)" ;
49
56
50
- TF_ASSERT_OK_AND_ASSIGN (std::unique_ptr<HloModule> module ,
51
- ParseAndReturnVerifiedModule ((kModuleStr )));
52
- CollectivePermuteCycleDecomposer decomposer (/* threshold_in_bytes=*/ 0 );
53
- TF_ASSERT_OK_AND_ASSIGN (bool changed, decomposer.Run (module .get ()));
54
- EXPECT_FALSE (changed);
57
+ TF_ASSERT_OK (RunAndCheckHloRewrite (kHlo , Decomposer (0 ), false ));
55
58
}
56
59
57
- TEST_F (CollectivePermuteCycleDecomposerTest, BelowThresholdNotTransformed ) {
60
+ TEST_F (CollectivePermuteCycleDecomposerTest, HonorsThreshold ) {
58
61
// When `size of data` >= `threshold`, then it is decomposed, otherwise it
59
62
// stays as it is.
60
- const absl::string_view kModuleStr = R"(
63
+ // u32[4,2] = 4*4*2 = 32 bytes
64
+ absl::string_view hlo = R"(
61
65
HloModule test
62
66
ENTRY test_computation {
63
67
p = u32[4,2] parameter(0)
@@ -66,16 +70,9 @@ TEST_F(CollectivePermuteCycleDecomposerTest, BelowThresholdNotTransformed) {
66
70
}
67
71
)" ;
68
72
69
- TF_ASSERT_OK_AND_ASSIGN (std::unique_ptr<HloModule> module ,
70
- ParseAndReturnVerifiedModule ((kModuleStr )));
71
- CollectivePermuteCycleDecomposer decomposer (/* threshold_in_bytes=*/ 33 );
72
- TF_ASSERT_OK_AND_ASSIGN (
73
- bool changed,
74
- RunHloPass (CollectivePermuteCycleDecomposer (33 ), module .get ()));
75
- EXPECT_FALSE (changed);
76
- TF_ASSERT_OK_AND_ASSIGN (
77
- changed, RunHloPass (CollectivePermuteCycleDecomposer (16 ), module .get ()));
78
- EXPECT_TRUE (changed);
73
+ TF_ASSERT_OK (RunAndCheckHloRewrite (hlo, Decomposer (33 ), false ));
74
+ TF_ASSERT_OK (RunAndCheckHloRewrite (hlo, Decomposer (32 ), true ));
75
+ TF_ASSERT_OK (RunAndCheckHloRewrite (hlo, Decomposer (16 ), true ));
79
76
}
80
77
81
78
TEST_F (CollectivePermuteCycleDecomposerTest, ForwardCycle) {
@@ -84,7 +81,7 @@ TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycle) {
84
81
// 2. They should split over the value of partition-id.
85
82
// 3. The metadata and frontend_attributes are propagated to split
86
83
// collectives.
87
- const absl::string_view kModuleStr = R"(
84
+ absl::string_view hlo = R"(
88
85
HloModule test
89
86
ENTRY test_computation {
90
87
p = u32[8,8] parameter(0)
@@ -94,30 +91,21 @@ TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycle) {
94
91
metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35}
95
92
}
96
93
)" ;
97
-
98
- TF_ASSERT_OK_AND_ASSIGN (std::unique_ptr<HloModule> module ,
99
- ParseAndReturnVerifiedModule ((kModuleStr )));
100
- CollectivePermuteCycleDecomposer decomposer (/* threshold_in_bytes=*/ 0 );
101
- TF_ASSERT_OK_AND_ASSIGN (bool changed, decomposer.Run (module .get ()));
102
- EXPECT_TRUE (changed);
103
-
104
- TF_CHECK_OK (VerifyHloModule (module .get (), false , true ));
105
- HloPrintOptions options;
106
- options.set_print_operand_shape (false );
107
- options.set_include_layout_in_shapes (false );
108
- EXPECT_TRUE (*RunFileCheck (module ->ToString (options), R"(
94
+ TF_ASSERT_OK_AND_ASSIGN (auto module ,
95
+ RunAndCheckHloRewrite (hlo, Decomposer (0 ), true ));
96
+ EXPECT_TRUE (*RunFileCheck (module ->ToString (PrintOptions ()), R"(
109
97
// CHECK: ENTRY %test_computation (p: u32[8,8]) -> u32[8,8] {
110
98
// CHECK-DAG: %[[partition_id:.+]] = u32[] partition-id()
111
99
// CHECK-DAG: %[[c0:.+]] = u32[] constant(0)
112
100
// CHECK-DAG: %[[compare:.+]] = pred[] compare(%[[partition_id]], %[[c0]]), direction=EQ
113
101
// CHECK-DAG: %{{.+}} = u32[8,8] parameter(0)
114
-
115
- // CHECK-DAG: %[[cp1:.+]] = u32[8,8] collective-permute(%{{.+}}), channel_id=1,
102
+
103
+ // CHECK-DAG: %[[cp1:.+]] = u32[8,8] collective-permute(%{{.+}}), channel_id=1,
116
104
// CHECK-SAME{LITERAL}: source_target_pairs={{3,0}}, frontend_attributes={_xla_send_recv_validation={{3,10}}}, metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35}
117
-
118
- // CHECK-DAG: %[[cp2:.+]] = u32[8,8] collective-permute(%{{.+}}), channel_id=2,
105
+
106
+ // CHECK-DAG: %[[cp2:.+]] = u32[8,8] collective-permute(%{{.+}}), channel_id=2,
119
107
// CHECK-SAME{LITERAL}: source_target_pairs={{0,1},{1,2},{2,3}}, frontend_attributes={_xla_send_recv_validation={{0,7},{1,8},{2,9}}}, metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35}
120
-
108
+
121
109
// CHECK-DAG: ROOT %select = u32[8,8] select(%[[compare]], %[[cp1]], %[[cp2]])
122
110
// CHECK-DAG: }
123
111
)" ));
@@ -127,7 +115,7 @@ TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycleNoChannel) {
127
115
// For a forward cycle, this checks:
128
116
// 1. Split collectives should not have channel-id
129
117
// 2. Split collectives are combined based on replica-id.
130
- const absl::string_view kModuleStr = R"(
118
+ absl::string_view hlo = R"(
131
119
HloModule test
132
120
ENTRY test_computation {
133
121
p = u32[8,8] parameter(0)
@@ -136,17 +124,9 @@ TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycleNoChannel) {
136
124
}
137
125
)" ;
138
126
139
- TF_ASSERT_OK_AND_ASSIGN (std::unique_ptr<HloModule> module ,
140
- ParseAndReturnVerifiedModule ((kModuleStr )));
141
- CollectivePermuteCycleDecomposer decomposer (/* threshold_in_bytes=*/ 0 );
142
- TF_ASSERT_OK_AND_ASSIGN (bool changed, decomposer.Run (module .get ()));
143
- EXPECT_TRUE (changed);
144
- TF_CHECK_OK (VerifyHloModule (module .get (), false , true ));
145
-
146
- HloPrintOptions options;
147
- options.set_print_operand_shape (false );
148
- options.set_include_layout_in_shapes (false );
149
- EXPECT_TRUE (*RunFileCheck (module ->ToString (options), R"(
127
+ TF_ASSERT_OK_AND_ASSIGN (auto module ,
128
+ RunAndCheckHloRewrite (hlo, Decomposer (0 ), true ));
129
+ EXPECT_TRUE (*RunFileCheck (module ->ToString (PrintOptions ()), R"(
150
130
// CHECK: ENTRY %test_computation (p: u32[8,8]) -> u32[8,8] {
151
131
// CHECK-DAG: %[[replica_id:.+]] = u32[] replica-id()
152
132
// CHECK-DAG: %[[c0:.+]] = u32[] constant(0)
@@ -155,17 +135,17 @@ TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycleNoChannel) {
155
135
156
136
// CHECK-DAG: %[[cp1:.+]] = u32[8,8] collective-permute(%{{.+}}), source_target_pairs=
157
137
// CHECK-SAME{LITERAL}: {{3,0}}
158
-
138
+
159
139
// CHECK-DAG: %[[cp2:.+]] = u32[8,8] collective-permute(%{{.+}}), source_target_pairs=
160
140
// CHECK-SAME{LITERAL}: {{0,1},{1,2},{2,3}}
161
-
141
+
162
142
// CHECK-DAG: ROOT %select = u32[8,8] select(%[[compare]], %[[cp1]], %[[cp2]])
163
143
// CHECK-DAG: }
164
144
)" ));
165
145
}
166
146
167
147
TEST_F (CollectivePermuteCycleDecomposerTest, ForwardCycleWithMatmul) {
168
- const absl::string_view kModuleStr = R"(
148
+ absl::string_view hlo = R"(
169
149
HloModule test
170
150
171
151
while_cond {
@@ -198,11 +178,8 @@ TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycleWithMatmul) {
198
178
while_res = (u32[], f32[2,2], f32[2,2]) while(input), condition=while_cond, body=while_body
199
179
ROOT data_out = f32[2,2] get-tuple-element(while_res), index=1
200
180
})" ;
201
- TF_ASSERT_OK_AND_ASSIGN (std::unique_ptr<HloModule> module ,
202
- ParseAndReturnVerifiedModule ((kModuleStr )));
203
- CollectivePermuteCycleDecomposer decomposer (/* threshold_in_bytes=*/ 0 );
204
- TF_ASSERT_OK_AND_ASSIGN (bool changed, decomposer.Run (module .get ()));
205
- EXPECT_TRUE (changed);
181
+ TF_ASSERT_OK_AND_ASSIGN (auto module ,
182
+ RunAndCheckHloRewrite (hlo, Decomposer (0 ), true ));
206
183
HloCollectivePermuteInstruction* cp1 =
207
184
DynCast<HloCollectivePermuteInstruction>(
208
185
FindInstruction (module .get (), " cp.backward" ));
@@ -222,7 +199,7 @@ TEST_F(CollectivePermuteCycleDecomposerTest, BackwardCycle) {
222
199
// 1. Metadata is propagated to split collectives.
223
200
// 2. Frontend attributes are accurately split.
224
201
// 3. The split collectives have channel IDs.
225
- const absl::string_view kModuleStr = R"(
202
+ absl::string_view hlo = R"(
226
203
HloModule test
227
204
ENTRY test_computation {
228
205
p = u32[8,8] parameter(0)
@@ -232,29 +209,21 @@ TEST_F(CollectivePermuteCycleDecomposerTest, BackwardCycle) {
232
209
metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35}
233
210
})" ;
234
211
235
- TF_ASSERT_OK_AND_ASSIGN (std::unique_ptr<HloModule> module ,
236
- ParseAndReturnVerifiedModule ((kModuleStr )));
237
- TF_ASSERT_OK_AND_ASSIGN (
238
- bool changed,
239
- RunHloPass (CollectivePermuteCycleDecomposer (0 ), module .get ()));
240
- EXPECT_TRUE (changed);
241
- TF_CHECK_OK (VerifyHloModule (module .get (), true , false ));
242
- HloPrintOptions options;
243
- options.set_print_operand_shape (false );
244
- options.set_include_layout_in_shapes (false );
245
- EXPECT_TRUE (*RunFileCheck (module ->ToString (options), R"(
212
+ TF_ASSERT_OK_AND_ASSIGN (auto module ,
213
+ RunAndCheckHloRewrite (hlo, Decomposer (0 ), true ));
214
+ EXPECT_TRUE (*RunFileCheck (module ->ToString (PrintOptions ()), R"(
246
215
// CHECK: ENTRY %test_computation (p: u32[8,8]) -> u32[8,8] {
247
216
// CHECK-DAG: %[[partition:.+]] = u32[] partition-id()
248
217
// CHECK-DAG: %[[three:.+]] = u32[] constant(3)
249
218
// CHECK-DAG: %[[compare:.+]] = pred[] compare(%[[partition]], %[[three]]), direction=EQ
250
219
// CHECK-DAG: %{{.+}} = u32[8,8] parameter(0)
251
-
220
+
252
221
// CHECK-DAG: %[[cp1:.+]] = u32[8,8] collective-permute(%{{.+}}), channel_id=1, source_target_pairs=
253
222
// CHECK-SAME{LITERAL}: {{0,3}}, frontend_attributes={_xla_send_recv_validation={{0,7}}}, metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35}
254
-
223
+
255
224
// CHECK-DAG: %[[cp2:.+]] = u32[8,8] collective-permute(%{{.+}}), channel_id=2, source_target_pairs=
256
225
// CHECK-SAME{LITERAL}: {{1,0},{2,1},{3,2}}, frontend_attributes={_xla_send_recv_validation={{1,8},{2,9},{3,10}}}, metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35}
257
-
226
+
258
227
// CHECK-DAG: ROOT %{{.+}} = u32[8,8] select(%[[compare]], %[[cp1]], %[[cp2]])
259
228
// CHECK-DAG: }
260
229
)" ));
@@ -264,7 +233,7 @@ TEST_F(CollectivePermuteCycleDecomposerTest, BackwardCycleNoChannel) {
264
233
// For backward cycle, this checks:
265
234
// 1. Split collectives do not have a channel-id
266
235
// 2. Split collectives are combined based on the value of replica-id.
267
- const absl::string_view kModuleStr = R"(
236
+ absl::string_view hlo = R"(
268
237
HloModule test
269
238
ENTRY test_computation {
270
239
p = u32[8,8] parameter(0)
@@ -273,28 +242,21 @@ TEST_F(CollectivePermuteCycleDecomposerTest, BackwardCycleNoChannel) {
273
242
frontend_attributes={_xla_send_recv_validation="{{0,7},{1,8},{2,9},{3,10}}"}
274
243
})" ;
275
244
276
- TF_ASSERT_OK_AND_ASSIGN (std::unique_ptr<HloModule> module ,
277
- ParseAndReturnVerifiedModule ((kModuleStr )));
278
- CollectivePermuteCycleDecomposer decomposer (/* threshold_in_bytes=*/ 0 );
279
- TF_ASSERT_OK_AND_ASSIGN (bool changed, decomposer.Run (module .get ()));
280
- EXPECT_TRUE (changed);
281
- HloPrintOptions options;
282
- options.set_print_operand_shape (false );
283
- options.set_include_layout_in_shapes (false );
284
- TF_CHECK_OK (VerifyHloModule (module .get (), false , true ));
285
- EXPECT_TRUE (*RunFileCheck (module ->ToString (options), R"(
245
+ TF_ASSERT_OK_AND_ASSIGN (auto module ,
246
+ RunAndCheckHloRewrite (hlo, Decomposer (0 ), true ));
247
+ EXPECT_TRUE (*RunFileCheck (module ->ToString (PrintOptions ()), R"(
286
248
// CHECK: ENTRY %test_computation (p: u32[8,8]) -> u32[8,8] {
287
249
// CHECK-DAG: %[[replica_id:.+]] = u32[] replica-id()
288
250
// CHECK-DAG: %[[three:.+]] = u32[] constant(3)
289
251
// CHECK-DAG: %[[compare:.+]] = pred[] compare(%[[replica_id]], %[[three]]), direction=EQ
290
252
// CHECK-DAG: %{{.+}} = u32[8,8] parameter(0)
291
-
253
+
292
254
// CHECK-DAG: %[[cp1:.+]] = u32[8,8] collective-permute(%{{.+}}), source_target_pairs=
293
255
// CHECK-SAME{LITERAL}: {{0,3}}, frontend_attributes={_xla_send_recv_validation={{0,7}}}
294
-
256
+
295
257
// CHECK-DAG: %[[cp2:.+]] = u32[8,8] collective-permute(%{{.+}}), source_target_pairs=
296
258
// CHECK-SAME{LITERAL}: {{1,0},{2,1},{3,2}}, frontend_attributes={_xla_send_recv_validation={{1,8},{2,9},{3,10}}}
297
-
259
+
298
260
// CHECK-DAG: ROOT %select = u32[8,8] select(%[[compare]], %[[cp1]], %[[cp2]])
299
261
// CHECK-DAG: }
300
262
)" ));
0 commit comments