Skip to content

Commit b833d6f

Browse files
toli-ytensorflower-gardener
authored andcommitted
Use RunAndCheckHloRewrite in collective_permute_cycle_decomposer_test.cc
PiperOrigin-RevId: 685025195
1 parent 40da335 commit b833d6f

File tree

3 files changed

+51
-91
lines changed

3 files changed

+51
-91
lines changed

third_party/xla/xla/service/gpu/transforms/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,7 @@ xla_cc_test(
464464
"//xla/tests:filecheck",
465465
"//xla/tests:hlo_test_base",
466466
"//xla/tests:test_utils",
467+
"//xla/tsl/lib/core:status_test_util",
467468
"@com_google_absl//absl/strings:string_view",
468469
"@com_google_googletest//:gtest",
469470
"@local_tsl//tsl/platform:statusor",

third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,6 @@ class CollectivePermuteCycleDecomposer : public HloModulePass {
5656
return "collective-permute-cycle-decomposer";
5757
}
5858

59-
using HloPassInterface::Run;
60-
// Runs CollectivePermuteCycleDecomposer pass on computations in 'module'.
61-
// Returns whether the 'module' was changed.
6259
absl::StatusOr<bool> Run(
6360
HloModule* module,
6461
const absl::flat_hash_set<absl::string_view>& execution_threads) override;

third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer_test.cc

Lines changed: 50 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -25,20 +25,27 @@ limitations under the License.
2525
#include "xla/hlo/ir/hlo_instruction.h"
2626
#include "xla/hlo/ir/hlo_instructions.h"
2727
#include "xla/hlo/ir/hlo_module.h"
28-
#include "xla/hlo/parser/hlo_parser.h"
2928
#include "xla/tests/filecheck.h"
3029
#include "xla/tests/hlo_test_base.h"
31-
#include "xla/tests/test_utils.h"
30+
#include "xla/tsl/lib/core/status_test_util.h"
3231
#include "tsl/platform/statusor.h"
3332

3433
namespace xla {
3534
namespace {
3635

3736
using ::testing::HasSubstr;
3837
using CollectivePermuteCycleDecomposerTest = HloTestBase;
38+
using Decomposer = CollectivePermuteCycleDecomposer;
3939

40-
TEST_F(CollectivePermuteCycleDecomposerTest, TrivialNotTransformed) {
41-
const absl::string_view kModuleStr = R"(
40+
HloPrintOptions PrintOptions() {
41+
HloPrintOptions options;
42+
options.set_print_operand_shape(false);
43+
options.set_include_layout_in_shapes(false);
44+
return options;
45+
}
46+
47+
TEST_F(CollectivePermuteCycleDecomposerTest, NoCycle_NotTransformed) {
48+
absl::string_view kHlo = R"(
4249
HloModule test
4350
ENTRY test_computation {
4451
p = u32[8,8] parameter(0)
@@ -47,17 +54,14 @@ TEST_F(CollectivePermuteCycleDecomposerTest, TrivialNotTransformed) {
4754
}
4855
)";
4956

50-
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
51-
ParseAndReturnVerifiedModule((kModuleStr)));
52-
CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/0);
53-
TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
54-
EXPECT_FALSE(changed);
57+
TF_ASSERT_OK(RunAndCheckHloRewrite(kHlo, Decomposer(0), false));
5558
}
5659

57-
TEST_F(CollectivePermuteCycleDecomposerTest, BelowThresholdNotTransformed) {
60+
TEST_F(CollectivePermuteCycleDecomposerTest, HonorsThreshold) {
5861
// When `size of data` > `threshold`, then it is decomposed, otherwise it
5962
// stays as it is.
60-
const absl::string_view kModuleStr = R"(
63+
// u32[4,2] = 4*4*2 = 32 bytes
64+
absl::string_view hlo = R"(
6165
HloModule test
6266
ENTRY test_computation {
6367
p = u32[4,2] parameter(0)
@@ -66,16 +70,9 @@ TEST_F(CollectivePermuteCycleDecomposerTest, BelowThresholdNotTransformed) {
6670
}
6771
)";
6872

69-
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
70-
ParseAndReturnVerifiedModule((kModuleStr)));
71-
CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/33);
72-
TF_ASSERT_OK_AND_ASSIGN(
73-
bool changed,
74-
RunHloPass(CollectivePermuteCycleDecomposer(33), module.get()));
75-
EXPECT_FALSE(changed);
76-
TF_ASSERT_OK_AND_ASSIGN(
77-
changed, RunHloPass(CollectivePermuteCycleDecomposer(16), module.get()));
78-
EXPECT_TRUE(changed);
73+
TF_ASSERT_OK(RunAndCheckHloRewrite(hlo, Decomposer(33), false));
74+
TF_ASSERT_OK(RunAndCheckHloRewrite(hlo, Decomposer(32), true));
75+
TF_ASSERT_OK(RunAndCheckHloRewrite(hlo, Decomposer(16), true));
7976
}
8077

8178
TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycle) {
@@ -84,7 +81,7 @@ TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycle) {
8481
// 2. They should split over the value of partition-id.
8582
// 3. The metadata and frontend_attributes are propagated to split
8683
// collectives.
87-
const absl::string_view kModuleStr = R"(
84+
absl::string_view hlo = R"(
8885
HloModule test
8986
ENTRY test_computation {
9087
p = u32[8,8] parameter(0)
@@ -94,30 +91,21 @@ TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycle) {
9491
metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35}
9592
}
9693
)";
97-
98-
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
99-
ParseAndReturnVerifiedModule((kModuleStr)));
100-
CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/0);
101-
TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
102-
EXPECT_TRUE(changed);
103-
104-
TF_CHECK_OK(VerifyHloModule(module.get(), false, true));
105-
HloPrintOptions options;
106-
options.set_print_operand_shape(false);
107-
options.set_include_layout_in_shapes(false);
108-
EXPECT_TRUE(*RunFileCheck(module->ToString(options), R"(
94+
TF_ASSERT_OK_AND_ASSIGN(auto module,
95+
RunAndCheckHloRewrite(hlo, Decomposer(0), true));
96+
EXPECT_TRUE(*RunFileCheck(module->ToString(PrintOptions()), R"(
10997
// CHECK: ENTRY %test_computation (p: u32[8,8]) -> u32[8,8] {
11098
// CHECK-DAG: %[[partition_id:.+]] = u32[] partition-id()
11199
// CHECK-DAG: %[[c0:.+]] = u32[] constant(0)
112100
// CHECK-DAG: %[[compare:.+]] = pred[] compare(%[[partition_id]], %[[c0]]), direction=EQ
113101
// CHECK-DAG: %{{.+}} = u32[8,8] parameter(0)
114-
115-
// CHECK-DAG: %[[cp1:.+]] = u32[8,8] collective-permute(%{{.+}}), channel_id=1,
102+
103+
// CHECK-DAG: %[[cp1:.+]] = u32[8,8] collective-permute(%{{.+}}), channel_id=1,
116104
// CHECK-SAME{LITERAL}: source_target_pairs={{3,0}}, frontend_attributes={_xla_send_recv_validation={{3,10}}}, metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35}
117-
118-
// CHECK-DAG: %[[cp2:.+]] = u32[8,8] collective-permute(%{{.+}}), channel_id=2,
105+
106+
// CHECK-DAG: %[[cp2:.+]] = u32[8,8] collective-permute(%{{.+}}), channel_id=2,
119107
// CHECK-SAME{LITERAL}: source_target_pairs={{0,1},{1,2},{2,3}}, frontend_attributes={_xla_send_recv_validation={{0,7},{1,8},{2,9}}}, metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35}
120-
108+
121109
// CHECK-DAG: ROOT %select = u32[8,8] select(%[[compare]], %[[cp1]], %[[cp2]])
122110
// CHECK-DAG: }
123111
)"));
@@ -127,7 +115,7 @@ TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycleNoChannel) {
127115
// For a forward cycle, this checks:
128116
// 1. Split collectives should not have channel-id
129117
// 2. Split collectives are combined based on replica-id.
130-
const absl::string_view kModuleStr = R"(
118+
absl::string_view hlo = R"(
131119
HloModule test
132120
ENTRY test_computation {
133121
p = u32[8,8] parameter(0)
@@ -136,17 +124,9 @@ TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycleNoChannel) {
136124
}
137125
)";
138126

139-
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
140-
ParseAndReturnVerifiedModule((kModuleStr)));
141-
CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/0);
142-
TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
143-
EXPECT_TRUE(changed);
144-
TF_CHECK_OK(VerifyHloModule(module.get(), false, true));
145-
146-
HloPrintOptions options;
147-
options.set_print_operand_shape(false);
148-
options.set_include_layout_in_shapes(false);
149-
EXPECT_TRUE(*RunFileCheck(module->ToString(options), R"(
127+
TF_ASSERT_OK_AND_ASSIGN(auto module,
128+
RunAndCheckHloRewrite(hlo, Decomposer(0), true));
129+
EXPECT_TRUE(*RunFileCheck(module->ToString(PrintOptions()), R"(
150130
// CHECK: ENTRY %test_computation (p: u32[8,8]) -> u32[8,8] {
151131
// CHECK-DAG: %[[replica_id:.+]] = u32[] replica-id()
152132
// CHECK-DAG: %[[c0:.+]] = u32[] constant(0)
@@ -155,17 +135,17 @@ TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycleNoChannel) {
155135
156136
// CHECK-DAG: %[[cp1:.+]] = u32[8,8] collective-permute(%{{.+}}), source_target_pairs=
157137
// CHECK-SAME{LITERAL}: {{3,0}}
158-
138+
159139
// CHECK-DAG: %[[cp2:.+]] = u32[8,8] collective-permute(%{{.+}}), source_target_pairs=
160140
// CHECK-SAME{LITERAL}: {{0,1},{1,2},{2,3}}
161-
141+
162142
// CHECK-DAG: ROOT %select = u32[8,8] select(%[[compare]], %[[cp1]], %[[cp2]])
163143
// CHECK-DAG: }
164144
)"));
165145
}
166146

167147
TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycleWithMatmul) {
168-
const absl::string_view kModuleStr = R"(
148+
absl::string_view hlo = R"(
169149
HloModule test
170150
171151
while_cond {
@@ -198,11 +178,8 @@ TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycleWithMatmul) {
198178
while_res = (u32[], f32[2,2], f32[2,2]) while(input), condition=while_cond, body=while_body
199179
ROOT data_out = f32[2,2] get-tuple-element(while_res), index=1
200180
})";
201-
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
202-
ParseAndReturnVerifiedModule((kModuleStr)));
203-
CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/0);
204-
TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
205-
EXPECT_TRUE(changed);
181+
TF_ASSERT_OK_AND_ASSIGN(auto module,
182+
RunAndCheckHloRewrite(hlo, Decomposer(0), true));
206183
HloCollectivePermuteInstruction* cp1 =
207184
DynCast<HloCollectivePermuteInstruction>(
208185
FindInstruction(module.get(), "cp.backward"));
@@ -222,7 +199,7 @@ TEST_F(CollectivePermuteCycleDecomposerTest, BackwardCycle) {
222199
// 1. Metadata is propagated to split collectives.
223200
// 2. Frontend attributes are accurately split.
224201
// 3. The split collectives have channel IDs.
225-
const absl::string_view kModuleStr = R"(
202+
absl::string_view hlo = R"(
226203
HloModule test
227204
ENTRY test_computation {
228205
p = u32[8,8] parameter(0)
@@ -232,29 +209,21 @@ TEST_F(CollectivePermuteCycleDecomposerTest, BackwardCycle) {
232209
metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35}
233210
})";
234211

