@@ -230,14 +230,36 @@ class TritonSupportTest : public TritonSupportTestBase {
                       std::vector<int64_t> output_tile_sizes,
                       se::GpuComputeCapability cc,
                       ExpectedFailMode failure_mode = ExpectedFailMode::kFail) {
+    // output_tile_sizes is embedded in a vector of 1 element to share the
+    // logic with the multiple output tiles case.
+    RunSupportTestMultipleOutputTiles(
+        std::move(ti), {std::move(output_tile_sizes)}, cc, failure_mode);
+  }
+
+  void RunSupportTestMultipleOutputTiles(
+      TestedInstruction ti, std::vector<std::vector<int64_t>> output_tile_sizes,
+      se::GpuComputeCapability cc,
+      ExpectedFailMode failure_mode = ExpectedFailMode::kFail) {
     // Ensure that the caller provided the right number of output tile sizes.
     // If that is not the case, codegen could fail for that reason---which
-    // wouldn't give any valuable signal here. We skip the check for non-array
-    // output shapes, since we have no meaningful way of providing tile sizes
-    // for them at the moment.
+    // wouldn't give any valuable signal here. The check is only done for array
+    // and tuple shapes (only one layer of nesting is supported for tuples).
     if (ti.Instruction().shape().IsArray()) {
+      ASSERT_EQ(output_tile_sizes.size(), 1);
+      ASSERT_EQ(output_tile_sizes[0].size(),
+                ti.Instruction().shape().dimensions().size());
+    } else if (ti.Instruction().shape().IsTuple()) {
       ASSERT_EQ(output_tile_sizes.size(),
-                ti.Instruction().shape().dimensions_size());
+                ti.Instruction().shape().tuple_shapes_size());
+      for (int64_t i = 0; i < output_tile_sizes.size(); ++i) {
+        const auto& shape = ti.Instruction().shape().tuple_shapes(i);
+        if (shape.IsTuple()) {
+          continue;  // No validation for nested tuples, as there is no way to
+                     // specify output tile sizes for them.
+        }
+        ASSERT_TRUE(shape.IsArray());
+        ASSERT_EQ(shape.dimensions().size(), output_tile_sizes[i].size());
+      }
     }
     BlockLevelParameters block_level_parameters =
         FromOutputTileSizes(std::move(output_tile_sizes));
@@ -726,16 +748,16 @@ add {
 ENTRY triton_computation {
   parameter_0 = $$0[125,127] parameter(0)
   constant_0 = $$0[] constant($0)
-  tuple = ($$0[125], $$0[125]) reduce(
+  ROOT reduce = ($$0[125], $$0[125]) reduce(
     parameter_0, parameter_0, constant_0, constant_0),
     dimensions={1}, to_apply=add
-  ROOT reduce = $$0[125] get-tuple-element(tuple), index=0
 })",
       init_value(data_type));
   TF_ASSERT_OK_AND_ASSIGN(
       TestedInstruction ti,
       ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode));
-  RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1}, cc);
+  RunSupportTestMultipleOutputTiles(std::move(ti),
+                                    /*output_tile_sizes=*/{{1}, {1}}, cc);
 }
 
 TEST_F(ReduceTest, ReduceWithNonConstReduceValueIsSupportedWithTriton) {
@@ -1025,7 +1047,8 @@ ENTRY triton_computation {
       TestedInstruction ti,
       ParseTemplateAndGetInstruction(kHloTestTemplate, data_type,
                                      HloOpcode::kAllGatherStart));
-  RunSupportTest(std::move(ti), /*output_tile_sizes=*/{2, 2}, cc);
+  RunSupportTestMultipleOutputTiles(std::move(ti),
+                                    /*output_tile_sizes=*/{{2, 2}, {2, 2}}, cc);
 }
 
 TEST_P(CollectiveTest, UnsupportedAllGatherDoneFailsGracefullyWithTriton) {
@@ -1142,7 +1165,8 @@ ENTRY triton_computation {
       ParseTemplateAndGetInstruction(kHloTestTemplate, data_type,
                                      HloOpcode::kCollectivePermuteDone));
 
-  RunSupportTest(std::move(ti_start), /*output_tile_sizes=*/{2, 2}, cc);
+  RunSupportTestMultipleOutputTiles(std::move(ti_start),
+                                    /*output_tile_sizes=*/{{2, 2}, {2, 2}}, cc);
   RunSupportTest(std::move(ti_done), /*output_tile_sizes=*/{2, 2}, cc);
 }
 
@@ -1197,8 +1221,10 @@ ENTRY triton_computation {
       TestedInstruction ti_done,
       ParseTemplateAndGetInstruction(kHloTestTemplate, data_type,
                                      HloOpcode::kAsyncDone));
-  RunSupportTest(std::move(ti_start), /*output_tile_sizes=*/{1}, cc);
-  RunSupportTest(std::move(ti_update), /*output_tile_sizes=*/{1}, cc);
+  RunSupportTestMultipleOutputTiles(std::move(ti_start),
+                                    /*output_tile_sizes=*/{{1}, {1}}, cc);
+  RunSupportTestMultipleOutputTiles(std::move(ti_update),
+                                    /*output_tile_sizes=*/{{1}, {1}}, cc);
   RunSupportTest(std::move(ti_done), /*output_tile_sizes=*/{1}, cc);
 }
 
@@ -1436,7 +1462,8 @@ ENTRY triton_computation {
   TF_ASSERT_OK_AND_ASSIGN(
       TestedInstruction ti,
       ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode));
-  RunSupportTest(std::move(ti), /*output_tile_sizes=*/{16, 32}, cc);
+  RunSupportTestMultipleOutputTiles(std::move(ti),
+                                    /*output_tile_sizes=*/{{1}, {16, 32}}, cc);
 }
 
 INSTANTIATE_TEST_SUITE_P(
@@ -1608,23 +1635,21 @@ INSTANTIATE_TEST_SUITE_P(
 
 using BatchNormTrainingTest = TritonSupportTestWithTypeAndOpcodeAndDeviceParam;
 
-// TODO: b/363981282 - Get rid of get-tuple-element by adding multiple output
-// tikes support to RunSupportTest.
 TEST_P(BatchNormTrainingTest, BatchNormTraining) {
   auto [data_type, opcode, cc] = GetParam();
   const std::string kHloTestTemplate = R"(
 ENTRY triton_computation {
   operand = $0[4,8,16,32] parameter(0)
   scale = $0[32] parameter(1)
   offset = $0[32] parameter(2)
-  bn_train = ($0[4,8,16,32], $0[32], $0[32]) batch-norm-training(operand, scale, offset),
+  ROOT bn_train = ($0[4,8,16,32], $0[32], $0[32]) batch-norm-training(operand, scale, offset),
     epsilon=0.001, feature_index=3
-  ROOT gte = $0[4,8,16,32] get-tuple-element(bn_train), index=0
 })";
   TF_ASSERT_OK_AND_ASSIGN(
       TestedInstruction ti,
       ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode));
-  RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1, 1, 4, 8}, cc);
+  RunSupportTestMultipleOutputTiles(
+      std::move(ti), /*output_tile_sizes=*/{{1, 1, 4, 8}, {1}, {1}}, cc);
 }
 
 INSTANTIATE_TEST_SUITE_P(
@@ -1634,8 +1659,6 @@ INSTANTIATE_TEST_SUITE_P(
 
 using BatchNormGradTest = TritonSupportTestWithTypeAndOpcodeAndDeviceParam;
 
-// TODO: b/363981282 - Get rid of get-tuple-element by adding multiple output
-// tikes support to RunSupportTest.
 TEST_P(BatchNormGradTest, BatchNormGrad) {
   auto [data_type, opcode, cc] = GetParam();
   const std::string kHloTestTemplate = R"(
@@ -1645,14 +1668,14 @@ ENTRY triton_computation {
   mean = $0[32] parameter(2)
   variance = $0[32] parameter(3)
   grad_output = $0[4,8,16,32] parameter(4)
-  bn_grad = ($0[4,8,16,32], $0[32], $0[32]) batch-norm-grad(operand, scale, mean, variance, grad_output),
+  ROOT bn_grad = ($0[4,8,16,32], $0[32], $0[32]) batch-norm-grad(operand, scale, mean, variance, grad_output),
     epsilon=0.001, feature_index=3
-  ROOT gte = $0[4,8,16,32] get-tuple-element(bn_grad), index=0
 })";
   TF_ASSERT_OK_AND_ASSIGN(
       TestedInstruction ti,
       ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode));
-  RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1, 1, 4, 8}, cc);
+  RunSupportTestMultipleOutputTiles(
+      std::move(ti), /*output_tile_sizes=*/{{1, 1, 4, 8}, {1}, {1}}, cc);
 }
 
 INSTANTIATE_TEST_SUITE_P(
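For reference, a minimal sketch of the calling convention after this change. The instruction names (ti_array, ti_tuple) are illustrative, not from the diff; the contract follows the validation logic in the first hunk: an array-shaped root takes one tile-size vector whose length matches the shape's rank, while a tuple-shaped root takes one tile-size vector per tuple element.

  // Array-shaped root instruction, e.g. $0[125,127]: a single tile-size
  // vector of matching rank.
  RunSupportTest(std::move(ti_array), /*output_tile_sizes=*/{16, 32}, cc);
  // Tuple-shaped root, e.g. a variadic reduce producing ($0[125], $0[125]):
  // one rank-1 tile-size vector per tuple element.
  RunSupportTestMultipleOutputTiles(std::move(ti_tuple),
                                    /*output_tile_sizes=*/{{1}, {1}}, cc);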