Skip to content

Commit aba14f7

Browse files
[XLA] Replace convert(constant) with constant
PiperOrigin-RevId: 652328778
1 parent 2718d5d commit aba14f7

File tree

2 files changed

+51
-0
lines changed

2 files changed

+51
-0
lines changed

third_party/xla/xla/service/algebraic_simplifier.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5200,6 +5200,19 @@ absl::Status AlgebraicSimplifierVisitor::HandleConvert(
52005200
convert->mutable_operand(0)->mutable_operand(0));
52015201
}
52025202

5203+
// Try to replace convert(constant) with a constant of the right type to begin
5204+
// with. Disallow moving int4 since it is not supported for many ops
5205+
HloInstruction* constant;
5206+
if (Match(convert, m::Convert(m::Constant(&constant))) &&
5207+
!primitive_util::IsSubByteNonPredType(src_type) &&
5208+
!primitive_util::IsSubByteNonPredType(dest_type)) {
5209+
TF_ASSIGN_OR_RETURN(Literal dest_literal,
5210+
constant->literal().Convert(dest_type));
5211+
VLOG(10) << "Replacing convert(constant) with constant";
5212+
return ReplaceWithNewInstruction(
5213+
convert, HloInstruction::CreateConstant(std::move(dest_literal)));
5214+
}
5215+
52035216
return TryRemoveUpcastAndDowncastSurroundingBinaryOp(convert);
52045217
}
52055218

third_party/xla/xla/service/algebraic_simplifier_test.cc

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11676,5 +11676,43 @@ ENTRY main.1 {
1167611676
HloOpcode::kParameter);
1167711677
}
1167811678

11679+
TEST_F(AlgebraicSimplifierTest, RemoveConvertConstant) {
11680+
const std::string hlo_string = R"(
11681+
HloModule module
11682+
11683+
add {
11684+
p0 = f32[] parameter(0)
11685+
p1 = f32[] parameter(1)
11686+
ROOT r = f32[] add(p0, p1)
11687+
}
11688+
11689+
ENTRY test {
11690+
a = f32[32,64] parameter(0)
11691+
b = s32[] constant(0)
11692+
c = f32[] convert(b)
11693+
ROOT reduce = f32[32] reduce(a, c),
11694+
dimensions={1}, to_apply=add
11695+
}
11696+
)";
11697+
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
11698+
EXPECT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value());
11699+
HloInstruction* root = m->entry_computation()->root_instruction();
11700+
EXPECT_THAT(root, GmockMatch(m::Reduce(m::Parameter(0),
11701+
m::Constant().WithShape(F32, {}))));
11702+
}
11703+
11704+
TEST_F(AlgebraicSimplifierTest, KeepInt4ConvertConstant) {
11705+
const std::string hlo_string = R"(
11706+
HloModule module
11707+
11708+
ENTRY test {
11709+
a = s4[] constant(0)
11710+
ROOT b = s8[] convert(a)
11711+
}
11712+
)";
11713+
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
11714+
ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).value());
11715+
}
11716+
1167911717
} // namespace
1168011718
} // namespace xla

0 commit comments

Comments
 (0)