File tree Expand file tree Collapse file tree 2 files changed +51
-0
lines changed
third_party/xla/xla/service Expand file tree Collapse file tree 2 files changed +51
-0
lines changed Original file line number Diff line number Diff line change @@ -5200,6 +5200,19 @@ absl::Status AlgebraicSimplifierVisitor::HandleConvert(
5200
5200
convert->mutable_operand (0 )->mutable_operand (0 ));
5201
5201
}
5202
5202
5203
+ // Try to replace convert(constant) with a constant of the right type to begin
5204
+ // with. Disallow moving int4 since it is not supported for many ops
5205
+ HloInstruction* constant;
5206
+ if (Match (convert, m::Convert (m::Constant (&constant))) &&
5207
+ !primitive_util::IsSubByteNonPredType (src_type) &&
5208
+ !primitive_util::IsSubByteNonPredType (dest_type)) {
5209
+ TF_ASSIGN_OR_RETURN (Literal dest_literal,
5210
+ constant->literal ().Convert (dest_type));
5211
+ VLOG (10 ) << " Replacing convert(constant) with constant" ;
5212
+ return ReplaceWithNewInstruction (
5213
+ convert, HloInstruction::CreateConstant (std::move (dest_literal)));
5214
+ }
5215
+
5203
5216
return TryRemoveUpcastAndDowncastSurroundingBinaryOp (convert);
5204
5217
}
5205
5218
Original file line number Diff line number Diff line change @@ -11676,5 +11676,43 @@ ENTRY main.1 {
11676
11676
HloOpcode::kParameter );
11677
11677
}
11678
11678
11679
+ TEST_F (AlgebraicSimplifierTest, RemoveConvertConstant) {
11680
+ const std::string hlo_string = R"(
11681
+ HloModule module
11682
+
11683
+ add {
11684
+ p0 = f32[] parameter(0)
11685
+ p1 = f32[] parameter(1)
11686
+ ROOT r = f32[] add(p0, p1)
11687
+ }
11688
+
11689
+ ENTRY test {
11690
+ a = f32[32,64] parameter(0)
11691
+ b = s32[] constant(0)
11692
+ c = f32[] convert(b)
11693
+ ROOT reduce = f32[32] reduce(a, c),
11694
+ dimensions={1}, to_apply=add
11695
+ }
11696
+ )" ;
11697
+ TF_ASSERT_OK_AND_ASSIGN (auto m, ParseAndReturnVerifiedModule (hlo_string));
11698
+ EXPECT_TRUE (AlgebraicSimplifier (default_options_).Run (m.get ()).value ());
11699
+ HloInstruction* root = m->entry_computation ()->root_instruction ();
11700
+ EXPECT_THAT (root, GmockMatch (m::Reduce (m::Parameter (0 ),
11701
+ m::Constant ().WithShape (F32, {}))));
11702
+ }
11703
+
11704
+ TEST_F (AlgebraicSimplifierTest, KeepInt4ConvertConstant) {
11705
+ const std::string hlo_string = R"(
11706
+ HloModule module
11707
+
11708
+ ENTRY test {
11709
+ a = s4[] constant(0)
11710
+ ROOT b = s8[] convert(a)
11711
+ }
11712
+ )" ;
11713
+ TF_ASSERT_OK_AND_ASSIGN (auto m, ParseAndReturnVerifiedModule (hlo_string));
11714
+ ASSERT_FALSE (AlgebraicSimplifier (default_options_).Run (m.get ()).value ());
11715
+ }
11716
+
11679
11717
} // namespace
11680
11718
} // namespace xla
You can’t perform that action at this time.
0 commit comments