Skip to content

Commit 235865e

Browse files
allow to disable one broadcast optimization in algebraic_simplifier
PiperOrigin-RevId: 685914179
1 parent 4bf4d86 commit 235865e

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

third_party/xla/xla/service/algebraic_simplifier.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5122,7 +5122,8 @@ absl::Status AlgebraicSimplifierVisitor::HandleBroadcast(
51225122
if (options_.is_layout_sensitive()) {
51235123
return absl::OkStatus();
51245124
}
5125-
if (ShapeUtil::HasDegenerateDimensions(operand->shape())) {
5125+
if (options_.enable_broadcast_degenerate_dimension() &&
5126+
ShapeUtil::HasDegenerateDimensions(operand->shape())) {
51265127
auto new_operand = operand->AddInstruction(HloInstruction::CreateReshape(
51275128
ShapeUtil::DropDegenerateDimensions(operand->shape()), operand));
51285129
std::vector<int64_t> new_dims;

third_party/xla/xla/service/algebraic_simplifier.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,15 @@ class AlgebraicSimplifierOptions {
313313
}
314314
bool enable_fast_math() const { return enable_fast_math_; }
315315

316+
void set_enable_broadcast_degenerate_dimension(
317+
bool enable_broadcast_degenerate_dimension) {
318+
enable_broadcast_degenerate_dimension_ =
319+
enable_broadcast_degenerate_dimension;
320+
}
321+
bool enable_broadcast_degenerate_dimension() const {
322+
return enable_broadcast_degenerate_dimension_;
323+
}
324+
316325
private:
317326
// Metadata struct can be used to store any metadata information encapsulated
318327
// with the AlgebraicSimplifierOptions that can be later used in an
@@ -354,6 +363,7 @@ class AlgebraicSimplifierOptions {
354363
bool use_convert_constant_folding_{false};
355364
bool disable_dynamic_slice_to_slice_conversion_{false};
356365
bool enable_fast_math_{false};
366+
bool enable_broadcast_degenerate_dimension_{true};
357367
Metadata metadata_;
358368
};
359369

0 commit comments

Comments
 (0)