235-
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
236-
ParseAndReturnVerifiedModule((kModuleStr)));
237-
TF_ASSERT_OK_AND_ASSIGN(
238-
bool changed,
239-
RunHloPass(CollectivePermuteCycleDecomposer(0), module.get()));
240-
EXPECT_TRUE(changed);
241-
TF_CHECK_OK(VerifyHloModule(module.get(), true, false));
242-
HloPrintOptions options;
243-
options.set_print_operand_shape(false);
244-
options.set_include_layout_in_shapes(false);
245-
EXPECT_TRUE(*RunFileCheck(module->ToString(options), R"(
212+
TF_ASSERT_OK_AND_ASSIGN(auto module,
213+
RunAndCheckHloRewrite(hlo, Decomposer(0), true));
214+
EXPECT_TRUE(*RunFileCheck(module->ToString(PrintOptions()), R"(
246215
// CHECK: ENTRY %test_computation (p: u32[8,8]) -> u32[8,8] {
247216
// CHECK-DAG: %[[partition:.+]] = u32[] partition-id()
248217
// CHECK-DAG: %[[three:.+]] = u32[] constant(3)
249218
// CHECK-DAG: %[[compare:.+]] = pred[] compare(%[[partition]], %[[three]]), direction=EQ
250219
// CHECK-DAG: %{{.+}} = u32[8,8] parameter(0)
251-
220+
252221
// CHECK-DAG: %[[cp1:.+]] = u32[8,8] collective-permute(%{{.+}}), channel_id=1, source_target_pairs=
253222
// CHECK-SAME{LITERAL}: {{0,3}}, frontend_attributes={_xla_send_recv_validation={{0,7}}}, metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35}
254-
223+
255224
// CHECK-DAG: %[[cp2:.+]] = u32[8,8] collective-permute(%{{.+}}), channel_id=2, source_target_pairs=
256225
// CHECK-SAME{LITERAL}: {{1,0},{2,1},{3,2}}, frontend_attributes={_xla_send_recv_validation={{1,8},{2,9},{3,10}}}, metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35}
257-
226+
258227
// CHECK-DAG: ROOT %{{.+}} = u32[8,8] select(%[[compare]], %[[cp1]], %[[cp2]])
259228
// CHECK-DAG: }
260229
)"));
@@ -264,7 +233,7 @@ TEST_F(CollectivePermuteCycleDecomposerTest, BackwardCycleNoChannel) {
264233
// For backward cycle, this checks:
265234
// 1. Split collectives do not have a channel-id
266235
// 2. Split collectives are combined based on the value of replica-id.
267-
const absl::string_view kModuleStr = R"(
236+
absl::string_view hlo = R"(
268237
HloModule test
269238
ENTRY test_computation {
270239
p = u32[8,8] parameter(0)
@@ -273,28 +242,21 @@ TEST_F(CollectivePermuteCycleDecomposerTest, BackwardCycleNoChannel) {
273242
frontend_attributes={_xla_send_recv_validation="{{0,7},{1,8},{2,9},{3,10}}"}
274243
})";
275244

276-
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
277-
ParseAndReturnVerifiedModule((kModuleStr)));
278-
CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/0);
279-
TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
280-
EXPECT_TRUE(changed);
281-
HloPrintOptions options;
282-
options.set_print_operand_shape(false);
283-
options.set_include_layout_in_shapes(false);
284-
TF_CHECK_OK(VerifyHloModule(module.get(), false, true));
285-
EXPECT_TRUE(*RunFileCheck(module->ToString(options), R"(
245+
TF_ASSERT_OK_AND_ASSIGN(auto module,
246+
RunAndCheckHloRewrite(hlo, Decomposer(0), true));
247+
EXPECT_TRUE(*RunFileCheck(module->ToString(PrintOptions()), R"(
286248
// CHECK: ENTRY %test_computation (p: u32[8,8]) -> u32[8,8] {
287249
// CHECK-DAG: %[[replica_id:.+]] = u32[] replica-id()
288250
// CHECK-DAG: %[[three:.+]] = u32[] constant(3)
289251
// CHECK-DAG: %[[compare:.+]] = pred[] compare(%[[replica_id]], %[[three]]), direction=EQ
290252
// CHECK-DAG: %{{.+}} = u32[8,8] parameter(0)
291-
253+
292254
// CHECK-DAG: %[[cp1:.+]] = u32[8,8] collective-permute(%{{.+}}), source_target_pairs=
293255
// CHECK-SAME{LITERAL}: {{0,3}}, frontend_attributes={_xla_send_recv_validation={{0,7}}}
294-
256+
295257
// CHECK-DAG: %[[cp2:.+]] = u32[8,8] collective-permute(%{{.+}}), source_target_pairs=
296258
// CHECK-SAME{LITERAL}: {{1,0},{2,1},{3,2}}, frontend_attributes={_xla_send_recv_validation={{1,8},{2,9},{3,10}}}
297-
259+
298260
// CHECK-DAG: ROOT %select = u32[8,8] select(%[[compare]], %[[cp1]], %[[cp2]])
299261
// CHECK-DAG: }
300262
)"));

0 commit comments

Comments
 (0)