Skip to content

Commit 0c08427

Browse files
jreifferstensorflower-gardener
authored andcommitted
Simplify more constraints and mods.
Additional mod simplification was a coincidence due to refactoring the sum-splitting logic. I can factor this out into a separate CL if desired. - enables additional vectorization when there is a constraint on the loop symbol that is redundant. - fixes some inconsistencies in div/mod simplification (should lead to better code in some cases, see loop emitter test). - replaces an ad-hoc simplification in fusion_emitter.cc with a more general one Also remove most of the change detector tests in reduction_test. PiperOrigin-RevId: 651294205
1 parent 19ad21f commit 0c08427

13 files changed

+273
-469
lines changed

third_party/xla/xla/service/gpu/fusions/concatenate_mlir.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ absl::Status MlirConcatenateFusion::EmitEntryFunction(
114114
auto thread_id_to_output_map = ComposeIndexingMaps(
115115
ComposeIndexingMaps(thread_id_to_input_map, input_to_output_map),
116116
epilogue_indexing);
117+
thread_id_to_output_map.Simplify();
117118

118119
auto loop_nest_body_builder =
119120
[&, operand_index = operand_index](

third_party/xla/xla/service/gpu/fusions/fusion_emitter.cc

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -188,13 +188,7 @@ IndexingMap KernelFusionInterface::GetDefaultThreadIdIndexingMap(
188188
mlir::AffineMap::get(/*dimCount=*/6,
189189
/*symbolCount=*/2, output_dims, ctx),
190190
dim_vars, range_vars, /*rt_vars=*/{});
191-
// Remove the unroll_elem_id symbol if unrolling divides num_elements.
192-
if (num_elements % unroll_factor == 0) {
193-
indexing_map.AddConstraint(linear_index.replace({{unroll_elem_id, c0}}),
194-
Interval{0, num_elements - unroll_factor});
195-
} else {
196-
indexing_map.AddConstraint(linear_index, Interval{0, num_elements - 1});
197-
}
191+
indexing_map.AddConstraint(linear_index, Interval{0, num_elements - 1});
198192
indexing_map.Simplify();
199193
return indexing_map;
200194
}

third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ TEST_F(MlirLoopFusionTest, ThreadId_IndexingUnrolled) {
5656
(th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (
5757
((bl_x * 128 + chunk_id * 129024 + th_x) floordiv 15000) mod 100,
5858
((bl_x * 128 + chunk_id * 129024 + th_x) floordiv 75) mod 200,
59-
(th_x * 4 + bl_x * 512 + chunk_id * 516096) mod 300 + unroll_id
59+
((bl_x * 128 + chunk_id * 129024 + th_x) mod 75) * 4 + unroll_id
6060
)
6161
domain:
6262
th_x in [0, 128)
@@ -67,7 +67,7 @@ TEST_F(MlirLoopFusionTest, ThreadId_IndexingUnrolled) {
6767
bl_z in [0, 1)
6868
chunk_id in [0, 12)
6969
unroll_id in [0, 4)
70-
(th_x + bl_x * 128) * 4 + chunk_id * 516096 in [0, 5999997)
70+
th_x + bl_x * 128 + chunk_id * 129024 in [0, 1500000)
7171
)"));
7272
}
7373

third_party/xla/xla/service/gpu/fusions/loop_test.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ TEST_F(LoopTest, ThreadIndexingUnrolled) {
9090
(th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (
9191
((bl_x * 128 + chunk_id * 129024 + th_x) floordiv 15000) mod 100,
9292
((bl_x * 128 + chunk_id * 129024 + th_x) floordiv 75) mod 200,
93-
(th_x * 4 + bl_x * 512 + chunk_id * 516096) mod 300 + unroll_id
93+
((bl_x * 128 + chunk_id * 129024 + th_x) mod 75) * 4 + unroll_id
9494
)
9595
domain:
9696
th_x in [0, 128)
@@ -101,7 +101,7 @@ TEST_F(LoopTest, ThreadIndexingUnrolled) {
101101
bl_z in [0, 1)
102102
chunk_id in [0, 12)
103103
unroll_id in [0, 4)
104-
(th_x + bl_x * 128) * 4 + chunk_id * 516096 in [0, 5999997)
104+
th_x + bl_x * 128 + chunk_id * 129024 in [0, 1500000)
105105
)"));
106106
}
107107

third_party/xla/xla/service/gpu/fusions/reduction.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1295,6 +1295,7 @@ std::optional<IndexingMap> ReductionInfo::ComputeThreadIdToOutputIndexing(
12951295
}();
12961296

12971297
AddGroupIdConstraint(map, root_index, groups_);
1298+
map.Simplify();
12981299
return map;
12991300
}
13001301

@@ -1321,6 +1322,7 @@ std::optional<IndexingMap> ReductionInfo::ComputeThreadIdToInputIndexing(
13211322
GetBitcastMap(tiling_.GetXlaShape(),
13221323
hero.operand(hero_operand_index)->shape(), ctx));
13231324
AddGroupIdConstraint(map, root_index, groups_);
1325+
map.Simplify();
13241326
return map;
13251327
}
13261328

third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -564,10 +564,13 @@ std::optional<IndexingMap> MlirReductionFusion::ComputeThreadIdToInputIndexing(
564564
.indexing_maps[hero_operand_index]
565565
.begin());
566566
}
567-
auto map = ComputeReductionInputIndexing(ctx);
568-
AddGroupIdConstraint(map, root_index, groups_);
569-
return map * GetBitcastMap(input_shape_,
570-
hero.operand(hero_operand_index)->shape(), ctx);
567+
auto projected_map = ComputeReductionInputIndexing(ctx);
568+
AddGroupIdConstraint(projected_map, root_index, groups_);
569+
auto map = projected_map *
570+
GetBitcastMap(input_shape_,
571+
hero.operand(hero_operand_index)->shape(), ctx);
572+
map.Simplify();
573+
return map;
571574
}
572575

