Commit b1971cc

thcmbs authored and Google-ML-Automation committed
[XLA:GPU] Add support for multiple output tiles in triton_support_test
+ removes dependency on `get-tuple-element` for Reduce, BatchNormGrad & BatchNormTraining tests

PiperOrigin-RevId: 742631101
1 parent 3f1d29b commit b1971cc

2 files changed: +47, -24 lines
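Overview: `RunSupportTest` keeps its single-output signature, while the new `RunSupportTestMultipleOutputTiles` takes one tile-size vector per element of a tuple-shaped result, each vector matching that element's rank. Below is a minimal, self-contained sketch of that rank-matching convention; it is a hypothetical illustration only, with `CheckTileSizesMatch` and the plain-vector "shapes" as stand-ins rather than XLA code.

// Hypothetical illustration (not from the commit): the rank-matching rule
// enforced by the new helper, expressed over plain vectors of dimension
// sizes instead of HLO shapes.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// One inner vector per output; each inner vector holds that output's dims.
using Dims = std::vector<std::vector<int64_t>>;

void CheckTileSizesMatch(const Dims& output_shapes, const Dims& tile_sizes) {
  // One tile vector per output element.
  assert(output_shapes.size() == tile_sizes.size());
  for (std::size_t i = 0; i < output_shapes.size(); ++i) {
    // Each tile vector has one entry per dimension of its output.
    assert(output_shapes[i].size() == tile_sizes[i].size());
  }
}

int main() {
  // A variadic reduce producing ($0[125], $0[125]) takes {{1}, {1}}.
  CheckTileSizesMatch({{125}, {125}}, {{1}, {1}});
  // batch-norm-training producing ($0[4,8,16,32], $0[32], $0[32]).
  CheckTileSizesMatch({{4, 8, 16, 32}, {32}, {32}}, {{1, 1, 4, 8}, {1}, {1}});
  return 0;
}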

xla/backends/gpu/codegen/triton/support_test.cc

+45, -22
@@ -230,14 +230,36 @@ class TritonSupportTest : public TritonSupportTestBase {
       std::vector<int64_t> output_tile_sizes,
       se::GpuComputeCapability cc,
       ExpectedFailMode failure_mode = ExpectedFailMode::kFail) {
+    // output_tile_sizes is embedded in a vector of 1 element to share the logic
+    // with the multiple output tiles case.
+    RunSupportTestMultipleOutputTiles(
+        std::move(ti), {std::move(output_tile_sizes)}, cc, failure_mode);
+  }
+
+  void RunSupportTestMultipleOutputTiles(
+      TestedInstruction ti, std::vector<std::vector<int64_t>> output_tile_sizes,
+      se::GpuComputeCapability cc,
+      ExpectedFailMode failure_mode = ExpectedFailMode::kFail) {
     // Ensure that the caller provided the right number of output tile sizes.
     // If that is not the case, codegen could fail for that reason---which
-    // wouldn't give any valuable signal here. We skip the check for non-array
-    // output shapes, since we have no meaningful way of providing tile sizes
-    // for them at the moment.
+    // wouldn't give any valuable signal here. The check is only done for array
+    // and tuple shapes (only one layer of nesting is supported for tuples).
     if (ti.Instruction().shape().IsArray()) {
+      ASSERT_EQ(output_tile_sizes.size(), 1);
+      ASSERT_EQ(output_tile_sizes[0].size(),
+                ti.Instruction().shape().dimensions().size());
+    } else if (ti.Instruction().shape().IsTuple()) {
       ASSERT_EQ(output_tile_sizes.size(),
-                ti.Instruction().shape().dimensions_size());
+                ti.Instruction().shape().tuple_shapes_size());
+      for (int64_t i = 0; i < output_tile_sizes.size(); ++i) {
+        const auto& shape = ti.Instruction().shape().tuple_shapes(i);
+        if (shape.IsTuple()) {
+          continue;  // No validation for nested tuples, as there is no way to
+                     // specify output tile sizes for them.
+        }
+        ASSERT_TRUE(shape.IsArray());
+        ASSERT_EQ(shape.dimensions().size(), output_tile_sizes[i].size());
+      }
     }
     BlockLevelParameters block_level_parameters =
         FromOutputTileSizes(std::move(output_tile_sizes));
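Per the comment in the hunk above, the single-output `RunSupportTest` now wraps its argument in a one-element outer vector and delegates, so both entry points share the validation and codegen logic. A standalone sketch of that wrap-and-delegate pattern follows; the names are simplified stand-ins, not the XLA helpers themselves.

#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Stand-in for the multi-output entry point: one tile vector per output.
void RunMulti(std::vector<std::vector<int64_t>> tile_sets) {
  std::cout << "got " << tile_sets.size() << " output tile vector(s)\n";
}

// Stand-in for the single-output overload: wraps its argument in a
// one-element outer vector and forwards, avoiding duplicated logic.
void RunSingle(std::vector<int64_t> tile_sizes) {
  RunMulti({std::move(tile_sizes)});
}

int main() {
  RunSingle({16, 32});        // prints: got 1 output tile vector(s)
  RunMulti({{1}, {16, 32}});  // prints: got 2 output tile vector(s)
  return 0;
}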
@@ -726,16 +748,16 @@ add {
 ENTRY triton_computation {
   parameter_0 = $$0[125,127] parameter(0)
   constant_0 = $$0[] constant($0)
-  tuple = ($$0[125], $$0[125]) reduce(
+  ROOT reduce = ($$0[125], $$0[125]) reduce(
     parameter_0, parameter_0, constant_0, constant_0),
     dimensions={1}, to_apply=add
-  ROOT reduce = $$0[125] get-tuple-element(tuple), index=0
 })",
       init_value(data_type));
   TF_ASSERT_OK_AND_ASSIGN(
       TestedInstruction ti,
       ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode));
-  RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1}, cc);
+  RunSupportTestMultipleOutputTiles(std::move(ti),
+                                    /*output_tile_sizes=*/{{1}, {1}}, cc);
 }
 
 TEST_F(ReduceTest, ReduceWithNonConstReduceValueIsSupportedWithTriton) {
@@ -1025,7 +1047,8 @@ ENTRY triton_computation {
       TestedInstruction ti,
       ParseTemplateAndGetInstruction(kHloTestTemplate, data_type,
                                      HloOpcode::kAllGatherStart));
-  RunSupportTest(std::move(ti), /*output_tile_sizes=*/{2, 2}, cc);
+  RunSupportTestMultipleOutputTiles(std::move(ti),
+                                    /*output_tile_sizes=*/{{2, 2}, {2, 2}}, cc);
 }
 
 TEST_P(CollectiveTest, UnsupportedAllGatherDoneFailsGracefullyWithTriton) {
@@ -1142,7 +1165,8 @@ ENTRY triton_computation {
       ParseTemplateAndGetInstruction(kHloTestTemplate, data_type,
                                      HloOpcode::kCollectivePermuteDone));
 
-  RunSupportTest(std::move(ti_start), /*output_tile_sizes=*/{2, 2}, cc);
+  RunSupportTestMultipleOutputTiles(std::move(ti_start),
+                                    /*output_tile_sizes=*/{{2, 2}, {2, 2}}, cc);
   RunSupportTest(std::move(ti_done), /*output_tile_sizes=*/{2, 2}, cc);
 }
 
@@ -1197,8 +1221,10 @@ ENTRY triton_computation {
       TestedInstruction ti_done,
       ParseTemplateAndGetInstruction(kHloTestTemplate, data_type,
                                      HloOpcode::kAsyncDone));
-  RunSupportTest(std::move(ti_start), /*output_tile_sizes=*/{1}, cc);
-  RunSupportTest(std::move(ti_update), /*output_tile_sizes=*/{1}, cc);
+  RunSupportTestMultipleOutputTiles(std::move(ti_start),
+                                    /*output_tile_sizes=*/{{1}, {1}}, cc);
+  RunSupportTestMultipleOutputTiles(std::move(ti_update),
+                                    /*output_tile_sizes=*/{{1}, {1}}, cc);
   RunSupportTest(std::move(ti_done), /*output_tile_sizes=*/{1}, cc);
 }
 
@@ -1436,7 +1462,8 @@ ENTRY triton_computation {
   TF_ASSERT_OK_AND_ASSIGN(
       TestedInstruction ti,
       ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode));
-  RunSupportTest(std::move(ti), /*output_tile_sizes=*/{16, 32}, cc);
+  RunSupportTestMultipleOutputTiles(std::move(ti),
+                                    /*output_tile_sizes=*/{{1}, {16, 32}}, cc);
 }
 
 INSTANTIATE_TEST_SUITE_P(
@@ -1608,23 +1635,21 @@ INSTANTIATE_TEST_SUITE_P(
 
 using BatchNormTrainingTest = TritonSupportTestWithTypeAndOpcodeAndDeviceParam;
 
-// TODO: b/363981282 - Get rid of get-tuple-element by adding multiple output
-// tikes support to RunSupportTest.
 TEST_P(BatchNormTrainingTest, BatchNormTraining) {
   auto [data_type, opcode, cc] = GetParam();
   const std::string kHloTestTemplate = R"(
 ENTRY triton_computation {
   operand = $0[4,8,16,32] parameter(0)
   scale = $0[32] parameter(1)
   offset = $0[32] parameter(2)
-  bn_train = ($0[4,8,16,32], $0[32], $0[32]) batch-norm-training(operand, scale, offset),
+  ROOT bn_train = ($0[4,8,16,32], $0[32], $0[32]) batch-norm-training(operand, scale, offset),
     epsilon=0.001, feature_index=3
-  ROOT gte = $0[4,8,16,32] get-tuple-element(bn_train), index=0
 })";
   TF_ASSERT_OK_AND_ASSIGN(
       TestedInstruction ti,
       ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode));
-  RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1, 1, 4, 8}, cc);
+  RunSupportTestMultipleOutputTiles(
+      std::move(ti), /*output_tile_sizes=*/{{1, 1, 4, 8}, {1}, {1}}, cc);
 }
 
 INSTANTIATE_TEST_SUITE_P(
@@ -1634,8 +1659,6 @@ INSTANTIATE_TEST_SUITE_P(
 
 using BatchNormGradTest = TritonSupportTestWithTypeAndOpcodeAndDeviceParam;
 
-// TODO: b/363981282 - Get rid of get-tuple-element by adding multiple output
-// tikes support to RunSupportTest.
 TEST_P(BatchNormGradTest, BatchNormGrad) {
   auto [data_type, opcode, cc] = GetParam();
   const std::string kHloTestTemplate = R"(
@@ -1645,14 +1668,14 @@ ENTRY triton_computation {
   mean = $0[32] parameter(2)
   variance = $0[32] parameter(3)
   grad_output = $0[4,8,16,32] parameter(4)
-  bn_grad = ($0[4,8,16,32], $0[32], $0[32]) batch-norm-grad(operand, scale, mean, variance, grad_output),
+  ROOT bn_grad = ($0[4,8,16,32], $0[32], $0[32]) batch-norm-grad(operand, scale, mean, variance, grad_output),
     epsilon=0.001, feature_index=3
-  ROOT gte = $0[4,8,16,32] get-tuple-element(bn_grad), index=0
 })";
   TF_ASSERT_OK_AND_ASSIGN(
      TestedInstruction ti,
      ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode));
-  RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1, 1, 4, 8}, cc);
+  RunSupportTestMultipleOutputTiles(
+      std::move(ti), /*output_tile_sizes=*/{{1, 1, 4, 8}, {1}, {1}}, cc);
 }
 
 INSTANTIATE_TEST_SUITE_P(

xla/backends/gpu/codegen/triton/test_utils.h

+2, -2
@@ -65,9 +65,9 @@ absl::Status CreateTritonIrAndFileCheckForDot(
     const HloComputation& computation, absl::string_view filecheck_pattern);
 
 inline BlockLevelParameters FromOutputTileSizes(
-    std::vector<int64_t> output_tile_sizes) {
+    std::vector<std::vector<int64_t>> output_tile_sizes) {
   BlockLevelParameters block_level_parameters;
-  block_level_parameters.output_tile_sizes.push_back(output_tile_sizes);
+  block_level_parameters.output_tile_sizes = std::move(output_tile_sizes);
   return block_level_parameters;
 }
 
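With the widened signature, `FromOutputTileSizes` moves the caller's nested vector straight into `BlockLevelParameters` instead of `push_back`-ing a single vector, and existing single-output callers just add one brace level. A self-contained sketch of this, where the `BlockLevelParameters` struct is a stub mirroring only the field used above (the real one lives in XLA):

#include <cstdint>
#include <utility>
#include <vector>

// Stub of BlockLevelParameters with just the field this helper touches.
struct BlockLevelParameters {
  std::vector<std::vector<int64_t>> output_tile_sizes;
};

inline BlockLevelParameters FromOutputTileSizes(
    std::vector<std::vector<int64_t>> output_tile_sizes) {
  BlockLevelParameters block_level_parameters;
  // Move the nested vector in wholesale; no per-output push_back needed.
  block_level_parameters.output_tile_sizes = std::move(output_tile_sizes);
  return block_level_parameters;
}

int main() {
  auto single = FromOutputTileSizes({{16, 32}});      // one output, rank 2
  auto multi = FromOutputTileSizes({{1}, {16, 32}});  // two outputs
  return (single.output_tile_sizes.size() == 1 &&
          multi.output_tile_sizes.size() == 2)
             ? 0
             : 1;
}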