573576
std::optional<IndexingMap> MlirReductionFusion::ComputeThreadIdToOutputIndexing(
@@ -578,6 +581,7 @@ std::optional<IndexingMap> MlirReductionFusion::ComputeThreadIdToOutputIndexing(
578581
GetBitcastMap(input_shape_, analysis_.fusion_root(root_index).shape(),
579582
ctx));
580583
AddGroupIdConstraint(map, root_index, groups_);
584+
map.Simplify();
581585
return map;
582586
}
583587

@@ -594,10 +598,12 @@ std::optional<IndexingMap> MlirReductionFusion::ComputeThreadIdToOutputIndexing(
594598
const auto& hero = analysis_.fusion_hero(root_index).instruction();
595599
auto physical_shape =
596600
ShapeUtil::DeleteDimensions(hero.dimensions(), hero.operand(0)->shape());
597-
return projected_indexing *
598-
GetBitcastMap(ShapeUtil::MakeShapeWithDescendingLayout(
599-
PrimitiveType::U8, output_shape),
600-
physical_shape, ctx);
601+
auto map = projected_indexing *
602+
GetBitcastMap(ShapeUtil::MakeShapeWithDescendingLayout(
603+
PrimitiveType::U8, output_shape),
604+
physical_shape, ctx);
605+
map.Simplify();
606+
return map;
601607
}
602608

603609
SmallVector<Value> MlirReductionFusion::EvaluateEpilogue(

third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -670,7 +670,7 @@ TEST_F(MlirColumnReductionTest, ColumnReduction) {
670670
(d0, d1, d2, d3, d4, d5)[s0, s1] -> (
671671
d3 floordiv 11,
672672
d0 floordiv 32 + s0 * 32,
673-
(d3 mod 11) * 32 + d0 mod 32 + s1
673+
(d3 mod 11) * 32 + d0 mod 32
674674
)
675675
domain:
676676
d0 in [0, 1024)
@@ -681,24 +681,24 @@ TEST_F(MlirColumnReductionTest, ColumnReduction) {
681681
d5 in [0, 1)
682682
s0 in [0, 33)
683683
s1 in [0, 1)
684-
(d3 mod 11) * 32 + d0 mod 32 + s1 in [0, 321)
684+
(d3 mod 11) * 32 + d0 mod 32 in [0, 321)
685685
d0 floordiv 32 + s0 * 32 in [0, 1051)
686686
)"));
687687
EXPECT_THAT(
688688
fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(),
689689
MatchIndexingString(R"(
690690
(d0, d1, d2, d3, d4, d5)[s0] -> (
691-
d3 floordiv 11, (d3 mod 11) * 32 + d0 floordiv 32 + s0
691+
d3 floordiv 11, (d3 mod 11) * 32 + d0 floordiv 32
692692
)
693693
domain:
694-
d0 in [0, 1024)
694+
d0 in [0, 993)
695695
d1 in [0, 1)
696696
d2 in [0, 1)
697697
d3 in [0, 143)
698698
d4 in [0, 1)
699699
d5 in [0, 1)
700700
s0 in [0, 1)
701-
(d3 mod 11) * 32 + d0 floordiv 32 + s0 in [0, 321)
701+
(d3 mod 11) * 32 + d0 floordiv 32 in [0, 321)
702702
d0 mod 32 in [0, 1)
703703
)"));
704704
TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"(

0 commit comments

Comments
 (0)